윤제로의 제로베이스

소프트맥스 회귀의 비용 함수 구현하기 본문

Background/Pytorch 기초

소프트맥스 회귀의 비용 함수 구현하기

윤_제로 2022. 1. 16. 21:32

https://wikidocs.net/60572

 

03. 소프트맥스 회귀의 비용 함수 구현하기

이번 챕터에서는 소프트맥스 회귀의 비용 함수를 구현해봅시다. 앞으로의 모든 실습은 아래의 코드가 이미 진행되었다고 가정합니다. ``` import torch impor ...

wikidocs.net

1. 파이토치로 소프트맥스의 비용함수 구현하기(로우 레벨 버전)

import torch
import torch.nn.functional as F
torch.manual_seed(1)

z = torch.rand(3, 5, requires_grad=True) # 소프트맥스 함수의 입력

hypothesis = F.softmax(z, dim=1)

y = torch.randint(5, (3,)).long()

# 모든 원소가 0의 값을 가진 3 × 5 텐서 생성
y_one_hot = torch.zeros_like(hypothesis) 
y_one_hot.scatter_(1, y.unsqueeze(1), 1)

cost = (y_one_hot * -torch.log(hypothesis)).sum(dim=1).mean()

2. 파이토치로 소프트맥스의 비용함수 구현하기(하이 레벨 버전)

1) F.softmax() + torch.log() = F.log_softmax()

# Low level
torch.log(F.softmax(z, dim=1))

# High level
F.log_softmax(z, dim=1)

2) F.log_softmax() + F.nll_loss() = F.cross_entropy()

# Low level
# 첫번째 수식
(y_one_hot * -torch.log(F.softmax(z, dim=1))).sum(dim=1).mean()

# 두번째 수식
(y_one_hot * - F.log_softmax(z, dim=1)).sum(dim=1).mean()

# High level
# 세번째 수식
F.nll_loss(F.log_softmax(z, dim=1), y)

# 네번째 수식
F.cross_entropy(z, y)

nll이란 Negative Log Likelihodd의 약자이다. 


  • F.cross_entropy는 비용 함수에 소프트맥스 함수까지 포함하고 있다.