Discriminated union in C#
I don't really like the type-checking and type-casting solutions provided above, so here's a 100% type-safe union which will produce compilation errors if you attempt to use the wrong datatype:
using System;
namespace Juliet
{
/// <summary>
/// Console demo: builds one union per case and prints the value each one matches.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        // One sample for every case of the three-way union.
        var samples = new Union3<int, char, string>[]
        {
            new Union3<int, char, string>.Case1(5),
            new Union3<int, char, string>.Case2('x'),
            new Union3<int, char, string>.Case3("Juliet")
        };

        for (int i = 0; i < samples.Length; i++)
        {
            // All three handlers return a string, so Match's T is inferred as string.
            string text = samples[i].Match(
                number => number.ToString(),
                letter => letter.ToString(),
                word => word);

            Console.WriteLine("Matched union with value '{0}'", text);
        }

        // Wait for Enter before the program exits.
        Console.ReadLine();
    }
}
/// <summary>
/// Closed three-case union: every instance is exactly one of <c>Case1</c> (holding an
/// <typeparamref name="A"/>), <c>Case2</c> (a <typeparamref name="B"/>) or <c>Case3</c>
/// (a <typeparamref name="C"/>).
/// </summary>
public abstract class Union3<A, B, C>
{
    // The constructor is private, so only the nested case classes below can derive from this type.
    private Union3()
    {
    }

    /// <summary>
    /// Calls the handler that corresponds to the case this instance represents and returns its result.
    /// </summary>
    /// <param name="f">Handler used when the instance is a <c>Case1</c>.</param>
    /// <param name="g">Handler used when the instance is a <c>Case2</c>.</param>
    /// <param name="h">Handler used when the instance is a <c>Case3</c>.</param>
    /// <returns>Whatever the selected handler returns.</returns>
    public abstract T Match<T>(Func<A, T> f, Func<B, T> g, Func<C, T> h);

    /// <summary>First case: wraps an <typeparamref name="A"/>.</summary>
    public sealed class Case1 : Union3<A, B, C>
    {
        public readonly A Item;

        public Case1(A item)
        {
            Item = item;
        }

        /// <summary>Applies <paramref name="f"/> to the wrapped item; the other handlers are not called.</summary>
        public override T Match<T>(Func<A, T> f, Func<B, T> g, Func<C, T> h)
        {
            return f.Invoke(Item);
        }
    }

    /// <summary>Second case: wraps a <typeparamref name="B"/>.</summary>
    public sealed class Case2 : Union3<A, B, C>
    {
        public readonly B Item;

        public Case2(B item)
        {
            Item = item;
        }

        /// <summary>Applies <paramref name="g"/> to the wrapped item; the other handlers are not called.</summary>
        public override T Match<T>(Func<A, T> f, Func<B, T> g, Func<C, T> h)
        {
            return g.Invoke(Item);
        }
    }

    /// <summary>Third case: wraps a <typeparamref name="C"/>.</summary>
    public sealed class Case3 : Union3<A, B, C>
    {
        public readonly C Item;

        public Case3(C item)
        {
            Item = item;
        }

        /// <summary>Applies <paramref name="h"/> to the wrapped item; the other handlers are not called.</summary>
        public override T Match<T>(Func<A, T> f, Func<B, T> g, Func<C, T> h)
        {
            return h.Invoke(Item);
        }
    }
}
}
I like the direction of the accepted solution but it doesn't scale well for unions of more than three items (e.g. a union of 9 items would require 9 class definitions).
Here is another approach that is also 100% type-safe at compile-time, but that is easy to grow to large unions.
/// <summary>
/// Single-case root of the Union family. It stores the wrapped value and provides
/// <c>InternalMatch</c>, the runtime dispatch that the wider unions build their
/// <c>Match</c> overloads on.
/// </summary>
/// <typeparam name="A">Type of the first (and, for this class, only) case.</typeparam>
public class UnionBase<A>
{
    // Held as object rather than dynamic: dispatch below is reflection-based, so no runtime
    // binder is needed, and a statically typed object is never mistaken for an argument array.
    private readonly object value;

    /// <summary>Wraps a value of the first case type.</summary>
    public UnionBase(A a) { value = a; }

    /// <summary>Wraps an arbitrary payload; used by derived unions for their additional cases.</summary>
    protected UnionBase(object x) { value = x; }

