Draggable line with draggable points

OK, I finally found the solution. I am posting it here for those who might need it. This code lets you have two draggable points linked by a line. If you move one of the points, the line follows. This is very useful for drawing a baseline in scientific applications.

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D


class DraggablePoint:
    """One draggable end point of a two-point baseline on a Matplotlib axes.

    The point is drawn as an ``Ellipse`` patch on ``parent.fig.axes[0]``.
    When ``parent.list_points`` is already non-empty at construction time
    (i.e. this is the second point), the instance also creates the ``Line2D``
    joining it to ``parent.list_points[0]``.  Dragging either point moves its
    patch and updates that line; blitting is used so that only the moving
    artists are redrawn during the drag.

    Parameters
    ----------
    parent : object
        Must expose ``fig`` (a figure with at least one axes) and
        ``list_points`` (the list in which the caller stores the points).
    x, y : float
        Initial centre of the point, in data coordinates.
    size : float
        Width of the ellipse; its height is ``3 * size``.
    """

    # http://stackoverflow.com/questions/21654008/matplotlib-drag-overlapping-points-interactively

    lock = None  # the point currently being dragged; only one can be animated at a time

    def __init__(self, parent, x=0.1, y=0.1, size=0.1):
        self.parent = parent
        self.point = patches.Ellipse((x, y), size, size * 3, fc='r', alpha=0.5, edgecolor='r')
        self.x = x
        self.y = y
        parent.fig.axes[0].add_patch(self.point)
        self.press = None
        self.background = None
        # Only the second point owns the connecting line; the first has none.
        self.line = None
        self.connect()

        if self.parent.list_points:
            line_x = [self.parent.list_points[0].x, self.x]
            line_y = [self.parent.list_points[0].y, self.y]

            self.line = Line2D(line_x, line_y, color='r', alpha=0.5)
            parent.fig.axes[0].add_line(self.line)

    def connect(self):
        """Connect to all the canvas events we need."""
        canvas = self.point.figure.canvas
        self.cidpress = canvas.mpl_connect('button_press_event', self.on_press)
        self.cidrelease = canvas.mpl_connect('button_release_event', self.on_release)
        self.cidmotion = canvas.mpl_connect('motion_notify_event', self.on_motion)

    def _linked_line(self):
        """Return the line joining the two points, or None if it does not exist yet."""
        points = self.parent.list_points
        if len(points) < 2:
            return None
        return getattr(points[1], 'line', None)

    def _update_line(self, line):
        """Move *line* so that it joins the current positions of the two points."""
        points = self.parent.list_points
        if self is points[1]:
            start, end = points[0], self
        else:
            start, end = self, points[1]
        line.set_data([start.x, end.x], [start.y, end.y])

    def on_press(self, event):
        """Start dragging this point if the press landed on it and no drag is active."""
        if event.inaxes != self.point.axes:
            return
        if DraggablePoint.lock is not None:
            return
        contains, _ = self.point.contains(event)
        if not contains:
            return
        # Remember where the point and the mouse were when the drag started.
        self.press = self.point.center, event.xdata, event.ydata
        DraggablePoint.lock = self

        canvas = self.point.figure.canvas
        axes = self.point.axes
        line = self._linked_line()

        # Draw everything but the moving artists and store the pixel buffer.
        self.point.set_animated(True)
        if line is not None:
            line.set_animated(True)
        canvas.draw()
        self.background = canvas.copy_from_bbox(axes.bbox)

        # Animated artists are skipped by canvas.draw(), so draw both the
        # point and the line on top of the saved background; otherwise the
        # line would vanish until the first motion event.
        axes.draw_artist(self.point)
        if line is not None:
            axes.draw_artist(line)

        # And blit just the redrawn area.
        canvas.blit(axes.bbox)

    def on_motion(self, event):
        """Move the dragged point with the mouse and keep the line attached."""
        if DraggablePoint.lock is not self:
            return
        if event.inaxes != self.point.axes:
            return
        if self.press is None:
            return

        (x0, y0), xpress, ypress = self.press
        dx = event.xdata - xpress
        dy = event.ydata - ypress
        self.point.center = (x0 + dx, y0 + dy)
        self.x = self.point.center[0]
        self.y = self.point.center[1]

        # Update the line data BEFORE drawing it, so the line never lags one
        # mouse event behind the point.
        line = self._linked_line()
        if line is not None:
            self._update_line(line)

        canvas = self.point.figure.canvas
        axes = self.point.axes

        # Restore the background region, then redraw only the moving artists.
        canvas.restore_region(self.background)
        axes.draw_artist(self.point)
        if line is not None:
            axes.draw_artist(line)

        # Blit just the redrawn area.
        canvas.blit(axes.bbox)

    def on_release(self, event):
        """Finish the drag: reset the press data and restore normal drawing."""
        if DraggablePoint.lock is not self:
            return

        self.press = None
        DraggablePoint.lock = None

        # Turn off the animation property and reset the background.
        self.point.set_animated(False)
        line = self._linked_line()
        if line is not None:
            line.set_animated(False)

        self.background = None

        # Redraw the full figure.
        self.point.figure.canvas.draw()

        self.x = self.point.center[0]
        self.y = self.point.center[1]

    def disconnect(self):
        """Disconnect all the stored connection ids."""
        canvas = self.point.figure.canvas
        canvas.mpl_disconnect(self.cidpress)
        canvas.mpl_disconnect(self.cidrelease)
        canvas.mpl_disconnect(self.cidmotion)

