Enumerable/Stream with look ahead
As discussed in the comments, my first attempt had some performance problems and didn't work with streams that have side effects, such as IO streams. I took the time to dig deeper into the stream library and finally came up with this solution:
defmodule MyStream do
  def lookahead(enum, n) do
    # Reducer that suspends after every element, so each call to `next` yields one value.
    step = fn val, _acc -> {:suspend, val} end
    next = &Enumerable.reduce(enum, &1, step)
    &do_lookahead(n, :buffer, [], next, &1, &2)
  end

  # stream suspended
  defp do_lookahead(n, state, buf, next, {:suspend, acc}, fun) do
    {:suspended, acc, &do_lookahead(n, state, buf, next, &1, fun)}
  end

  # stream halted: also halt the input, so e.g. a Stream.resource/3 input can clean up
  defp do_lookahead(_n, _state, _buf, next, {:halt, acc}, _fun) do
    _ = next.({:halt, []})
    {:halted, acc}
  end

  # initial buffering
  defp do_lookahead(n, :buffer, buf, next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        new_state = if length(buf) < n, do: :buffer, else: :emit
        do_lookahead(n, new_state, buf ++ [val], next, {:cont, acc}, fun)

      {_, _} ->
        do_lookahead(n, :emit, buf, &exhausted/1, {:cont, acc}, fun)
    end
  end

  # emitting
  defp do_lookahead(n, :emit, [_ | rest] = buf, next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        do_lookahead(n, :emit, rest ++ [val], next, fun.(buf, acc), fun)

      {_, _} ->
        do_lookahead(n, :emit, rest, &exhausted/1, fun.(buf, acc), fun)
    end
  end

  # buffer empty: the input has ended and everything was emitted
  defp do_lookahead(_n, :emit, [], _next, {:cont, acc}, _fun) do
    {:done, acc}
  end

  # Stand-in for an input that has already finished, so it is never resumed again.
  defp exhausted({:cont, acc}), do: {:done, acc}
  defp exhausted({:halt, acc}), do: {:halted, acc}
end
This may look daunting at first, but actually it's not that hard. I will try to break it down for you, although that is easier with smaller examples than with a full-fledged one like this.
Let's start with a simpler example instead: a stream that endlessly repeats the value given to it. To emit a stream, we can return a function that takes an accumulator and a reducer function as arguments. To emit a value, we call the reducer function with two arguments: the value to emit and the current accumulator acc. The accumulator we receive is a tagged tuple whose first element is a command (:cont, :suspend or :halt) that tells us what the consumer wants us to do; the result we need to return depends on that command. If the stream should be suspended, we return a three-element tuple of the atom :suspended, the accumulator and a function that will be called when the enumeration continues (sometimes called a "continuation"). For the :halt command, we simply return {:halted, acc}, and for :cont we emit a value by calling the reducer, which hands back a new tagged accumulator, and then recurse with that accumulator. A stream that runs out of values on its own returns {:done, acc} instead, which never happens for a stream that repeats forever. The whole thing then looks like this:
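To see these commands from the consumer's side, it can help to call Enumerable.reduce/3 by hand on a plain list with each kind of tagged accumulator. This is a small illustrative script, not part of the answer's code; the reducer functions in it are ad hoc:

```elixir
# Drive Enumerable.reduce/3 by hand to see each command and the matching reply.
sum = fn x, acc -> {:cont, acc + x} end

# :cont all the way: the list is exhausted, so the reply is {:done, acc}.
{:done, 6} = Enumerable.reduce([1, 2, 3], {:cont, 0}, sum)

# The reducer asks to stop early by returning a :halt accumulator.
stop_at_two = fn x, acc -> if x == 2, do: {:halt, acc + x}, else: {:cont, acc + x} end
{:halted, 3} = Enumerable.reduce([1, 2, 3], {:cont, 0}, stop_at_two)

# A :suspend accumulator pauses the enumeration and hands back a continuation.
pause = fn x, acc -> {:suspend, acc + x} end
{:suspended, 1, continuation} = Enumerable.reduce([1, 2, 3], {:cont, 0}, pause)

# Resuming with :cont picks up at the next element (2) and suspends again.
{:suspended, 3, continuation} = continuation.({:cont, 1})

# Resuming with :halt ends the enumeration without visiting the rest.
{:halted, 3} = continuation.({:halt, 3})

IO.puts("all protocol replies matched")
```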
defmodule MyStream do
  def repeat(val) do
    &do_repeat(val, &1, &2)
  end

  defp do_repeat(val, {:suspend, acc}, fun) do
    {:suspended, acc, &do_repeat(val, &1, fun)}
  end

  defp do_repeat(_val, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  defp do_repeat(val, {:cont, acc}, fun) do
    do_repeat(val, fun.(val, acc), fun)
  end
end
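Because repeat never ends on its own, it never needs the {:done, acc} reply. As a sketch, assuming a hypothetical repeat_n/2 that emits a value a fixed number of times (the module name MyFiniteStream is also made up for this example), a finite stream in the same style reports its natural end like this:

```elixir
defmodule MyFiniteStream do
  # Hypothetical example: emit `val` exactly `count` times.
  def repeat_n(val, count) when is_integer(count) and count >= 0 do
    &do_repeat_n(val, count, &1, &2)
  end

  defp do_repeat_n(val, count, {:suspend, acc}, fun) do
    {:suspended, acc, &do_repeat_n(val, count, &1, fun)}
  end

  defp do_repeat_n(_val, _count, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  # Nothing left to emit: report a natural end with :done, not :halted.
  defp do_repeat_n(_val, 0, {:cont, acc}, _fun) do
    {:done, acc}
  end

  defp do_repeat_n(val, count, {:cont, acc}, fun) do
    do_repeat_n(val, count - 1, fun.(val, acc), fun)
  end
end

MyFiniteStream.repeat_n(:a, 3) |> Enum.to_list() |> IO.inspect()
# prints [:a, :a, :a]
```

Since the returned value is an arity-2 function, it can be handed straight to Enum and Stream functions, just like the repeat stream above.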
Now this is only one part of the puzzle. We can emit a stream, but we don't process an incoming stream yet. Again, to explain how that works it makes sense to construct a simpler example. Here, I will build a function that takes an enumerable and just suspends and re-emits for every value.
defmodule MyStream do
  def passthrough(enum) do
    step = fn val, _acc -> {:suspend, val} end
    next = &Enumerable.reduce(enum, &1, step)
    &do_passthrough(next, &1, &2)
  end

  defp do_passthrough(next, {:suspend, acc}, fun) do
    {:suspended, acc, &do_passthrough(next, &1, fun)}
  end

  defp do_passthrough(_next, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  defp do_passthrough(next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        do_passthrough(next, fun.(val, acc), fun)

      {_, _} ->
        # the input has ended, so our stream is done as well
        {:done, acc}
    end
  end
end
The public passthrough/1 function sets up the next function that gets passed down to do_passthrough. It serves the purpose of getting the next value from the incoming stream: the step reducer used internally suspends after every item, so each call to next yields exactly one value. The rest is pretty similar except for the last clause. Here, we call next with {:cont, []} to get a new value and process the result with a case expression. If there is a value, we get back {:suspended, val, next}; if not, the incoming stream has finished and we end our stream as well.
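The pulling mechanism can be tried out on its own. This minimal script reuses the same step reducer as passthrough/1 on a plain two-element list and resumes the continuation by hand:

```elixir
# Same `step` reducer as in passthrough/1: suspend after every element,
# handing the element itself back as the accumulator.
step = fn val, _acc -> {:suspend, val} end

# Starting the reduction yields the first element plus a continuation.
{:suspended, 1, next} = Enumerable.reduce([1, 2], {:cont, []}, step)

# Each call to the continuation with {:cont, []} pulls exactly one more value.
{:suspended, 2, next} = next.({:cont, []})

# Once the list is exhausted, the reply is {:done, acc} instead of :suspended.
{:done, []} = next.({:cont, []})

IO.puts("pulled 1, then 2, then reached the end")
```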
I hope that clarifies a few things about how to build streams in Elixir manually. Unfortunately, working with streams at this level requires an awful lot of boilerplate. If you go back to the lookahead implementation now, you will see that there are only tiny differences, and those are the actually interesting parts. There are two additional parameters: state, which distinguishes the :buffer step from the :emit step, and buf, which is pre-filled with n+1 items in the initial buffering step. In the emit phase, the current buffer is emitted and then shifted to the left by one on each iteration, with the next input value (if there is one) appended at the end. We're done when the input has ended and the buffer has run empty, or when the consumer halts our stream directly.
I am leaving my original answer here for reference:
Here's a solution that uses Stream.unfold/2 to emit a true stream of values according to your specification. This means you need to add Enum.to_list to the end of your first two examples to obtain the actual values.
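defmodule MyStream do
  def lookahead(stream, n) do
    Stream.unfold(split(stream, n + 1), fn
      {[], _stream} ->
        nil

      {[_ | buf] = current, stream} ->
        {value, stream} = split(stream, 1)
        {current, {buf ++ value, stream}}
    end)
  end

  # Enum.take/2 enumerates the input now; Stream.drop/2 enumerates it again,
  # from the beginning, whenever the dropped stream is consumed later.
  defp split(stream, n) do
    {Enum.take(stream, n), Stream.drop(stream, n)}
  end
end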
The general idea is that we keep a buffer (buf) of values from the previous iterations around. On each iteration, we emit the current buffer, take one value from the stream and append it to the end of the buffer. This repeats until the buffer is empty.
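A likely source of the performance and side-effect problems of this version is that split/2 reads the same input twice: Enum.take/2 evaluates the first elements, and the stream returned by Stream.drop/2 evaluates them again when it is consumed. A small counter makes the repeated work visible; the counter and the mapped source stream are illustrative assumptions, not part of the original code:

```elixir
# Count how often the source stream actually produces an element.
counter = :counters.new(1, [])

source =
  Stream.map(1..5, fn x ->
    :counters.add(counter, 1, 1)
    x
  end)

# The same two steps as split(stream, 2) followed by split(rest, 1):
first = Enum.take(source, 2)
rest = Stream.drop(source, 2)
second = Enum.take(rest, 1)

IO.inspect({first, second, :counters.get(counter, 1)})
# prints {[1, 2], [3], 5}: elements 1 and 2 were produced twice
```

Every further unfold step nests another Stream.drop/2, so each new value re-walks the input from the start, and any side effect in the source is repeated each time.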
Example:
iex> MyStream.lookahead(1..6, 1) |> Enum.to_list
[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6]]
iex> MyStream.lookahead(1..4, 2) |> Enum.to_list
[[1, 2, 3], [2, 3, 4], [3, 4], [4]]
iex> Stream.cycle(1..3) |> MyStream.lookahead(2) |> Enum.take(5)
[[1, 2, 3], [2, 3, 1], [3, 1, 2], [1, 2, 3], [2, 3, 1]]
Here is an inefficient implementation of such a function:
defmodule Lookahead do
  def lookahead(enumerable, n) when n > 0 do
    enumerable
    |> Stream.chunk(n + 1, 1, [])
    |> Stream.flat_map(fn list ->
      len = length(list)

      if len < n + 1 do
        # A short chunk means the input has ended: emit it and each of its
        # shorter suffixes, e.g. [3, 4] becomes [[3, 4], [4]].
        for i <- 0..(len - 1), do: Enum.drop(list, i)
      else
        [list]
      end
    end)
  end
end
It builds on top of @hahuang65's implementation, except that we use Stream.flat_map/2 to check the length of each emitted chunk, adding the missing ones as soon as we detect that a chunk is shorter than n + 1.
A hand-written implementation from scratch would be faster because we would not need to call length(list) on every iteration. The implementation above may be fine though if n is small. If n is fixed, you could even pattern match on the generated list explicitly.
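For a fixed n, that pattern matching might look like the following sketch for n = 2, with a hypothetical module name Lookahead2. It uses Stream.chunk_every/4, since Stream.chunk/4 has been deprecated in its favour in newer Elixir versions, and it relies on a leftover of [] producing a final, shorter chunk:

```elixir
defmodule Lookahead2 do
  # Hypothetical fixed-size variant: lookahead with n = 2 (windows of 3).
  def lookahead(enumerable) do
    enumerable
    |> Stream.chunk_every(3, 1, [])
    |> Stream.flat_map(fn
      # A full window is emitted as-is.
      [_, _, _] = full -> [full]
      # A two-element chunk is the leftover at the end: emit it and its tail.
      [_, b] = pair -> [pair, [b]]
      # A one-element chunk only occurs when the input has a single element.
      [_] = single -> [single]
    end)
  end
end

IO.inspect(Lookahead2.lookahead(1..4) |> Enum.to_list())
# prints [[1, 2, 3], [2, 3, 4], [3, 4], [4]]
```

No length/1 call is needed, because the shape of each chunk already tells us whether the input has ended.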
I had started a discussion about my proposed Stream.mutate method on the Elixir core mailing list, where Peter Hamilton suggested another way of solving this problem. By using make_ref to create a globally unique reference, we can create a padding stream and concatenate it with the original enumerable to continue emitting after the original stream has ended. This can then either be used in conjunction with Stream.chunk, which means we need to remove the unwanted references in a last step:
def lookahead(enum, n) do
  stop = make_ref()

  enum
  |> Stream.concat(List.duplicate(stop, n))
  |> Stream.chunk(n + 1, 1)
  |> Stream.map(&Enum.reject(&1, fn x -> x == stop end))
end
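The padding has to be a value that can never equal a real element, which is why a fresh reference works where something like nil would not. Here is a sketch of the same idea written with Stream.chunk_every/4 (the replacement for the deprecated Stream.chunk/3), checked against input that itself contains nil; the module name PaddedLookahead is hypothetical:

```elixir
defmodule PaddedLookahead do
  # Hypothetical module: pad with a unique reference, then slide a window.
  def lookahead(enum, n) do
    stop = make_ref()

    enum
    |> Stream.concat(List.duplicate(stop, n))
    # :discard drops the trailing, shorter chunk, which consists only of padding.
    |> Stream.chunk_every(n + 1, 1, :discard)
    |> Stream.map(fn chunk -> Enum.reject(chunk, &(&1 == stop)) end)
  end
end

IO.inspect(PaddedLookahead.lookahead([1, nil, 3], 1) |> Enum.to_list())
# prints [[1, nil], [nil, 3], [3]]
```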
I think this is the prettiest solution yet, from a syntactical point of view. Alternatively, we can use Stream.transform to build the buffer manually, which is quite similar to the manual solution I proposed earlier:
def lookahead(enum, n) do
  stop = make_ref()

  enum
  |> Stream.concat(List.duplicate(stop, n + 1))
  |> Stream.transform([], fn val, acc ->
    case {val, acc} do
      {^stop, []} -> {[], []}
      {^stop, [_ | rest] = buf} -> {[buf], rest}
      {val, buf} when length(buf) < n + 1 -> {[], buf ++ [val]}
      {val, [_ | rest] = buf} -> {[buf], rest ++ [val]}
    end
  end)
end
I haven't benchmarked these solutions, but I suppose the second one, although slightly clunkier, should perform a little better because it does not have to make an extra pass over each chunk to remove the padding.
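To check that guess, a rough timing harness based on :timer.tc/1 could run both padding-based versions over the same input. This is a sketch with hypothetical module and function names; it uses Stream.chunk_every/4 in place of the deprecated Stream.chunk/3, only prints timings rather than assuming a winner, and single runs like this are noisy, so a dedicated benchmarking tool would give more reliable numbers:

```elixir
defmodule LookaheadBench do
  # Padding + chunking, then removing the padding from every chunk.
  def via_chunk(enum, n) do
    stop = make_ref()

    enum
    |> Stream.concat(List.duplicate(stop, n))
    |> Stream.chunk_every(n + 1, 1, :discard)
    |> Stream.map(fn chunk -> Enum.reject(chunk, &(&1 == stop)) end)
  end

  # Padding + a manually maintained buffer in Stream.transform/3.
  def via_transform(enum, n) do
    stop = make_ref()

    enum
    |> Stream.concat(List.duplicate(stop, n + 1))
    |> Stream.transform([], fn
      ^stop, [] -> {[], []}
      ^stop, [_ | rest] = buf -> {[buf], rest}
      val, buf when length(buf) < n + 1 -> {[], buf ++ [val]}
      val, [_ | rest] = buf -> {[buf], rest ++ [val]}
    end)
  end

  # Time one full enumeration of the stream built by `fun`, in microseconds.
  def time(fun) do
    {micros, _result} = :timer.tc(fn -> Stream.run(fun.()) end)
    micros
  end
end

input = 1..100_000
n = 3
IO.puts("chunk:     #{LookaheadBench.time(fn -> LookaheadBench.via_chunk(input, n) end)} microseconds")
IO.puts("transform: #{LookaheadBench.time(fn -> LookaheadBench.via_transform(input, n) end)} microseconds")
```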
By the way, the second solution can also be written without the case expression, since Elixir allows the pin operator in anonymous function heads:
def lookahead(enum, n) do
  stop = make_ref()

  enum
  |> Stream.concat(List.duplicate(stop, n + 1))
  |> Stream.transform([], fn
    ^stop, [] -> {[], []}
    ^stop, [_ | rest] = buf -> {[buf], rest}
    val, buf when length(buf) < n + 1 -> {[], buf ++ [val]}
    val, [_ | rest] = buf -> {[buf], rest ++ [val]}
  end)
end