How to plot 2 seaborn lmplots side-by-side?
If the reason for using lmplot is its hue argument, applied to two different sets of variables, regplot alone may not be sufficient without some tweaks, because regplot has no hue parameter.
To get the effect of lmplot's hue argument in two side-by-side plots, one possible solution is:
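def hue_regplot(data, x, y, hue, palette=None, **kwargs):
    # matplotlib.cm.get_cmap is deprecated in recent Matplotlib releases;
    # plt.get_cmap is used here instead
    import matplotlib.pyplot as plt
    import seaborn as sns

    regplots = []
    levels = data[hue].unique()

    # default to one tab10 color per hue level
    if palette is None:
        default_colors = plt.get_cmap('tab10')
        palette = {k: default_colors(i) for i, k in enumerate(levels)}

    # draw one regression plot per hue level; ax, if given, arrives via kwargs
    for key in levels:
        regplots.append(
            sns.regplot(
                x=x,
                y=y,
                data=data[data[hue] == key],
                color=palette[key],
                **kwargs
            )
        )

    return regplots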
This function gives a result similar to lmplot (with the hue option), but it accepts the ax argument, which is necessary for creating a composite figure.
An example of usage is:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
# Jupyter/IPython only; remove this line when running as a plain script
%matplotlib inline
rnd = np.random.default_rng(1234567890)
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({
    'x': x,
    'y': np.sin(x ** 2),
    # random group labels, used only to illustrate hue
    'color1': rnd.integers(0, 2, size=400),
    'color2': rnd.integers(0, 3, size=400),
})
# Two subplots
f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
# ax1.plot(df.x, df.y)
ax1.set_title('Sharing Y axis')
# ax2.scatter(df.x, df.y)
hue_regplot(data=df, x='x', y='y', hue='color1', ax=ax1)
hue_regplot(data=df, x='x', y='y', hue='color2', ax=ax2)
plt.show()
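The palette argument of hue_regplot can also be supplied explicitly. The following is a minimal sketch, not part of the original answer: the color choices, the column name group, and the proxy-handle legend are assumptions made for illustration, and hue_regplot is repeated in condensed form only so the block runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D


def hue_regplot(data, x, y, hue, palette=None, **kwargs):
    """Condensed copy of the function above: one sns.regplot per hue level."""
    levels = data[hue].unique()
    if palette is None:
        cmap = plt.get_cmap("tab10")
        palette = {k: cmap(i) for i, k in enumerate(levels)}
    return [
        sns.regplot(x=x, y=y, data=data[data[hue] == key],
                    color=palette[key], **kwargs)
        for key in levels
    ]


rng = np.random.default_rng(0)
x = np.linspace(0, 2 * np.pi, 200)
df = pd.DataFrame({"x": x, "y": np.sin(x ** 2),
                   "group": rng.integers(0, 2, size=200)})

# Hypothetical palette: keys are the hue levels, values are any matplotlib color spec
my_palette = {0: "tab:blue", 1: "tab:orange"}

fig, ax = plt.subplots()
result = hue_regplot(data=df, x="x", y="y", hue="group",
                     palette=my_palette, ax=ax)

# regplot draws no legend for this per-level loop, so build one from proxy handles
handles = [Line2D([], [], marker="o", linestyle="", color=color, label=str(level))
           for level, color in my_palette.items()]
ax.legend(handles=handles, title="group")
```

Because every call receives the same ax, each element of the returned list is that one Axes object, so further styling can go through ax directly.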
You get that error because matplotlib and its objects are completely unaware of seaborn functions.
Pass your axes objects (i.e., ax1 and ax2) to seaborn.regplot, or skip defining them and use the col kwarg of seaborn.lmplot.
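The difference between the two functions can be checked directly. This is a small sketch on made-up data (the demo frame is an assumption, not the asker's data): regplot draws on the Axes it is given and returns it, whereas lmplot builds its own figure and returns a FacetGrid.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

x = np.linspace(0, 2 * np.pi, 50)
demo = pd.DataFrame({"x": x, "y": np.sin(x)})  # illustrative data only

# Axes-level function: draws on the Axes passed in and returns that same object
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)
returned = sns.regplot(x="x", y="y", data=demo, ax=ax1)
drew_on_ax1 = returned is ax1

# Figure-level function: creates a new figure and returns a FacetGrid
grid = sns.lmplot(x="x", y="y", data=demo)
grid_type = type(grid).__name__
uses_own_axes = grid.axes.flat[0] is not ax1
```

Here drew_on_ax1 and uses_own_axes both come out true, which is why lmplot cannot be dropped into axes created with plt.subplots.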
With your same imports, pre-defining your axes and using regplot looks like this:
# create df
x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({'x': x, 'y': np.sin(x ** 2)})
df.index.names = ['obs']
df.columns.names = ['vars']
idx = np.array(df.index.tolist(), dtype='float') # make an array of x-values
# call regplot on each axes
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)
sns.regplot(x=idx, y=df['x'], ax=ax1)
sns.regplot(x=idx, y=df['y'], ax=ax2)
Using lmplot requires your dataframe to be tidy. Continuing from the code above:
tidy = (
    df.stack()                       # pull the columns into row variables
      .to_frame()                    # convert the resulting Series to a DataFrame
      .reset_index()                 # pull the resulting MultiIndex into the columns
      .rename(columns={0: 'val'})    # rename the unnamed column
)
sns.lmplot(x='obs', y='val', col='vars', hue='vars', data=tidy)
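As an alternative way to reach the same tidy layout, DataFrame.melt can be used instead of the stack chain. This is a sketch under the assumption that the same x and y columns are wanted, and it reuses the obs, vars and val names from above. The rows come out in a different order than with stack, which does not affect the plot.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs as a script; omit in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

x = np.linspace(0, 2 * np.pi, 400)
df = pd.DataFrame({"x": x, "y": np.sin(x ** 2)})

tidy = (
    df.rename_axis(index="obs")   # name the index so it becomes the 'obs' column
      .reset_index()              # move the index into the columns
      .melt(id_vars="obs", var_name="vars", value_name="val")  # wide to long
)

# one panel per original column, colored by the same variable
g = sns.lmplot(x="obs", y="val", col="vars", hue="vars", data=tidy)
```

Each original column contributes 400 rows, so tidy has 800 rows and the grid has one row of two panels.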