    /// <summary>
    /// Invokes the first handler in <paramref name="ds"/> whose parameter type can accept the
    /// stored value and returns its result.
    /// </summary>
    /// <param name="ds">Single-parameter delegates, tried in order.</param>
    /// <returns>The selected handler's return value, cast to <typeparamref name="T"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="ds"/> is null.</exception>
    /// <exception cref="ArgumentException">A handler reached during the search is null.</exception>
    /// <exception cref="InvalidOperationException">No handler accepts the stored value.</exception>
    /// <remarks>
    /// A null payload carries no runtime type, so it is routed to the first handler whose
    /// parameter type can hold null (a reference type or a <c>Nullable&lt;T&gt;</c>). If several
    /// cases are nullable the original case cannot be recovered and the earliest one wins.
    /// </remarks>
    protected T InternalMatch<T>(params Delegate[] ds)
    {
        if (ds == null)
        {
            throw new ArgumentNullException("ds");
        }

        Type valueType = value == null ? null : value.GetType();

        foreach (Delegate handler in ds)
        {
            if (handler == null)
            {
                throw new ArgumentException("Match handlers must not be null.", "ds");
            }

            // Read the signature from the delegate type itself; the target method's own
            // parameter list can differ from it (variance, closed delegates).
            var invoke = handler.GetType().GetMethod("Invoke");
            var parameters = invoke.GetParameters();

            // These are always true if InternalMatch is used correctly.
            Debug.Assert(parameters.Length == 1);
            Debug.Assert(typeof(T).IsAssignableFrom(invoke.ReturnType));

            Type parameterType = parameters[0].ParameterType;
            bool accepts;
            if (valueType == null)
            {
                accepts = !parameterType.IsValueType || Nullable.GetUnderlyingType(parameterType) != null;
            }
            else
            {
                accepts = parameterType.IsAssignableFrom(valueType);
            }

            if (accepts)
            {
                // DynamicInvoke supplies the delegate's own target. Calling the MethodInfo with a
                // null target fails for every delegate bound to an instance method, which includes
                // lambdas that capture variables and compiler-generated closure classes.
                return (T)handler.DynamicInvoke(new object[] { value });
            }
        }

        throw new InvalidOperationException("No appropriate matching function was provided");
    }

    /// <summary>Applies <paramref name="fa"/> to the stored value.</summary>
    public T Match<T>(Func<A, T> fa) { return InternalMatch<T>(fa); }
}
/// <summary>Union with two cases: the value is an <typeparamref name="A"/> or a <typeparamref name="B"/>.</summary>
public class Union<A, B> : UnionBase<A>
{
    /// <summary>Wraps a value of the first case type.</summary>
    public Union(A a)
        : base(a)
    {
    }

    /// <summary>Wraps a value of the second case type.</summary>
    public Union(B b)
        : base(b)
    {
    }

    /// <summary>Wraps an arbitrary payload; available to derived unions.</summary>
    protected Union(object x)
        : base(x)
    {
    }

    /// <summary>Passes one handler per case, in case order, to InternalMatch and returns its result.</summary>
    public T Match<T>(Func<A, T> fa, Func<B, T> fb)
    {
        Delegate[] handlers = { fa, fb };
        return InternalMatch<T>(handlers);
    }
}
/// <summary>
/// Union with three cases: the value is an <typeparamref name="A"/>, a <typeparamref name="B"/>
/// or a <typeparamref name="C"/>.
/// </summary>
public class Union<A, B, C> : Union<A, B>
{
    /// <summary>Wraps a value of the first case type.</summary>
    public Union(A a)
        : base(a)
    {
    }

    /// <summary>Wraps a value of the second case type.</summary>
    public Union(B b)
        : base(b)
    {
    }

    /// <summary>Wraps a value of the third case type.</summary>
    public Union(C c)
        : base(c)
    {
    }

    /// <summary>Wraps an arbitrary payload; available to derived unions.</summary>
    protected Union(object x)
        : base(x)
    {
    }

    /// <summary>Passes one handler per case, in case order, to InternalMatch and returns its result.</summary>
    public T Match<T>(Func<A, T> fa, Func<B, T> fb, Func<C, T> fc)
    {
        Delegate[] handlers = { fa, fb, fc };
        return InternalMatch<T>(handlers);
    }
}
/// <summary>
/// Union with four cases: the value is an <typeparamref name="A"/>, a <typeparamref name="B"/>,
/// a <typeparamref name="C"/> or a <typeparamref name="D"/>.
/// </summary>
public class Union<A, B, C, D> : Union<A, B, C>
{
    /// <summary>Wraps a value of the first case type.</summary>
    public Union(A a)
        : base(a)
    {
    }

    /// <summary>Wraps a value of the second case type.</summary>
    public Union(B b)
        : base(b)
    {
    }

