윤제로의 제로베이스

소프트맥스 회귀로 MNIST 데이터 분류하기 본문

Background/Pytorch 기초

소프트맥스 회귀로 MNIST 데이터 분류하기

윤_제로 2022. 1. 16. 21:59

https://wikidocs.net/60324

 

05. 소프트맥스 회귀로 MNIST 데이터 분류하기

이번 챕터에서는 MNIST 데이터에 대해서 이해하고, 파이토치(PyTorch)로 소프트맥스 회귀를 구현하여 MNIST 데이터를 분류하는 실습을 진행해봅시다. MNIST ...

wikidocs.net

for X, Y in data_loader:
  # 입력 이미지를 [batch_size × 784]의 크기로 reshape
  # 레이블은 원-핫 인코딩
  X = X.view(-1, 28*28)
   
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import random

USE_CUDA = torch.cuda.is_available() # GPU를 사용가능하면 True, 아니라면 False를 리턴
device = torch.device("cuda" if USE_CUDA else "cpu") # GPU 사용 가능하면 사용하고 아니면 CPU 사용
print("다음 기기로 학습합니다:", device)

# for reproducibility
random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
    
# hyperparameters
training_epochs = 15
batch_size = 100

분류기 구현

# MNIST dataset
mnist_train = dsets.MNIST(root='MNIST_data/',
                          train=True,
                          transform=transforms.ToTensor(),
                          download=True)

mnist_test = dsets.MNIST(root='MNIST_data/',
                         train=False,
                         transform=transforms.ToTensor(),
                         download=True)

 

첫번째 인자 root는 MNIST 데이터를 받을 경로이다.

두번째 인자 train은 인자로 True를 주면 MNIST의 훈련 데이터를 리턴 받으며 False를 주면 테스트 데이터를 리턴 받는다.

세번째 인자 transform은 현재 데이터를 파이토치 텐서로 변환해준다.

네번째 인자 download는 해당 경로에 MNIST 데이터가 없다면 다운받겠다는 의미이다.

 

# dataset loader
data_loader = DataLoader(dataset=mnist_train,
                                          batch_size=batch_size, # 배치 크기는 100
                                          shuffle=True,
                                          drop_last=True)

 

첫번째 인자인 dataset은 로드할 대상을 의미한다.

두번째 인자인 batch_size는 배치 크기, shuffle은 매 에포크마다 미니 배치를 셔플할 것인지의 여부, drop_last는 마지막 배치를 버릴 것인지를 의미한다.

drop_last를 하는 이유는 미니배치를 가득 채우지 못한 마지막 배치를 버릴지를 설정해주는 부분이다. 이는 다른 미니배치보다 개수가 적은 마지막 배치를 경사 하가업ㅂ에 사용하여 마지막 배치가 상대적으로 과대 평가되는 현상을 막아준다.

 

# MNIST data image of shape 28 * 28 = 784
linear = nn.Linear(784, 10, bias=True).to(device)

 

to() 함수는 연산을 어디서 수행할지를 정한다. 

to()함수는 모델의 매개변수를 지정한 장치의 메모리로 보낸다.

CPU를 사용할 경우 필요 없지만, GPU를 사용하려면 to('cuda')를 해주어야 한다.

 

bias는 편향 b를 사용할 것인지를 나타낸다. 

기본값이 True이므로 굳이 할 필요는 없다.

 

# 비용 함수와 옵티마이저 정의
criterion = nn.CrossEntropyLoss().to(device) # 내부적으로 소프트맥스 함수를 포함하고 있음.
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1)

 

여기서는 torch.nn.CrossEntropyLoss()를 사용하고 있으나 앞서 사용했던 torch.nn.functional.cross_entropy()와 동일하게 소프트맥스 함수를 포함하고 있는 크로스 엔트로피 함수이다.

 

for epoch in range(training_epochs): # 앞서 training_epochs의 값은 15로 지정함.
    avg_cost = 0
    total_batch = len(data_loader)

    for X, Y in data_loader:
        # 배치 크기가 100이므로 아래의 연산에서 X는 (100, 784)의 텐서가 된다.
        X = X.view(-1, 28 * 28).to(device)
        # 레이블은 원-핫 인코딩이 된 상태가 아니라 0 ~ 9의 정수.
        Y = Y.to(device)

        optimizer.zero_grad()
        hypothesis = linear(X)
        cost = criterion(hypothesis, Y)
        cost.backward()
        optimizer.step()

        avg_cost += cost / total_batch

    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_cost))

print('Learning finished')

 

# 테스트 데이터를 사용하여 모델을 테스트한다.
with torch.no_grad(): # torch.no_grad()를 하면 gradient 계산을 수행하지 않는다.
    X_test = mnist_test.test_data.view(-1, 28 * 28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)

    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print('Accuracy:', accuracy.item())

    # MNIST 테스트 데이터에서 무작위로 하나를 뽑아서 예측을 해본다
    r = random.randint(0, len(mnist_test) - 1)
    X_single_data = mnist_test.test_data[r:r + 1].view(-1, 28 * 28).float().to(device)
    Y_single_data = mnist_test.test_labels[r:r + 1].to(device)

    print('Label: ', Y_single_data.item())
    single_prediction = linear(X_single_data)
    print('Prediction: ', torch.argmax(single_prediction, 1).item())

    plt.imshow(mnist_test.test_data[r:r + 1].view(28, 28), cmap='Greys', interpolation='nearest')
    plt.show()