How do PyTorch's `fold` and `unfold` work?
`unfold` and `fold` are used to facilitate "sliding window" operations (like convolutions).
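To make the connection to convolutions concrete, here is a minimal sketch that implements a stride-1, unpadded 2D convolution as `unfold` followed by a matrix multiply and checks it against `f.conv2d`. The input shape, the seed and the random `weight` tensor are illustrative assumptions:

```python
import torch
from torch.nn import functional as f

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)       # assumed input: batch=2, channels=3, 8x8
weight = torch.randn(4, 3, 5, 5)  # assumed filters: 4 output channels, 5x5 kernels

# Each column of `cols` is one flattened 3x5x5 window: (batch, 3*5*5, num_windows)
cols = f.unfold(x, kernel_size=5)  # (2, 75, 16), since (8 - 5 + 1)**2 == 16

# Convolution = dot product of every flattened filter with every window
out = weight.view(4, -1) @ cols    # (4, 75) @ (2, 75, 16) -> (2, 4, 16)
out = out.view(2, 4, 4, 4)         # put the 16 windows back on a 4x4 grid

ref = f.conv2d(x, weight)
print(torch.allclose(out, ref, atol=1e-4))  # True
```

This works because `unfold` flattens each window in (channel, row, column) order, the same order `weight.view(4, -1)` uses for each filter.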
Suppose you want to apply a function `foo` to every 5x5 window in a feature map/image:
```python
from torch.nn import functional as f
# x is a (batch, channels, height, width) tensor
windows = f.unfold(x, kernel_size=5)
```
Now `windows` has shape `(batch, 5*5*x.size(1), num_windows)`, and you can apply `foo` to `windows`:
```python
processed = foo(windows)
```
Now you need to "fold" `processed` back to the original size of `x`:
```python
out = f.fold(processed, x.shape[-2:], kernel_size=5)
```
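A runnable version of these three steps, assuming a random `x` of shape `(2, 3, 32, 32)` and a hypothetical `foo` that subtracts each window's mean (any function that keeps the shape of `windows` would do):

```python
import torch
from torch.nn import functional as f

def foo(windows):
    # Hypothetical per-window operation: centre every window on its own mean.
    # The mean is over dim 1, i.e. over the 5*5*C values of a single window.
    return windows - windows.mean(dim=1, keepdim=True)

x = torch.randn(2, 3, 32, 32)
windows = f.unfold(x, kernel_size=5)
# (batch, 5*5*C, num_windows) with num_windows = (32 - 5 + 1)**2 = 784
print(windows.shape)  # torch.Size([2, 75, 784])

processed = foo(windows)
out = f.fold(processed, x.shape[-2:], kernel_size=5)
print(out.shape)      # torch.Size([2, 3, 32, 32])
```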
You need to take care of `padding` and `kernel_size`, since they determine whether you can "fold" `processed` back to the size of `x`.
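As an illustration of how these parameters decide whether `fold` can rebuild the input, here is a sketch using non-overlapping 5x5 windows (`stride=5`, an assumption not used in the snippet above). On a 30x30 input the windows tile the image exactly, so `fold` undoes `unfold`; on a 32x32 input the last two rows and columns are never covered by any window, so they come back as zeros:

```python
import torch
from torch.nn import functional as f

def roundtrip(x, k=5, s=5):
    # Unfold into non-overlapping k x k windows, then fold them straight back
    windows = f.unfold(x, kernel_size=k, stride=s)
    return f.fold(windows, x.shape[-2:], kernel_size=k, stride=s)

x30 = torch.randn(1, 1, 30, 30)
print(torch.equal(roundtrip(x30), x30))  # True: 6x6 windows tile 30x30 exactly

x32 = torch.randn(1, 1, 32, 32)
out = roundtrip(x32)
print(torch.equal(out[..., :30, :30], x32[..., :30, :30]))  # True
print(out[..., 30:, :].abs().sum().item())  # 0.0: uncovered rows are lost
print(out[..., :, 30:].abs().sum().item())  # 0.0: uncovered columns are lost
```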
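Moreover, `fold` sums over overlapping elements, so you might want to normalize its output. Dividing by the patch size (5*5 here) is only exact for pixels that fall inside the maximum number of windows; pixels near the border fall inside fewer windows.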
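A sketch of per-pixel normalization, assuming the stride-1, `kernel_size=5` setup from the snippets above: folding an unfolded tensor of ones counts how many windows cover each pixel (1 at a corner, 25 in the interior), and dividing by those counts recovers `x`:

```python
import torch
from torch.nn import functional as f

x = torch.randn(2, 3, 32, 32)
k = 5

# fold(unfold(x)) adds every pixel once per window that contains it
summed = f.fold(f.unfold(x, kernel_size=k), x.shape[-2:], kernel_size=k)

# Number of windows covering each pixel: the same round trip applied to ones
ones = torch.ones(1, 1, *x.shape[-2:])
counts = f.fold(f.unfold(ones, kernel_size=k), x.shape[-2:], kernel_size=k)
print(counts[0, 0, 0, 0].item(), counts[0, 0, 16, 16].item())  # 1.0 25.0

averaged = summed / counts  # counts broadcasts over batch and channels
print(torch.allclose(averaged, x, atol=1e-5))  # True
```

In a real pipeline you would fold `processed` instead of the plain round trip; `counts` only depends on the spatial size and the unfold parameters, so it can be computed once and reused.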
One-dimensional unfolding is easy with `Tensor.unfold(dimension, size, step)`:
```python
import torch

x = torch.arange(1, 9).float()
print(x)
# Tensor.unfold(dimension, size, step)
print(x.unfold(0, 2, 1))
print(x.unfold(0, 3, 2))
```
Out:
```
tensor([1., 2., 3., 4., 5., 6., 7., 8.])
tensor([[1., 2.],
        [2., 3.],
        [3., 4.],
        [4., 5.],
        [5., 6.],
        [6., 7.],
        [7., 8.]])
tensor([[1., 2., 3.],
        [3., 4., 5.],
        [5., 6., 7.]])
```
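The number of windows is `(n - size) // step + 1`, and trailing elements that don't fill a whole window are dropped (the `8.` above with size 3, step 2). A small sketch that checks `Tensor.unfold` against plain slicing; `unfold_by_slicing` is a hypothetical helper written only for this comparison:

```python
import torch

def unfold_by_slicing(x, size, step):
    # Hypothetical reference helper: start indices of every window that fits in x
    starts = range(0, x.numel() - size + 1, step)
    return torch.stack([x[i:i + size] for i in starts])

x = torch.arange(1, 9).float()
for size, step in [(2, 1), (3, 2)]:
    expected = unfold_by_slicing(x, size, step)
    assert torch.equal(x.unfold(0, size, step), expected)
    # number of windows vs. the closed-form count
    print(size, step, expected.shape[0], (x.numel() - size) // step + 1)
# 2 1 7 7
# 3 2 3 3
```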
Two-dimensional unfolding (also called patching):
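```python
import torch

patch = (3, 3)
x = torch.arange(16).float()
print(x, x.shape)
x2d = x.reshape(1, 1, 4, 4)  # (batch, channels, height, width)
print(x2d, x2d.shape)
h, w = patch
c = x2d.size(1)
print(c)  # channels
# unfold(dimension, size, step): slide over height (dim 2), then width (dim 3)
# giving shape (1, c, 2, 2, h, w). transpose(1, 3) swaps the channel and
# width-window dims before flattening, so patches come out column by column.
r = x2d.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).reshape(-1, c, h, w)
print(r.shape)
print(r)  # result
```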
Out:
```
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15.]) torch.Size([16])
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]]) torch.Size([1, 1, 4, 4])
1
torch.Size([4, 1, 3, 3])
tensor([[[[ 0.,  1.,  2.],
          [ 4.,  5.,  6.],
          [ 8.,  9., 10.]]],


        [[[ 4.,  5.,  6.],
          [ 8.,  9., 10.],
          [12., 13., 14.]]],


        [[[ 1.,  2.,  3.],
          [ 5.,  6.,  7.],
          [ 9., 10., 11.]]],


        [[[ 5.,  6.,  7.],
          [ 9., 10., 11.],
          [13., 14., 15.]]]])
```
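In the output above the patches are ordered column by column: the window first moves down, then right. If you want row-major order instead, one option (a sketch; the `permute` call is an assumption, not part of the snippet above) is to permute rather than transpose. The result then matches what `f.unfold` produces once its columns are reshaped back into patches:

```python
import torch
from torch.nn import functional as f

x2d = torch.arange(16).float().reshape(1, 1, 4, 4)
h, w = 3, 3
c = x2d.size(1)

# (N, C, nH, nW, h, w) -> (N, nH, nW, C, h, w): windows in row-major order
rowmajor = (x2d.unfold(2, h, 1).unfold(3, w, 1)
            .permute(0, 2, 3, 1, 4, 5)
            .reshape(-1, c, h, w))

# f.unfold gives (N, C*h*w, L); move L forward and split C*h*w into (C, h, w)
cols = f.unfold(x2d, kernel_size=(h, w))
via_functional = cols.transpose(1, 2).reshape(-1, c, h, w)

print(torch.equal(rowmajor, via_functional))  # True
print(rowmajor[1])  # the window one step to the right of the first one
```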