How to get the device type of a PyTorch module conveniently?
My solution, which works in 99% of cases:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Zero-element parameter that tags along with the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        device = self.dummy_param.device
        ...  # etc.
Thereafter, dummy_param will always be on the same device as the module Net, so you can query it at any time. For example:
net = Net()
print(net.dummy_param.device)   # cpu

net = net.to('cuda')
print(net.dummy_param.device)   # cuda:0
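A closely related variant (my own sketch, not part of the answer above; NetWithBuffer is just an illustrative name) registers a buffer instead of a parameter. Buffers are moved by .to() / .cuda() just like parameters, but they are not returned by model.parameters(), so the dummy tensor is never handed to an optimizer:

import torch
import torch.nn as nn

class NetWithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # A zero-element buffer: moved together with the module, but not a parameter.
        self.register_buffer('dummy_buffer', torch.empty(0))

    @property
    def device(self):
        return self.dummy_buffer.device

    def forward(self, x):
        # New tensors can be created directly on the module's device.
        return x.to(self.device)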
@Duane's answer creates a parameter in the model (albeit a tiny one). I think this answer is slightly more Pythonic and elegant:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Modules do not define a device attribute by default, so track it ourselves.
        self.device = torch.device('cpu')

    def _apply(self, fn):
        # https://stackoverflow.com/questions/54706146/moving-member-tensors-with-module-to-in-pytorch
        # Override _apply so the device attribute is updated whenever the module is moved.
        # fn is the per-tensor conversion, so apply it to a throwaway tensor on the
        # current device and read back the resulting device.
        super()._apply(fn)
        self.device = fn(torch.empty(0, device=self.device)).device
        return self
net.cuda(), net.float(), etc. will all work as well, since they all call _apply rather than to (as can be seen in the source).
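A quick usage sketch (my own minimal example, assuming the Model class defined above and that a CUDA device is available for the .to('cuda') call):

import torch

net = Model()
print(net.device)        # cpu

net = net.to('cuda')     # .to() calls _apply under the hood, which updates net.device
print(net.device)        # cuda:0

# New tensors can now be created directly on the module's device:
x = torch.randn(4, 8, device=net.device)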
An alternative solution, from @Kani's comment on the accepted answer, is also very elegant:
class Model(nn.Module):
    def __init__(self, *args, **kwargs):
        """
        Constructor for Neural Network.
        """
        super().__init__()

    @property
    def device(self):
        return next(self.parameters()).device
You access the device through model.device, just as you would for a parameter. This solution does not work when the model has no parameters, since next(self.parameters()) then raises StopIteration.
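For example (a minimal sketch of mine; the nn.Linear layer is only there so that parameters() is non-empty):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)   # at least one parameter, so next(...) succeeds

    @property
    def device(self):
        return next(self.parameters()).device

model = Model()
print(model.device)            # cpu

if torch.cuda.is_available():
    model = model.to('cuda')
    print(model.device)        # cuda:0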
This question has been asked many times (1, 2). Quoting the reply from a PyTorch developer:
That’s not possible. Modules can hold parameters of different types on different devices, and so it’s not always possible to unambiguously determine the device.
The recommended workflow (as described on the PyTorch blog) is to create the device object separately and use it everywhere. Copy-pasting the example from the blog here:
# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
...
# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)
Do note that there is nothing stopping you from adding a .device property to your models.
As mentioned by Kani (in the comments), if all the parameters in the model are on the same device, one could use next(model.parameters()).device.
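If you want that one-liner to be a bit more robust, a small helper (my own sketch, not from the linked discussions; get_module_device is a hypothetical name) can fall back to buffers when a module owns no parameters:

import torch
import torch.nn as nn

def get_module_device(module: nn.Module, default: torch.device = torch.device('cpu')) -> torch.device:
    """Best-effort guess of a module's device.

    Assumes all parameters/buffers live on a single device; returns `default`
    for modules that own neither parameters nor buffers.
    """
    for t in module.parameters():
        return t.device
    for t in module.buffers():
        return t.device
    return default

print(get_module_device(nn.Linear(8, 2)))   # cpu
print(get_module_device(nn.ReLU()))         # cpu (no parameters or buffers, falls back to default)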