Haskell - Reproduce numpy's reshape
There are two details here that are qualitatively different from Python, ultimately stemming from dynamic vs. static typing.
The first one you have noticed yourself: at each chunking step the resulting type is different from the input type. This means you cannot use `foldr`, because it expects a function of one specific type. You could do it via recursion though.
The second problem is a bit less obvious: the return type of your `reshape` function depends on the value of the first argument. For example, if the first argument is `[2]`, the return type is `[[a]]`, but if the first argument is `[2, 3]`, then the return type is `[[[a]]]`. In Haskell, all types must be known at compile time, which means that your `reshape` function cannot take a first argument whose value is only known at runtime. In other words, the first argument must live at the type level.
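Both problems can be seen in a small sketch with fixed-depth helpers. The names `reshape1` and `reshape2` are hypothetical, and `chunksOf` here is a local stand-in for the library function so that the snippet needs only `base`:

```haskell
-- Local stand-in for Data.List.Split.chunksOf (assumption: the real one
-- treats a chunk size of 0 differently; this version just returns []).
chunksOf :: Int -> [a] -> [[a]]
chunksOf n xs
  | n <= 0 || null xs = []
  | otherwise         = let (h, t) = splitAt n xs in h : chunksOf n t

-- One dimension: a single chunking step adds one list layer.
reshape1 :: Int -> [a] -> [[a]]
reshape1 n = chunksOf n

-- Two dimensions: two chunking steps, so two extra list layers.
-- Each step has a different type ([a] -> [[a]], then [[a]] -> [[[a]]]),
-- which is why a single step function passed to foldr cannot cover both.
reshape2 :: Int -> Int -> [a] -> [[[a]]]
reshape2 m n = chunksOf m . chunksOf n
```

The two helpers have different result types, so no single ordinary function taking a runtime `[Int]` could stand in for both.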
Type-level values may be computed via type functions (aka "type families"), but because it's not just the type (i.e. you also have a value to compute), the natural (or the only?) mechanism for that is a type class.
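To illustrate the type-function half of that statement, here is a minimal sketch of a closed type family that computes only the result type from the dimension list. The name `Nested` is hypothetical and not part of the answer's solution; it produces no value, which is exactly the part a type class is needed for:

```haskell
{-# LANGUAGE DataKinds, KindSignatures, TypeFamilies, TypeOperators #-}

import GHC.TypeLits (Nat)

-- Hypothetical type family: one extra list layer per requested dimension.
type family Nested (dims :: [Nat]) a where
  Nested '[]       a = [a]
  Nested (n ': ds) a = [Nested ds a]

-- These definitions typecheck only because the family reduces as annotated.
oneDim :: Nested '[2] Int       -- reduces to [[Int]]
oneDim = [[1, 2], [3, 4]]

twoDims :: Nested '[2, 3] Int   -- reduces to [[[Int]]]
twoDims = [[[1, 2, 3], [4, 5, 6]]]
```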
So, first let's define our type class:
class Reshape (dimensions :: [Nat]) from to | dimensions from -> to where
  reshape :: from -> to
The class has three parameters: `dimensions` of kind `[Nat]` is a type-level list of numbers, representing the desired dimensions. `from` is the argument type, and `to` is the result type. Note that, even though it is known that the argument type is always `[a]`, we have to have it as a type variable here, because otherwise our class instances won't be able to correctly match the same `a` between argument and result.
Plus, the class has a functional dependency `dimensions from -> to` to indicate that if I know both `dimensions` and `from`, I can unambiguously determine `to`.
Next, the base case: when `dimensions` is an empty list, the function just degrades to `id`:
instance Reshape '[] [a] [a] where
  reshape = id
And now the meat: the recursive case.
instance (KnownNat n, Reshape tail [a] [b]) => Reshape (n:tail) [a] [[b]] where
  reshape = chunksOf n . reshape @tail
    where n = fromInteger . natVal $ Proxy @n
First it makes the recursive call `reshape @tail` to chunk out the remaining (inner) dimensions, and then it chunks out the result of that using the value of the current dimension as chunk size.
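To make the recursion concrete, here is a hand-written sketch of what instance resolution amounts to for the dimensions `'[2, 3]`. The name `reshape23` is hypothetical, and `chunksOf` is again a local stand-in for the library function so the snippet depends only on `base`:

```haskell
-- Local stand-in for Data.List.Split.chunksOf (assumption: simplified
-- behaviour for non-positive chunk sizes).
chunksOf :: Int -> [a] -> [[a]]
chunksOf n xs
  | n <= 0 || null xs = []
  | otherwise         = let (h, t) = splitAt n xs in h : chunksOf n t

-- How the instances unfold, one step at a time:
--   Reshape '[2, 3]  gives  chunksOf 2 . reshape @'[3]
--   Reshape '[3]     gives  chunksOf 3 . reshape @'[]
--   Reshape '[]      gives  id
reshape23 :: [a] -> [[[a]]]
reshape23 = chunksOf 2 . chunksOf 3 . id
```

Applied to `[1..12]`, this yields `[[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]`. Note that the dimensions act as chunk sizes, so the length of the outermost list is not specified anywhere; it simply falls out of the input length.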
Note also that I'm using the `chunksOf` function from the `split` library. No need to redefine it yourself.
Let's test it out:
λ reshape @ '[1] [1,2,3]
[[1],[2],[3]]
λ reshape @ '[1,2] [1,2,3,4]
[[[1,2]],[[3,4]]]
λ reshape @ '[2,3] [1..12]
[[[1,2,3],[4,5,6]],[[7,8,9],[10,11,12]]]
λ reshape @ '[2,3,4] [1..24]
[[[[1,2,3,4],[5,6,7,8],[9,10,11,12]],[[13,14,15,16],[17,18,19,20],[21,22,23,24]]]]
For reference, here's the full program with all imports and extensions:
{-# LANGUAGE
    MultiParamTypeClasses, FunctionalDependencies, TypeApplications,
    ScopedTypeVariables, DataKinds, TypeOperators, KindSignatures,
    FlexibleInstances, FlexibleContexts, UndecidableInstances,
    AllowAmbiguousTypes
#-}

import Data.Proxy (Proxy(..))
import Data.List.Split (chunksOf)
import GHC.TypeLits (Nat, KnownNat, natVal)

class Reshape (dimensions :: [Nat]) from to | dimensions from -> to where
  reshape :: from -> to

instance Reshape '[] [a] [a] where
  reshape = id

instance (KnownNat n, Reshape tail [a] [b]) => Reshape (n:tail) [a] [[b]] where
  reshape = chunksOf n . reshape @tail
    where n = fromInteger . natVal $ Proxy @n
@Fyodor Soikin's answer is perfect with respect to the actual question, but there is a bit of a problem with the question itself: a list of lists is not the same thing as an array. It is a common misconception that Haskell doesn't have arrays and that you are forced to deal with lists, which could not be further from the truth.
Because the question is tagged with `array` and there is a comparison to `numpy`, I would like to add a proper answer that handles this situation for multidimensional arrays. There are a couple of array libraries in the Haskell ecosystem, one of which is `massiv`. Functionality like numpy's `reshape` can be achieved with the `resize'` function:
λ> 1 ... (18 :: Int)
Array D Seq (Sz1 18)
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18 ]
λ> resize' (Sz (3 :> 2 :. 3)) (1 ... (18 :: Int))
Array D Seq (Sz (3 :> 2 :. 3))
[ [ [ 1, 2, 3 ]
, [ 4, 5, 6 ]
]
, [ [ 7, 8, 9 ]
, [ 10, 11, 12 ]
]
, [ [ 13, 14, 15 ]
, [ 16, 17, 18 ]
]
]