728x90

nn.Variable()은 이제 더 이상 사용하지 않는다고 한다. Tensor를 사용하기 위해서는 nn.parameters의 매개변수를 사용해야 하며, nn.Module의 매개변수로 표시된 특정 Tensor일 때, 모듈을 호출하면 반환된다 .
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.param = nn.Parameter(torch.randn(1, 1))
def forward(self, x):
x = x * self.param
return x
model = MyModel()
print(dict(model.named_parameters()))
# {'param': Parameter containing:
# tensor([[0.6077]], requires_grad=True)}
out = model(torch.randn(1, 1))
loss = out.mean()
loss.backward()
print(model.param.grad)
# tensor([[-1.3033]])
모듈의 예시 코드이다.
출처: https://discuss.pytorch.org/t/practical-usage-of-nn-variable-and-nn-parameter/148644
728x90
'Skills > PyTorch' 카테고리의 다른 글
| [PyTorch] ADJUST_CONTRAST 이미지 대비 (torchvision.transforms.functional) (14) | 2023.07.27 |
|---|---|
| [PyTorch] Optimizer(옵티마이저) 최적화의 단계와 역할 (2) | 2023.07.26 |
| [PyTorch] vutils.save_image 텐서 형태 이미지 저장하는 법 (0) | 2023.07.26 |
| [PyTorch] STATE_DICT, CHECKPOINT 에 대해 (0) | 2023.07.25 |
| [PyTorch] learning rate scheduler - StepLR에 대해 (0) | 2023.06.22 |
댓글