Bland-Altman plot in Python

I slightly modified @sodd's excellent code, adding a few more labels and text annotations to make the plot more publication-ready.

[Image: Bland-Altman plot with the mean difference and the ±1.96 SD limits labelled to the right of the axes]

import matplotlib.pyplot as plt
import numpy as np
import pdb
from numpy.random import random

def bland_altman_plot(data1, data2, *args, sd_limit=1.96, ddof=0, **kwargs):
    """Draw a Bland-Altman (mean-difference) scatter plot on the current axes.

    Each point is one pair of measurements: its x coordinate is the mean of
    the pair and its y coordinate is ``data1 - data2``. A solid black line
    marks the mean difference and two dashed gray lines mark the limits of
    agreement, ``mean difference +/- sd_limit * SD``.

    Parameters
    ----------
    data1, data2 : array_like
        Paired measurements; they must have the same shape (presumably 1-D).
    *args, **kwargs
        Forwarded unchanged to ``plt.scatter`` (e.g. ``color``, ``s``).
    sd_limit : float, keyword-only
        Multiple of the SD used for the limits of agreement. The default 1.96
        gives the conventional ~95% limits for roughly normal differences.
    ddof : int, keyword-only
        Delta degrees of freedom passed to ``np.std``. The default 0
        (population SD) keeps the original behaviour; use 1 for sample SD.

    Returns
    -------
    tuple
        ``(md, sd, mean, CI_low, CI_high)``: mean difference, SD of the
        differences, per-pair means, and the lower/upper limits of agreement.
        The ``CI_*`` names are kept for compatibility; these are not
        confidence intervals.

    Raises
    ------
    ValueError
        If ``data1`` and ``data2`` do not have the same shape.
    """
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)
    if data1.shape != data2.shape:
        raise ValueError(
            f"data1 and data2 must have the same shape, "
            f"got {data1.shape} and {data2.shape}"
        )

    mean = np.mean([data1, data2], axis=0)    # x: average of each pair
    diff = data1 - data2                      # y: difference of each pair
    md = np.mean(diff)                        # mean of the differences
    sd = np.std(diff, axis=0, ddof=ddof)      # spread of the differences
    CI_low = md - sd_limit * sd               # lower limit of agreement
    CI_high = md + sd_limit * sd              # upper limit of agreement

    plt.scatter(mean, diff, *args, **kwargs)
    plt.axhline(md, color='black', linestyle='-')
    # Reuse the computed limits so the lines always match the returned values.
    plt.axhline(CI_high, color='gray', linestyle='--')
    plt.axhline(CI_low, color='gray', linestyle='--')
    return md, sd, mean, CI_low, CI_high


# Example: plot two random series and label the three reference lines.
md, sd, mean, CI_low, CI_high = bland_altman_plot(random(10), random(10))
# Plain bold text instead of mathtext: inside $...$ the hyphen in
# "Bland-Altman" is typeset as a spaced minus sign.
plt.title("Bland-Altman Plot", fontweight="bold")
plt.xlabel("Means")
plt.ylabel("Difference")
plt.ylim(md - 3.5*sd, md + 3.5*sd)

# Put the value labels just right of the axes. The blended transform takes x
# in axes fraction (1.0 is the right edge) and y in data units. Each label
# therefore lines up with its horizontal line independent of the x-data
# range, including the degenerate case where all means are equal.
label_transform = plt.gca().get_yaxis_transform()
for y_value, caption in ((CI_low, "-1.96SD:"), (CI_high, "+1.96SD:"), (md, "Mean:")):
    plt.text(
        1.08, y_value,
        f"{caption}\n{y_value:.2f}",
        ha="center",
        va="center",
        transform=label_transform,
    )
# Leave room on the right for the labels placed outside the axes.
plt.subplots_adjust(right=0.85)
plt.show()


This is now implemented in statsmodels as `statsmodels.graphics.agreement.mean_diff_plot`: https://www.statsmodels.org/devel/generated/statsmodels.graphics.agreement.mean_diff_plot.html

