티스토리 뷰

728x90

 

nn.Variable()은 이제 더 이상 사용하지 않는다고 한다. Tensor를 사용하기 위해서는 nn.parameters의 매개변수를 사용해야 하며, nn.Module의 매개변수로 표시된 특정 Tensor일 때, 모듈을 호출하면 반환된다 .

 

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(1, 1))
        
    def forward(self, x):
        x = x * self.param
        return x

model = MyModel()
print(dict(model.named_parameters()))
# {'param': Parameter containing:
# tensor([[0.6077]], requires_grad=True)}

out = model(torch.randn(1, 1))
loss = out.mean()
loss.backward()

print(model.param.grad)
# tensor([[-1.3033]])

 

모듈의 예시 코드이다.

 

출처: https://discuss.pytorch.org/t/practical-usage-of-nn-variable-and-nn-parameter/148644

728x90
댓글