plot different color for different categorical levels using matplotlib
Here's a succinct and generic solution to use a seaborn color palette.
First find a color palette you like and optionally visualize it:
sns.palplot(sns.color_palette("Set2", 8))
Then you can use it with matplotlib like this:
# Unique category labels: 'D', 'F', 'G', ...
color_labels = df['color'].unique()
# List of RGB triplets
rgb_values = sns.color_palette("Set2", len(color_labels))
# Map label to RGB
color_map = dict(zip(color_labels, rgb_values))
# Finally use the mapped values
plt.scatter(df['carat'], df['price'], c=df['color'].map(color_map))
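The label-to-RGB mapping above can be sketched without seaborn as well. The version below is only an illustration under a few assumptions: it swaps in matplotlib's built-in "Set2" colormap for the seaborn palette, uses a small made-up DataFrame named toy in place of the diamonds data, and fails loudly when there are more labels than palette entries instead of silently dropping them.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the diamonds data (not the real dataset)
toy = pd.DataFrame({
    "carat": [0.2, 0.3, 0.5, 0.7, 1.0, 1.2],
    "price": [300, 400, 900, 1500, 4000, 6000],
    "color": ["D", "E", "D", "F", "E", "F"],
})

# Unique category labels, sorted so the mapping is reproducible
labels = sorted(toy["color"].unique())

# "Set2" is a qualitative colormap; .colors holds its RGB triplets
palette = plt.get_cmap("Set2").colors
if len(labels) > len(palette):
    raise ValueError("more categories than colors in the palette")

# Map each label to one RGB triplet
label_to_rgb = dict(zip(labels, palette))

# One RGB triplet per row, passed to scatter as a plain list
point_colors = toy["color"].map(label_to_rgb).tolist()

fig, ax = plt.subplots()
points = ax.scatter(toy["carat"], toy["price"], c=point_colors)
```

Sorting the labels is a design choice, not a requirement: it keeps the label-to-color assignment stable between runs, whereas unique() returns labels in order of first appearance.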
Imports and Sample DataFrame
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns # for sample data
from matplotlib.lines import Line2D # for legend handle
# DataFrame used for all options
df = sns.load_dataset('diamonds')
   carat      cut color clarity  depth  table  price     x     y     z
0   0.23    Ideal     E     SI2   61.5   55.0    326  3.95  3.98  2.43
1   0.21  Premium     E     SI1   59.8   61.0    326  3.89  3.84  2.31
2   0.23     Good     E     VS1   56.9   65.0    327  4.05  4.07  2.31
With matplotlib
You can pass plt.scatter a c argument, which sets the color of each point. The following code defines a colors dictionary that maps the diamond colors to the plotting colors.
fig, ax = plt.subplots(figsize=(6, 6))
colors = {'D':'tab:blue', 'E':'tab:orange', 'F':'tab:green', 'G':'tab:red', 'H':'tab:purple', 'I':'tab:brown', 'J':'tab:pink'}
ax.scatter(df['carat'], df['price'], c=df['color'].map(colors))
# add a legend
handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=v, label=k, markersize=8) for k, v in colors.items()]
ax.legend(title='color', handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.show()
df['color'].map(colors) effectively maps the colors from "diamond" to "plotting".
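One thing to watch with a dictionary-based map is a label that has no entry in the dictionary: Series.map returns NaN for it, and NaN is not a valid color. The sketch below uses a made-up label 'K' and a fallback color of 'lightgray'; both are assumptions chosen for illustration, not part of the diamonds data.

```python
import pandas as pd

# Deliberately incomplete mapping, for illustration only
colors = {"D": "tab:blue", "E": "tab:orange"}

# 'K' is a hypothetical label that the dictionary does not cover
labels = pd.Series(["D", "E", "K", "D"])

mapped = labels.map(colors)

# Labels that did not get a color
missing = sorted(labels[mapped.isna()].unique())

# Option: give every unmapped label a neutral fallback color
mapped_with_fallback = mapped.fillna("lightgray")
```

Checking missing before plotting makes it clear which dictionary entry is absent, which is usually easier to debug than a color error raised from inside the plotting call.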
With seaborn
You can use seaborn, which is a wrapper around matplotlib that makes plots look prettier by default (rather opinion-based, I know :P) and also adds some plotting functions.
For this you could use seaborn.lmplot with fit_reg=False (which prevents it from automatically fitting a regression). sns.scatterplot(x='carat', y='price', data=df, hue='color', ec=None) does the same thing.
Selecting hue='color' tells seaborn to split and plot the data based on the unique values in the 'color' column.
sns.lmplot(x='carat', y='price', data=df, hue='color', fit_reg=False)
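If you want seaborn to use specific colors rather than its default palette, the palette parameter accepts a dictionary that maps hue levels to colors, and hue_order fixes the order of the legend entries. This is a minimal sketch on a small made-up DataFrame named toy (an assumption, so that it runs without downloading the diamonds data):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the diamonds data
toy = pd.DataFrame({
    "carat": [0.2, 0.3, 0.5, 0.7, 1.0, 1.2],
    "price": [300, 400, 900, 1500, 4000, 6000],
    "color": ["D", "E", "D", "F", "E", "F"],
})

# Same idea as the colors dictionary in the matplotlib section
colors = {"D": "tab:blue", "E": "tab:orange", "F": "tab:green"}

fig, ax = plt.subplots()
sns.scatterplot(
    data=toy, x="carat", y="price",
    hue="color",              # split the points by this column
    hue_order=list(colors),   # legend order follows the dictionary
    palette=colors,           # level -> color mapping
    ax=ax,
)

# Collect the legend labels to confirm every level is shown
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
```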
With pandas.DataFrame.groupby & pandas.DataFrame.plot
If you don't want to use seaborn, use pandas.DataFrame.groupby to split the data by color, and then plot each group using just matplotlib. You'll have to assign the colors manually as you go. I've added an example below:
fig, ax = plt.subplots(figsize=(6, 6))
grouped = df.groupby('color')
for key, group in grouped:
    group.plot(ax=ax, kind='scatter', x='carat', y='price', label=key, color=colors[key])
plt.show()
This code assumes the same DataFrame as above, and then groups it based on color. It then iterates over these groups, plotting for each one. To select a color, I've created a colors dictionary, which can map the diamond color (for instance D) to a real color (for instance tab:blue).
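The same group-by-group idea also works with ax.scatter directly instead of DataFrame.plot. A side effect of plotting one group per call is that each call can carry its own label, so ax.legend() builds the legend without hand-made handles. The sketch below assumes a small made-up DataFrame named toy and a three-entry colors dictionary; neither comes from the diamonds data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical stand-in for the diamonds data
toy = pd.DataFrame({
    "carat": [0.2, 0.3, 0.5, 0.7, 1.0, 1.2],
    "price": [300, 400, 900, 1500, 4000, 6000],
    "color": ["D", "E", "D", "F", "E", "F"],
})
colors = {"D": "tab:blue", "E": "tab:orange", "F": "tab:green"}

fig, ax = plt.subplots()

# groupby yields (key, sub-DataFrame) pairs, sorted by key by default
for key, group in toy.groupby("color"):
    ax.scatter(group["carat"], group["price"], color=colors[key], label=key)

ax.set_xlabel("carat")
ax.set_ylabel("price")
ax.legend(title="color")

# Legend entries come from the label passed in each scatter call
legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
```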