Here is their example:

import statsmodels.api as sm
import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global random number generator so that the two simulated
# measurement series, and therefore the figure, are reproducible.
np.random.seed(9999)
m1, m2 = (np.random.random(20) for _ in range(2))

f, ax = plt.subplots(nrows=1, figsize=(8, 5))
sm.graphics.mean_diff_plot(m1, m2, ax=ax)

plt.show()

which produces this:

[Image: the Bland-Altman (mean-difference) plot produced by statsmodels' mean_diff_plot]


If I have understood the theory behind the plot correctly, this code provides the basic plot, which you can then customize to your own needs.

import matplotlib.pyplot as plt
import numpy as np

def bland_altman_plot(data1, data2, *args, **kwargs):
    """Scatter paired measurements as a Bland-Altman plot on the current axes.

    For every pair, the x coordinate is the average of the two values and the
    y coordinate is ``data1 - data2``. Three dashed gray reference lines are
    drawn: at the mean difference, and at the mean difference plus and minus
    1.96 times the population standard deviation of the differences
    (``np.std`` with its default ddof=0). Any extra positional or keyword
    arguments are handed straight to ``plt.scatter``. Returns None.
    """
    first = np.asarray(data1)
    second = np.asarray(data2)

    pair_means = np.mean([first, second], axis=0)
    pair_diffs = first - second
    bias = np.mean(pair_diffs)
    spread = np.std(pair_diffs, axis=0)

    plt.scatter(pair_means, pair_diffs, *args, **kwargs)
    # Mean difference first, then the upper and the lower limit.
    for level in (bias, bias + 1.96 * spread, bias - 1.96 * spread):
        plt.axhline(level, color='gray', linestyle='--')

Corresponding elements of `data1` and `data2` form one plotted point: its x coordinate is the mean of the pair and its y coordinate is their difference (`data1 - data2`).

You can then create a plot by running, e.g.:

from numpy.random import random

# Two independent series of ten uniform samples stand in for paired data.
first, second = random(10), random(10)
bland_altman_plot(first, second)
plt.title('Bland-Altman Plot')
plt.show()

Bland-Altman Plot


I adapted sodd's answer into a Plotly implementation; this seemed like the easiest place to share it.

from scipy.stats import linregress
import numpy as np
import plotly.graph_objects as go
def bland_altman_plot(data1, data2, data1_name='A', data2_name='B', subgroups=None, plotly_template='none', annotation_offset=0.05, plot_trendline=True, n_sd=1.96, *args, **kwargs):
    """Build a Plotly Bland-Altman (mean-difference) figure for paired data.

    Each marker is one pair: x is the average of the two measurements and y
    is ``data1 - data2``. A dash-dot line marks the mean difference. A shaded
    band spans ``mean difference +/- n_sd * SD`` of the differences.

    Parameters
    ----------
    data1, data2 : array_like
        Paired measurements; must have the same shape.
    data1_name, data2_name : str
        Labels used in the title and axis titles.
    subgroups : array_like, optional
        One group label per pair. When given, each distinct label is drawn as
        its own marker trace named ``str(label)``.
    plotly_template : str
        Name of the Plotly layout template.
    annotation_offset : float
        Vertical offset, in y-axis data units, between each limit line and
        its text labels (caption above, numeric value below).
    plot_trendline : bool
        If True, overlay a least-squares line of difference vs. average. The
        line is skipped when the fit is undefined (fewer than two distinct
        averages).
    n_sd : float
        Half-width of the band, in standard deviations of the differences.
    *args
        Accepted for backward compatibility; ignored.
    **kwargs
        Forwarded to every marker ``go.Scatter`` trace.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``data1``/``data2`` differ in shape, or ``subgroups`` does not
        have one label per pair.

    Notes
    -----
    The SD is the population standard deviation (``np.std`` with ddof=0).
    """
    data1 = np.asarray(data1)
    data2 = np.asarray(data2)
    if data1.shape != data2.shape:
        raise ValueError(
            f"data1 and data2 must have the same shape, "
            f"got {data1.shape} and {data2.shape}"
        )
    mean = np.mean([data1, data2], axis=0)  # x: average of each pair
    diff = data1 - data2                    # y: difference of each pair
    md = np.mean(diff)                      # mean of the differences
    sd = np.std(diff, axis=0)               # population SD of the differences
    upper = md + n_sd * sd                  # upper edge of the band
    lower = md - n_sd * sd                  # lower edge of the band

    fig = go.Figure()

    # linregress raises ValueError when all x values are identical, so only
    # fit a trendline when there are at least two distinct averages.
    if plot_trendline and np.unique(mean).size > 1:
        slope, intercept, *_ = linregress(mean, diff)
        trendline_x = np.linspace(mean.min(), mean.max(), 10)
        fig.add_trace(go.Scatter(x=trendline_x, y=slope * trendline_x + intercept,
                                 name='Trendline',
                                 mode='lines',
                                 line=dict(width=4, dash='dot')))

    if subgroups is None:
        fig.add_trace(go.Scatter(x=mean, y=diff, mode='markers', **kwargs))
    else:
        # Convert once, outside the loop, and make sure every pair has a label.
        groups = np.asarray(subgroups)
        if groups.shape != mean.shape:
            raise ValueError(
                f"subgroups must have one label per pair (shape {mean.shape}), "
                f"got shape {groups.shape}"
            )
        for group_name in np.unique(groups):
            group_mask = groups == group_name
            fig.add_trace(go.Scatter(x=mean[group_mask], y=diff[group_mask],
                                     mode='markers', name=str(group_name), **kwargs))

    fig.add_shape(
        # Horizontal line at the mean difference across the full plot width.
        type="line",
        xref="paper",
        x0=0,
        y0=md,
        x1=1,
        y1=md,
        line=dict(
            width=6,
            dash="dashdot",
        ),
        name=f'Mean {round(md, 2)}',
    )
    fig.add_shape(
        # Band between the limits. Shapes default to layer="above", which
        # draws the semi-transparent fill over the markers; keep it beneath.
        type="rect",
        xref="paper",
        x0=0,
        y0=lower,
        x1=1,
        y1=upper,
        line=dict(
            color="SeaGreen",
            width=2,
        ),
        fillcolor="LightSkyBlue",
        opacity=0.4,
        layer="below",
        name=f'±{n_sd} Standard Deviations',
    )

    fig.update_layout(
        title=f'Bland-Altman Plot for {data1_name} and {data2_name}',
        xaxis_title=f'Average of {data1_name} and {data2_name}',
        yaxis_title=f'{data1_name} Minus {data2_name}',
        template=plotly_template,
        annotations=[
            # Arrowed label pointing at the mean line from the right.
            dict(x=1, y=md, xref="paper", yref="y",
                 text=f"Mean {round(md, 2)}",
                 showarrow=True, arrowhead=7, ax=50, ay=0),
            # Captions just above each limit line...
            dict(x=1, y=upper + annotation_offset, xref="paper", yref="y",
                 text=f"+{n_sd} SD",
                 showarrow=False, arrowhead=0, ax=0, ay=-20),
            dict(x=1, y=lower + annotation_offset, xref="paper", yref="y",
                 text=f"-{n_sd} SD",
                 showarrow=False, arrowhead=0, ax=0, ay=20),
            # ...and the numeric limit values just below them.
            dict(x=1, y=upper - annotation_offset, xref="paper", yref="y",
                 text=f"{round(upper, 2)}",
                 showarrow=False, arrowhead=0, ax=0, ay=20),
            dict(x=1, y=lower - annotation_offset, xref="paper", yref="y",
                 text=f"{round(lower, 2)}",
                 showarrow=False, arrowhead=0, ax=0, ay=20),
        ],
    )
    return fig