    /// <summary>Wraps a value of the third case type.</summary>
    public Union(C c)
        : base(c)
    {
    }

    /// <summary>Wraps a value of the fourth case type.</summary>
    public Union(D d)
        : base(d)
    {
    }

    /// <summary>Wraps an arbitrary payload; available to derived unions.</summary>
    protected Union(object x)
        : base(x)
    {
    }

    /// <summary>Passes one handler per case, in case order, to InternalMatch and returns its result.</summary>
    public T Match<T>(Func<A, T> fa, Func<B, T> fb, Func<C, T> fc, Func<D, T> fd)
    {
        Delegate[] handlers = { fa, fb, fc, fd };
        return InternalMatch<T>(handlers);
    }
}
/// <summary>
/// Union with five cases: the value is an <typeparamref name="A"/>, a <typeparamref name="B"/>,
/// a <typeparamref name="C"/>, a <typeparamref name="D"/> or an <typeparamref name="E"/>.
/// </summary>
public class Union<A, B, C, D, E> : Union<A, B, C, D>
{
    /// <summary>Wraps a value of the first case type.</summary>
    public Union(A a)
        : base(a)
    {
    }

    /// <summary>Wraps a value of the second case type.</summary>
    public Union(B b)
        : base(b)
    {
    }

    /// <summary>Wraps a value of the third case type.</summary>
    public Union(C c)
        : base(c)
    {
    }

    /// <summary>Wraps a value of the fourth case type.</summary>
    public Union(D d)
        : base(d)
    {
    }

    /// <summary>Wraps a value of the fifth case type.</summary>
    public Union(E e)
        : base(e)
    {
    }

    /// <summary>Wraps an arbitrary payload; available to derived classes.</summary>
    protected Union(object x)
        : base(x)
    {
    }

    /// <summary>Passes one handler per case, in case order, to InternalMatch and returns its result.</summary>
    public T Match<T>(Func<A, T> fa, Func<B, T> fb, Func<C, T> fc, Func<D, T> fd, Func<E, T> fe)
    {
        Delegate[] handlers = { fa, fb, fc, fd, fe };
        return InternalMatch<T>(handlers);
    }
}
/// <summary>
/// Example that wraps int, bool, string and int[] values in a four-case union and
/// prints a description of each one.
/// </summary>
public class DiscriminatedUnionTest : IExample
{
    /// <summary>Wraps an int in the union.</summary>
    public Union<int, bool, string, int[]> MakeUnion(int n)
    {
        return new Union<int, bool, string, int[]>(n);
    }

    /// <summary>Wraps a bool in the union.</summary>
    public Union<int, bool, string, int[]> MakeUnion(bool b)
    {
        return new Union<int, bool, string, int[]>(b);
    }

    /// <summary>Wraps a string in the union.</summary>
    public Union<int, bool, string, int[]> MakeUnion(string s)
    {
        return new Union<int, bool, string, int[]>(s);
    }

    /// <summary>
    /// Wraps an int array in the union. A call with a single int argument binds to the
    /// int overload instead, because a non-params match is preferred.
    /// </summary>
    public Union<int, bool, string, int[]> MakeUnion(params int[] xs)
    {
        return new Union<int, bool, string, int[]>(xs);
    }

    /// <summary>Writes a one-line description of the union's current case to the console.</summary>
    public void Print(Union<int, bool, string, int[]> union)
    {
        var text = union.Match(
            n => "This is an int " + n.ToString(),
            b => "This is a boolean " + b.ToString(),
            // Trailing space added so the label is separated from the value, as in the other cases.
            s => "This is a string " + s,
            xs => "This is an array of ints " + String.Join(", ", xs));
        Console.WriteLine(text);
    }

