Tracing back a deprecation warning in PyTorch
Found the problem. It was this line:

loss_x = self.mse_loss(x[mask], tx[mask])

The mask variable was a ByteTensor (dtype torch.uint8), and indexing with a uint8 mask is deprecated. Replacing it with a BoolTensor (dtype torch.bool) fixed the warning.
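Below is a minimal sketch of the fix. It assumes `self.mse_loss` is a `torch.nn.MSELoss` instance. The tensors `x`, `tx` and `mask` are hypothetical stand-ins with made-up values. The sketch converts the uint8 mask with `.bool()` before indexing. It also runs the indexing under `warnings.simplefilter("error")`, which turns any warning into an exception. Doing the same thing globally, for example by running the script with `python -W error`, is one way to make a traceback point at the line that triggers a deprecation warning.

```python
import warnings

import torch
import torch.nn as nn

# Hypothetical stand-ins for the prediction `x`, target `tx` and `mask`
# from the line above; values are illustrative only.
x = torch.tensor([0.5, 1.0, 2.0, 3.0])
tx = torch.tensor([0.0, 1.0, 1.0, 0.0])
mask = torch.tensor([1, 0, 1, 0], dtype=torch.uint8)  # old-style ByteTensor mask

# Assumption: self.mse_loss in the original code is an nn.MSELoss().
mse_loss = nn.MSELoss()

# Convert the uint8 mask to torch.bool before using it for indexing.
# mask.bool() is shorthand for mask.to(torch.bool).
bool_mask = mask.bool()

# Escalate warnings to exceptions inside this block: if indexing still
# emitted a warning, it would raise here with a traceback to this line.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    loss_x = mse_loss(x[bool_mask], tx[bool_mask])

print(loss_x.item())
```

If the mask comes from a comparison such as `tx > 0`, the result is already a `torch.bool` tensor, so no conversion is needed. The uint8 mask usually appears only when it is built explicitly or loaded from older code.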