How to print the decision path / rules used to predict a specific row in PySpark?
I changed your DataFrame slightly so that different features show up in the explanations.
I also changed the assembler to take a `features` list, so we have easy access to the feature names later.
Changes below:
```python
import pandas as pd
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, monotonically_increasing_id, row_number
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler

spark = SparkSession.builder.getOrCreate()

# change1: ball goes from [0,1,2,3] -> [0,1,1,3] so we can see other features in the explanations
# change2: added multiple paths to the same prediction
# change3: added a categorical variable
# change4: features list so we can look the feature names up easily later
data = pd.DataFrame({
    'ball': [0, 1, 1, 3, 1, 0, 1, 3],
    'keep': [4, 5, 6, 7, 7, 4, 6, 7],
    'hall': [8, 9, 10, 11, 2, 6, 10, 11],
    'fall': [12, 13, 14, 15, 15, 12, 14, 15],
    'mall': [16, 17, 18, 10, 10, 16, 18, 10],
    'wall': ['a', 'a', 'a', 'a', 'a', 'a', 'c', 'e'],
    'label': [21, 31, 41, 51, 51, 51, 21, 31]
})
df = spark.createDataFrame(data)
# add a 1-based row number ("tagvalue") so individual rows can be looked up later
df = df.withColumn("mono_ID", monotonically_increasing_id())
w = Window.orderBy("mono_ID")
df = df.select(row_number().over(w).alias("tagvalue"), col("*"))

indexer = StringIndexer(inputCol='wall', outputCol='wallIndex')
encoder = OneHotEncoder(inputCol='wallIndex', outputCol='wallVec')
# I added this list so the feature names can be looked up later;
# note that 'wallVec' occupies more than one slot of the assembled vector
features = ['ball', 'keep', 'wallVec', 'hall', 'fall']
assembler = VectorAssembler(inputCols=features, outputCol='features')
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label')
pipeline = Pipeline(stages=[indexer, encoder, assembler, dtc]).fit(df)
transformed_pipeline = pipeline.transform(df)
```
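If you would rather not rely on the order of the `features` list, the assembled column also carries ML attribute metadata that names every slot of the vector. Below is a minimal sketch, assuming the usual `ml_attr` layout (attributes grouped under keys such as `numeric` or `binary`, each entry holding an `idx` and a `name`). The helper `slot_names_from_metadata` is hypothetical, and the exact names (for example whether the one-hot slots come out as `wallVec_a`) may vary with the Spark version, so the metadata dict here is hand-built for illustration:

```python
def slot_names_from_metadata(metadata):
    """Return one name per slot of an assembled vector column.

    `metadata` is the dict found at df.schema["features"].metadata;
    this assumes Spark's usual {"ml_attr": {"attrs": {...}}} layout.
    """
    attrs = metadata["ml_attr"]["attrs"]
    by_index = {}
    # attributes are grouped by type, e.g. "numeric", "binary", "nominal"
    for group in attrs.values():
        for attr in group:
            by_index[attr["idx"]] = attr.get("name", "slot_{0}".format(attr["idx"]))
    # order the names by their position in the vector
    return [by_index[i] for i in sorted(by_index)]


# hand-built metadata mimicking the assembled 'features' column (illustrative only)
example_metadata = {
    "ml_attr": {
        "attrs": {
            "numeric": [
                {"idx": 0, "name": "ball"},
                {"idx": 1, "name": "keep"},
                {"idx": 4, "name": "hall"},
                {"idx": 5, "name": "fall"},
            ],
            "binary": [
                {"idx": 2, "name": "wallVec_a"},
                {"idx": 3, "name": "wallVec_c"},
            ],
        },
        "num_attrs": 6,
    }
}

slot_names = slot_names_from_metadata(example_metadata)
print(slot_names)
```

With the real data you would pass `transformed_pipeline.schema['features'].metadata`. Because `wallVec` takes two slots here (three categories, last one dropped), slot 4 of the vector is `hall`, not `features[4]`.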
Below is a method I've found for working with the decision tree itself:
```python
import networkx as nx
from pyspark.ml.linalg import Vector

# get the fitted tree back out of the pipeline; this changed to [3] because of the categorical encoders
tree_model = pipeline.stages[3]
# save the model so we can get at the internals that the Scala code keeps private
tree_model.write().overwrite().save("mymodel_test")
# read back in the node data (one row per tree node)
modeldf = spark.read.parquet("mymodel_test/data")
# select only the columns that we NEED and collect into a list
noderows = modeldf.select("id", "prediction", "leftChild", "rightChild", "split").collect()

# slot index -> input column; 'wallVec' spans several slots, so features[featureIndex] is not enough
first_row = transformed_pipeline.select(features).first()
slot_source = [c for c in features
               for _ in range(len(first_row[c]) if isinstance(first_row[c], Vector) else 1)]

# create a directed graph for the decision tree; you could use a simpler tree structure here instead
G = nx.DiGraph()
# first pass to add the nodes (leaves have leftChild == rightChild == -1)
for rw in noderows:
    if rw['leftChild'] < 0 and rw['rightChild'] < 0:
        G.add_node(rw['id'], cat="Prediction", predval=rw['prediction'])
    else:
        G.add_node(rw['id'], cat="splitter", featureIndex=rw['split']['featureIndex'],
                   thresh=rw['split']['leftCategoriesOrThreshold'], leftChild=rw['leftChild'],
                   rightChild=rw['rightChild'], numCat=rw['split']['numCategories'])
# second pass to add the relationships, each labelled with the condition that leads to the child
for rw in noderows:
    if rw['leftChild'] < 0:
        continue
    node = G.nodes[rw['id']]
    idx, thresh = node['featureIndex'], node['thresh']
    name = slot_source[idx]
    if node['numCat'] > 0:  # categorical: the listed categories go left
        name = "{0}[{1}]".format(name, idx - slot_source.index(name))
        left, right = "{0} in {1}".format(name, thresh), "{0} not in {1}".format(name, thresh)
    else:  # continuous: values <= threshold go left
        left = "{0} less than or equal to {1}".format(name, thresh[0])
        right = "{0} greater than {1}".format(name, thresh[0])
    G.add_edge(rw['id'], rw['leftChild'], reason=left)
    G.add_edge(rw['id'], rw['rightChild'], reason=right)
```
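As the comment in the code says, a graph library is not strictly necessary. Here is a minimal sketch of the same two passes with plain dictionaries, assuming node rows with the columns selected above (`id`, `prediction`, `leftChild`, `rightChild`, `split`) and leaves marked by `-1` children. The helpers `build_tree` and `edge_reasons`, the slot names, and the toy rows are all hypothetical:

```python
def build_tree(rows):
    """Index the saved node rows by id (leaves have leftChild == rightChild == -1)."""
    return {row["id"]: row for row in rows}


def edge_reasons(tree, slot_names):
    """Describe every parent -> child edge with the condition that leads to the child."""
    reasons = {}
    for node_id, node in tree.items():
        if node["leftChild"] < 0:  # leaf: no outgoing edges
            continue
        split = node["split"]
        name = slot_names[split["featureIndex"]]
        values = split["leftCategoriesOrThreshold"]
        if split["numCategories"] > 0:
            # categorical split: the listed categories go left, all others go right
            left = "{0} in {1}".format(name, values)
            right = "{0} not in {1}".format(name, values)
        else:
            # continuous split: values <= threshold go left
            left = "{0} <= {1}".format(name, values[0])
            right = "{0} > {1}".format(name, values[0])
        reasons[(node_id, node["leftChild"])] = left
        reasons[(node_id, node["rightChild"])] = right
    return reasons


# toy rows shaped like the saved node data (illustrative values, not the real model)
leaf = {"featureIndex": -1, "leftCategoriesOrThreshold": [], "numCategories": -1}
rows = [
    {"id": 0, "prediction": 0.0, "leftChild": 1, "rightChild": 2,
     "split": {"featureIndex": 1, "leftCategoriesOrThreshold": [6.0], "numCategories": -1}},
    {"id": 1, "prediction": 51.0, "leftChild": -1, "rightChild": -1, "split": leaf},
    {"id": 2, "prediction": 0.0, "leftChild": 3, "rightChild": 4,
     "split": {"featureIndex": 0, "leftCategoriesOrThreshold": [0.0], "numCategories": 2}},
    {"id": 3, "prediction": 21.0, "leftChild": -1, "rightChild": -1, "split": leaf},
    {"id": 4, "prediction": 41.0, "leftChild": -1, "rightChild": -1, "split": leaf},
]
tree = build_tree(rows)
for edge, reason in sorted(edge_reasons(tree, ["wallVec[0]", "hall"]).items()):
    print(edge, reason)
```

The trade-off is that you lose networkx's traversal and drawing helpers, but nothing beyond the standard library is needed.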
Now let's build a function that uses all of this to print the decision path for a given row.
Note: this could be written more cleanly.
```python
from pyspark.ml.linalg import Vector

# slot index -> input column; a one-hot column such as 'wallVec' spans several slots,
# so features[featureIndex] alone would name the wrong column
first_row = transformed_pipeline.select(features).first()
slot_source = [c for c in features
               for _ in range(len(first_row[c]) if isinstance(first_row[c], Vector) else 1)]


# function to print the path taken by the row with the given tagvalue
def decision_path(tag2search):
    wanted_row = transformed_pipeline.where("tagvalue = {0}".format(int(tag2search))).first()
    wanted_features = wanted_row['features']
    start_node = G.nodes[0]  # the root of the saved tree has id 0
    while start_node['cat'] != 'Prediction':
        idx = start_node['featureIndex']
        feature_value = wanted_features[idx]
        feature_column = slot_source[idx]
        if start_node['numCat'] > 0:
            # categorical split: go left iff the value is one of the left categories
            go_left = feature_value in start_node['thresh']
            # this assumes categorical columns are named like 'wall' -> 'wallVec', 'ball' -> 'ballVec'
            original_column = feature_column[:-3]
            slot = "{0}[{1}]".format(feature_column, idx - slot_source.index(feature_column))
            print("'{0}' value of {1} ({2} = {3}) {4} left categories {5}; ".format(
                original_column, wanted_row[original_column], slot, feature_value,
                "in" if go_left else "not in", start_node['thresh']))
        else:
            # continuous split: go left iff the value is <= the threshold
            go_left = feature_value <= start_node['thresh'][0]
            print("'{0}' value of {1} was {2} {3}; ".format(
                feature_column, feature_value,
                "less than or equal to" if go_left else "greater than",
                start_node['thresh'][0]))
        start_node = G.nodes[start_node['leftChild'] if go_left else start_node['rightChild']]
    print("leads to prediction of {0}".format(start_node['predval']))
```

Results take this form:

```python
for X in range(1, 8):
    decision_path(X)
```

```
'hall' value of 8.0 was greater than 6.0;
'ball' value of 0.0 was less than or equal to 1.0;
'ball' value of 0.0 was less than or equal to 0.0;
leads to prediction of 21.0
'hall' value of 9.0 was greater than 6.0;
'ball' value of 1.0 was less than or equal to 1.0;
'ball' value of 1.0 was greater than 0.0;
'keep' value of 5.0 was less than or equal to 5.0;
leads to prediction of 31.0
'hall' value of 10.0 was greater than 6.0;
'ball' value of 1.0 was less than or equal to 1.0;
'ball' value of 1.0 was greater than 0.0;
'keep' value of 6.0 was greater than 5.0;
'wall' value of a (wallVec[0] = 1.0) not in left categories [0.0];
leads to prediction of 41.0
'hall' value of 11.0 was greater than 6.0;
'ball' value of 3.0 was greater than 1.0;
'wall' value of a (wallVec[0] = 1.0) not in left categories [0.0];
leads to prediction of 51.0
'hall' value of 2.0 was less than or equal to 6.0;
leads to prediction of 51.0
'hall' value of 6.0 was less than or equal to 6.0;
leads to prediction of 51.0
'hall' value of 10.0 was greater than 6.0;
'ball' value of 1.0 was less than or equal to 1.0;
'ball' value of 1.0 was greater than 0.0;
'keep' value of 6.0 was greater than 5.0;
'wall' value of c (wallVec[0] = 0.0) in left categories [0.0];
leads to prediction of 21.0
```
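The rule applied at each node can be isolated and checked without Spark. Here is a minimal sketch, assuming the same node-row layout as the saved model data: continuous splits send a value left when it is `<=` the threshold, and categorical splits send it left when it is one of the listed left categories. The names `goes_left` and `walk` and the toy tree are hypothetical:

```python
def goes_left(value, split):
    """Spark-style split test: True if `value` should follow the left branch."""
    if split["numCategories"] > 0:
        # categorical split: left iff the value is one of the left categories
        return value in split["leftCategoriesOrThreshold"]
    # continuous split: left iff the value is <= the threshold
    return value <= split["leftCategoriesOrThreshold"][0]


def walk(tree, features):
    """Follow one feature vector from the root; return visited node ids and the prediction."""
    node_id = 0  # the root of the saved tree has id 0
    visited = [node_id]
    node = tree[node_id]
    while node["leftChild"] >= 0:  # leaves have leftChild == -1
        split = node["split"]
        if goes_left(features[split["featureIndex"]], split):
            node_id = node["leftChild"]
        else:
            node_id = node["rightChild"]
        visited.append(node_id)
        node = tree[node_id]
    return visited, node["prediction"]


# toy tree: slot 0 is continuous, slot 1 is a 0/1 one-hot slot (illustrative values)
leaf = {"featureIndex": -1, "leftCategoriesOrThreshold": [], "numCategories": -1}
tree = {
    0: {"leftChild": 1, "rightChild": 2, "prediction": 0.0,
        "split": {"featureIndex": 0, "leftCategoriesOrThreshold": [6.0], "numCategories": -1}},
    1: {"leftChild": -1, "rightChild": -1, "prediction": 51.0, "split": leaf},
    2: {"leftChild": 3, "rightChild": 4, "prediction": 0.0,
        "split": {"featureIndex": 1, "leftCategoriesOrThreshold": [0.0], "numCategories": 2}},
    3: {"leftChild": -1, "rightChild": -1, "prediction": 21.0, "split": leaf},
    4: {"leftChild": -1, "rightChild": -1, "prediction": 41.0, "split": leaf},
}
print(walk(tree, [10.0, 1.0]))
print(walk(tree, [10.0, 0.0]))
print(walk(tree, [2.0, 1.0]))
```

The list of visited node ids is similar in spirit to the node indicator that scikit-learn's `decision_path` returns. With the real model you could pass `{rw['id']: rw for rw in noderows}` and a row's `features` vector.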
Notes:
- If you want to stay entirely within Spark, you could use GraphFrames instead of networkx (I don't have that luxury :( )
- You can modify the phrasing as you wish
- If you need the impurity, impurityStats, or gain, those are all in the node data that gets saved with the model (`modeldf` above)
- I chose to work with the tree instead of parsing `.toDebugString` because having access to the tree itself seemed more foundationally important (and more extensible).
- On that note, comparing against the `.toDebugString` output and sklearn's `decision_path` output, I find these printed paths easier to read and understand.
- If you want to visualize the tree, check out: https://github.com/tristaneljed/Decision-Tree-Visualization-Spark/blob/master/DT.py
- I had found a pure Scala implementation at some point, but can't find that again right now :(
- I feel like I'm missing a test case with the "Not In" categorical, if someone wants to throw in what that row would look like, I can edit if I have to
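Since the saved node data includes `gain` for every split, you can, for instance, total it per feature to see which features drive the tree. Below is a minimal sketch over plain dict rows. The helper `gain_by_feature` and the toy gains are hypothetical, and this plain sum is a simpler measure than Spark's own `featureImportances`, which, as I understand it, also weights each node by the number of instances reaching it:

```python
from collections import defaultdict


def gain_by_feature(rows, slot_names):
    """Total the saved 'gain' of the split nodes, grouped by the slot they split on."""
    totals = defaultdict(float)
    for row in rows:
        if row["leftChild"] < 0:  # leaves do not split, so they contribute nothing
            continue
        totals[slot_names[row["split"]["featureIndex"]]] += row["gain"]
    # largest total first
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)


# toy node rows (only the fields used here) with made-up gains, not from the real model
rows = [
    {"id": 0, "leftChild": 1, "rightChild": 2, "gain": 0.25, "split": {"featureIndex": 1}},
    {"id": 1, "leftChild": -1, "rightChild": -1, "gain": -1.0, "split": {"featureIndex": -1}},
    {"id": 2, "leftChild": 3, "rightChild": 6, "gain": 0.5, "split": {"featureIndex": 0}},
    {"id": 3, "leftChild": 4, "rightChild": 5, "gain": 0.125, "split": {"featureIndex": 1}},
    {"id": 4, "leftChild": -1, "rightChild": -1, "gain": -1.0, "split": {"featureIndex": -1}},
    {"id": 5, "leftChild": -1, "rightChild": -1, "gain": -1.0, "split": {"featureIndex": -1}},
    {"id": 6, "leftChild": -1, "rightChild": -1, "gain": -1.0, "split": {"featureIndex": -1}},
]
print(gain_by_feature(rows, ["ball", "hall"]))
```

With the real data, the rows could come from `modeldf.select("id", "leftChild", "rightChild", "gain", "split").collect()`, since Spark `Row` objects support the same `row["..."]` access.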