    /// <summary>Prints one example of each of the four cases.</summary>
    public void Run()
    {
        Print(MakeUnion(1));
        Print(MakeUnion(true));
        Print(MakeUnion("forty-two"));
        Print(MakeUnion(0, 1, 1, 2, 3, 5, 8));
    }
}
I wrote some blog posts on this subject that might be useful:
- Union Types in C#
- Implementing Tic-Tac-Toe Using State Classes
Let's say you have a shopping cart scenario with three states: "Empty", "Active" and "Paid", each with different behavior.
- You create an ICartState interface that all states have in common (it could just be an empty marker interface).
- You create three classes that implement that interface. (The classes do not have to be in an inheritance relationship.)
- The interface contains a "fold" method, whereby you pass a lambda in for each state or case that you need to handle.
You could use the F# runtime from C# but as a lighter weight alternative, I have written a little T4 template for generating code like this.
Here's the interface:
/// <summary>
/// Contract shared by the three cart states (CartStateEmpty, CartStateActive, CartStatePaid).
/// </summary>
partial interface ICartState
{
/// <summary>
/// Takes one handler per concrete state and returns an ICartState.
/// Each implementing class in this file invokes only the handler for its own
/// type, passing itself, and returns that handler's result.
/// </summary>
/// <param name="cartStateEmpty">Handler for a CartStateEmpty.</param>
/// <param name="cartStateActive">Handler for a CartStateActive.</param>
/// <param name="cartStatePaid">Handler for a CartStatePaid.</param>
/// <returns>The state produced by the selected handler.</returns>
ICartState Transition(
Func<CartStateEmpty, ICartState> cartStateEmpty,
Func<CartStateActive, ICartState> cartStateActive,
Func<CartStatePaid, ICartState> cartStatePaid
);
}
And here's the implementation:
/// <summary>The "Empty" cart state.</summary>
class CartStateEmpty : ICartState
{
    /// <summary>
    /// Selects the empty-state handler, passes this instance to it and returns the state it produces.
    /// Explicitly implemented, so it is reachable only through an ICartState reference.
    /// </summary>
    ICartState ICartState.Transition(
        Func<CartStateEmpty, ICartState> cartStateEmpty,
        Func<CartStateActive, ICartState> cartStateActive,
        Func<CartStatePaid, ICartState> cartStatePaid)
    {
        // Only the handler for this concrete state runs; the other two are ignored.
        ICartState next = cartStateEmpty.Invoke(this);
        return next;
    }
}
/// <summary>The "Active" cart state.</summary>
class CartStateActive : ICartState
{
    /// <summary>
    /// Selects the active-state handler, passes this instance to it and returns the state it produces.
    /// Explicitly implemented, so it is reachable only through an ICartState reference.
    /// </summary>
    ICartState ICartState.Transition(
        Func<CartStateEmpty, ICartState> cartStateEmpty,
        Func<CartStateActive, ICartState> cartStateActive,
        Func<CartStatePaid, ICartState> cartStatePaid)
    {
        // Only the handler for this concrete state runs; the other two are ignored.
        ICartState next = cartStateActive.Invoke(this);
        return next;
    }
}
/// <summary>The "Paid" cart state.</summary>
class CartStatePaid : ICartState
{
    /// <summary>
    /// Selects the paid-state handler, passes this instance to it and returns the state it produces.
    /// Explicitly implemented, so it is reachable only through an ICartState reference.
    /// </summary>
    ICartState ICartState.Transition(
        Func<CartStateEmpty, ICartState> cartStateEmpty,
        Func<CartStateActive, ICartState> cartStateActive,
        Func<CartStatePaid, ICartState> cartStatePaid)
    {
        // Only the handler for this concrete state runs; the other two are ignored.
        ICartState next = cartStatePaid.Invoke(this);
        return next;
    }
}
Now let's say you extend CartStateEmpty and CartStateActive with an AddItem method, which is not implemented by CartStatePaid.
And let's also say that CartStateActive has a Pay method that the other states don't have.
Then here's some code that shows it in use -- adding two items and then paying for the cart:
/// <summary>
/// Adds <paramref name="product"/> to the cart when the current state supports it.
/// </summary>
/// <param name="currentState">State to transition from.</param>
/// <param name="product">Product handed to the state's AddItem method.</param>
/// <returns>
/// The result of AddItem for the empty and active states; a paid cart is returned unchanged.
/// </returns>
public ICartState AddProduct(ICartState currentState, Product product)
{
    return currentState.Transition(
        cartStateEmpty: empty => empty.AddItem(product),
        cartStateActive: active => active.AddItem(product),
        // not allowed in this case, so the paid state is handed back as-is
        cartStatePaid: paid => paid);
}
/// <summary>
/// Walks a cart through the example flow: start empty, add two products, then pay.
/// </summary>
public void Example()
{
    const decimal paidAmount = 12.34m;

    // Typed as the interface because Transition is implemented explicitly on the state classes.
    ICartState currentState = new CartStateEmpty();

    // Add some products.
    currentState = AddProduct(currentState, Product.ProductX);
    currentState = AddProduct(currentState, Product.ProductY);

    // Pay: only the active state has a Pay method; the other states are returned unchanged.
    currentState = currentState.Transition(
        cartStateEmpty: empty => empty,
        cartStateActive: active => active.Pay(paidAmount),
        cartStatePaid: paid => paid);
}
Note that this code is completely typesafe -- no casting or conditionals anywhere, and compiler errors if you try to pay for an empty cart, say.