Sub-quadratic algorithm for fitting a curve with two lines
Disclaimer: I don't feel like figuring out how to do this in C++, so I will use Python (numpy) notation. The concepts are completely transferable, so you should have no trouble translating back to the language of your choice.
Let's say that you have a pair of arrays, x and y, containing the data points, and that x is monotonically increasing. Let's also say that you will always select a partition point that leaves at least two elements in each partition, so the equations are solvable.
Now you can compute some relevant quantities:
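# x and y are assumed to be numpy arrays of equal length
N = len(x)
# the left partition starts with the first point only
sum_x_left = x[0]
sum_x2_left = x[0] * x[0]
sum_y_left = y[0]
sum_y2_left = y[0] * y[0]
sum_xy_left = x[0] * y[0]
# the right partition starts with everything else
sum_x_right = x[1:].sum()
sum_x2_right = (x[1:] * x[1:]).sum()
sum_y_right = y[1:].sum()
sum_y2_right = (y[1:] * y[1:]).sum()
sum_xy_right = (x[1:] * y[1:]).sum()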
The reason that we need these quantities (which are O(N) to initialize) is that you can use them directly to compute some well known formulae for the parameters of a linear regression. For example, the optimal m and b for y = m * x + b are given by
$$\mu_x = \frac{1}{N}\sum_i x_i, \qquad \mu_y = \frac{1}{N}\sum_i y_i$$

$$m = \frac{\sum_i (x_i - \mu_x)(y_i - \mu_y)}{\sum_i (x_i - \mu_x)^2}, \qquad b = \mu_y - m\,\mu_x$$
The sum of squared errors is given by
$$e = \sum_i (y_i - m\,x_i - b)^2$$
These can be expanded using simple algebra into the following:
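$$m = \frac{\sum_i x_i y_i - \sum_i x_i \sum_i y_i / N}{\sum_i x_i^2 - \left(\sum_i x_i\right)^2 / N}$$

$$b = \frac{\sum_i y_i}{N} - m\,\frac{\sum_i x_i}{N}$$

$$e = \sum_i y_i^2 + m^2 \sum_i x_i^2 + N b^2 - 2m \sum_i x_i y_i - 2b \sum_i y_i + 2mb \sum_i x_i$$

Note the factor of 2 on each cross term in e; it comes from multiplying out the square.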
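As a sanity check on the algebra, here is a minimal sketch that computes m, b and e both from the mean-centered definitions and from the raw sums, with the cross terms of e doubled as they come out of expanding the square. The function names and the sample data are made up for illustration; plain Python sequences are used, but numpy arrays should work the same way.

```python
def fit_from_definition(x, y):
    """Slope, intercept and squared error from the mean-centered formulas."""
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    m = (sum((xi - mu_x) * (yi - mu_y) for xi, yi in zip(x, y))
         / sum((xi - mu_x) ** 2 for xi in x))
    b = mu_y - m * mu_x
    e = sum((yi - m * xi - b) ** 2 for xi, yi in zip(x, y))
    return m, b, e


def fit_from_sums(n, sx, sy, sxx, syy, sxy):
    """Same three quantities, computed only from the raw sums."""
    m = (sxy - sx * sy / n) / (sxx - sx * sx / n)
    b = sy / n - m * sx / n
    # Expansion of sum((y - m*x - b)**2); every cross term carries a factor of 2.
    e = syy + m * m * sxx + n * b * b - 2 * m * sxy - 2 * b * sy + 2 * m * b * sx
    return m, b, e


# Hypothetical sample data, invented for this check.
x = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 2.9, 5.2, 6.8, 9.1]

direct = fit_from_definition(x, y)
expanded = fit_from_sums(
    len(x),
    sum(x),
    sum(y),
    sum(xi * xi for xi in x),
    sum(yi * yi for yi in y),
    sum(xi * yi for xi, yi in zip(x, y)),
)
```

The two tuples should agree up to floating-point rounding. The sum-based form of e subtracts large, nearly equal numbers, so for data with large magnitudes it may lose precision compared to the direct form.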
You can therefore loop over all the possibilities and record the minimal e:
# p is the index of the point moved from the right partition to the left one;
# stopping at p = N - 3 leaves at least two points on the right
for p in range(1, N - 2):
    # shift sums: O(1)
    sum_x_left += x[p]
    sum_x2_left += x[p] * x[p]
    sum_y_left += y[p]
    sum_y2_left += y[p] * y[p]
    sum_xy_left += x[p] * y[p]
    sum_x_right -= x[p]
    sum_x2_right -= x[p] * x[p]
    sum_y_right -= y[p]
    sum_y2_right -= y[p] * y[p]
    sum_xy_right -= x[p] * y[p]
    # compute err: O(1)
    n_left = p + 1
    slope_left = (sum_xy_left - sum_x_left * sum_y_left / n_left) / (sum_x2_left - sum_x_left * sum_x_left / n_left)
    intercept_left = sum_y_left / n_left - slope_left * sum_x_left / n_left
    err_left = sum_y2_left + slope_left * slope_left * sum_x2_left + n_left * intercept_left * intercept_left - 2 * slope_left * sum_xy_left - 2 * intercept_left * sum_y_left + 2 * slope_left * intercept_left * sum_x_left
    n_right = N - n_left
    slope_right = (sum_xy_right - sum_x_right * sum_y_right / n_right) / (sum_x2_right - sum_x_right * sum_x_right / n_right)
    intercept_right = sum_y_right / n_right - slope_right * sum_x_right / n_right
    err_right = sum_y2_right + slope_right * slope_right * sum_x2_right + n_right * intercept_right * intercept_right - 2 * slope_right * sum_xy_right - 2 * intercept_right * sum_y_right + 2 * slope_right * intercept_right * sum_x_right
    err = err_left + err_right
    # keep the best split seen so far
    if p == 1 or err < err_min:
        err_min = err
        n_min_left = n_left
        n_min_right = n_right
        slope_min_left = slope_left
        slope_min_right = slope_right
        intercept_min_left = intercept_left
        intercept_min_right = intercept_right
There are probably other simplifications you can make, but this is sufficient to have an O(N) algorithm.
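To put the pieces together, here is one way the whole procedure could be packaged as a function. This is only a sketch under the same assumptions as above (x increasing, at least two points per partition); the names fit_two_lines, _moments and _fit are invented for this example, and the error uses the full expansion with a factor of 2 on each cross term. It is written with plain Python sequences, though numpy arrays should work as inputs too.

```python
def _moments(xs, ys):
    """Raw sums [sum x, sum y, sum x^2, sum y^2, sum x*y] of a set of points."""
    return [
        sum(xs),
        sum(ys),
        sum(a * a for a in xs),
        sum(b * b for b in ys),
        sum(a * b for a, b in zip(xs, ys)),
    ]


def _fit(n, sx, sy, sxx, syy, sxy):
    """Slope, intercept and squared error of the least-squares line from raw sums."""
    m = (sxy - sx * sy / n) / (sxx - sx * sx / n)
    b = sy / n - m * sx / n
    e = syy + m * m * sxx + n * b * b - 2 * m * sxy - 2 * b * sy + 2 * m * b * sx
    return m, b, e


def fit_two_lines(x, y):
    """Return (err, n_left, (m_left, b_left), (m_right, b_right)) for the best split."""
    n = len(x)
    if n < 4:
        raise ValueError("need at least 4 points, two per partition")
    left = _moments(x[:1], y[:1])    # O(1): first point only
    right = _moments(x[1:], y[1:])   # O(N): everything else
    best = None
    for p in range(1, n - 2):
        # Move point p from the right partition to the left one: O(1).
        delta = (x[p], y[p], x[p] * x[p], y[p] * y[p], x[p] * y[p])
        left = [s + d for s, d in zip(left, delta)]
        right = [s - d for s, d in zip(right, delta)]
        n_left = p + 1
        m_l, b_l, e_l = _fit(n_left, *left)
        m_r, b_r, e_r = _fit(n - n_left, *right)
        err = e_l + e_r
        if best is None or err < best[0]:
            best = (err, n_left, (m_l, b_l), (m_r, b_r))
    return best


# Hypothetical data: y = 2x for the first five points, y = 20 - x for the rest.
x = [float(i) for i in range(10)]
y = [0.0, 2.0, 4.0, 6.0, 8.0, 15.0, 14.0, 13.0, 12.0, 11.0]
err, n_left, left_line, right_line = fit_two_lines(x, y)
```

On this made-up data the best split should put the first five points on the left line. The function does not guard against repeated x values inside a partition, which would make the slope denominator zero.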