Saturday, October 27, 2012

Monads in C++

In my last two articles, I discussed fmap in C++. At the end, I implemented it with a combination of tag dispatch and a type class, calling it type-class dispatch. Today I want to talk about the next step: monads. Familiarity with fmap is required, but not with monads or Haskell. I will be using the same type-class dispatch code here as in the last post, but without explanation.

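#include <iterator>    // std::begin
#include <type_traits> // std::decay, std::result_of
#include <utility>     // std::declval, std::forward

struct sequence_tag {};
struct pointer_tag {};

// Fallback: a type with no known category is its own category.
template< class X >
X category( ... );

template< class S >
auto category( const S& s ) -> decltype( std::begin(s), sequence_tag() );

template< class Ptr >
auto category( const Ptr& p ) -> decltype( *p, p==nullptr, pointer_tag() );

// Cat<T> is the tag for T, ignoring references and cv-qualifiers.
// (Assumed definition; the previous post's may differ.)
template< class T, class U = typename std::decay<T>::type >
using Cat = decltype( category<U>( std::declval<U>() ) );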

Monads are scary. Or at least they seem scary; people talk about them as if they were. In reality, they are not much more complicated than Functors, and the two are very similar. In the fmap posts, we had a function f and a Functor, F(x), and fmap let us apply f to the data inside the Functor. Monads do the same thing, except that f is monad-aware: it already returns a monad of the correct type. For example, with fmap we might write

std::unique_ptr<int> p( new int(5) );
auto f = []( int x ) { return -x; };
std::unique_ptr<int> q = fmap( f, p );

and we know that *q == -*p. What if f knew that we wanted a unique_ptr returned? Then we could use the monad's version of fmap, which I'll refer to as mbind (monad bind).

std::unique_ptr<int> p( new int(5) );
auto f = []( int x ) { 
  return std::unique_ptr<int>( new int(-x) ); 
};
std::unique_ptr<int> q = mbind( f, p );

So what is a Monad? Almost the same as what a Functor is!

    If fmap(f,F(x)) = F(f(x)),
        then, roughly, mbind(f,M(x)) = f(x), where f(x) is already some M(y).

Monads can be std::vectors, std::unique_ptrs, or std::pairs; the limit is your imagination.

Monads have one more ability: to construct a value of type M(x), given an x, with a function called return. But return means different things in Haskell and C++, so I'll use the name mreturn. This is a pretty simple concept; we can rewrite the above example like so:

std::unique_ptr<int> p( new int(5) );
auto f = []( int x ) { return mreturn(-x); };
std::unique_ptr<int> q = mbind( f, p );

In this example, mreturn is a function that takes an int and returns an std::unique_ptr<int>.

And so we have two basic operations:

    auto m = mreturn<M>(x); // creates an M<X>
    mbind( f, m ); // Applies f to x.

And we know

    auto p = mreturn<std::unique_ptr>(3); // will create a std::unique_ptr<int>.
    mbind( f, p ); // is equivalent to f(*p)

So, for the moment, we'll consider a Monad to be any type for which these two operations are defined.
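Without the dispatch machinery, the two operations are small. Here is a minimal sketch hard-coded for std::unique_ptr<int>. The names mreturn_ptr, mbind_ptr and negate_m are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <memory>

// Hypothetical, hand-specialized versions of mreturn and mbind for
// std::unique_ptr<int>, with no tag dispatch involved.

// mreturn: wrap a plain value x as M(x).
std::unique_ptr<int> mreturn_ptr( int x ) {
    return std::unique_ptr<int>( new int(x) );
}

// mbind: hand the value inside p to f, which itself returns a unique_ptr.
// An empty p short-circuits: f is never called.
template< class F >
std::unique_ptr<int> mbind_ptr( F f, const std::unique_ptr<int>& p ) {
    return p ? f( *p ) : std::unique_ptr<int>();
}

// A monad-aware function: it builds its own result with mreturn_ptr.
std::unique_ptr<int> negate_m( int x ) { return mreturn_ptr( -x ); }
```

With these, mbind_ptr(negate_m, mreturn_ptr(3)) holds -3, the same as negate_m(3). Binding an empty pointer returns an empty pointer without calling negate_m.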

I'll implement this much the same way I did fmap. We start with a free function, mbind, which forwards to the static member function Monad::mbind, and a free function mreturn, which forwards to Monad::mreturn.

template< class ... > struct Monad;

template< class F, class M, class Mo=Monad<Cat<M>> >
auto mbind( F&& f, M&& m )
    -> decltype( Mo::mbind(std::declval<F>(),std::declval<M>()) )
{
    return Mo::mbind( std::forward<F>(f), std::forward<M>(m) );
}

// The first template argument must be explicit!
template< class M, class X, class Mo = Monad<Cat<M>> >
M mreturn( X&& x ) {
    // We have to forward the monad type, too.
    return Mo::template mreturn<M>( std::forward<X>(x) );
}

One might notice that the example above and this definition don't match: instead of mreturn(-x), we would have to write mreturn<std::unique_ptr<int>>(-x). The int part is redundant, though, so let's overload mreturn with a template template parameter so that we only have to supply std::unique_ptr.

template< template<class...>class M, class X, 
          class Mo = Monad<Cat<M<X>>> >
M<X> mreturn( const X& x ) {
    return Mo::template mreturn<M<X>>( x );
}

Now, we can write that example like so:

std::unique_ptr<int> p( new int(5) );
auto f = []( int x ) { 
  return mreturn<std::unique_ptr>(-x); 
};
std::unique_ptr<int> q = mbind( f, p );

The pointer monad.

template< > struct Monad< pointer_tag > {
    // Rest... absorbs extra template arguments such as unique_ptr's deleter;
    // a bare Ptr<X> does not deduce from std::unique_ptr<X, Deleter>.
    template< class F, template<class...>class Ptr, class X,
              class R = typename std::result_of<F(X)>::type, class... Rest >
    static R mbind( F&& f, const Ptr<X, Rest...>& p ) {
        // Just like fmap, but without needing to explicitly return the correct type.
        return p ? std::forward<F>(f)( *p ) : nullptr;
    }

    template< class M, class X >
    static M mreturn( X&& x ) {
        // All standard smart pointers define element_type.
        using Y = typename M::element_type; 
        return M( new Y(std::forward<X>(x)) );
    }
};
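For readers who want to compile the pointer monad, here is the code so far gathered into one translation unit, as a sketch. Cat is an assumed reconstruction (the previous post has its own definition). This version of Monad<pointer_tag>::mbind deduces its result with decltype rather than std::result_of, which C++17 deprecates and C++20 removes.

```cpp
#include <cassert>
#include <iterator>     // std::begin
#include <memory>       // std::unique_ptr
#include <type_traits>  // std::decay
#include <utility>      // std::declval, std::forward

struct sequence_tag {};
struct pointer_tag {};

template< class X > X category( ... );

template< class S >
auto category( const S& s ) -> decltype( std::begin(s), sequence_tag() );

template< class Ptr >
auto category( const Ptr& p ) -> decltype( *p, p == nullptr, pointer_tag() );

// Assumed definition: the tag for T, ignoring references and cv-qualifiers.
template< class T, class U = typename std::decay<T>::type >
using Cat = decltype( category<U>( std::declval<U>() ) );

template< class ... > struct Monad;

template< class F, class M, class Mo = Monad<Cat<M>> >
auto mbind( F&& f, M&& m )
    -> decltype( Mo::mbind( std::declval<F>(), std::declval<M>() ) )
{
    return Mo::mbind( std::forward<F>(f), std::forward<M>(m) );
}

// mreturn<std::unique_ptr<int>>(x): the full monad type is given.
template< class M, class X, class Mo = Monad<Cat<M>> >
M mreturn( X&& x ) {
    return Mo::template mreturn<M>( std::forward<X>(x) );
}

// mreturn<std::unique_ptr>(x): only the template is given.
template< template<class...>class M, class X, class Mo = Monad<Cat<M<X>>> >
M<X> mreturn( const X& x ) {
    return Mo::template mreturn<M<X>>( x );
}

template< > struct Monad< pointer_tag > {
    template< class F, template<class...>class Ptr, class X, class... Rest >
    static auto mbind( F&& f, const Ptr<X, Rest...>& p )
        -> decltype( std::forward<F>(f)( *p ) )
    {
        using R = decltype( std::forward<F>(f)( *p ) );
        return p ? std::forward<F>(f)( *p ) : R();   // empty in, empty out
    }

    template< class M, class X >
    static M mreturn( X&& x ) {
        using Y = typename M::element_type;
        return M( new Y( std::forward<X>(x) ) );
    }
};
```

Binding f = [](int x){ return mreturn<std::unique_ptr>(-x); } to a pointer holding 5 yields a pointer holding -5. Binding it to an empty pointer yields an empty pointer without calling f.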

This may not be the most exciting code, but we can use it to translate a small Haskell function into C++.

    -- Haskell
    addM a b = do
        x <- a -- Extract x from a
        y <- b -- and y from b.
        return (x+y) -- Return a new monad with the value (x+y)

In C++ terms, if we supplied two unique_ptrs, we'd get one back holding the value x+y. The first line, x <- a, syntactically means "what follows is a function of x." This is addM written with do notation; another way to write it:

    addM a b = a >>= (\x -> b >>= (\y -> return (x+y)) )

Here, >>= denotes a bind and (\x -> ...) denotes a lambda that takes x. The innermost function, (\y -> return (x+y)), returns the actual value as a monad. It gets called when we extract the value from b with b >>= (\y -> ...). The x came from a >>= (\x -> ...). So it extracts x from a, then y from b, and constructs a new monad with the value x+y.

// C++
template< class M >
M addM( const M& a, const M& b ) {
    return mbind (
        [&]( int x ) {
            return mbind (
                [=]( int y ) { return mreturn<M>(x+y); },
                b
            );
        },
        a
    );
}

Yuck! This is a literal translation, but Haskell handles scope automatically with do notation and implicitly returns the last statement, while we have to write return mreturn<M>(...). We can rewrite the inner bind as fmap([=](int y){return x+y;},b), which solves the return problem but not the scoping one. We can alleviate that by defining an operator overload for mbind, and why not use the very same operator as in Haskell?

template< class M, class F >
auto operator >>= ( M&& m, F&& f )
    -> decltype( mbind(std::declval<F>(),std::declval<M>()) )
{
    return mbind( std::forward<F>(f), std::forward<M>(m) );
}

template< class M >
M addM( const M& a, const M& b ) {
    // Capture b by reference; the lambda runs before addM returns.
    return a >>= [&]( int x ) {
        return fmap( [=]( int y ){ return x+y; }, b );
    };
}

It's hard to justify the use of operator overloads in C++, but >>= otherwise rarely gets any use. The overload won't change the behavior of basic types: given some int x, x >>= 2 still means you wish to shift the bits right by two. If this gives one an uncomfortable feeling, the overload can be put in its own namespace, so that to use it the user would have to write using namespace monad; (or similar) before writing >>=.
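To illustrate the namespace idea, here is a sketch with a deliberately narrow, hypothetical operator>>= that only accepts a std::vector on the left. It is not the generic overload from the article. The namespace name monad follows the text; with_tens and shift_by_two are illustrative names.

```cpp
#include <cassert>
#include <vector>

namespace monad {
    // A hypothetical, vector-only operator>>=: feed each element of xs to f
    // (which returns a vector) and concatenate the results.
    template< class X, class F >
    auto operator >>= ( const std::vector<X>& xs, F f ) -> decltype( f( xs[0] ) ) {
        decltype( f( xs[0] ) ) result;
        for( const X& x : xs ) {
            auto ys = f( x );
            result.insert( result.end(), ys.begin(), ys.end() );
        }
        return result;
    }
}

// Opting in: without the using-directive, v >>= ... finds no operator>>=.
std::vector<int> with_tens( const std::vector<int>& v ) {
    using namespace monad;
    return v >>= []( int x ) { return std::vector<int>{ x, 10 * x }; };
}

// Untouched: when both operands are built-in types, overloaded operators are
// never considered, so this is still the built-in right-shift-and-assign.
int shift_by_two( int x ) {
    using namespace monad;
    x >>= 2;
    return x;
}
```

Here with_tens applied to {1,2} gives {1,10,2,20}, and shift_by_two(8) returns 2.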

Monadic sequences.

Remember, fmap(f,seq) took a regular function and made a new sequence by applying f to each element of seq. What will mbind(f,seq) do? This time, f is monad-aware, so it already returns a sequence. Does mbind return a sequence of sequences? That would be very confusing. It actually returns the concatenation of every sequence produced by f(x). So, if f(x)={-x,x}, then mbind(f,{1,2}) = {-1,1,-2,2}.
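The concatenating behavior can be checked in isolation before writing the generic specialization. This sketch hard-codes std::vector, and bind_vec and neg_and_pos are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// A sketch of mbind hand-specialized for std::vector: f maps each element
// to a whole vector, and the results are concatenated in order.
template< class X, class F >
auto bind_vec( F f, const std::vector<X>& xs ) -> decltype( f( xs[0] ) ) {
    decltype( f( xs[0] ) ) result;
    for( const X& x : xs ) {
        auto ys = f( x );                                    // ys is a vector
        result.insert( result.end(), ys.begin(), ys.end() ); // append, don't nest
    }
    return result;
}

// The f from the text: f(x) = {-x, x}.
std::vector<int> neg_and_pos( int x ) { return { -x, x }; }
```

Calling bind_vec(neg_and_pos, std::vector<int>{1,2}) produces {-1,1,-2,2}, and an empty input produces an empty result.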

template< > struct Monad< sequence_tag > {
    // As with pointers, Rest... absorbs the allocator argument.
    template< class F, template<class...>class S, class X,
              class R = typename std::result_of<F(X)>::type, class... Rest >
    static R mbind( F&& f, const S<X, Rest...>& xs ) {
        R r;
        for( const X& x : xs ) {
            // f may be called many times, so don't std::forward it here.
            auto ys = f( x );
            std::move( std::begin(ys), std::end(ys), std::back_inserter(r) );
        }
        return r;
    }

    template< class S, class X >
    static S mreturn( X&& x ) {
        return S{ std::forward<X>(x) }; // Construct an S of one element, x.
    }
};
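(The std::move above is the three-iterator algorithm from <algorithm>, which moves each element of ys onto the end of r; std::back_inserter comes from <iterator>.)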

What implications does this have on our addM function? If v={1,2} and w={3,4}, what does addM(v,w) return? Try it!

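#include <algorithm> // std::copy
#include <iostream>
#include <iterator>  // std::ostream_iterator
#include <vector>

int main() {
    std::vector<int> v={1,2}, w={3,4};
    auto vw = addM(v,w);

    std::cout << "v+w = { ";
    std::copy (
        std::begin(vw), std::end(vw),
        std::ostream_iterator<int>(std::cout, " ")
    );
    std::cout << '}' << std::endl;
}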

Just in case you didn't actually run the code, it prints v+w = { 4 5 5 6 }. Does this sequence seem odd? It's { 1+3 1+4 2+3 2+4 }: addM applied the addition to every pair of elements from v and w, that is, add(v[0],w[0]), then add(v[0],w[1]), then add(v[1],w[0]), then add(v[1],w[1]).
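This is nothing more exotic than every pair. The same computation can be written as two nested loops; add_all_pairs is a hypothetical name, and the comments note how each loop corresponds to one of the mbind calls in addM.

```cpp
#include <cassert>
#include <vector>

// What addM computes for sequences, written out by hand: the outer loop plays
// the role of the outer mbind over v, the inner loop the inner mbind over w,
// and push_back plays mreturn followed by concatenation.
std::vector<int> add_all_pairs( const std::vector<int>& v, const std::vector<int>& w ) {
    std::vector<int> result;
    for( int x : v )
        for( int y : w )
            result.push_back( x + y );
    return result;
}
```

Here add_all_pairs({1,2},{3,4}) returns {4,5,5,6}. If either input is empty, so is the result, much as an empty unique_ptr short-circuits in the pointer monad.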

This is the magic of monads. The behavior of addM changed appropriately with its arguments, without us even thinking about how it might. And now every int-holding type for which one specializes Monad works with addM, too!


In conclusion:

Monads are often talked about as mysterious, tricky, and hard to understand. They are none of these. It is of little importance to know concretely what a monad is. mbind is a simple function that applies some function, f, to some object M(x), where f returns M(y). mreturn is a simple function that constructs an object of type M(x), given an x.

Note that Haskell also has another Monad function, >>, or mdo as I call it (though I can't remember why). mdo is not always as obvious as mbind; however, I did implement it in the gist (see below).
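The gist's mdo is not reproduced here. As a sketch, assuming the standard Haskell definition a >> b = a >>= (\_ -> b), here is >> for std::vector only. The names bind_vec and mthen_vec are hypothetical, and the gist's version may be written differently.

```cpp
#include <cassert>
#include <vector>

// Hypothetical vector-only bind: concatenate f(x) for each x in xs.
template< class X, class F >
auto bind_vec( F f, const std::vector<X>& xs ) -> decltype( f( xs[0] ) ) {
    decltype( f( xs[0] ) ) result;
    for( const X& x : xs ) {
        auto ys = f( x );
        result.insert( result.end(), ys.begin(), ys.end() );
    }
    return result;
}

// a >> b, defined as a >>= (\_ -> b): bind a to a function that ignores
// its argument and always returns b.
template< class X, class Y >
std::vector<Y> mthen_vec( const std::vector<X>& a, const std::vector<Y>& b ) {
    return bind_vec( [&b]( const X& ) { return b; }, a );
}
```

For sequences, a >> b repeats b once per element of a: {1,2} >> {3,4} gives {3,4,3,4}, and an empty a gives an empty result.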

In full, the monadic operations are:

    a >> b ; //  see the gist
    a >>= f ; // Apply the value(s) in a to f.
    mreturn<M>(x) ; // Create an M<X>.

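These operations obey three helpful properties, the monad laws (left identity, right identity, and associativity). They are written here informally, in a mix of C++ and Haskell notation, rather than as compilable C++: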

    mreturn<M>(x) >>= f == f(x)
    m >>= mreturn<M> == m 
    m >>= (\x -> k x >>= h) == (m >>= k) >>= h
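The three laws can be spot-checked on concrete values. This sketch assumes hypothetical vector-only helpers, ret_vec and bind_vec, in place of the generic mreturn and mbind, plus arbitrary test functions k and h. Passing such checks on examples is evidence, not a proof.

```cpp
#include <cassert>
#include <vector>

// Hypothetical vector-only mreturn: a sequence of one element.
template< class X >
std::vector<X> ret_vec( const X& x ) { return std::vector<X>{ x }; }

// Hypothetical vector-only mbind: concatenate f(x) for each x in m.
template< class X, class F >
auto bind_vec( F f, const std::vector<X>& m ) -> decltype( f( m[0] ) ) {
    decltype( f( m[0] ) ) result;
    for( const X& x : m ) {
        auto ys = f( x );
        result.insert( result.end(), ys.begin(), ys.end() );
    }
    return result;
}

// Arbitrary monad-aware test functions.
std::vector<int> k( int x ) { return { x, x + 1 }; }
std::vector<int> h( int x ) { return { 10 * x }; }

// mreturn<M>(x) >>= f == f(x)
bool left_identity( int x ) { return bind_vec( k, ret_vec( x ) ) == k( x ); }

// m >>= mreturn<M> == m
bool right_identity( const std::vector<int>& m ) {
    return bind_vec( []( int x ) { return ret_vec( x ); }, m ) == m;
}

// m >>= (\x -> k x >>= h) == (m >>= k) >>= h
bool associativity( const std::vector<int>& m ) {
    auto lhs = bind_vec( []( int x ) { return bind_vec( h, k( x ) ); }, m );
    auto rhs = bind_vec( h, bind_vec( k, m ) );
    return lhs == rhs;
}
```

For example, with m = {1,2} both sides of the associativity law evaluate to {10,20,20,30}.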


Here's the code I wrote for this article: https://gist.github.com/3965514 (It contains a few extra examples.)
Monads in Haskell: http://www.haskell.org/ghc/docs/latest/html/libraries/base/Control-Monad.html#t:Monad