윤제로의 제로베이스

클래스로 파이토치 선형회귀 모델 구현하기 본문

Background/Pytorch 기초

클래스로 파이토치 선형회귀 모델 구현하기

윤_제로 2022. 1. 14. 00:54

https://wikidocs.net/60036

 

05. 클래스로 파이토치 모델 구현하기

파이토치의 대부분의 구현체들은 대부분 모델을 생성할 때 클래스(Class)를 사용하고 있습니다. 앞서 배운 선형 회귀를 클래스로 구현해보겠습니다. 앞서 구현한 코드와 다른 ...

wikidocs.net

1. 모델을 클래스로 구현하기

# 모델을 선언 및 초기화. 단순 선형 회귀이므로 input_dim=1, output_dim=1.
model = nn.Linear(1,1)

 

이를 클래스로 구현하면 아래와 같다.

 

class LinearRegressionModel(nn.Module): # torch.nn.Module을 상속받는 파이썬 클래스
    def __init__(self): #
        super().__init__()
        self.linear = nn.Linear(1, 1) # 단순 선형 회귀이므로 input_dim=1, output_dim=1.

    def forward(self, x):
        return self.linear(x)
        
model = LinearRegressionModel()

 

클래스 형태의 모델은 nn.Module을 상속받는다.

그리고 __init__()에서 모델의 구조와 동작을 정의하는 생성자를 정의한다. 

이는 파이썬에서 객체가 갖는 속성값을 초기화하는 역할로, 객체가 생성될 때 자동으로 호출된다.

super() 함수를 부르면 여기서 만든 클래스는 nn.Module 클래스의 속성을 가지고 초기화 된다.

forward() 함수는 모델이 학습 데이터를 입력 받아서 forward 연산을 진행시키는 함수이다.

이 foward() 함수는 model 객체를 데이터와 함께 호출하면 자동으로 실행된다.

 

# 다중 선형 회귀 모델 클래스
class MultivariateLinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1) # 다중 선형 회귀이므로 input_dim=3, output_dim=1.

    def forward(self, x):
        return self.linear(x)