Enumerable/Stream with look ahead

As discussed in the comments, my first attempt had some performance problems and didn't work with streams that have side-effects, such as IO streams. I took the time to dig deeper into the stream library and finally came up with this solution:

defmodule MyStream do
  def lookahead(enum, n) do
    step = fn val, _acc -> {:suspend, val} end
    next = &Enumerable.reduce(enum, &1, step)
    &do_lookahead(n, :buffer, [], next, &1, &2)
  end

  # stream suspended
  defp do_lookahead(n, state, buf, next, {:suspend, acc}, fun) do
    {:suspended, acc, &do_lookahead(n, state, buf, next, &1, fun)}
  end

  # stream halted
  defp do_lookahead(_n, _state, _buf, _next, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  # initial buffering
  defp do_lookahead(n, :buffer, buf, next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        new_state = if length(buf) < n, do: :buffer, else: :emit
        do_lookahead(n, new_state, buf ++ [val], next, {:cont, acc}, fun)
      {_, _} ->
        do_lookahead(n, :emit, buf, next, {:cont, acc}, fun)
    end
  end

  # emitting
  defp do_lookahead(n, :emit, [_|rest] = buf, next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        do_lookahead(n, :emit, rest ++ [val], next, fun.(buf, acc), fun)
      {_, _} ->
        do_lookahead(n, :emit, rest, next, fun.(buf, acc), fun)
    end
  end

  # buffer empty, halting
  defp do_lookahead(_n, :emit, [], _next, {:cont, acc}, _fun) do
    {:halted, acc}
  end
end

This may look daunting at first, but it is less complicated than it seems. I will try to break it down for you, although that is difficult to do with a full-fledged example like this.

Let's start with a simpler example instead: a stream that endlessly repeats the value given to it. To build a stream, we return a function that takes an accumulator and a function as arguments. To emit a value, we call that function with two arguments: the value to emit and the accumulator. The accumulator we receive is a tuple whose first element is a command (:cont, :suspend or :halt) that tells us what the consumer wants us to do; the result we need to return depends on that command. For :suspend, we return a three-element tuple of the atom :suspended, the accumulator and a function that will be called when the enumeration continues (sometimes called a "continuation"). For :halt, we simply return {:halted, acc}. For :cont, we emit a value by calling the function with the value and the accumulator and then recursing on whatever it returns. The whole thing then looks like this:

defmodule MyStream do
  def repeat(val) do
    &do_repeat(val, &1, &2)
  end

  defp do_repeat(val, {:suspend, acc}, fun) do
    {:suspended, acc, &do_repeat(val, &1, fun)}
  end

  defp do_repeat(_val, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  defp do_repeat(val, {:cont, acc}, fun) do
    do_repeat(val, fun.(val, acc), fun)
  end
end
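To see the three commands in action, here is a small sketch that uses the same clauses under a hypothetical module name (RepeatDemo, chosen only so it does not clash with MyStream). It first consumes the stream through Enum and then drives it by hand with a reducer that asks to suspend after every value.

```elixir
defmodule RepeatDemo do
  # Same shape as MyStream.repeat/1 above; the module name is only for this sketch.
  def repeat(val), do: &do_repeat(val, &1, &2)

  defp do_repeat(val, {:suspend, acc}, fun), do: {:suspended, acc, &do_repeat(val, &1, fun)}
  defp do_repeat(_val, {:halt, acc}, _fun), do: {:halted, acc}
  defp do_repeat(val, {:cont, acc}, fun), do: do_repeat(val, fun.(val, acc), fun)
end

# A two-argument function is enumerable, so Enum works on it directly.
taken = RepeatDemo.repeat(:a) |> Enum.take(3)
IO.inspect(taken)
# => [:a, :a, :a]

# Driving it by hand: the reducer asks to suspend after every value.
{:suspended, val, continuation} =
  Enumerable.reduce(RepeatDemo.repeat(:a), {:cont, nil}, fn v, _acc -> {:suspend, v} end)

IO.inspect(val)
# => :a

# Sending :halt to the continuation ends the enumeration.
{:halted, _} = continuation.({:halt, val})
```

Enum.take/2 sends {:halt, acc} once it has enough values, which is why the endless stream terminates.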

Now this is only one part of the puzzle. We can emit a stream, but we don't process an incoming stream yet. Again, to explain how that works, it makes sense to construct a simpler example. Here, I will build a function that takes an enumerable, suspends it after every value and re-emits that value.

defmodule MyStream do
  def passthrough(enum) do
    step = fn val, _acc -> {:suspend, val} end
    next = &Enumerable.reduce(enum, &1, step)
    &do_passthrough(next, &1, &2)
  end

  defp do_passthrough(next, {:suspend, acc}, fun) do
    {:suspended, acc, &do_passthrough(next, &1, fun)}
  end

  defp do_passthrough(_next, {:halt, acc}, _fun) do
    {:halted, acc}
  end

  defp do_passthrough(next, {:cont, acc}, fun) do
    case next.({:cont, []}) do
      {:suspended, val, next} ->
        do_passthrough(next, fun.(val, acc), fun)
      {_, _} ->
        {:halted, acc}
    end
  end
end

The passthrough function sets up the next function that gets passed down to do_passthrough. Its purpose is to fetch the next value from the incoming stream. The step function it uses internally makes the enumeration suspend after every item in the stream. The rest is pretty similar to the previous example, except for the last clause. Here, we call next with {:cont, []} to get a new value and process the result with a case statement. If there is a value, we get back {:suspended, val, next}; if not, the stream is halted and we pass that through to the consumer.
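The behaviour of next can be observed in isolation. The following sketch applies the same step and next definitions to a plain two-element list (the list is just an assumed example input) and pulls the values out one at a time.

```elixir
# The reducer suspends after every element and hands the element back as the accumulator.
step = fn val, _acc -> {:suspend, val} end
next = &Enumerable.reduce([1, 2], &1, step)

# Each call resumes the enumeration and returns one value plus a new continuation.
{:suspended, first, next} = next.({:cont, []})
{:suspended, second, next} = next.({:cont, []})

# Once the source is exhausted, a two-element tuple comes back instead.
finished = next.({:cont, []})

IO.inspect({first, second, finished})
# => {1, 2, {:done, []}}
```

This is why the case statement only needs two branches: a three-element :suspended tuple carries a value, and any two-element tuple means there is nothing left to read.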

I hope that clarifies a few things about how to build streams in Elixir manually. Unfortunately, there's an awful lot of boilerplate required to work with streams. If you go back to the lookahead implementation now, you will see that there are only tiny differences, and those are the actually interesting parts. There are two additional parameters: state, which serves to differentiate between the :buffer and :emit steps, and buf, which is pre-filled with n+1 items in the initial buffering step. In the emit phase, the current buffer is emitted and then shifted to the left on each iteration. We're done when the input stream halts or our stream is halted directly.


I am leaving my original answer here for reference:

Here's a solution that uses Stream.unfold/2 to emit a true stream of values according to your specification. This means you need to add Enum.to_list to the end of your first two examples to obtain the actual values.

defmodule MyStream do
  def lookahead(stream, n) do
    Stream.unfold split(stream, n+1), fn
      {[], stream} ->
        nil
      {[_ | buf] = current, stream} ->
        {value, stream} = split(stream, 1)
        {current, {buf ++ value, stream}}
    end
  end

  defp split(stream, n) do
    {Enum.take(stream, n), Stream.drop(stream, n)}
  end
end

The general idea is that we keep a buf of the previous iterations around. On each iteration, we emit the current buf, take one value from the stream and append it to the end of the buf. This repeats until the buf is empty.

Example:

iex> MyStream.lookahead(1..6, 1) |> Enum.to_list
[[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6]]

iex> MyStream.lookahead(1..4, 2) |> Enum.to_list
[[1, 2, 3], [2, 3, 4], [3, 4], [4]]

iex> Stream.cycle(1..3) |> MyStream.lookahead(2) |> Enum.take(5)
[[1, 2, 3], [2, 3, 1], [3, 1, 2], [1, 2, 3], [2, 3, 1]]
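The side-effect problem mentioned at the top comes from the split helper: Enum.take and Stream.drop each enumerate the source from the beginning. A minimal sketch of that effect, assuming a hypothetical source that reports every element it produces to the current process:

```elixir
# Hypothetical source that reports every element it produces to the current process.
parent = self()

source =
  Stream.map(1..3, fn x ->
    send(parent, {:visited, x})
    x
  end)

# The same split as in the unfold-based version above.
{head, rest} = {Enum.take(source, 1), Stream.drop(source, 1)}
tail = Enum.to_list(rest)

# Collect the reports: element 1 was produced twice.
visited =
  for _ <- 1..4 do
    receive do
      {:visited, x} -> x
    after
      0 -> :none
    end
  end

IO.inspect({head, tail, visited})
# => {[1], [2, 3], [1, 1, 2, 3]}
```

With an IO stream, the repeated enumeration would read from the device again instead of replaying the same elements, and the nested Stream.drop calls grow with every iteration, which also explains the performance problems.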

Here is an inefficient implementation of such a function:

defmodule Lookahead do
  def lookahead(enumerable, n) when n > 0 do
    enumerable
    |> Stream.chunk(n + 1, 1, [])
    |> Stream.flat_map(fn list ->
        length = length(list)
        if length < n + 1 do
          [list|Enum.scan(1..n-1, list, fn _, acc -> Enum.drop(acc, 1) end)]
        else
          [list]
        end
      end)
  end
end

It builds on top of @hahuang65's implementation, except that we use Stream.flat_map/2 to check the length of each emitted item, adding the missing ones as soon as we detect that the emitted item got shorter.

A hand-written implementation from scratch would be faster because we would not need to call length(list) on every iteration. The implementation above may be fine though if n is small. If n is fixed, you could even pattern match on the generated list explicitly.


I had started a discussion about my proposed Stream.mutate function on the Elixir core mailing list, where Peter Hamilton suggested another way of solving this problem. By using make_ref to create a globally unique reference, we can create a padding stream and concatenate it with the original enumerable to continue emitting after the original stream has halted. This can then either be used in conjunction with Stream.chunk, which means we need to remove the unwanted references in a last step:

def lookahead(enum, n) do
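  stop = make_ref()
  enum
  |> Stream.concat(List.duplicate(stop, n))
  # Stream.chunk is deprecated in newer Elixir versions;
  # Stream.chunk_every(n+1, 1, :discard) is the piped equivalent.
  |> Stream.chunk(n+1, 1)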
  |> Stream.map(&Enum.reject(&1, fn x -> x == stop end))
end

I think this is the prettiest solution yet, from a syntactical point of view. Alternatively, we can use Stream.transform to build the buffer manually, which is quite similar to the manual solution I proposed earlier:

def lookahead(enum, n) do
  stop = make_ref()
  enum
  |> Stream.concat(List.duplicate(stop, n+1))
  |> Stream.transform([], fn val, acc ->
    case {val, acc} do
      {^stop, []}                         -> {[]   , []           }
      {^stop, [_|rest] = buf}             -> {[buf], rest         }
      {val  , buf} when length(buf) < n+1 -> {[]   , buf ++ [val] }
      {val  , [_|rest] = buf}             -> {[buf], rest ++ [val]}
    end
  end)
end

I haven't benchmarked these solutions but I suppose the second one, although slightly clunkier, should perform a little bit better because it does not have to iterate over each chunk.
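For anyone who wants to check that guess, a rough comparison could look like the sketch below. The module and function names are hypothetical, Stream.chunk_every/4 with :discard stands in for the deprecated Stream.chunk/3, and :timer.tc/1 only gives a coarse single-run timing, so treat the numbers as an indication rather than a benchmark.

```elixir
defmodule LookaheadCompare do
  # Hypothetical module; both functions follow the two variants shown above.
  def with_chunk(enum, n) do
    stop = make_ref()

    enum
    |> Stream.concat(List.duplicate(stop, n))
    |> Stream.chunk_every(n + 1, 1, :discard)
    |> Stream.map(&Enum.reject(&1, fn x -> x == stop end))
  end

  def with_transform(enum, n) do
    stop = make_ref()

    enum
    |> Stream.concat(List.duplicate(stop, n + 1))
    |> Stream.transform([], fn val, acc ->
      case {val, acc} do
        {^stop, []} -> {[], []}
        {^stop, [_ | rest] = buf} -> {[buf], rest}
        {val, buf} when length(buf) < n + 1 -> {[], buf ++ [val]}
        {val, [_ | rest] = buf} -> {[buf], rest ++ [val]}
      end
    end)
  end
end

input = 1..10_000

# :timer.tc/1 returns {elapsed_microseconds, result}.
{t_chunk, a} = :timer.tc(fn -> input |> LookaheadCompare.with_chunk(3) |> Enum.to_list() end)
{t_transform, b} = :timer.tc(fn -> input |> LookaheadCompare.with_transform(3) |> Enum.to_list() end)

# Both variants should produce the same windows.
true = a == b
IO.puts("chunk: #{t_chunk} µs, transform: #{t_transform} µs")
```

A dedicated benchmarking tool with warm-up and repeated runs would give more trustworthy numbers than a single :timer.tc/1 call.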

By the way, the second solution can be written without the case statement once Elixir allows the pin operator to be used in function heads (probably in v1.1.0):

def lookahead(enum, n) do
  stop = make_ref()
  enum
  |> Stream.concat(List.duplicate(stop, n+1))
  |> Stream.transform([], fn
    ^stop, []                         -> {[]   , []           }
    ^stop, [_|rest] = buf             -> {[buf], rest         }
    val  , buf when length(buf) < n+1 -> {[]   , buf ++ [val] }
    val  , [_|rest] = buf             -> {[buf], rest ++ [val]}
  end)
end
