Finding the non-intersection of two PyTorch tensors
If you don't want to leave CUDA (i.e. copy the tensors back to the CPU), one workaround is:
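import torch

t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
t2 = torch.tensor([1, 24], device='cuda')
# Boolean mask: stays True only for elements of t1 that differ from every element of t2
mask = torch.ones_like(t1, dtype=torch.bool)
for elem in t2:
    mask &= (t1 != elem)
# Elements of t1 that do not appear in t2 (the non-intersection)
non_intersection = t1[mask]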
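For comparison, newer PyTorch releases (1.10 and later, as far as I know) provide torch.isin, which builds the same membership mask in one vectorized call and also works on CUDA tensors. A minimal sketch, assuming such a version is installed; the CPU fallback is only there so the snippet also runs on machines without a GPU:

```python
import torch

# Use the GPU when available; torch.isin works on CPU tensors too
device = 'cuda' if torch.cuda.is_available() else 'cpu'

t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
t2 = torch.tensor([1, 24], device=device)

# Boolean mask: True where an element of t1 also appears in t2
in_t2 = torch.isin(t1, t2)

intersection = t1[in_t2]        # elements of t1 found in t2
non_intersection = t1[~in_t2]   # elements of t1 not found in t2

print(intersection)
print(non_intersection)
```

torch.isin also accepts invert=True, which returns the negated mask directly instead of applying ~ afterwards.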
If you don't want a for loop, you can compare all values in one go. This gives you both the intersection and the non-intersection:
import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
# Create a (len(t2), len(t1)) tensor to compare all values at once
compareview = t2.repeat(t1.shape[0], 1).T
# Intersection: any(1) rather than sum(1) == 1, so duplicates in t2 are handled
print(t1[(compareview == t1).T.any(1)])
# Non-intersection: elements of t1 that differ from every element of t2
print(t1[(compareview != t1).T.all(1)])

Output:
tensor([ 1, 24])
tensor([ 9, 12,  5])
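The repeat/transpose step can also be expressed with broadcasting: indexing t1 with [:, None] turns it into a column, and comparing that column with t2 yields the full len(t1) x len(t2) comparison matrix directly. This sketch is an alternative formulation, not the original code; it still allocates len(t1) * len(t2) booleans, so it suits moderately sized inputs:

```python
import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# t1[:, None] has shape (len(t1), 1); comparing it with t2 (shape (len(t2),))
# broadcasts to a (len(t1), len(t2)) boolean matrix without an explicit repeat
matches = t1[:, None] == t2

# Row i is True if t1[i] equals at least one element of t2
in_t2 = matches.any(dim=1)

intersection = t1[in_t2]
non_intersection = t1[~in_t2]
print(intersection)
print(non_intersection)
```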
I ran into the same problem, but the proposed solutions were far too slow on larger tensors. The following simple solution works on both CPU and GPU and is significantly faster than the other proposed solutions:
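# Assumes t1 and t2 each contain no duplicate values; results come back sorted
combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]    # in exactly one tensor (symmetric difference)
intersection = uniques[counts > 1]   # in both tensors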
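The counting trick relies on each value appearing at most once per tensor; otherwise a value repeated inside t1 alone would be counted as shared. A minimal sketch that removes that assumption by deduplicating each input first, and that also recovers the one-sided difference (values of t1 not in t2) by repeating the trick against the intersection. The helper name unique_based_sets is hypothetical, and all returned tensors are sorted because unique sorts by default:

```python
import torch

def unique_based_sets(t1: torch.Tensor, t2: torch.Tensor):
    """Hypothetical helper: set operations via concatenation + unique counts."""
    # Deduplicate each input so a value repeated within one tensor
    # is not mistaken for a value shared by both tensors
    a = t1.unique()
    b = t2.unique()

    uniques, counts = torch.cat((a, b)).unique(return_counts=True)
    intersection = uniques[counts > 1]      # values present in both
    sym_difference = uniques[counts == 1]   # values present in exactly one

    # intersection is a subset of a, so a value occurs once in (a, intersection)
    # exactly when it is in t1 but not in t2
    only_a, only_counts = torch.cat((a, intersection)).unique(return_counts=True)
    t1_minus_t2 = only_a[only_counts == 1]

    return intersection, t1_minus_t2, sym_difference


inter, diff, sym = unique_based_sets(
    torch.tensor([1, 9, 9, 12, 5, 24]),
    torch.tensor([1, 24, 24, 7]),
)
print(inter, diff, sym)
```

Without the deduplication step, the repeated 9 in the first tensor would land in the intersection even though 9 never appears in the second tensor.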