Plotting two distributions in seaborn.jointplot
A better solution, in my opinion, is to use the axes handles for the joint and marginal distributions that sns.jointplot returns. Using those (the names are ax_joint, ax_marg_x and ax_marg_y), it is also possible to draw on the marginal distribution plots.
import numpy as np
import pandas as pd
import seaborn as sns
data1 = np.random.randn(100)
data2 = np.random.randn(100)
data3 = np.random.randn(100)
data4 = np.random.randn(100)
df1 = pd.DataFrame({'col1': data1, 'col2':data2})
df2 = pd.DataFrame({'col1': data3, 'col2':data4})
# x and y must be passed as keywords in recent seaborn versions
axs = sns.jointplot(x='col1', y='col2', data=df1)
axs.ax_joint.scatter('col1', 'col2', data=df2, c='r', marker='x')
# drawing pdf instead of histograms on the marginal axes
axs.ax_marg_x.cla()
axs.ax_marg_y.cla()
# sns.distplot is deprecated since seaborn 0.11; kdeplot draws the density curve
sns.kdeplot(x=df1.col1, ax=axs.ax_marg_x)
sns.kdeplot(y=df1.col2, ax=axs.ax_marg_y)
Here is one way to do it by modifying the underlying data of sns.JointGrid.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# simulate some artificial data
# ========================================
np.random.seed(0)
data1 = np.random.multivariate_normal([0,0], [[1,0.5],[0.5,1]], size=200)
data2 = np.random.multivariate_normal([0,0], [[1,-0.8],[-0.8,1]], size=100)
# both df1 and df2 are bivariate normal samples; df1 has 200 rows, df2 has 100
df1 = pd.DataFrame(data1, columns=['x1', 'y1'])
df2 = pd.DataFrame(data2, columns=['x2', 'y2'])
# plot
# ========================================
graph = sns.jointplot(x=df1.x1, y=df1.y1, color='r')
graph.x = df2.x2
graph.y = df2.y2
graph.plot_joint(plt.scatter, marker='x', c='b', s=50)
It might be easier, after drawing the jointplot, to switch to the axis on which you want to draw something and then use normal pyplot or axis-based seaborn plots:
g = sns.jointplot(...)
plt.sca(g.ax_joint)  # or g.ax_marg_x / g.ax_marg_y
plt.plot(...)  # or plt.scatter(...), or sns.kdeplot(..., ax=g.ax_joint)
The axis is one of the attributes of the returned grid: ax_joint for the 2D plot, or ax_marg_x or ax_marg_y for the 1D plots on the side.
Furthermore, if you want to use the jointplot structure but draw all plots with pyplot, use the cla function, e.g. for clearing the 2D plot:
g.ax_joint.cla()