Performing Convolution (NOT cross-correlation) in pytorch
TL;DR: Use the convolution from the functional toolbox, torch.nn.functional.conv2d, not torch.nn.Conv2d, and flip your filter around the vertical and horizontal axes.
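The two operations differ only in kernel orientation. Below is a naive NumPy sketch of both in "valid" mode (no padding). `cross_correlate2d` and `convolve2d` are hypothetical helper names written for illustration, not library functions.

```python
import numpy as np

def cross_correlate2d(x, k):
    """'Valid' cross-correlation: out[i, j] = sum over (m, n) of x[i + m, j + n] * k[m, n]."""
    kh, kw = k.shape
    oh, ow = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((oh, ow), dtype=np.result_type(x, k))
    for i in range(oh):
        for j in range(ow):
            # Slide the kernel over the window as-is, without flipping it
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
    return out

def convolve2d(x, k):
    """'Valid' convolution: cross-correlation with the kernel flipped on both axes."""
    return cross_correlate2d(x, k[::-1, ::-1])

# The square image and [-1, 1] edge filter from the test below
x = np.array([[0, 0, 0, 0, 0, 0, 0],
              [0, 0, 1, 1, 1, 0, 0],
              [0, 0, 1, 1, 1, 0, 0],
              [0, 0, 1, 1, 1, 0, 0],
              [0, 0, 0, 0, 0, 0, 0]])
k = np.array([[-1, 1]])

print(cross_correlate2d(x, k)[1].tolist())  # [0, 1, 0, 0, -1, 0]
print(convolve2d(x, k)[1].tolist())         # [0, -1, 0, 0, 1, 0]
```

Row 1 crosses the square's left and right edges, so the sign of each edge response is reversed between the two operations.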
torch.nn.Conv2d is a convolutional layer for a network. Because its weights are learned, it does not matter that it is implemented using cross-correlation: the network will simply learn a mirrored version of the kernel (thanks @etarion for this clarification).
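To illustrate why the distinction disappears for a learned layer, this sketch copies a spatially flipped kernel into an nn.Conv2d. The layer cross-correlates, yet it reproduces the true convolution of the original kernel, which is the weight a trained layer would settle on. It assumes a recent PyTorch, where plain tensors replace Variable and `Tensor.flip` exists. The kernel and image values are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative asymmetric 2x2 kernel, shape (out_channels, in_channels, kH, kW)
kernel = torch.tensor([[[[1.0, 2.0],
                         [3.0, 4.0]]]])
# Illustrative 1x1x4x4 input image
image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

# True convolution: flip the kernel's spatial axes, then let conv2d cross-correlate
true_conv = F.conv2d(image, kernel.flip([2, 3]))

# A Conv2d layer whose weight is set to the flipped kernel
layer = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, bias=False)
with torch.no_grad():
    layer.weight.copy_(kernel.flip([2, 3]))
    layer_out = layer(image)

print(torch.allclose(layer_out, true_conv))  # True
```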
torch.nn.functional.conv2d performs convolution with the inputs and weights provided as arguments, similar to the TensorFlow function in your example. I wrote a simple test to determine whether, like the TensorFlow function, it actually performs cross-correlation, in which case the filter must be flipped to get correct convolution results.
import torch
import torch.nn.functional as F
import torch.autograd as autograd
import numpy as np
#A vertical edge detection filter.
#Because this filter is not symmetric, for correct convolution the filter must be flipped before element-wise multiplication
filters = autograd.Variable(torch.FloatTensor([[[[-1, 1]]]]))
#A test image of a square
inputs = autograd.Variable(torch.FloatTensor([[[[0, 0, 0, 0, 0, 0, 0],
                                                [0, 0, 1, 1, 1, 0, 0],
                                                [0, 0, 1, 1, 1, 0, 0],
                                                [0, 0, 1, 1, 1, 0, 0],
                                                [0, 0, 0, 0, 0, 0, 0]]]]))
print(F.conv2d(inputs, filters))
This outputs:
Variable containing:
(0 ,0 ,.,.) =
0 0 0 0 0 0
0 1 0 0 -1 0
0 1 0 0 -1 0
0 1 0 0 -1 0
0 0 0 0 0 0
[torch.FloatTensor of size 1x1x5x6]
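This output is the result of cross-correlation: the [-1, 1] kernel was applied without being flipped, so the left edge of the square gives +1 and the right edge gives -1. Therefore, we need to flip the filter: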
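def flip_tensor(t):
    flipped = t.numpy().copy()
    # Flip every axis; for this 1x1xHxW filter that equals flipping only the spatial axes
    for i in range(flipped.ndim):
        flipped = np.flip(flipped, i)  # Reverse the array along dimension i
    # np.flip returns a negative-stride view, which torch.from_numpy cannot wrap, so copy it
    return torch.from_numpy(flipped.copy())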
print(F.conv2d(inputs, autograd.Variable(flip_tensor(filters.data))))
The new output, with the edge signs reversed, is the correct result for convolution:
Variable containing:
(0 ,0 ,.,.) =
0 0 0 0 0 0
0 -1 0 0 1 0
0 -1 0 0 1 0
0 -1 0 0 1 0
0 0 0 0 0 0
[torch.FloatTensor of size 1x1x5x6]
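As an independent check, SciPy's `scipy.signal.convolve2d` computes true (flipped-kernel) convolution. Its `"valid"` mode should therefore reproduce the flipped PyTorch result. This sketch assumes SciPy is installed and that PyTorch is recent enough for plain tensors to replace Variable.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d

# Same square image and [-1, 1] filter as above, as plain tensors
inputs = torch.tensor([[[[0, 0, 0, 0, 0, 0, 0],
                         [0, 0, 1, 1, 1, 0, 0],
                         [0, 0, 1, 1, 1, 0, 0],
                         [0, 0, 1, 1, 1, 0, 0],
                         [0, 0, 0, 0, 0, 0, 0]]]], dtype=torch.float32)
filters = torch.tensor([[[[-1.0, 1.0]]]])

# PyTorch: flip the kernel's spatial axes, then let conv2d cross-correlate
torch_result = F.conv2d(inputs, filters.flip([2, 3]))[0, 0].numpy()

# SciPy: true convolution; "valid" matches conv2d without padding
scipy_result = convolve2d(inputs[0, 0].numpy(), filters[0, 0].numpy(), mode="valid")

print(np.allclose(torch_result, scipy_result))  # True
```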
Nothing too different from the answer above, but torch can do flip(i) natively (and I guess you only wanted to flip(2) and flip(3), the kernel's height and width axes):
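import torch.nn.functional as F
def convolution(A, B, padding=0):
    # Flip the kernel along its height (dim 2) and width (dim 3), then cross-correlate
    return F.conv2d(A, B.flip(2).flip(3), padding=padding)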
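Flipping only dims 2 and 3 matters once there is more than one input channel. A conv2d weight has shape (out_channels, in_channels, kH, kW), so also flipping dim 1 would pair each kernel slice with the wrong input channel. This sketch uses random illustrative data and a single output channel. It uses SciPy's `convolve2d` summed over input channels as the reference (assuming SciPy is installed) and repeats the helper so the block runs on its own.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d

def convolution(A, B, padding=0):
    # True 2D convolution: flip only the kernel's spatial axes (kH, kW)
    return F.conv2d(A, B.flip([2, 3]), padding=padding)

torch.manual_seed(0)
A = torch.randn(1, 3, 6, 6)  # one image, 3 input channels (illustrative)
B = torch.randn(1, 3, 3, 3)  # 1 output channel, 3 input channels, 3x3 kernel

# Reference: multi-channel convolution sums per-channel 2D convolutions
reference = sum(
    convolve2d(A[0, c].numpy(), B[0, c].numpy(), mode="valid") for c in range(3)
)

ours = convolution(A, B)[0, 0].numpy()
# Flipping every axis also reverses the input-channel axis, mismatching channels
all_flipped = F.conv2d(A, B.flip([0, 1, 2, 3]))[0, 0].numpy()

print(np.allclose(ours, reference, atol=1e-4))         # True
print(np.allclose(all_flipped, reference, atol=1e-4))  # False
print(convolution(A, B, padding=1).shape)              # torch.Size([1, 1, 6, 6])
```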