Random number generation in PySpark
For my use case, most of the solution was buried in an edit at the bottom of the question. However, there is another complication: I wanted to use the same function to generate multiple (different) random columns. It turns out that Spark assumes the output of a UDF is deterministic, which means it may skip later calls to the same function with the same inputs and reuse the earlier result. For functions that return random output, this is obviously not what you want.
To work around this, I generated a separate seed column for every random column that I wanted, using the built-in PySpark rand function:
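import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType
import numpy as np

@F.udf(IntegerType())
def my_rand(seed):
    # Deterministic given its seed: the randomness comes from the seed column.
    rs = np.random.RandomState(seed)
    # Convert the numpy integer to a plain Python int so Spark can serialize it.
    return int(rs.randint(1000))

def seed_expr():
    # Build a fresh F.rand() for each seed column: an unseeded F.rand() fixes its
    # seed when the Column object is created, so reusing one Column object can
    # give identical seed columns. RandomState accepts seeds in [0, 2**32 - 1].
    return (F.rand() * F.lit(4294967295).astype('double')).astype('bigint')

# my_df is an existing DataFrame
my_df = (
    my_df
    .withColumn('seed_0', seed_expr())
    .withColumn('seed_1', seed_expr())
    .withColumn('myrand_0', my_rand(F.col('seed_0')))
    .withColumn('myrand_1', my_rand(F.col('seed_1')))
    .drop('seed_0', 'seed_1')
)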
I'm using the DataFrame API rather than the RDD API of the original question because that's what I'm more familiar with, but the same concepts presumably apply.
NB: apparently it is possible to disable the assumption of determinism for Scala Spark UDFs since v2.3: https://jira.apache.org/jira/browse/SPARK-20586.
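On the Python side, PySpark 2.3 and later expose the same switch: the object returned by pyspark.sql.functions.udf has an asNondeterministic() method, and the udf documentation recommends it for functions that are not deterministic. The sketch below is a minimal illustration, assuming PySpark 2.3+; random_int and make_random_udf are hypothetical names, and the pyspark imports sit inside the factory so the pure-Python helper can be used without a Spark installation.

```python
import random


def random_int(upper=1000):
    """Draw an integer in [0, upper) from Python's global generator."""
    return random.randrange(upper)


def make_random_udf(upper=1000):
    """Build a zero-argument UDF that Spark treats as non-deterministic.

    Hypothetical helper; asNondeterministic() requires PySpark 2.3+.
    """
    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    # Marking the UDF non-deterministic stops Spark from eliminating
    # duplicate invocations, as it may do for deterministic UDFs.
    return udf(lambda: random_int(upper), IntegerType()).asNondeterministic()
```

With an existing DataFrame my_df, each call to the UDF yields a separate column expression:
rand_udf = make_random_udf()
my_df = my_df.withColumn('myrand_0', rand_udf()).withColumn('myrand_1', rand_udf())
This only stops Spark from treating the calls as interchangeable; it does not by itself change how each Python worker's generator is seeded.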
This seems to be a bug (or feature) of randint. I see the same behavior, but as soon as I change f, the values do indeed change. So I'm not sure about the actual randomness of this method... I can't find any documentation, but it seems to use some deterministic math algorithm instead of more variable features of the running machine. Even if I go back and forth, the numbers are the same whenever I return to the original value...
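That observation matches how a pseudo-random generator works: Python's random module uses the Mersenne Twister, a deterministic algorithm, so the same starting state always produces the same sequence. A minimal stdlib-only sketch (first_draws is a hypothetical helper, not from the question):

```python
import random


def first_draws(seed, n=5):
    """Return the first n integers in [0, 999] from a generator seeded with `seed`."""
    rng = random.Random(seed)  # an independent Mersenne Twister instance
    return [rng.randint(0, 999) for _ in range(n)]


run_1 = first_draws(42)
run_2 = first_draws(42)
# Re-seeding with the same value reproduces the sequence exactly.
print(run_1 == run_2)  # True
```

So if every worker starts from the same generator state, every worker draws the same numbers.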
So the actual problem here is relatively simple. Each Python worker subprocess inherits its random generator state from its parent:
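import random
# sc is an existing SparkContext; count the distinct generator states across 4 partitions
len(set(sc.parallelize(range(4), 4).map(lambda _: random.getstate()).collect()))
# 1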
Since the parent's state has no reason to change in this particular scenario and workers have a limited lifespan, the state of every child will be exactly the same on each run.
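One common way around this, not part of the original answer, is to create a fresh generator inside each partition instead of relying on the inherited global state. A minimal sketch, assuming the RDD API; add_random and its base_seed parameter are hypothetical. With base_seed=None, random.Random seeds itself from OS randomness (or the system time if none is available); with an integer, partition i uses base_seed + i, which makes runs reproducible.

```python
import random


def add_random(index, iterator, base_seed=None):
    """Pair every record in one partition with a uniform draw in [0.0, 1.0).

    Meant for rdd.mapPartitionsWithIndex, which calls it with the
    partition index and an iterator over that partition's records.
    """
    seed = None if base_seed is None else base_seed + index
    rng = random.Random(seed)  # private per-partition generator; global state untouched
    for record in iterator:
        yield record, rng.random()
```

Usage, with an existing SparkContext sc:
sc.parallelize(range(8), 4).mapPartitionsWithIndex(functools.partial(add_random, base_seed=2024)).collect()
Seeding from the partition index keeps partitions independent of each other while still letting a fixed base_seed reproduce a run.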