R expand.grid() function in Python

The pandas documentation defines an expand_grid function:

def expand_grid(data_dict):
    """Create a dataframe from every combination of given values."""
    rows = itertools.product(*data_dict.values())
    return pd.DataFrame.from_records(rows, columns=data_dict.keys())

For this code to work, you will need the following two imports:

import itertools
import pandas as pd

The output is a pandas.DataFrame, which is the Python object most comparable to an R data.frame.


itertools.product is the key to your solution: it produces the Cartesian product of its inputs.

import pandas as pd
from itertools import product

def expand_grid(dictionary):
    # Each tuple from product() becomes one row; the dict keys name the columns.
    return pd.DataFrame(list(product(*dictionary.values())),
                        columns=list(dictionary.keys()))

dictionary = {'color': ['red', 'green', 'blue'], 
              'vehicle': ['car', 'van', 'truck'], 
              'cylinders': [6, 8]}

>>> expand_grid(dictionary)
    color  cylinders vehicle
0     red          6     car
1     red          6     van
2     red          6   truck
3     red          8     car
4     red          8     van
5     red          8   truck
6   green          6     car
7   green          6     van
8   green          6   truck
9   green          8     car
10  green          8     van
11  green          8   truck
12   blue          6     car
13   blue          6     van
14   blue          6   truck
15   blue          8     car
16   blue          8     van
17   blue          8   truck

(This output comes from an older Python version, where dict ordering was arbitrary. On Python 3.7 and later, dicts preserve insertion order, so the columns appear as color, vehicle, cylinders, and cylinders varies fastest.)

Here's an example that gives output similar to what you need:

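import itertools

def expandgrid(*itrs):
    # Materialize the Cartesian product so each column can be read from it.
    rows = list(itertools.product(*itrs))
    # Column i is named Var{i+1}, mirroring R's default column names.
    return {'Var{}'.format(i + 1): [row[i] for row in rows] for i in range(len(itrs))}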

>>> a = [1,2,3]
>>> b = [5,7,9]
>>> expandgrid(a, b)
{'Var1': [1, 1, 1, 2, 2, 2, 3, 3, 3], 'Var2': [5, 7, 9, 5, 7, 9, 5, 7, 9]}

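The difference from R comes from the iteration order: in itertools.product the rightmost element advances on every iteration, whereas R's expand.grid varies the first factor fastest. If the row order matters, you can reverse the inputs before taking the product and reverse the resulting columns afterwards.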


EDIT (by S. Laurent)

To get the same row order as R:

def expandgrid(*itrs):  # https://stackoverflow.com/a/12131385/1100107
    """
    Cartesian product. The inputs are reversed for compatibility with R,
    so that the first argument varies fastest.
    """
    product = list(itertools.product(*reversed(itrs)))
    return [[x[i] for x in product] for i in range(len(itrs))][::-1]
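
As a rough sketch of how the two versions above could be combined, the following returns a dict with R-style Var1, Var2, ... names while keeping R's row order (first input varying fastest). The name expandgrid_r is hypothetical and not part of any library.

```python
import itertools

def expandgrid_r(*itrs):
    """Hypothetical helper: dict of columns, first input varying fastest (R order)."""
    # Reverse the inputs so the first one becomes the rightmost (fastest) factor.
    rows = list(itertools.product(*reversed(itrs)))
    n = len(itrs)
    # In each row, input i now sits at position n - 1 - i.
    return {'Var{}'.format(i + 1): [row[n - 1 - i] for row in rows] for i in range(n)}

result = expandgrid_r([1, 2, 3], [5, 7, 9])
print(result)
# {'Var1': [1, 2, 3, 1, 2, 3, 1, 2, 3], 'Var2': [5, 5, 5, 7, 7, 7, 9, 9, 9]}
```

Here Var1 cycles through 1, 2, 3 while Var2 changes only every third row, which is the order expand.grid(1:3, c(5, 7, 9)) produces in R.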

Just use list comprehensions:

>>> [(x, y) for x in range(5) for y in range(5)]

[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (3, 0), (3, 1), (3, 2), (3, 3), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4)]

Convert to a NumPy array if desired:

>>> import numpy as np
>>> x = np.array([(x, y) for x in range(5) for y in range(5)])
>>> x.shape
(25, 2)

I have tested this for sizes up to 10000 x 10000, and Python's performance is comparable to that of expand.grid in R. Using a tuple (x, y) in the comprehension is about 40% faster than using a list [x, y].

OR...

Around 3x faster with np.meshgrid, and much less memory-intensive:

%timeit np.array(np.meshgrid(range(10000), range(10000))).reshape(2, 100000000).T
1 loops, best of 3: 736 ms per loop

For comparison, in R:

> system.time(expand.grid(1:10000, 1:10000))
   user  system elapsed 
  1.991   0.416   2.424 

Keep in mind that R is 1-based whereas Python is 0-based: range(10000) yields 0 to 9999, while 1:10000 yields 1 to 10000.
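
To see what the meshgrid expression does, here is a small sketch on tiny inputs. It uses np.arange(1, n + 1) to mirror R's 1:n; the variable names are only for illustration.

```python
import numpy as np

a = np.arange(1, 4)  # 1, 2, 3 (like R's 1:3)
b = np.arange(1, 3)  # 1, 2    (like R's 1:2)

# meshgrid returns one 2-D array per input; stacking them gives shape (2, len(b), len(a)).
# Flattening each to one row and transposing yields one (a, b) pair per row.
grid = np.array(np.meshgrid(a, b)).reshape(2, -1).T

print(grid.shape)     # (6, 2)
print(grid.tolist())  # [[1, 1], [2, 1], [3, 1], [1, 2], [2, 2], [3, 2]]
```

With the default indexing of np.meshgrid, the first column varies fastest, so the row order here matches expand.grid(1:3, 1:2) in R.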
