Conditional state monad expressions
Doh! I don't know why this didn't occur to me sooner. Sometimes just explaining your problem in simpler terms forces you to look at it afresh, I guess...
One possibility is to handle sequences of transitions, so that the next task is only undertaken if the current task succeeds.
import cats.data.State
import scala.util.{Failure, Try}
// Assumes definitions along the lines of the question's:
//   case class Counter(count: Int)
//   type Transition[M] = State[Counter, Try[M]]
// Run a sequence of transitions, until one fails.
def untilFailure[M](ts: List[Transition[M]]): Transition[M] = State { s =>
  ts match {
    // If we have an empty list, that's an error. (Cannot report a success value.)
    case Nil => (s, Failure(new RuntimeException("Empty transition sequence")))
    // If there's only one transition left, perform it and return the result.
    case t :: Nil => t.run(s).value
    // Otherwise, we have more than one transition remaining.
    //
    // Run the next transition. If it fails, report the failure, otherwise repeat
    // for the tail.
    case t :: tt => {
      val r = t.run(s).value
      if (r._2.isFailure) r
      else untilFailure(tt).run(r._1).value
    }
  }
}
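The short-circuiting in untilFailure does not really depend on cats. As a minimal sketch, assuming each transition is modelled as a plain state-passing function S => (S, Try[A]) (the hypothetical Step alias below, roughly what a State[S, Try[A]] wraps), the same logic can be written as a tail-recursive loop. The Counter, increment and decrement definitions here are illustrative stand-ins for the question's, not its actual code.

```scala
import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

object UntilFailureSketch {
  // Hypothetical stand-in for a transition: a plain state-passing function.
  type Step[S, A] = S => (S, Try[A])

  // Run the steps in order, stopping at (and reporting) the first failure.
  def untilFailure[S, A](steps: List[Step[S, A]])(start: S): (S, Try[A]) = {
    @tailrec
    def loop(remaining: List[Step[S, A]], state: S): (S, Try[A]) =
      remaining match {
        // Empty list: there is no success value to report.
        case Nil =>
          (state, Failure(new RuntimeException("Empty transition sequence")))
        // Last step: its result is the overall result.
        case t :: Nil =>
          t(state)
        // Otherwise run this step and continue only if it succeeded.
        case t :: rest =>
          val (next, result) = t(state)
          if (result.isFailure) (next, result) else loop(rest, next)
      }
    loop(steps, start)
  }

  // Hypothetical counter and transitions mirroring the question's rules.
  final case class Counter(count: Int)

  val increment: Step[Counter, Unit] = c =>
    if (c.count == Int.MaxValue)
      (c, Failure(new ArithmeticException("Attempt to overflow counter failed")))
    else (c.copy(count = c.count + 1), Success(()))

  val decrement: Step[Counter, Unit] = c =>
    if (c.count == 0)
      (c, Failure(new ArithmeticException("Attempt to make count negative failed")))
    else (c.copy(count = c.count - 1), Success(()))
}
```

Using @tailrec keeps this version stack-safe however long the list is, whereas untilFailure above makes one nested run(...).value call per remaining transition.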
We can then implement counterManip as a sequence:
val counterManip: Transition[Unit] =
  untilFailure(List(decrement, increment, increment, increment))
which gives the correct results:
scala> counterManip.run(Counter(0)).value
res0: (Counter, scala.util.Try[Unit]) = (Counter(0),Failure(java.lang.ArithmeticException: Attempt to make count negative failed))
scala> counterManip.run(Counter(1)).value
res1: (Counter, scala.util.Try[Unit]) = (Counter(3),Success(()))
scala> counterManip.run(Counter(Int.MaxValue - 2)).value
res2: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Success(()))
scala> counterManip.run(Counter(Int.MaxValue - 1)).value
res3: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))
scala> counterManip.run(Counter(Int.MaxValue)).value
res4: (Counter, scala.util.Try[Unit]) = (Counter(2147483647),Failure(java.lang.ArithmeticException: Attempt to overflow counter failed))
The downside is that all of the transitions need to have a return value in common (unless you're OK with an Any result).
From what I understand, your computation has two states (still succeeding, or failed with the last good value and the cause), which you can define as an ADT:
sealed trait CompState[A]
case class Ok[A](value: A) extends CompState[A]
case class Err[A](lastValue: A, cause: Exception) extends CompState[A]
The next step you can take is to define an update method for CompState, to encapsulate your logic of what should happen when chaining the computations:
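// Defined inside the body of sealed trait CompState[A] { ... }, since it matches on `this`.
def update(f: A => A): CompState[A] = this match {
  // Still OK: apply f, turning any exception into an Err that keeps the last good value.
  case Ok(a) =>
    try Ok(f(a))
    catch { case e: Exception => Err(a, e) }
  // Already failed: skip f and keep the existing error.
  case Err(a, e) => Err(a, e)
}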
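A small self-contained sketch of these chaining semantics, with update placed inside the trait (it matches on this, so it has to live there) and exercised on a plain Int. The doubleEven function is a hypothetical stand-in for a step that can throw:

```scala
object CompStateSketch {
  sealed trait CompState[A] {
    // Apply f to the current value; an exception turns the result into Err,
    // keeping the last good value. Once in Err, later updates are skipped.
    def update(f: A => A): CompState[A] = this match {
      case Ok(a) =>
        try Ok(f(a))
        catch { case e: Exception => Err(a, e) }
      case Err(a, e) => Err(a, e)
    }
  }
  final case class Ok[A](value: A) extends CompState[A]
  final case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

  // Hypothetical step: doubles even numbers, throws on odd ones.
  def doubleEven(n: Int): Int =
    if (n % 2 != 0) throw new IllegalArgumentException(s"$n is odd")
    else n * 2
}
```

For example, Ok(2).update(doubleEven).update(_ + 1) gives Ok(5), while Ok(2).update(_ + 1).update(doubleEven).update(_ * 100) ends up as an Err whose lastValue is 3: once a step throws, the Err keeps the last good value and every later update is skipped.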
From there, redefine Transition and the two operations:
type Transition[M] = State[CompState[Counter], M]
// Operation to increment a counter.
// note: using `State.modify` instead of `.apply`
val increment: Transition[Unit] = State.modify { cs =>
  // use the new `update` method to take advantage of your chaining semantics
  cs update { c =>
    // If the count is at its maximum, incrementing it must fail.
    if (c.count == Int.MaxValue) {
      throw new ArithmeticException("Attempt to overflow counter failed")
    }
    // Otherwise, increment the count (the state stays Ok).
    else c.copy(count = c.count + 1)
  }
}
// Operation to decrement a counter.
val decrement: Transition[Unit] = State.modify { cs =>
  cs update { c =>
    // If the count is zero, decrementing it must fail.
    if (c.count == 0) {
      throw new ArithmeticException("Attempt to make count negative failed")
    }
    // Otherwise, decrement the count (the state stays Ok).
    else c.copy(count = c.count - 1)
  }
}
Note that in the updated increment/decrement transitions above, I used State.modify, which changes the state but does not generate a result. It looks like the "idiomatic" way to obtain the current state at the end of your transitions is to use State.get, i.e.
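val counterManip: State[CompState[Counter], CompState[Counter]] = for {
  _ <- decrement
  _ <- increment
  _ <- increment
  _ <- increment
  // State.get with an explicit state type, so the compiler does not have to infer it.
  r <- State.get[CompState[Counter]]
} yield r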
And you can run this and discard the final state using the runA helper, which returns only the result value (here, the CompState read by State.get), i.e.
counterManip.runA(Ok(Counter(0))).value
// Err(Counter(0),java.lang.ArithmeticException: Attempt to make count negative failed)
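Since State.modify(f) just replaces the state with f(state), running the chain above and reading it back with State.get amounts to applying the update functions in order to the initial state. A minimal cats-free sketch of that, assuming the CompState definitions above (repeated so the block compiles on its own); ModifyChainSketch, Step and runAll are hypothetical names:

```scala
object ModifyChainSketch {
  final case class Counter(count: Int)

  sealed trait CompState[A] {
    def update(f: A => A): CompState[A] = this match {
      case Ok(a) =>
        try Ok(f(a))
        catch { case e: Exception => Err(a, e) }
      case Err(a, e) => Err(a, e)
    }
  }
  final case class Ok[A](value: A) extends CompState[A]
  final case class Err[A](lastValue: A, cause: Exception) extends CompState[A]

  // Hypothetical alias for the function a State.modify transition applies.
  type Step = CompState[Counter] => CompState[Counter]

  // The bodies passed to State.modify in increment/decrement, as plain values.
  val increment: Step = cs =>
    cs.update { c =>
      if (c.count == Int.MaxValue)
        throw new ArithmeticException("Attempt to overflow counter failed")
      else c.copy(count = c.count + 1)
    }

  val decrement: Step = cs =>
    cs.update { c =>
      if (c.count == 0)
        throw new ArithmeticException("Attempt to make count negative failed")
      else c.copy(count = c.count - 1)
    }

  // Apply each step in order, starting from Ok(start).
  def runAll(steps: List[Step], start: Counter): CompState[Counter] =
    steps.foldLeft(Ok(start): CompState[Counter])((cs, step) => step(cs))
}
```

Starting from the same counters as in the first answer, runAll(List(decrement, increment, increment, increment), ...) gives Ok(Counter(3)) for Counter(1) and an Err holding Counter(2147483647) for Counter(Int.MaxValue - 1), matching the Try-based results.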