윤제로의 제로베이스

토치텍스트(Torchtext)의 batch_first 본문

Background/Pytorch 기초

토치텍스트(Torchtext)의 batch_first

윤_제로 2022. 1. 19. 19:48

https://wikidocs.net/65794

 

04. 토치텍스트(TorchText)의 batch_first

이번 챕터에서는 토치텍스트에서 배치퍼스트(batch_first)를 True로 한 경우와 False를 한 경우를 비교해보겠습니다. 이번 챕터는 토치텍스트 튜토리얼 챕터가 아니 ...

wikidocs.net

1. 훈련 데이터와 테스트 데이터로 분리하기

import urllib.request
import pandas as pd

urllib.request.urlretrieve("https://raw.githubusercontent.com/LawrenceDuan/IMDb-Review-Analysis/master/IMDb_Reviews.csv", filename="IMDb_Reviews.csv")

df = pd.read_csv('IMDb_Reviews.csv', encoding='latin1') # 50000

train_df = df[:25000]
test_df = df[25000:]

train_df.to_csv("train_data.csv", index=False)
test_df.to_csv("test_data.csv", index=False)

2. 필드 정의하기(torchtext.data)

from torchtext import data # torchtext.data 임포트

# 필드 정의
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=str.split,
                  lower=True,
                  batch_first=True, # <== 이 부분을 True로 합니다.
                  fix_length=20)

LABEL = data.Field(sequential=False,
                   use_vocab=False,
                   batch_first=False,
                   is_target=True)
  • batch_first : 미니 배치 차원을 맨 앞으로 하여 데이터를 불러올 것인지 여부(False가 기본값)

3. 데이터셋/단어 집합/데이터로더 만들기

from torchtext.data import TabularDataset
from torchtext.data import Iterator

# TabularDataset은 데이터를 불러오면서 필드에서 정의했던 토큰화 방법으로 토큰화를 수행합니다.
train_data, test_data = TabularDataset.splits(
        path='.', train='train_data.csv', test='test_data.csv', format='csv',
        fields=[('text', TEXT), ('label', LABEL)], skip_header=True)

# 정의한 필드에 .build_vocab() 도구를 사용하면 단어 집합을 생성합니다.
TEXT.build_vocab(train_data, min_freq=10, max_size=10000) # 10,000개의 단어를 가진 단어 집합 생성

# 배치 크기를 정하고 첫번째 배치를 출력해보겠습니다.
batch_size = 5
train_loader = Iterator(dataset=train_data, batch_size = batch_size)
batch = next(iter(train_loader)) # 첫번째 미니배치

 

배치 크기가 5이기 때문에 각 샘플의 길이가 20인 샘플 5개가 들어있다.

앞서 필드를 정의할 때 fix_length를 20으로 정의했기 때문이다.

다시 말해 하나의 미니 배치의 크기는 (배치 크기 * fix_length)이다.

4. 필드 재정의하기(torchtext.data)

# 필드 정의
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=str.split,
                  lower=True,
                  fix_length=20)

LABEL = data.Field(sequential=False,
                   use_vocab=False,
                   batch_first=False,
                   is_target=True)

 

TEXT 필드에서 batch_first = True 인자를 제거한다.

기본값이 False이기 때문에 batch_first값이 False가 된다.

5. batch_first = False로 하였을 때 텐서 크기

하나의 미니 배치 크기가 (fix_length * 배치 크기) 로 변화한다.