UPDATE:

How to use the DraggablePoint class, with PyQt5:

#!/usr/bin/python
# -*-coding:Utf-8 -*

import sys
import matplotlib
matplotlib.use("Qt5Agg")
from PyQt5 import QtWidgets, QtGui

from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure

# Personal modules
from drag import DraggablePoint


class MyGraph(FigureCanvas):

    """A Qt canvas showing a gridded axes with a two-point draggable baseline."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure, embed it in the Qt widget and plot the baseline.

        parent -- Qt parent widget, passed to ``setParent``
        width, height -- figure size, passed to ``Figure(figsize=...)``
        dpi -- figure resolution
        """

        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)

        self.axes.grid(True)

        FigureCanvas.__init__(self, self.fig)
        self.setParent(parent)

        FigureCanvas.setSizePolicy(self,
                                   QtWidgets.QSizePolicy.Expanding,
                                   QtWidgets.QSizePolicy.Expanding)
        FigureCanvas.updateGeometry(self)

        # To store the 2 draggable points
        self.list_points = []

        self.show()
        # DraggablePoint uses ``size`` as a scalar ellipse width (and 3 * size
        # as its height), so a number must be passed here, not a list.
        self.plotDraggablePoints([0.1, 0.1], [0.2, 0.2], 0.1)

    def plotDraggablePoints(self, xy1, xy2, size=None):

        """Plot and define the 2 draggable points of the baseline

        xy1, xy2 -- (x, y) positions of the two points
        size -- scalar size forwarded to DraggablePoint; when None, the
                default size of DraggablePoint is used
        """

        # Forward ``size`` only when it was given: passing None through would
        # make DraggablePoint compute ``None * 3`` and raise a TypeError.
        extra = {} if size is None else {'size': size}
        self.list_points.append(DraggablePoint(self, xy1[0], xy1[1], **extra))
        self.list_points.append(DraggablePoint(self, xy2[0], xy2[1], **extra))
        self.updateFigure()

    def clearFigure(self):

        """Clear the graph and forget the draggable points"""

        # Disconnect the event handlers first, while each patch still knows
        # its figure; otherwise stale handlers could fire for points that are
        # no longer drawn.
        for point in self.list_points:
            point.disconnect()
        self.axes.clear()
        self.axes.grid(True)
        del self.list_points[:]
        self.updateFigure()

    def updateFigure(self):

        """Update the graph. Necessary, to call after each plot"""

        self.draw()

if __name__ == '__main__':
    # Build the Qt application, create the graph widget (it shows itself in
    # its constructor) and hand control to the Qt event loop; the loop's
    # return code becomes the exit status of the process.
    qt_app = QtWidgets.QApplication(sys.argv)
    graph = MyGraph()  # bound to a name so the widget stays alive
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

Here is my simple solution, with the additional feature of adding or removing points. You then have a draggable segmented line with controls on the points.

The code is simple despite the event handling. Improvements are welcome.

enter image description here

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D

#------------------------------------------------
# Shared state used by the event handlers below.

# Labels of the point circles, kept index-aligned with the vertices of the
# line (labels are appended/inserted at the same index as the line data).
listLabelPoints = []
# Alpha value passed to every point circle when it is created.
point_alpha_default = 0.8
# "left" or "right", set by on_press from the last pressed mouse button.
mousepress = None
# True between a button press and the following release.
currently_dragging = False
# Artist selected by on_pick; reset to None by on_release.
current_artist = None
# Artist position minus mouse position at pick time; on_motion adds it back
# to the mouse position so the artist does not jump under the cursor.
offset = [0,0]
# Counter used to build the point labels ("point" + str(n)).
n = 0
# Value returned by ax.plot() once two points exist; element 0 is the line.
line_object = None

