# PyTorch: detecting the compute device (CPU or GPU) — code examples
# Example 1: minimal example of checking whether PyTorch is using the GPU
import torch
import torch.nn as nn
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# t1 is created on the default device (normally the CPU); t2 is created the
# same way and then moved to `dev`.
t1 = torch.randn(1,2)
t2 = torch.randn(1,2).to(dev)
print(t1)
print(t2)
# Deliberate pitfall: Tensor.to() does NOT modify t1 in place -- it returns the
# converted tensor. The return value is discarded here, so t1 is unchanged and
# (with the usual CPU default device) is_cuda below still prints False.
t1.to(dev)
print(t1)
print(t1.is_cuda)
# Correct usage: rebind t1 to the tensor returned by .to(); is_cuda now
# reports True whenever `dev` is the CUDA device.
t1 = t1.to(dev)
print(t1)
print(t1.is_cuda)
class M(nn.Module):
    """Minimal model wrapping a single fully connected layer.

    Used to demonstrate that ``Module.to(device)`` moves all parameters.

    Args:
        in_features: Size of the last input dimension (default 1).
        out_features: Size of the last output dimension (default 2).
    """

    def __init__(self, in_features: int = 1, out_features: int = 2):
        super().__init__()
        self.l1 = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        ``x`` must have ``in_features`` as its last dimension; the result
        has ``out_features`` as its last dimension.
        """
        x = self.l1(x)
        return x
# Build the model and move its parameters to `dev`. Unlike Tensor.to(),
# Module.to() updates the module in place and returns it, so chaining the
# call gives the same result as calling it as a separate statement.
model = M().to(dev)
# All parameters share one device, so inspecting the first one is enough.
print(next(model.parameters()).is_cuda)
# Example 2: check which CUDA version PyTorch was built with
# torch.version.cuda is the CUDA version this PyTorch build was compiled
# against (a string, e.g. "12.1"), or None for CPU-only builds. A bare
# expression is only echoed in an interactive session, so print it explicitly
# for the value to be visible when this file is run as a script.
print(torch.version.cuda)