티스토리 뷰

728x90

 

출처

https://tutorials.pytorch.kr/recipes/recipes/what_is_state_dict.html

https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html

 

 

PyTorch에서 torch.nn.Module 모델의 학습 가능한 매개변수(예. 가중치와 편향)들은 모델의 매개변수에 포함되어 있습니다. (model.parameters()로 접근합니다)

 

state_dict는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 Python 사전(dict) 객체입니다. state_dict 는 PyTorch에서 모델을 저장하거나 불러오는 데 관심이 있다면 필수적인 항목입니다. state_dict 객체는 Python 사전이기 때문에 쉽게 저장, 업데이트, 변경 및 복원할 수 있으며, 이는 PyTorch 모델과 옵티마이저에 엄청난 모듈성(modularity)을 제공합니다. 이때, 학습 가능한 매개변수를 갖는 계층(합성곱 계층, 선형 계층 등) 및 등록된 버퍼들(batchnorm의 running_mean)만 모델의 state_dict 항목을 갖습니다. 

 

( torch.optim ) 또한 옵티마이저의 상태 뿐만 아니라 사용된 하이퍼 매개변수 (Hyperparameter) 정보가 포함된 state_dict 을 갖습니다.

 

 

단계

1. 데이터를 불러올 때 필요한 모든 라이브러리 불러오기

2. 신경망을 구성하고 초기화하기

3. 옵티마이저 초기화하기

4. 모델과 옵티마이저의 state_dict 접근하기

 

 

1. 모든 라이브러리 불러오기

import torch
import torch.nn as nn
import torch.optim as optim

 

2. 신경망을 구성하고 초기화하기

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

 

3. 옵티마이저 초기화하기

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

 

 

4. 모델과 옵티마이저의 state_dict 접근하기

# 모델의 state_dict 출력
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

print()

# 옵티마이저의 state_dict 출력
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

이 정보는 향후 모델 및 옵티마이저를 저장하고 불러오는 것과 관련이 있습니다.

 

5. 체크포인트 저장하기

# 추가 정보
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

 

6. 체크포인트 불러오기

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - 또는 -
model.train()

 

추론(inference)을 실행하기 전에 model.eval() 을 호출하여 드롭아웃(dropout)과 배치 정규화 층(batch normalization layer)을 평가(evaluation) 모드로 바꿔야합니다.

 

state_dict 같은 경우는 모델의 매개변수를 저장하는 방법 중 하나이고, 모델을 불러올 때 저장한 모델에 load_state_dict()를 사용해서 매개변수를 적용할 수 있다. 설명을 보면 각 계층마다 매개변수를 텐서 형태로 dict 타입으로 저장하는 것 같다.

 

.pt 또는 .pth 확장자를 사용해 모델을 저장하고, 그 외 save하여 pickle로 가져오는 방법이 있다.

 

728x90
댓글