#------------------------------------------------
def on_press(event):
    """Record that a mouse button is down and which one it is.

    Sets ``currently_dragging`` to True and ``mousepress`` to "left"
    (button 1), "right" (button 3) or None for any other button, so that a
    press of another button is never mistaken for the previous one.
    """
    global currently_dragging
    global mousepress
    currently_dragging = True
    if event.button == 3:
        mousepress = "right"
    elif event.button == 1:
        mousepress = "left"
    else:
        # Do not keep the value of an earlier left/right press.
        mousepress = None

#------------------------------------------------
def on_release(event):
    """End the current drag: no button is held and no artist is selected."""
    global currently_dragging, current_artist
    currently_dragging, current_artist = False, None

#------------------------------------------------
def _nearest_segment_index(xdata, ydata, x, y):
    """Return i such that segment (i, i+1) of the polyline is closest to (x, y).

    Distances are measured in data coordinates.  Returns None when the
    polyline has fewer than two vertices.  On ties the first segment wins.
    """
    best_index = None
    best_dist2 = None
    for i in range(len(xdata) - 1):
        x_a, y_a = xdata[i], ydata[i]
        seg_dx = xdata[i + 1] - x_a
        seg_dy = ydata[i + 1] - y_a
        seg_len2 = seg_dx * seg_dx + seg_dy * seg_dy
        if seg_len2 == 0:
            t = 0.0  # degenerate segment: both ends coincide
        else:
            # Projection of the click on the segment, clamped to its ends.
            t = ((x - x_a) * seg_dx + (y - y_a) * seg_dy) / seg_len2
            t = min(1.0, max(0.0, t))
        near_x = x_a + t * seg_dx
        near_y = y_a + t * seg_dy
        dist2 = (x - near_x) ** 2 + (y - near_y) ** 2
        if best_dist2 is None or dist2 < best_dist2:
            best_index, best_dist2 = i, dist2
    return best_index


def _remove_point(artist):
    """Remove the point circle *artist* and its vertex, keeping at least two points."""
    if len(ax.patches) <= 2:
        return
    label = artist.get_label()
    artist.remove()
    if label in listLabelPoints:
        i = listLabelPoints.index(label)
        xdata = list(line_object[0].get_xdata())
        ydata = list(line_object[0].get_ydata())
        if i < len(xdata):
            xdata.pop(i)
            ydata.pop(i)
        listLabelPoints.pop(i)
        line_object[0].set_data(xdata, ydata)
    plt.draw()


def _insert_point_on_line(x, y):
    """Insert a new point at (x, y) into the line segment closest to it."""
    global n
    line = line_object[0]
    xdata = list(line.get_xdata())
    ydata = list(line.get_ydata())
    segment = _nearest_segment_index(xdata, ydata, x, y)
    if segment is None:
        print("Error: point not inserted")
        return
    # Create the circle only once the insertion position is known, so that
    # no circle without a matching line vertex is ever left on the axes.
    n = n + 1
    newPointLabel = "point" + str(n)
    point_object = patches.Circle([x, y], radius=50, color='r', fill=False, lw=2,
                                  alpha=point_alpha_default, transform=ax.transData,
                                  label=newPointLabel)
    point_object.set_picker(5)
    ax.add_patch(point_object)
    xdata.insert(segment + 1, x)
    ydata.insert(segment + 1, y)
    listLabelPoints.insert(segment + 1, newPointLabel)
    line.set_data(xdata, ydata)
    plt.draw()


def on_pick(event):
    """Handle a pick on a point circle or on the line.

    Circle: a right double-click removes the point (while more than two
    remain); a single click selects it for dragging.
    Line: a left double-click inserts a new point on the nearest segment; a
    single click selects the whole line for dragging.

    The button is read from the pick's own mouse event rather than from the
    ``mousepress`` global, so it does not depend on the order in which the
    canvas dispatches the press and pick callbacks.
    """
    global current_artist, offset
    if current_artist is not None:
        return
    mouse = event.mouseevent
    if mouse.xdata is None or mouse.ydata is None:
        # Pick outside the axes: there are no data coordinates to work with.
        return
    current_artist = event.artist
    if isinstance(current_artist, patches.Circle):
        if mouse.dblclick:
            if mouse.button == 3:
                _remove_point(current_artist)
        else:
            x0, y0 = current_artist.center
            offset = [x0 - mouse.xdata, y0 - mouse.ydata]
    elif isinstance(current_artist, Line2D):
        if mouse.dblclick:
            if mouse.button == 1:
                _insert_point_on_line(mouse.xdata, mouse.ydata)
        else:
            xdata = current_artist.get_xdata()
            ydata = current_artist.get_ydata()
            offset = xdata[0] - mouse.xdata, ydata[0] - mouse.ydata

