pytorch RuntimeError: Expected object of scalar type Double but got scalar type Float
Now that I have more experience with PyTorch, I think I can explain the error message. It seems that the line
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm
is actually referring to the weights of the linear layer when the matrix multiplication is called. Since the input is double while the weights are float, it makes sense for the line
output = input.matmul(weight.t())
to expect the weights to be double.
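To make this concrete, here is a minimal sketch that reproduces the mismatch with nn.Linear and shows the two usual fixes: cast the input down to float, or cast the layer's parameters up to double. It assumes a reasonably recent PyTorch with the default dtype left at float32. The exact error text differs between versions (older ones mention _th_mm, newer ones say something like "mat1 and mat2 must have the same dtype"), so the sketch only checks that a RuntimeError is raised. The layer sizes, tensor shapes and the mismatch_raised flag are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# nn.Linear creates its weight and bias as float32 (torch.float) by default.
layer = nn.Linear(3, 2)
print(layer.weight.dtype)  # torch.float32

# Simulate a double-precision input, like the one torch.from_numpy gives
# for a default NumPy float64 array.
x = torch.randn(4, 3, dtype=torch.double)

mismatch_raised = False
try:
    # The forward pass computes x @ weight.t() + bias; float64 vs float32 fails.
    layer(x)
except RuntimeError as err:
    mismatch_raised = True
    print("dtype mismatch:", err)

# Fix 1: cast the input down so it matches the layer's float32 weights.
out_float = layer(x.float())
print(out_float.dtype)  # torch.float32

# Fix 2: cast the layer's parameters up so they match the float64 input.
layer_double = nn.Linear(3, 2).double()
out_double = layer_double(x)
print(out_double.dtype)  # torch.float64
```

Casting the input with .float() is usually the simpler choice, since float32 is PyTorch's default and is faster on most hardware. Converting the model with .double() keeps the extra precision, at the cost of speed and memory.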