Seaborn Jointplot add colors for each class
The obvious solution is to let regplot draw only the regression line, not the points, and add the points with an ordinary matplotlib scatter plot, whose c argument sets a color per point.
g = sns.jointplot(x=X, y=y, kind='reg', scatter=False)
g.ax_joint.scatter(X, y, c=classes)
I managed to find a solution that does exactly what I need. Thanks to @ImportanceOfBeingErnest, who gave me the idea of letting regplot draw only the regression line.
Solution:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
classes = np.array([1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
3., 3., 3., 3., 3., 3., 3.])
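# X, y: the question's data, one value per sample (ravel also accepts column vectors)
df = pd.DataFrame({'x': np.ravel(X), 'y': np.ravel(y), 'class': classes})
# seaborn >= 0.11 keyword style; regplot draws only the regression line
g = sns.jointplot(x=df['x'], y=df['y'], kind='reg', scatter=False)
for i, (cls, subdata) in enumerate(df.groupby('class')):
    color = f'C{i}'  # one color per class, shared by all three axes
    sns.kdeplot(x=subdata['x'], ax=g.ax_marg_x, color=color, legend=False)
    sns.kdeplot(y=subdata['y'], ax=g.ax_marg_y, color=color, legend=False)
    g.ax_joint.plot(subdata['x'], subdata['y'], 'o', ms=8, color=color)
plt.tight_layout()
plt.show()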