#------------------------------------------------
def on_motion(event):
    """Drag the selected artist with the mouse.

    A selected circle is moved to the mouse position plus ``offset`` and the
    matching line vertex (if the line exists) follows.  A selected line is
    translated as a whole and every tracked circle is moved onto its vertex.
    Nothing happens when no button is held, no artist is selected, or the
    mouse is outside the axes.
    """
    if not currently_dragging or current_artist is None:
        return
    if event.xdata is None or event.ydata is None:
        return
    dx, dy = offset
    # The line only exists once two points have been created.
    line = line_object[0] if line_object else None
    if isinstance(current_artist, patches.Circle):
        cx, cy = event.xdata + dx, event.ydata + dy
        current_artist.center = cx, cy
        label = current_artist.get_label()
        if line is not None and label in listLabelPoints:
            i = listLabelPoints.index(label)
            xdata = list(line.get_xdata())
            ydata = list(line.get_ydata())
            if i < len(xdata):
                xdata[i] = cx
                ydata[i] = cy
                line.set_data(xdata, ydata)
    elif isinstance(current_artist, Line2D) and line is not None:
        xdata = list(line.get_xdata())
        ydata = list(line.get_ydata())
        if xdata:
            # Translate every vertex so the first one lands on mouse + offset.
            xdata0 = xdata[0]
            ydata0 = ydata[0]
            xdata = [event.xdata + dx + xv - xdata0 for xv in xdata]
            ydata = [event.ydata + dy + yv - ydata0 for yv in ydata]
            line.set_data(xdata, ydata)
            vertex_of = {label: i for i, label in enumerate(listLabelPoints)}
            for p in ax.patches:
                i = vertex_of.get(p.get_label())
                # Skip patches that are not tracked points of the line.
                if i is not None and i < len(xdata):
                    p.center = xdata[i], ydata[i]
    plt.draw()

#------------------------------------------------
def on_click(event):
    """Create the first two points (and then the line) on double-click.

    Each double-click inside the axes adds a point circle until two exist;
    when the second one is added, the line joining all circles is plotted.
    Double-clicks outside the axes are ignored because they carry no data
    coordinates.
    """
    global n, line_object
    if not (event and event.dblclick):
        return
    if event.xdata is None or event.ydata is None:
        # Outside the axes: a circle cannot be placed at (None, None).
        return
    if len(listLabelPoints) >= 2:
        return
    n = n + 1
    x, y = event.xdata, event.ydata
    newPointLabel = "point" + str(n)
    point_object = patches.Circle([x, y], radius=50, color='r', fill=False, lw=2,
                                  alpha=point_alpha_default, transform=ax.transData,
                                  label=newPointLabel)
    point_object.set_picker(5)
    ax.add_patch(point_object)
    listLabelPoints.append(newPointLabel)
    if len(listLabelPoints) == 2:
        # Both end points exist: build the line through the circle centres.
        centers = [p.center for p in ax.patches]
        xdata = [cx for cx, _ in centers]
        ydata = [cy for _, cy in centers]
        line_object = ax.plot(xdata, ydata, alpha=0.5, c='r', lw=2, picker=True)
        line_object[0].set_pickradius(5)
    plt.draw()

#================================================
# Build the figure: a 4000 x 3000 data area with equal aspect ratio.
fig, ax = plt.subplots()

usage_title = "Double click left button to create draggable point\nDouble click right to remove a point"
ax.set_title(usage_title, loc="left")
ax.set_xlim(0, 4000)
ax.set_ylim(0, 3000)
ax.set_aspect('equal')

# Register the handlers in the same order as before: on_click and on_press
# both listen to button presses.
for event_name, handler in (
        ('button_press_event', on_click),
        ('button_press_event', on_press),
        ('button_release_event', on_release),
        ('pick_event', on_pick),
        ('motion_notify_event', on_motion),
):
    fig.canvas.mpl_connect(event_name, handler)

# ``ax`` is the current axes, so this is the grid plt.grid(True) would enable.
ax.grid(True)
plt.show()