티스토리 뷰
출처
https://tutorials.pytorch.kr/recipes/recipes/what_is_state_dict.html
https://tutorials.pytorch.kr/recipes/recipes/saving_and_loading_a_general_checkpoint.html
PyTorch에서 torch.nn.Module 모델의 학습 가능한 매개변수(예. 가중치와 편향)들은 모델의 매개변수에 포함되어 있습니다. (model.parameters()로 접근합니다)
state_dict는 간단히 말해 각 계층을 매개변수 텐서로 매핑되는 Python 사전(dict) 객체입니다. state_dict 는 PyTorch에서 모델을 저장하거나 불러오는 데 관심이 있다면 필수적인 항목입니다. state_dict 객체는 Python 사전이기 때문에 쉽게 저장, 업데이트, 변경 및 복원할 수 있으며, 이는 PyTorch 모델과 옵티마이저에 엄청난 모듈성(modularity)을 제공합니다. 이때, 학습 가능한 매개변수를 갖는 계층(합성곱 계층, 선형 계층 등) 및 등록된 버퍼들(batchnorm의 running_mean)만 모델의 state_dict 항목을 갖습니다.
( torch.optim ) 또한 옵티마이저의 상태 뿐만 아니라 사용된 하이퍼 매개변수 (Hyperparameter) 정보가 포함된 state_dict 을 갖습니다.
단계
1. 데이터를 불러올 때 필요한 모든 라이브러리 불러오기
2. 신경망을 구성하고 초기화하기
3. 옵티마이저 초기화하기
4. 모델과 옵티마이저의 state_dict 접근하기
1. 모든 라이브러리 불러오기
import torch
import torch.nn as nn
import torch.optim as optim
2. 신경망을 구성하고 초기화하기
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
3. 옵티마이저 초기화하기
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4. 모델과 옵티마이저의 state_dict 접근하기
# 모델의 state_dict 출력
print("Model's state_dict:")
for param_tensor in net.state_dict():
print(param_tensor, "\t", net.state_dict()[param_tensor].size())
print()
# 옵티마이저의 state_dict 출력
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
이 정보는 향후 모델 및 옵티마이저를 저장하고 불러오는 것과 관련이 있습니다.
5. 체크포인트 저장하기
# 추가 정보
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
6. 체크포인트 불러오기
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - 또는 -
model.train()
추론(inference)을 실행하기 전에 model.eval() 을 호출하여 드롭아웃(dropout)과 배치 정규화 층(batch normalization layer)을 평가(evaluation) 모드로 바꿔야합니다.
state_dict 같은 경우는 모델의 매개변수를 저장하는 방법 중 하나이고, 모델을 불러올 때 저장한 모델에 load_state_dict()를 사용해서 매개변수를 적용할 수 있다. 설명을 보면 각 계층마다 매개변수를 텐서 형태로 dict 타입으로 저장하는 것 같다.
.pt 또는 .pth 확장자를 사용해 모델을 저장하고, 그 외 save하여 pickle로 가져오는 방법이 있다.
'Skills > PyTorch' 카테고리의 다른 글
[PyTorch] ADJUST_CONTRAST 이미지 대비 (torchvision.transforms.functional) (14) | 2023.07.27 |
---|---|
[PyTorch] Optimizer(옵티마이저) 최적화의 단계와 역할 (2) | 2023.07.26 |
[PyTorch] vutils.save_image 텐서 형태 이미지 저장하는 법 (0) | 2023.07.26 |
[PyTorch] nn.Parameter() / nn.Variable() (0) | 2023.07.10 |
[PyTorch] learning rate scheduler - StepLR에 대해 (0) | 2023.06.22 |
- Total
- Today
- Yesterday
- 파이썬 딕셔너리
- 프롬프트
- style transfer
- prompt learning
- stylegan
- 구글드라이브서버다운
- python
- 파이썬 클래스 계층 구조
- Unsupervised learning
- few-shot learning
- CNN
- Prompt
- 딥러닝
- support set
- 서버에다운
- NLP
- vscode 자동 저장
- clip
- docker
- 서버구글드라이브연동
- 퓨샷러닝
- 도커 컨테이너
- 구글드라이브연동
- 구글드라이브다운
- 파이썬 클래스 다형성
- 도커
- 데이터셋다운로드
- cs231n
- 구글드라이브서버연동
- 파이썬
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |