Scalaz StateT Monad Transformer

In a previous blog post, we talked about how to handle the situation when the outer levels of your code need to run in one monad, but the inner levels need to run in an Option monad. In the blog post examples, the outer code levels run in the State monad, while the inner layers use Option. In these situations, we use the OptionT monad transformer to blend the inner Option with the outer State so that we’re effectively working with both monads at the same time. While the example uses State as the outer monad, the outer monad could in fact be any arbitrary monad. That is, OptionT can blend an Option with any monad, not just a State.

In this blog post, we’ll see how to use the StateT monad transformer to blend an inner State monad with an arbitrary outer monad. We’ll use IO as our outer monad in the examples.

A Brief Overview of the IO Monad

Since we’ll be using the IO monad in our examples, let’s start with a very brief overview of it. We’ll examine this monad in more detail in a future blog post.

The IO monad works a bit like a simplified State monad. Like State, IO wraps an arbitrary function. Binding two IO monads together using flatMap() creates a bigger function that chains together the two smaller functions. The big difference between IO and State is the parameters and return types of the wrapped functions. For State, the functions take as an input a state value of some arbitrary type. The functions return a tuple holding a new state value as well as a value of some other arbitrary type. For IO, the functions take no input values and simply return a value of an arbitrary type.

Since the functions wrapped in IO take no input parameters, they are generally not pure. Pure functions operate only on their inputs, produce no side effects, and return only their output value. A function that takes no inputs can only be pure if it returns a constant. Otherwise it must be operating on some external resource, perhaps a resource stored in a global mutable variable or perhaps a resource accessed via I/O. The purpose of the IO monad is to document that a function is not pure but rather produces side effects. While Scala doesn’t require you to put impure functions in the IO monad, it’s a good practice to enclose any methods that perform I/O or operate on mutable variables in the IO monad. (In fact, other languages like Haskell require you to use the IO monad.)

To use the IO monad, you first need to include the following imports:

  import scalaz._, Scalaz._, effect._, IO._

To wrap a function in the IO monad, simply use

  val m = IO { your function }

To call the function wrapped in an IO monad, use the unsafePerformIO() method:

  val result = m.unsafePerformIO

Here’s a brief example.

  val m: IO[Employee] = for {
    _        <- putStrLn("before reading the database")
    employee <- IO { readEmployeeFromDatabase(employeeId) }
    _        <- putStrLn("after reading the database")
  } yield employee

  val employee: Employee = m.unsafePerformIO

Method putStrLn() is a lot like println(). The only difference is that println() returns a Unit while putStrLn() returns an IO[Unit]. That is, putStrLn() runs in the IO monad, therefore documenting that it is impure.

Note that in the example, nothing is printed to the screen and nothing is read from the database until we call unsafePerformIO(). The for() comprehension simply builds a function from the three smaller functions. The constructed function isn’t actually called until unsafePerformIO() is called.
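To make the deferred execution concrete, here is a minimal sketch of an IO-like wrapper in plain Scala. The names MiniIO, MiniIODemo and log are hypothetical and exist only for this illustration. This is not how Scalaz implements IO; it is just a simplified model of the idea that flatMap() builds a bigger function without running anything.

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical, simplified stand-in for IO (NOT Scalaz's implementation).
final class MiniIO[A](thunk: () => A) {
  // Actually runs the wrapped function.
  def unsafePerformIO(): A = thunk()

  def map[B](f: A => B): MiniIO[B] =
    new MiniIO(() => f(thunk()))

  // Chains two wrapped functions into one bigger wrapped function.
  def flatMap[B](f: A => MiniIO[B]): MiniIO[B] =
    new MiniIO(() => f(thunk()).unsafePerformIO())
}

object MiniIO {
  def apply[A](a: => A): MiniIO[A] = new MiniIO(() => a)
}

object MiniIODemo {
  // Records side effects so we can see when they actually happen.
  val log: ListBuffer[String] = ListBuffer.empty[String]

  val program: MiniIO[Int] = for {
    _ <- MiniIO { log += "before" }
    n <- MiniIO { log += "reading"; 42 }
    _ <- MiniIO { log += "after" }
  } yield n

  // The program has been built but not run, so nothing is logged yet.
  val sizeBeforeRun: Int = log.size

  // Only now do the three wrapped functions execute, in order.
  val result: Int = program.unsafePerformIO()

  def main(args: Array[String]): Unit = {
    println(sizeBeforeRun)
    println(result)
    println(log.toList)
  }
}
```

In this sketch the log stays empty until unsafePerformIO() is called, which mirrors the behavior described for the Scalaz example.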

The Problem

Suppose we’re writing an application that stores and retrieves actors and movies in some sort of NoSQL database. Suppose we have the following class to represent actors as well as the following two services to retrieve actors and movies from the database:

  case class Actor(id: String, name: String, ...)

  trait MovieService {
    def getActorIdsForMovieId(movieId: String): IO[List[String]]
    ...
  }

  trait ActorService {
    def getActorById(actorId: String): IO[Option[Actor]]
    ...
  }

Notice how all of the service methods run in the IO monad. Since these methods access the database, they’re not pure.

Now let’s write a function that takes a list of movie IDs and returns the actors for those movies. The return value will be a Map from movie IDs to the list of actors in each movie.

  def getActorsForMovieIds(
      movieIds: List[String],
      movieService: MovieService,
      actorService: ActorService): IO[Map[String, List[Actor]]] = {

    val movieIdAndActors = movieIds traverse { movieId =>
      for {
        actorIds <- movieService.getActorIdsForMovieId(movieId)

        actors <- actorIds traverse { actorId =>
          actorService.getActorById(actorId)
        }
      } yield (movieId, actors.flatten)
    }

    movieIdAndActors map { _.toMap }
  }

  ...

  val movieToActors: Map[String, List[Actor]] =
    getActorsForMovieIds(
      List(movieId1, movieId2, ...),
      movieService,
      actorService).unsafePerformIO

The code uses the traverse() method a couple of times. This method works a lot like the standard map() method, with the only difference being that the function you give to traverse() must return a monad (strictly speaking, any applicative functor will do). The traverse() method then combines all the resulting monads, in order, into a single monad that wraps the list of results.
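As a rough illustration of the difference between map() and traverse(), here is a sketch in plain Scala that specializes the idea to List and Option. The helper traverseOption, the parseInt function and the object name are hypothetical; Scalaz’s real traverse() is generic over the wrapper type, which this sketch doesn’t attempt.

```scala
import scala.util.Try

object TraverseSketch {
  // Hypothetical helper: traverse specialised to List and Option.
  // Applies f to every item and combines the Options into a single Option.
  def traverseOption[A, B](items: List[A])(f: A => Option[B]): Option[List[B]] =
    items.foldRight(Option(List.empty[B])) { (item, acc) =>
      for {
        b  <- f(item)
        bs <- acc
      } yield b :: bs
    }

  def parseInt(s: String): Option[Int] = Try(s.trim.toInt).toOption

  // map() leaves every result individually wrapped.
  val mapped: List[Option[Int]] = List("1", "x", "3").map(parseInt)

  // traverse() yields one wrapper around the whole list of results.
  val allGood: Option[List[Int]] = traverseOption(List("1", "2", "3"))(parseInt)
  val oneBad: Option[List[Int]]  = traverseOption(List("1", "x", "3"))(parseInt)

  def main(args: Array[String]): Unit = {
    println(mapped)
    println(allGood)
    println(oneBad)
  }
}
```

With Option as the wrapper, a single None makes the whole result None. With IO as the wrapper, as in getActorsForMovieIds(), the combined result is instead one IO that performs every lookup when it is finally run.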

The getActorsForMovieIds() function loops through each movie ID. It first fetches the list of actor IDs for each movie. Then for each actor ID, it calls the actor service to load the actor object. Since the actor service returns an Option, our resulting list of actors is really a list of Option[Actor] objects. We call flatten() on it to get rid of the Option wrappers and throw away any None values from actors that couldn’t be loaded. The end result is a list of whatever Actor objects could be loaded for the current movie ID.

This gives us a list of tuples, where each tuple holds a movie ID and a list of actors in that movie. Finally, we convert this list of tuples into a map from movie IDs to lists of actors.

Since getActorsForMovieIds() wraps its result in the IO monad, we have to call unsafePerformIO() to get the movie-to-actor map.

This code works, but it’s inefficient. Since actors can be in multiple movies, it’s quite possible that the function’s input movies have overlapping actors. We’ll end up fetching these actors multiple times from the actor service. It would be better to have some sort of cache of actors we’ve already loaded. That way, we don’t have to call out to the actor service if we can find an actor in the cache. The cache will simply be a Map from actor IDs to Actor objects. Let’s use the State monad to represent this cache. We’ll start with a couple of type definitions:

  type ActorCache = Map[String, Actor]
  type ActorCacheMonad[+A] = State[ActorCache, A]

Now let’s try to modify our ActorService so that it has a getActorByIdWithCache() method.

  // THIS CODE WON'T COMPILE!!!

  trait ActorServiceBad1 {
    def getActorById(actorId: String): IO[Option[Actor]]

    def getActorByIdWithCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeCachedActor <- gets { cache: ActorCache =>
          cache.get(actorId)
        }

        maybeActor <- maybeCachedActor match {
          case Some(cachedActor) =>
            cachedActor.some.point[ActorCacheMonad]

          case None =>
            loadActorInCache(actorId)
        }
      } yield maybeActor
    }

    protected def loadActorInCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeActor <- getActorById(actorId)
      
        _ <- maybeActor match {
          case Some(actor) =>
            modify { cache: ActorCache => cache + (actorId -> actor) }

          case None =>
            ().point[ActorCacheMonad]
        }
      } yield maybeActor
    }
  }

This code doesn’t work. What it’s trying to do is simple enough. It first looks up the actor in the cache. If the actor is found, it is returned (wrapped in the State monad). If the actor isn’t found in the cache, we call loadActorInCache(), which calls getActorById() to load the actor from the database. If the actor is found, loadActorInCache() modifies the cache to include the newly loaded actor before returning the actor.

The problem is that we’ve created a mess out of our monads. Both getActorByIdWithCache() and loadActorInCache() run in the State monad (using our ActorCacheMonad type). But we’re also trying to use the IO monad inside loadActorInCache(). Recall that in a for() comprehension, each arrow must have the same type of monad on the right side. In loadActorInCache()’s for() comprehension, getActorById() uses the IO monad, while the subsequent match expression uses the State monad. That’s not allowed. So, this code won’t compile.
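The rule about arrows follows from how the compiler rewrites a for() comprehension into flatMap() and map() calls. Here is a small sketch using the standard library’s Option rather than State or IO; the lookup function and the object name are hypothetical and only serve the illustration.

```scala
object ForDesugarSketch {
  // Hypothetical lookup that may or may not find a value.
  def lookup(id: String): Option[Int] = Map("a" -> 1, "b" -> 2).get(id)

  // The comprehension form...
  val viaFor: Option[Int] = for {
    x <- lookup("a")
    y <- lookup("b")
  } yield x + y

  // ...is rewritten by the compiler into nested flatMap()/map() calls.
  val viaFlatMap: Option[Int] =
    lookup("a").flatMap { x =>
      lookup("b").map { y => x + y }
    }

  // Option.flatMap() requires a function that returns an Option, so a
  // generator from a different monad does not type-check:
  //
  //   for {
  //     x <- lookup("a")
  //     y <- (Right(x): Either[String, Int]) // type mismatch
  //   } yield y

  def main(args: Array[String]): Unit = {
    println(viaFor)
    println(viaFlatMap)
  }
}
```

The same thing happens in loadActorInCache(): the first arrow would call IO’s flatMap(), which expects a function returning an IO, but the rest of the comprehension produces a State.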

Ok, take 2. Let’s try a slightly different approach:

  // THIS CODE STILL WON'T COMPILE!!!

  trait ActorServiceBad2 {
    def getActorById(actorId: String): IO[Option[Actor]]

    def getActorByIdWithCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeCachedActor <- gets { cache: ActorCache =>
          cache.get(actorId)
        }

        maybeActor <- maybeCachedActor match {
          case Some(cachedActor) =>
            cachedActor.some.point[ActorCacheMonad]

          case None =>
            loadActorInCache(actorId)
        }
      } yield maybeActor
    }

    protected def loadActorInCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeActor <- getActorById(actorId).point[ActorCacheMonad]
      
        _ <- maybeActor match {
          case Some(actor) =>
            modify { cache: ActorCache => cache + (actorId -> actor) }

          case None =>
            ().point[ActorCacheMonad]
        }
      } yield maybeActor
    }
  }

The only change is in loadActorInCache(). We now take the result of getActorById() and wrap it in the State monad. So, the right side of the first arrow is now of type ActorCacheMonad[IO[Option[Actor]]]. All the monads on the right side of the arrows are now ActorCacheMonad’s. Everything should be happy, but the code still doesn’t compile. What gives?

The problem is that we want the maybeActor variable in loadActorInCache() to be of type Option[Actor]. But it’s really of type IO[Option[Actor]] since the arrow operator only strips off the outer ActorCacheMonad leaving everything else. That is, our value is still wrapped up in the IO monad. But then it doesn’t make any sense to try to match an IO[Option[Actor]] against a Some and a None. So the compiler gets angry, and the code doesn’t work.

Damn. Let’s try take 3:

  // THIS CODE COMPILES, BUT IT SUCKS!!!

  trait ActorServiceBad3 {
    def getActorById(actorId: String): IO[Option[Actor]]

    def getActorByIdWithCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeCachedActor <- gets { cache: ActorCache =>
          cache.get(actorId)
        }

        maybeActor <- maybeCachedActor match {
          case Some(cachedActor) =>
            cachedActor.some.point[ActorCacheMonad]

          case None =>
            loadActorInCache(actorId)
        }
      } yield maybeActor
    }

    protected def loadActorInCache(actorId: String):
        ActorCacheMonad[Option[Actor]] = {

      for {
        maybeActor <- getActorById(actorId)
          .unsafePerformIO
          .point[ActorCacheMonad]
      
        _ <- maybeActor match {
          case Some(actor) =>
            modify { cache: ActorCache => cache + (actorId -> actor) }

          case None =>
            ().point[ActorCacheMonad]
        }
      } yield maybeActor
    }
  }

The only difference is in that same stupid call to getActorById() in loadActorInCache(). Now we’re calling unsafePerformIO() to get the value out of the IO monad before wrapping it in the ActorCacheMonad. That is, getActorById() returns an IO[Option[Actor]]. Once we call unsafePerformIO(), we just have an Option[Actor]. Then we wrap this in the ActorCacheMonad to get an ActorCacheMonad[Option[Actor]]. The arrow strips off the ActorCacheMonad part, meaning maybeActor is just an Option[Actor]. That’s exactly what we want! The compiler is happy! The code works!

Unfortunately, the code is constructed pretty poorly. The problem is that loadActorInCache() is not a pure function. It has side effects since it’s calling out to the database. While it’s fine that it’s not pure, nothing about the signature of loadActorInCache() lets us know that it’s not pure. That is, if loadActorInCache() is performing IO, its return value should be wrapped in the IO monad. We’re hiding the fact that it does IO by calling unsafePerformIO() within it. Essentially, we’ve defeated the entire purpose of using the IO monad! If any function can just hide the fact that it does IO by calling unsafePerformIO(), then we might as well not use the IO monad at all. In short, this code works, but it’s a hack.

(It’s worth noting that Haskell does have an unsafePerformIO function, tucked away in System.IO.Unsafe, but using it to hide I/O like this is strongly discouraged there for exactly the same reason.)

StateT to the Rescue

We need a monad that blends together IO and ActorCacheMonad, allowing us to work with both at the same time. That’s exactly what the StateT monad transformer does. It allows you to take code that’s running in some arbitrary monad (not just an IO monad) and layer a State monad on top of it. That way, you can have some operations that return the outer monad, some operations that return the State monad, and you can mix these operations all together.

To avoid compiler warnings when using the StateT monad, you’ll need to add the following import statement to your code:

  import scala.language.higherKinds

You’ll also need a couple of boilerplate type definitions:

  type ActorCacheMonadT[M[+_], +A] = StateT[M, ActorCache, A]
  type ActorCacheMonadIO[+A] = ActorCacheMonadT[IO, A]

The StateT type has three type parameters: 1) an arbitrary monad M, 2) the type of the state (ActorCache in our example), and 3) an arbitrary type A representing the value that will be wrapped up inside the monad. Our first type definition above locks down the state type while leaving the monad M as arbitrary. Our second type definition locks down the monad as being an IO. Again, it’s worth noting that we could use any monad here, not just an IO. In fact, it’s common to use StateT to combine two different State monads. Perhaps your outer layers of code are running in one State monad and your inner layers need to run in a different State monad.

If we have a value wrapped up in an IO monad, we can easily convert it to a StateT using the liftM() method:

  val m1: IO[Int] = 53.point[IO]
  val m2: ActorCacheMonadIO[Int] = m1.liftM[ActorCacheMonadT]

Similarly, if we have a value wrapped up in a State monad, we can easily convert it to a StateT using the lift() method:

  val m3: ActorCacheMonad[String] = "hello".point[ActorCacheMonad]
  val m4: ActorCacheMonadIO[String] = m3.lift[IO]

In these examples, m1 and m3 are two different monad types (IO and ActorCacheMonad). That means we can’t use them together with flatMap() or a for() comprehension. But once we’ve used liftM() and lift(), m2 and m4 do have the same monad type (StateT). Therefore, we can use m2 and m4 together with flatMap() or a for() comprehension. That’s the power of StateT.

What if you have a raw value not wrapped in a monad? How do you get it into a StateT?

  val m5: ActorCacheMonadIO[Int] = 10.point[ActorCacheMonadIO]

You can just use point() to put the raw value directly into the StateT monad.
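To show roughly what liftM(), lift(), point() and run() do under the hood, here is a stripped-down StateT in plain Scala. Everything in it (Monad, MiniStateT, Counter, optionMonad) is a hypothetical name made up for this sketch, and it uses Option as the outer monad instead of IO to stay self-contained. Scalaz’s real StateT is considerably more general, so treat this as a model of the idea rather than the library’s implementation.

```scala
object StateTSketch {
  // Minimal monad type class (hypothetical; Scalaz's Monad is much richer).
  trait Monad[M[_]] {
    def point[A](a: A): M[A]
    def bind[A, B](ma: M[A])(f: A => M[B]): M[B]
  }

  // A StateT wraps a function from a state to an outer-monad-wrapped
  // (new state, value) pair.
  final case class MiniStateT[M[_], S, A](run: S => M[(S, A)]) {
    def flatMap[B](f: A => MiniStateT[M, S, B])(implicit M: Monad[M]): MiniStateT[M, S, B] =
      MiniStateT[M, S, B] { s =>
        M.bind(run(s)) { case (s1, a) => f(a).run(s1) }
      }

    def map[B](f: A => B)(implicit M: Monad[M]): MiniStateT[M, S, B] =
      flatMap(a => MiniStateT.point[M, S, B](f(a)))
  }

  object MiniStateT {
    // Like point(): wrap a raw value, leaving the state untouched.
    def point[M[_], S, A](a: A)(implicit M: Monad[M]): MiniStateT[M, S, A] =
      MiniStateT[M, S, A](s => M.point((s, a)))

    // Like liftM(): run the outer monad's effect, leaving the state untouched.
    def liftM[M[_], S, A](ma: M[A])(implicit M: Monad[M]): MiniStateT[M, S, A] =
      MiniStateT[M, S, A](s => M.bind(ma)(a => M.point((s, a))))

    // Like lift(): take a plain state function and wrap its result in M.
    def lift[M[_], S, A](st: S => (S, A))(implicit M: Monad[M]): MiniStateT[M, S, A] =
      MiniStateT[M, S, A](s => M.point(st(s)))
  }

  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def point[A](a: A): Option[A] = Some(a)
    def bind[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
  }

  // An Int counter as the state, layered on top of Option.
  type Counter[A] = MiniStateT[Option, Int, A]

  val fromOuter: Counter[String] = MiniStateT.liftM[Option, Int, String](Some("hello"))
  val fromState: Counter[Unit]   = MiniStateT.lift[Option, Int, Unit](n => (n + 1, ()))

  // Once lifted, both kinds of value can share one for() comprehension.
  val program: Counter[String] = for {
    greeting <- fromOuter
    _        <- fromState
    suffix   <- MiniStateT.point[Option, Int, String]("!")
  } yield greeting + suffix

  // run() peels off the state layer, leaving only the outer monad.
  val result: Option[(Int, String)] = program.run(0)

  def main(args: Array[String]): Unit = println(result)
}
```

Running the sketch with an initial counter of 0 gives Some((1, "hello!")): the lifted state function incremented the counter once, and the outer Option is all that remains after run().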

Let’s see how we can use StateT to come up with a better version of our loadActorInCache() method:

  protected def loadActorInCache(actorId: String):
      ActorCacheMonadIO[Option[Actor]] = {

    for {

      maybeActor <-
        getActorById(actorId).liftM[ActorCacheMonadT]
    
      _ <- maybeActor match {
        case Some(actor) => (
          modify { cache: ActorCache =>
            cache + (actorId -> actor)
          }
        ).lift[IO]

        case None => ().point[ActorCacheMonadIO]
      }
    } yield maybeActor
  }

The method is now running in the StateT monad instead of the State monad. That is, the return value is ActorCacheMonadIO[Option[Actor]] rather than ActorCacheMonad[Option[Actor]]. Within the for() comprehension, the call to getActorById() returns a value wrapped in IO. So we have to use liftM to get the value into the StateT monad.

Within the match() clause, the Some case calls the modify() method, which returns a value wrapped in the State monad. Its type is ActorCacheMonad[Unit]. We have to use the lift() method to get the value into the StateT monad. The None case simply wraps a Unit value directly into the StateT monad. Thus, both the Some and None cases return ActorCacheMonadIO[Unit].

The loadActorInCache() method now not only compiles, but its return type indicates that it runs in the IO monad and is not a pure function. We’re no longer cheating by calling unsafePerformIO() to hide the fact that it’s doing I/O.

Since loadActorInCache() now runs in ActorCacheMonadIO, we’ll have to change getActorByIdWithCache() to also run in ActorCacheMonadIO rather than ActorCacheMonad. Given what we’ve already covered, the changes are pretty straightforward.

  def getActorByIdWithCache(actorId: String):
      ActorCacheMonadIO[Option[Actor]] = {

    for {
      maybeCachedActor <-
        (gets { cache: ActorCache => cache.get(actorId) }).lift[IO]

      maybeActor <- maybeCachedActor match {
        case Some(cachedActor) =>
          cachedActor.some.point[ActorCacheMonadIO]

        case None =>
          loadActorInCache(actorId)
      }
    } yield maybeActor
  }

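Finally, let’s rewrite our getActorsForMovieIds() function so that it uses getActorByIdWithCache() instead of getActorById(). There are a couple of new things to introduce here. First, when you’re using a StateT monad, use method traverseU() instead of traverse(). The reason is type inference: StateT takes more than one type parameter, and the Scala compiler can’t work out on its own which one-parameter monad traverse() should use. The traverseU() method relies on Scalaz’s Unapply machinery to figure that out for you. As a rule of thumb, with StateT and OptionT, you use traverseU(); with State, you use traverseS(); with most other monads, you use traverse().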

Second, we’ll introduce the StateT run() method. First, let’s recall what the State run() method does:

  val m: ActorCacheMonad[Int] = ...

  val (finalState: ActorCache, i: Int) =
    m.run(Map.empty[String, Actor])

For State, we pass the run() method an initial state (an empty cache in this case), and the run() method returns a tuple holding the final state and the final value wrapped up in the monad (an Int in this case). For StateT, the run() method is similar. The only difference is that the returned tuple is wrapped up in the outer monad (an IO in our example). Basically, run() peels off the State layer leaving us with only our outer monad. So we start with only our outer monad, use StateT to layer a State monad on top of it, do some operations involving both the State and the outer monad, and finally call run() to strip off the State layer leaving us right back where we started with just the outer monad.

  val m: ActorCacheMonadIO[Int] = ...

  val result: IO[(ActorCache, Int)] =
    m.run(Map.empty[String, Actor])

Here’s the code for getting the actors for a list of movies using the cache:

  def getActorsForMovieIdsWithCache(
      movieIds: List[String],
      movieService: MovieService,
      actorService: ActorService):
      IO[Map[String, List[Actor]]] = {

    val movieIdAndActors = movieIds traverseU { movieId =>
      for {
        actorIds <-
          movieService.getActorIdsForMovieId(movieId)
            .liftM[ActorCacheMonadT]

        actors <- actorIds traverseU { actorId =>
          actorService.getActorByIdWithCache(actorId)
        }
      } yield (movieId, actors.flatten)
    }

    val movieIdToActorsMap = movieIdAndActors map { _.toMap }
    movieIdToActorsMap.run(Map.empty[String, Actor]) map { _._2 }
  }

Mostly pretty straightforward. The movieIdAndActors variable holds an ActorCacheMonadIO[List[(String, List[Actor])]]. In movieIdToActorsMap, we convert that wrapped list into a Map so that we have ActorCacheMonadIO[Map[String, List[Actor]]]. We then call the run() method to strip that down to just an IO monad. But run() returns a tuple with two values, the final cache and the movie-to-actor Map. We don’t really care about the final cache. So we’ll transform the result of run() to throw out the final cache value and just return the movie-to-actor Map wrapped up in the IO monad.

That’s all there is to it. We’re now running in the IO monad so that we correctly declare what is impure, and we’re using a cache so that we don’t load actors multiple times.
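To see what threading the cache through buys us, here is a sketch in plain Scala that passes the cache by hand, which is the bookkeeping StateT automates. It deliberately leaves out the IO layer. The in-memory database, the databaseHits counter and the sample actors are all hypothetical, and the mutable counter is there only to make the number of lookups visible.

```scala
object CacheThreadingSketch {
  final case class Actor(id: String, name: String)
  type ActorCache = Map[String, Actor]

  // Hypothetical in-memory "database" plus a counter of how often it is hit.
  val database: Map[String, Actor] =
    Map("a1" -> Actor("a1", "Ann"), "a2" -> Actor("a2", "Bob"))
  var databaseHits: Int = 0

  def getActorById(id: String): Option[Actor] = {
    databaseHits += 1
    database.get(id)
  }

  // Same shape as the function inside State: cache in, (new cache, value) out.
  def getActorByIdWithCache(id: String)(cache: ActorCache): (ActorCache, Option[Actor]) =
    cache.get(id) match {
      case Some(cached) => (cache, Some(cached))
      case None =>
        getActorById(id) match {
          case Some(actor) => (cache + (id -> actor), Some(actor))
          case None        => (cache, None)
        }
    }

  val requestedIds: List[String] = List("a1", "a2", "a1", "missing", "a2")

  // Thread the cache through every lookup, starting from an empty cache.
  val (finalCache, actors) =
    requestedIds.foldLeft((Map.empty[String, Actor], List.empty[Actor])) {
      case ((cache, found), id) =>
        val (nextCache, maybeActor) = getActorByIdWithCache(id)(cache)
        (nextCache, found ++ maybeActor.toList)
    }

  def main(args: Array[String]): Unit = {
    println(actors.map(_.name))
    println(finalCache.keySet)
    println(databaseHits)
  }
}
```

Five lookups are requested, but the database is consulted only three times: once each for the two known actors and once for the ID that doesn’t exist, since missing actors are never cached.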

Summary

The StateT monad allows you to layer a State monad on top of some other arbitrary monad, allowing you to use all the capabilities of both monads at the same time. If you have some state of type SomeState and some monad of type SomeM, here are the rules for using StateT with these types:

Be sure to import scala.language.higherKinds.

Create a type definition for the State monad:

  type SomeStateMonad[+A] = State[SomeState, A]

Create two type definitions for the StateT monad:

  type SomeStateMonadT[M[+_], +A] = StateT[M, SomeState, A]
  type SomeStateMonadSomeM[+A] = SomeStateMonadT[SomeM, A]

If you have a value of type SomeM[A], you convert it into a StateT by calling liftM[SomeStateMonadT] on it.

If you have a value of type SomeStateMonad[A], you convert it into a StateT by calling lift[SomeM] on it.

If you have a raw unwrapped value, you convert it into a StateT by calling point[SomeStateMonadSomeM] on it.

If you have a value of type SomeStateMonadSomeM[A], use the run() method to convert it to a SomeM[(SomeState, A)].

If you need to apply a function that returns a StateT to each item in a list, use traverseU to invoke the function on each list item and bind the resulting StateT’s into a single StateT.
