[PyTorch] nn.Parameter() / nn.Variable()

    728x90

     

    nn.Variable()은 이제 더 이상 사용하지 않는다고 한다. Tensor를 사용하기 위해서는 nn.parameters의 매개변수를 사용해야 하며, nn.Module의 매개변수로 표시된 특정 Tensor일 때, 모듈을 호출하면 반환된다 .

     

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.param = nn.Parameter(torch.randn(1, 1))
            
        def forward(self, x):
            x = x * self.param
            return x
    
    model = MyModel()
    print(dict(model.named_parameters()))
    # {'param': Parameter containing:
    # tensor([[0.6077]], requires_grad=True)}
    
    out = model(torch.randn(1, 1))
    loss = out.mean()
    loss.backward()
    
    print(model.param.grad)
    # tensor([[-1.3033]])

     

    모듈의 예시 코드이다.

     

    출처: https://discuss.pytorch.org/t/practical-usage-of-nn-variable-and-nn-parameter/148644

    728x90

    댓글