Plotting two distributions in seaborn.jointplot

A better solution, in my opinion, is to use the axes handles for the joint and marginal plots, which are exposed by the JointGrid object that sns.jointplot returns. Using those handles (the attributes are ax_joint, ax_marg_x and ax_marg_y), it is also possible to draw on the marginal distribution plots.

import numpy as np
import pandas as pd
import seaborn as sns

data1 = np.random.randn(100)
data2 = np.random.randn(100)
data3 = np.random.randn(100)
data4 = np.random.randn(100)

df1 = pd.DataFrame({'col1': data1, 'col2': data2})
df2 = pd.DataFrame({'col1': data3, 'col2': data4})

# jointplot returns a JointGrid; x and y must be passed as keywords in recent seaborn
g = sns.jointplot(x='col1', y='col2', data=df1)
g.ax_joint.scatter('col1', 'col2', data=df2, c='r', marker='x')

# drawing pdf estimates (KDE) instead of histograms on the marginal axes
g.ax_marg_x.cla()
g.ax_marg_y.cla()
sns.kdeplot(x=df1.col1, ax=g.ax_marg_x)   # replaces the deprecated sns.distplot
sns.kdeplot(y=df1.col2, ax=g.ax_marg_y)   # y= draws the curve vertically

[Image: density plots on the marginal axes]


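Here is one way to do it: reassign the x and y data stored on the JointGrid that sns.jointplot returns, then call plot_joint again so the second dataset is drawn on the same joint axes.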

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# simulate some artificial data
# ========================================
np.random.seed(0)
data1 = np.random.multivariate_normal([0,0], [[1,0.5],[0.5,1]], size=200)
data2 = np.random.multivariate_normal([0,0], [[1,-0.8],[-0.8,1]], size=100)

# both df1 and df2 hold bivariate normal samples; df1 has 200 rows, df2 has 100 rows
df1 = pd.DataFrame(data1, columns=['x1', 'y1'])
df2 = pd.DataFrame(data2, columns=['x2', 'y2'])


# plot
# ========================================   
graph = sns.jointplot(x=df1.x1, y=df1.y1, color='r')

graph.x = df2.x2
graph.y = df2.y2
graph.plot_joint(plt.scatter, marker='x', c='b', s=50)



It might be easier to draw the jointplot first, then switch to the axes you want to draw on and use normal pyplot functions or axes-based seaborn plots:

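import matplotlib.pyplot as plt
import seaborn as sns
g = sns.jointplot(x=[1, 2, 3, 4], y=[2, 1, 4, 3])
plt.sca(g.ax_joint)                                 # sca takes an Axes object, not a name
plt.scatter([1.5, 3.5], [3, 2], c='r', marker='x')  # pyplot now draws on ax_joint
sns.kdeplot(x=[1, 2, 3, 4], ax=g.ax_marg_x)         # axes-based seaborn plots take ax=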

The axes are available as g.ax_joint for the 2D plot and g.ax_marg_x or g.ax_marg_y for the 1D marginal plots on the sides.

Furthermore, if you want to keep the jointplot layout but draw every plot yourself with pyplot, clear the existing content with the cla() method, e.g. for the 2D plot:

g.ax_joint.cla()