Tuesday, January 31, 2012

Encoding algebraic data types in C#

Algebraic data types are generally not directly expressible in C#, but they're far too useful a tool to leave unused. ADTs are a very precise modeling tool that helps make illegal states unrepresentable. The Base Class Library already includes a Tuple type representing the anonymous product type, and C# also has anonymous types to represent labeled product types (isn't that confusing?). And of course you could even consider simple classes or structs to be product types.

But the BCL doesn't include any anonymous sum type. We can use F#'s Choice type in C#. For example, this discriminated union type in F# (borrowed from MSDN):

type Shape =
    // The value here is the radius.
    | Circle of float
    // The values here are the height and width.
    | Rectangle of double * double

Could be represented in C# as FSharpChoice<double, Tuple<double, double>> (F#'s float is System.Double, not C#'s 32-bit float). But this obviously loses the labels (Circle, Rectangle). These "labels" are usually called "constructors".

A reasonable way to start encoding ADTs in C# is to use ILSpy to reverse-engineer the code the F# compiler generates for the discriminated union above (this might hurt a bit, but don't be scared!):

using Microsoft.FSharp.Core;
using System;
using System.Collections;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
[DebuggerDisplay("{__DebugDisplay(),nq}"), CompilationMapping(SourceConstructFlags.SumType)]
[Serializable]
[StructLayout(LayoutKind.Auto, CharSet = CharSet.Auto)]
public abstract class Shape : IEquatable<Program.Shape>, IStructuralEquatable, IComparable<Program.Shape>, IComparable, IStructuralComparable
{
    public static class Tags
    {
        public const int Circle = 0;
        public const int Rectangle = 1;
    }
    [DebuggerTypeProxy(typeof(Program.Shape.Circle@DebugTypeProxy)), DebuggerDisplay("{__DebugDisplay(),nq}")]
    [Serializable]
    public class Circle : Program.Shape
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item;
        [CompilationMapping(SourceConstructFlags.Field, 0, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        internal Circle(double item)
        {
            this.item = item;
        }
    }
    [DebuggerTypeProxy(typeof(Program.Shape.Rectangle@DebugTypeProxy)), DebuggerDisplay("{__DebugDisplay(),nq}")]
    [Serializable]
    public class Rectangle : Program.Shape
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item1;
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal readonly double item2;
        [CompilationMapping(SourceConstructFlags.Field, 1, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item1
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item1;
            }
        }
        [CompilationMapping(SourceConstructFlags.Field, 1, 1), CompilerGenerated, DebuggerNonUserCode]
        public double Item2
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this.item2;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        internal Rectangle(double item1, double item2)
        {
            this.item1 = item1;
            this.item2 = item2;
        }
    }
    internal class Circle@DebugTypeProxy
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal Program.Shape.Circle _obj;
        [CompilationMapping(SourceConstructFlags.Field, 0, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        public Circle@DebugTypeProxy(Program.Shape.Circle obj)
        {
            this._obj = obj;
        }
    }
    internal class Rectangle@DebugTypeProxy
    {
        [DebuggerBrowsable(DebuggerBrowsableState.Never), CompilerGenerated, DebuggerNonUserCode]
        internal Program.Shape.Rectangle _obj;
        [CompilationMapping(SourceConstructFlags.Field, 1, 0), CompilerGenerated, DebuggerNonUserCode]
        public double Item1
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item1;
            }
        }
        [CompilationMapping(SourceConstructFlags.Field, 1, 1), CompilerGenerated, DebuggerNonUserCode]
        public double Item2
        {
            [CompilerGenerated, DebuggerNonUserCode]
            get
            {
                return this._obj.item2;
            }
        }
        [CompilerGenerated, DebuggerNonUserCode]
        public Rectangle@DebugTypeProxy(Program.Shape.Rectangle obj)
        {
            this._obj = obj;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public int Tag
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public bool IsRectangle
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return this is Program.Shape.Rectangle;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode, DebuggerBrowsable(DebuggerBrowsableState.Never)]
    public bool IsCircle
    {
        [CompilerGenerated, DebuggerNonUserCode]
        get
        {
            return this is Program.Shape.Circle;
        }
    }
    [CompilerGenerated, DebuggerNonUserCode]
    internal Shape()
    {
    }
    [CompilationMapping(SourceConstructFlags.UnionCase, 1)]
    public static Program.Shape NewRectangle(double item1, double item2)
    {
        return new Program.Shape.Rectangle(item1, item2);
    }
    [CompilationMapping(SourceConstructFlags.UnionCase, 0)]
    public static Program.Shape NewCircle(double item)
    {
        return new Program.Shape.Circle(item);
    }
    [CompilerGenerated, DebuggerNonUserCode]
    internal object __DebugDisplay()
    {
        return ExtraTopLevelOperators.PrintFormatToString<FSharpFunc<Program.Shape, string>>(new PrintfFormat<FSharpFunc<Program.Shape, string>, Unit, string, string, string>("%+0.8A")).Invoke(this);
    }
    [CompilerGenerated]
    public sealed override int CompareTo(Program.Shape obj)
    {
        if (this != null)
        {
            if (obj == null)
            {
                return 1;
            }
            int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
            int num2 = (!(obj is Program.Shape.Rectangle)) ? 0 : 1;
            if (num != num2)
            {
                return num - num2;
            }
            if (this is Program.Shape.Circle)
            {
                Program.Shape.Circle circle = (Program.Shape.Circle)this;
                Program.Shape.Circle circle2 = (Program.Shape.Circle)obj;
                IComparer genericComparer = LanguagePrimitives.GenericComparer;
                double item = circle.item;
                double item2 = circle2.item;
                if (item < item2)
                {
                    return -1;
                }
                if (item > item2)
                {
                    return 1;
                }
                if (item == item2)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer, item, item2);
            }
            else
            {
                Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
                Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)obj;
                IComparer genericComparer2 = LanguagePrimitives.GenericComparer;
                double item3 = rectangle.item1;
                double item4 = rectangle2.item1;
                int num3 = (item3 >= item4) ? ((item3 <= item4) ? ((item3 != item4) ? LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer2, item3, item4) : 0) : 1) : -1;
                if (num3 < 0)
                {
                    return num3;
                }
                if (num3 > 0)
                {
                    return num3;
                }
                IComparer genericComparer3 = LanguagePrimitives.GenericComparer;
                double item5 = rectangle.item2;
                double item6 = rectangle2.item2;
                if (item5 < item6)
                {
                    return -1;
                }
                if (item5 > item6)
                {
                    return 1;
                }
                if (item5 == item6)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(genericComparer3, item5, item6);
            }
        }
        else
        {
            if (obj != null)
            {
                return -1;
            }
            return 0;
        }
    }
    [CompilerGenerated]
    public sealed override int CompareTo(object obj)
    {
        return this.CompareTo((Program.Shape)obj);
    }
    [CompilerGenerated]
    public sealed override int CompareTo(object obj, IComparer comp)
    {
        Program.Shape shape = (Program.Shape)obj;
        if (this != null)
        {
            if ((Program.Shape)obj == null)
            {
                return 1;
            }
            int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
            Program.Shape shape2 = shape;
            int num2 = (!(shape2 is Program.Shape.Rectangle)) ? 0 : 1;
            if (num != num2)
            {
                return num - num2;
            }
            if (this is Program.Shape.Circle)
            {
                Program.Shape.Circle circle = (Program.Shape.Circle)this;
                Program.Shape.Circle circle2 = (Program.Shape.Circle)shape;
                double item = circle.item;
                double item2 = circle2.item;
                if (item < item2)
                {
                    return -1;
                }
                if (item > item2)
                {
                    return 1;
                }
                if (item == item2)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item, item2);
            }
            else
            {
                Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
                Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)shape;
                double item3 = rectangle.item1;
                double item4 = rectangle2.item1;
                int num3 = (item3 >= item4) ? ((item3 <= item4) ? ((item3 != item4) ? LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item3, item4) : 0) : 1) : -1;
                if (num3 < 0)
                {
                    return num3;
                }
                if (num3 > 0)
                {
                    return num3;
                }
                double item5 = rectangle.item2;
                double item6 = rectangle2.item2;
                if (item5 < item6)
                {
                    return -1;
                }
                if (item5 > item6)
                {
                    return 1;
                }
                if (item5 == item6)
                {
                    return 0;
                }
                return LanguagePrimitives.HashCompare.GenericComparisonWithComparerIntrinsic<double>(comp, item5, item6);
            }
        }
        else
        {
            if ((Program.Shape)obj != null)
            {
                return -1;
            }
            return 0;
        }
    }
    [CompilerGenerated]
    public sealed override int GetHashCode(IEqualityComparer comp)
    {
        if (this == null)
        {
            return 0;
        }
        int num;
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            num = 0;
            return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, circle.item) + ((num << 6) + (num >> 2)));
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        num = 1;
        num = -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, rectangle.item2) + ((num << 6) + (num >> 2)));
        return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic<double>(comp, rectangle.item1) + ((num << 6) + (num >> 2)));
    }
    [CompilerGenerated]
    public sealed override int GetHashCode()
    {
        return this.GetHashCode(LanguagePrimitives.GenericEqualityComparer);
    }
    [CompilerGenerated]
    public sealed override bool Equals(object obj, IEqualityComparer comp)
    {
        if (this == null)
        {
            return obj == null;
        }
        Program.Shape shape = obj as Program.Shape;
        if (shape == null)
        {
            return false;
        }
        Program.Shape shape2 = shape;
        int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        Program.Shape shape3 = shape2;
        int num2 = (!(shape3 is Program.Shape.Rectangle)) ? 0 : 1;
        if (num != num2)
        {
            return false;
        }
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            Program.Shape.Circle circle2 = (Program.Shape.Circle)shape2;
            return circle.item == circle2.item;
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)shape2;
        return rectangle.item1 == rectangle2.item1 && rectangle.item2 == rectangle2.item2;
    }
    [CompilerGenerated]
    public sealed override bool Equals(Program.Shape obj)
    {
        if (this == null)
        {
            return obj == null;
        }
        if (obj == null)
        {
            return false;
        }
        int num = (!(this is Program.Shape.Rectangle)) ? 0 : 1;
        int num2 = (!(obj is Program.Shape.Rectangle)) ? 0 : 1;
        if (num != num2)
        {
            return false;
        }
        if (this is Program.Shape.Circle)
        {
            Program.Shape.Circle circle = (Program.Shape.Circle)this;
            Program.Shape.Circle circle2 = (Program.Shape.Circle)obj;
            double item = circle.item;
            double item2 = circle2.item;
            return (item != item && item2 != item2) || item == item2;
        }
        Program.Shape.Rectangle rectangle = (Program.Shape.Rectangle)this;
        Program.Shape.Rectangle rectangle2 = (Program.Shape.Rectangle)obj;
        double item3 = rectangle.item1;
        double item4 = rectangle2.item1;
        if ((item3 != item3 && item4 != item4) || item3 == item4)
        {
            double item5 = rectangle.item2;
            double item6 = rectangle2.item2;
            return (item5 != item5 && item6 != item6) || item5 == item6;
        }
        return false;
    }
    [CompilerGenerated]
    public sealed override bool Equals(object obj)
    {
        Program.Shape shape = obj as Program.Shape;
        return shape != null && this.Equals(shape);
    }
}

Whew! That's a lot of code! Let's break it down:

  • The DebuggerDisplay, DebuggerTypeProxy, DebuggerBrowsable, DebuggerNonUserCode attributes and DebugTypeProxy classes enhance the debugging experience.
  • Discriminated union types are marked as Serializable.
  • Discriminated union types implement equality and comparison (IEquatable, IComparable, IStructuralEquatable, IStructuralComparable, Equals(), GetHashCode()).
  • Tags (simple integer constants) are used to optimize the implementation of equality and comparison.

It seems impractical to write that much code in C# every time we want an ADT. However, underneath all the attributes and noise, the gist of it is quite simple: a class hierarchy with an abstract base class plus a concrete subclass for each case:

abstract class Shape {
    class Circle: Shape {
        public readonly float Radius;

        public Circle(float radius) {
            Radius = radius;
        }
    }

    class Rectangle: Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }
}

Shape is abstract because the only valid cases are Circle and Rectangle. Instantiating a Shape doesn't make sense!

Now, there's a detail we have to take care of: ADTs are closed, which means we can't add new shapes to the type without changing the Shape type itself. This contradicts the usual practice in OOP: if we wanted a Square, we'd just create a new subclass of Shape. That, however, makes writing total functions over Shapes harder (impossible?), since a function can never be sure it handles every subclass. For more information about this and a comparison of OO (classes/subtyping) vs. FP (closed ADTs), see this Stack Overflow question. OCaml supports "open ADTs" (actually called "open variants" or "polymorphic variants"), which are powerful but have their own drawbacks.

But I digress. The goal is to prevent further subclasses of Shape, as well as subclasses of Circle and Rectangle. We can do this by making the constructor of Shape private and sealing Circle and Rectangle:

public abstract class Shape {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly float Radius;

        public Circle(float radius) {
            Radius = radius;
        }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }
}

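This is also why I placed Circle and Rectangle as nested classes of Shape: a private constructor is accessible only within the body of Shape, nested classes included, so Circle and Rectangle can derive from Shape but no class declared outside it can.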
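The F#-generated class shown earlier also exposes static factory methods (NewCircle, NewRectangle) and IsCircle/IsRectangle properties. A sketch adding the same conveniences to the sealed hierarchy might look like this. It uses double for the radius because F#'s float is System.Double, and the member names simply mirror the decompiled code; they aren't required by anything.

```csharp
using System;

public abstract class Shape {
    // Private: accessible from the nested classes below (they sit inside Shape's body),
    // but not from any class declared outside Shape, so no other subclasses can exist.
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly double Radius;

        public Circle(double radius) {
            Radius = radius;
        }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;

        public Rectangle(double height, double width) {
            Height = height;
            Width = width;
        }
    }

    // Factory methods modeled on F#'s generated NewCircle/NewRectangle:
    // the static return type is Shape, not the concrete case class.
    public static Shape NewCircle(double radius) {
        return new Circle(radius);
    }

    public static Shape NewRectangle(double height, double width) {
        return new Rectangle(height, width);
    }

    // Case tests modeled on the generated IsCircle/IsRectangle properties.
    public bool IsCircle {
        get { return this is Circle; }
    }

    public bool IsRectangle {
        get { return this is Rectangle; }
    }
}
```

With `var shape = Shape.NewCircle(2.0);` the variable's static type is Shape rather than Shape.Circle, which is usually what you want for a value of a sum type.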

So far we have a general structure and constructors. But we're not done yet! Given a Shape, how do we know whether it's a Circle or a Rectangle? How do we get to the data (radius, height, width)?

Any red-blooded object-oriented programmer would at this point be yelling "implement a Visitor!". Languages with first-class support for ADTs, such as the ML dialects, use pattern matching instead. Continuing with the same MSDN sample, let's calculate the area of a Shape:

let shape = Circle 2.0

let area =
    match shape with
    | Circle radius -> System.Math.PI * radius * radius
    | Rectangle (h, w) -> h * w

System.Console.WriteLine area

Again we open this with ILSpy and see the following C# code (I edited it a bit to remove name mangling):

Shape shape = new Circle(2.0f);
double arg_67_0;
if (shape is Shape.Rectangle)
{
    var rectangle = (Shape.Rectangle)shape;
    double w = rectangle.Width;
    double h = rectangle.Height;
    arg_67_0 = h * w;
}
else
{
    var circle = (Shape.Circle)shape;
    double radius = circle.Radius;
    arg_67_0 = 3.1415926535897931 * radius * radius;
}
Console.WriteLine(arg_67_0);

It does runtime type testing and downcasting! As it turns out, on the CLR, runtime type testing is faster than a vtable dispatch, and the F# compiler takes advantage of this. If the discriminated union had more cases, we'd see the F# compiler simply nest ifs to test each of them. This is fast, but writing such code by hand in C# isn't very practical, and the C# compiler can't check statically that a chain of type tests is exhaustive. What we really want is a method on Shape (we'll call it Match) that takes two functions as parameters: one to handle the Circle case, another to handle the Rectangle case.
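Written by hand in C#, such a chain of type tests might look like the sketch below. It assumes a sealed Shape hierarchy like the one above but with double fields (F#'s float is System.Double); the Geometry class and its Area method are hypothetical names used only for illustration. The final throw is the part the compiler can't rule out: if a new case were added to Shape and this method forgotten, it would still compile and fail only at runtime.

```csharp
using System;

public abstract class Shape {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly double Radius;
        public Circle(double radius) { Radius = radius; }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
    }
}

// Hypothetical helper: the F#-style chain of type tests, written by hand.
public static class Geometry {
    public static double Area(Shape shape) {
        var circle = shape as Shape.Circle;
        if (circle != null)
            return Math.PI * circle.Radius * circle.Radius;

        var rectangle = shape as Shape.Rectangle;
        if (rectangle != null)
            return rectangle.Height * rectangle.Width;

        // The compiler can't tell that the tests above cover every case,
        // so a fallback is required; a forgotten case only surfaces here, at runtime.
        throw new ArgumentException("Unknown shape", "shape");
    }
}
```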

At this point we have several alternatives. We can implement a full visitor pattern as I said earlier, or we can take advantage of the fact that it's a closed type and simply encapsulate the type testing and downcasting, like this:

public T Match<T>(Func<float, T> circle, Func<double, double, T> rectangle) {
    if (this is Circle) {
        var x = (Circle)this;
        return circle(x.Radius);
    }
    var y = (Rectangle)this;
    return rectangle(y.Width, y.Height);
}

And now we can calculate the area of a Shape like this:

Shape shape = new Shape.Circle(2.0f);
var area = shape.Match<double>(circle: radius => Math.PI * radius * radius,
                               rectangle: (width, height) => width * height);
Console.WriteLine(area);

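Note how named arguments in C# 4 make this a bit more readable. Also, since both handlers are required parameters of Match, we can't forget a case: leaving one out is a compile-time error. The explicit <double> type argument is optional, though: Match's signature fixes the lambdas' parameter types, so the C# compiler can infer T from what the lambdas return.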
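Put together, a minimal self-contained version of the closed Shape type with its Match method might look like the sketch below. It uses double for the radius (F#'s float is System.Double) and `as` instead of an `is` check followed by a cast; the ShapeFunctions class is a hypothetical client added only to show two different operations built on the same Match.

```csharp
using System;

public abstract class Shape {
    // Only the nested case classes can call this, which keeps the hierarchy closed.
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly double Radius;
        public Circle(double radius) { Radius = radius; }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
    }

    // Encapsulates the type test and downcast; callers must supply a handler for each case.
    public T Match<T>(Func<double, T> circle, Func<double, double, T> rectangle) {
        var c = this as Circle;
        if (c != null) return circle(c.Radius);
        var r = (Rectangle)this;
        return rectangle(r.Width, r.Height);
    }
}

// Hypothetical client code: two functions over Shape, each a single Match call.
public static class ShapeFunctions {
    public static double Area(Shape shape) {
        return shape.Match<double>(
            circle: radius => Math.PI * radius * radius,
            rectangle: (width, height) => width * height);
    }

    public static string Describe(Shape shape) {
        return shape.Match<string>(
            circle: radius => "circle, radius " + radius,
            rectangle: (width, height) => "rectangle, " + width + " x " + height);
    }
}
```

Each new operation over Shape is just another Match call with a different result type, and none of them can silently skip a case.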

Another alternative is to pass the whole object to the handling functions instead of just its data, e.g.

public T Match<T>(Func<Circle, T> circle, Func<Rectangle, T> rectangle)
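Filled in, that variant might look like the following sketch. It assumes the sealed Shape hierarchy from this post with double fields (F#'s float is System.Double); ShapeFunctions is a hypothetical client for illustration.

```csharp
using System;

public abstract class Shape {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly double Radius;
        public Circle(double radius) { Radius = radius; }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
    }

    // Each handler receives the whole case object rather than its unpacked fields.
    public T Match<T>(Func<Circle, T> circle, Func<Rectangle, T> rectangle) {
        var c = this as Circle;
        if (c != null) return circle(c);
        return rectangle((Rectangle)this);
    }
}

public static class ShapeFunctions {
    // Fields are read by name, so Height and Width can't be swapped by position.
    public static double Area(Shape shape) {
        return shape.Match<double>(
            circle: c => Math.PI * c.Radius * c.Radius,
            rectangle: r => r.Height * r.Width);
    }
}
```

The trade-off: handlers now depend on the case classes' members, but adding a field to a case no longer changes Match's signature.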

Now we have a usable algebraic data type. Compared to F#, the boilerplate required in C# is considerable and tedious, but still worth it in my opinion. However, if you also need equality and comparison, you might as well just use F# instead ;-)
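If equality alone is enough (no comparison), writing it by hand is manageable. The sketch below is one way to do it, not a port of the F#-generated code: it assumes the sealed Shape hierarchy from this post with double fields, implements only IEquatable<Shape>, Equals and GetHashCode, and starts each hash from a per-case tag, loosely mirroring F#'s Tags class.

```csharp
using System;

public abstract class Shape : IEquatable<Shape> {
    private Shape() {}

    public sealed class Circle : Shape {
        public readonly double Radius;
        public Circle(double radius) { Radius = radius; }
    }

    public sealed class Rectangle : Shape {
        public readonly double Height;
        public readonly double Width;
        public Rectangle(double height, double width) { Height = height; Width = width; }
    }

    // Equal only when both values are the same case and all their fields are equal.
    public bool Equals(Shape other) {
        if (ReferenceEquals(other, null)) return false;
        if (ReferenceEquals(this, other)) return true;

        var c1 = this as Circle;
        if (c1 != null) {
            var c2 = other as Circle;
            return c2 != null && c1.Radius.Equals(c2.Radius);
        }

        var r1 = (Rectangle)this;
        var r2 = other as Rectangle;
        return r2 != null && r1.Height.Equals(r2.Height) && r1.Width.Equals(r2.Width);
    }

    public override bool Equals(object obj) {
        return Equals(obj as Shape);
    }

    public override int GetHashCode() {
        unchecked {
            int hash;
            var c = this as Circle;
            if (c != null) {
                hash = 0;  // tag for Circle
                hash = hash * 31 + c.Radius.GetHashCode();
            } else {
                var r = (Rectangle)this;
                hash = 1;  // tag for Rectangle
                hash = hash * 31 + r.Height.GetHashCode();
                hash = hash * 31 + r.Width.GetHashCode();
            }
            return hash;
        }
    }
}
```

Comparing the doubles with double.Equals rather than == means two NaN radii compare equal, which matches the NaN handling of the F#-generated Equals(Shape) shown earlier. Note that the == operator still compares references, since the sketch doesn't overload it.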

7 comments:

Matías Giovannini said...

Make your Match abstract in Shape, implement concrete versions in each of Circle and Rectangle, and you have your canonical visitor with an unrolled list of visiting functions. If you need to keep state between visits, you'd be better off with an explicit visitor anyway.

Mauricio Scheffer said...

@Matías: thanks for your comments. Yes, in fact I started using that variant myself, but in the CLR, runtime type testing is faster than a vtable dispatch, and double dispatch is considerably slower. Not that it would probably make a significant difference in most cases, but it's simple enough to use it instead of the alternatives.
As for stateful visitors, I don't think I've ever needed them... though I agree that sometimes it may be simpler to just stick the state in the visitor instead of externally.

Matías Giovannini said...

For an example of the need for stateful visitors, see my article on implementing Brainfuck in Java, in particular the Evaluator visitor: http://alaska-kamtchatka.blogspot.com/2011/11/brainfuck-in-java.html

Anonymous said...

Nice article. I really enjoy reading your blog, and good links.
Thanks.

In regards to this particular post: IMO using dynamic dispatch is unsafe and can't be verified by the compiler, and therefore it defeats the point of using a statically typed system, and even the point of ADTs, don't you think?

In regards to vtable dispatch versus dynamic cast performance: that sounds like apples vs. oranges; it depends a lot on the usage scenario.
It seems that the test is actually testing virtual dispatch versus a statically inlined generic type (a compiler optimization). I could be wrong; this topic had very long, heated threads in the C# newsgroups a couple of years ago. The easiest way is to check the asm code in SOS for each case. (And by the way, vtable dispatch changed with .NET 4.)

Mauricio Scheffer said...

@Anonymous: I disagree that using runtime type testing mutes the point of ADTs, because it's localized. It's like mutability: many functional operations in LINQ and F# are implemented with mutable variables under the hood for performance, but it's *local* mutability with a perfectly pure interface.

About performance, I refer you to Sandro Magi's benchmark.
About .NET 4 changes: F# 3.0 still compiles to runtime type testing, so I don't think it's changed that much.

Anonymous said...

Sorry Mauricio, but I don't see how this localization helps you reason about your code statically.

You still get a run-time error if you provide the wrong type. That already pushes this type definition into dynamic typing.

Imagine a situation where you use such types in a many-core environment: you would need to know up front that your ADT does runtime checks on the provided types, and make sure to lock any data that could affect those checks.
In such a case, IMO, the type declaration itself should state that this type is unsafe.

Mauricio Scheffer said...

@Anonymous: again, just as with mutability, what matters is that client code is easy to reason about. How the innards of a fold or an ADT are implemented does not affect your ability to reason about your code, as long as the "ugliness" is localized. Client code (i.e. users of an ADT as defined here) never has to do any runtime type testing or casting, and can't "provide a wrong type". That's entirely contained in the "Match" method. Whether "Match" is implemented with a visitor or runtime type testing doesn't make any difference to client code.

There is nothing thread-unsafe about the technique or the types I described in this post. You don't need any kind of locks here.

Maybe if you could come up with some concrete code I could address your concerns more precisely.