Efficient rolling trimmed mean with Python
One observation that could come in handy is that you do not need to re-sort all the values at each step. Rather, if you ensure that the window is always kept sorted, all you need to do is insert the new value at the relevant spot and remove the old one from where it was. The position for both operations can be found in O(log2(window_size)) using bisect (the list insertion and deletion themselves still shift up to window_size elements, but for a window of 50 that is cheap). In practice, this would look something like
    import bisect

    import numpy as np

    def rolling_mean(data):
        x = sorted(data[:49])
        res = np.repeat(np.nan, len(data))
        for i in range(49, len(data)):
            if i != 49:
                del x[bisect.bisect_left(x, data[i - 50])]
            bisect.insort_right(x, data[i])
            res[i] = np.mean(x[3:47])
        return res
Now, the additional benefit in this case turns out to be less than what is gained by the vectorization that scipy.stats.trim_mean relies on, and so in particular, this will still be slower than @ChrisA's solution, but it is a useful starting point for further performance optimization.
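As a quick sanity check (this relies on SciPy's convention of cutting int(proportiontocut * n) values from each end), trimming a sorted 50-element window to x[3:47] gives the same result as trim_mean with proportiontocut=0.06:

```python
import numpy as np
from scipy.stats import trim_mean

# int(0.06 * 50) == 3, so trim_mean drops the 3 smallest and 3 largest
# values of the window, exactly like slicing the sorted window to [3:47].
w = np.random.randint(0, 1000, 50)
assert np.isclose(trim_mean(w, 0.06), np.sort(w)[3:47].mean())
```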
> data = pd.Series(np.random.randint(0, 1000, 50000))
> %timeit data.rolling(50).apply(lambda w: trim_mean(w, 0.06))
727 ms ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> %timeit rolling_mean(data.values)
812 ms ± 42.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Notably, Numba's JIT compiler, which is often useful in situations like these, also provides no benefit:
> from numba import jit
> rolling_mean_jit = jit(rolling_mean)
> %timeit rolling_mean_jit(data.values)
1.05 s ± 183 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
The following, seemingly far-from-optimal, approach outperforms both of the other approaches considered above:
    def rolling_mean_np(data):
        res = np.repeat(np.nan, len(data))
        for i in range(len(data) - 49):
            x = np.sort(data[i:i + 50])
            res[i + 49] = x[3:47].mean()
        return res
Timing:
> %timeit rolling_mean_np(data.values)
564 ms ± 4.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
What is more, this time around, JIT compilation does help:
> rolling_mean_np_jit = jit(rolling_mean_np)
> %timeit rolling_mean_np_jit(data.values)
94.9 ms ± 605 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
While we're at it, let's just quickly verify that this actually does what we expect it to:
> np.all(rolling_mean_np_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True
In fact, by helping out the sorter just a little bit, we can squeeze out another factor of 2, taking the total time down to 57 ms:
    def rolling_mean_np_manual(data):
        x = np.sort(data[:50])
        res = np.repeat(np.nan, len(data))
        for i in range(50, len(data) + 1):
            res[i - 1] = x[3:47].mean()
            if i != len(data):
                idx_old = np.searchsorted(x, data[i - 50])
                x[idx_old] = data[i]
                x.sort()
        return res
> %timeit rolling_mean_np_manual(data.values)
580 ms ± 23 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_manual_jit = jit(rolling_mean_np_manual)
> %timeit rolling_mean_np_manual_jit(data.values)
57 ms ± 5.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_manual_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True
Now, the "sorting" that is going on in this example of course just boils down to placing the new element in the right place, while shifting everything in between by one. Doing this by hand will make the pure Python code slower, but the jitted version gains another factor of 2, taking us below 30 ms:
    def rolling_mean_np_shift(data):
        x = np.sort(data[:50])
        res = np.repeat(np.nan, len(data))
        for i in range(50, len(data) + 1):
            res[i - 1] = x[3:47].mean()
            if i != len(data):
                idx_old, idx_new = np.searchsorted(x, [data[i - 50], data[i]])
                if idx_old < idx_new:
                    x[idx_old:idx_new - 1] = x[idx_old + 1:idx_new]
                    x[idx_new - 1] = data[i]
                elif idx_new < idx_old:
                    x[idx_new + 1:idx_old + 1] = x[idx_new:idx_old]
                    x[idx_new] = data[i]
                else:
                    x[idx_new] = data[i]
        return res
> %timeit rolling_mean_np_shift(data.values)
937 ms ± 97.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
> rolling_mean_np_shift_jit = jit(rolling_mean_np_shift)
> %timeit rolling_mean_np_shift_jit(data.values)
26.4 ms ± 693 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
> np.all(rolling_mean_np_shift_jit(data.values)[49:] == data.rolling(50).apply(lambda w: trim_mean(w, 0.06)).values[49:])
True
At this point, most of the time is spent in np.searchsorted, so let us make the search itself JIT-friendly. Adapting the source code for bisect, we let
    @jit
    def binary_search(a, x):
        lo = 0
        hi = 50
        while lo < hi:
            mid = (lo + hi) // 2
            if a[mid] < x:
                lo = mid + 1
            else:
                hi = mid
        return lo

    @jit
    def rolling_mean_np_jitted_search(data):
        x = np.sort(data[:50])
        res = np.repeat(np.nan, len(data))
        for i in range(50, len(data) + 1):
            res[i - 1] = x[3:47].mean()
            if i != len(data):
                idx_old = binary_search(x, data[i - 50])
                idx_new = binary_search(x, data[i])
                if idx_old < idx_new:
                    x[idx_old:idx_new - 1] = x[idx_old + 1:idx_new]
                    x[idx_new - 1] = data[i]
                elif idx_new < idx_old:
                    x[idx_new + 1:idx_old + 1] = x[idx_new:idx_old]
                    x[idx_new] = data[i]
                else:
                    x[idx_new] = data[i]
        return res
This takes us down to 12 ms, a 60× improvement over the raw pandas + SciPy approach:
> %timeit rolling_mean_np_jitted_search(data.values)
12 ms ± 210 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
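For completeness, we can check that the hand-rolled search agrees with np.searchsorted(..., side='left') on a sorted 50-element window; the version below is the same function without the @jit decorator, so it runs as plain Python:

```python
import numpy as np

def binary_search(a, x):
    # Leftmost insertion point for x in the sorted array a (length 50),
    # mirroring the jitted search above.
    lo, hi = 0, 50
    while lo < hi:
        mid = (lo + hi) // 2
        if a[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return lo

a = np.sort(np.random.randint(0, 1000, 50))
for x in (-1, a[0], a[25], a[-1], 2000):
    assert binary_search(a, x) == np.searchsorted(a, x, side='left')
```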
You might try using scipy.stats.trim_mean:

    from scipy.stats import trim_mean

    df['value'].rolling(5).apply(lambda x: trim_mean(x, 0.2))
[output]
0 NaN
1 NaN
2 NaN
3 NaN
4 10.000000
5 11.000000
6 13.000000
7 13.333333
8 14.000000
9 15.666667
Note that I had to use rolling(5) and proportiontocut=0.2 for your toy data set. For your real data you should use rolling(50) and trim_mean(x, 0.06) to remove the top and bottom 3 values from the rolling window.
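To make the mapping from proportiontocut to a count concrete (using a made-up five-value window, not your data): with n = 5 and proportiontocut = 0.2, int(0.2 * 5) = 1 value is cut from each end, so only the middle three values contribute:

```python
from scipy.stats import trim_mean

# One value is trimmed from each end of the sorted window [1, 2, 3, 4, 100],
# so the outlier 100 is discarded and the result is mean(2, 3, 4).
print(trim_mean([1, 2, 3, 4, 100], 0.2))  # 3.0
```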