Scalaz State Monad

Suppose you are given a function words() that takes a string as a parameter and returns a list of all the words within that string. Now suppose you want to write a new function that counts the number of times each word within a string is used. Here’s a first pass at the function:

def wordCounts(str: String): Map[String, Int] = {
  words(str).foldLeft(Map.empty[String, Int]) { (map, word) =>
    val count = map.getOrElse(word, 0) + 1
    map + (word -> count)
  }
}

I’ll use immutable data structures throughout this article. (Explaining the benefits of immutability is beyond the scope of this blog post.) The above function creates a map of words to their counts by folding an empty map across the words. As each word is encountered, a new map is created with an updated word-to-count mapping.
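To make that concrete, here’s a minimal runnable sketch. The words() implementation below is only an assumption for illustration (the article treats words() as a given); it lower-cases the input and splits on anything that isn’t a letter from a to z. The WordCountSketch object name is also made up.

```scala
object WordCountSketch {
  // Hypothetical stand-in for the article's words() helper: lower-cases the
  // input and splits on runs of characters that are not a-z.
  def words(str: String): List[String] =
    str.toLowerCase.split("[^a-z]+").filter(_.nonEmpty).toList

  // Same fold as above: each step returns a new immutable map.
  def wordCounts(str: String): Map[String, Int] =
    words(str).foldLeft(Map.empty[String, Int]) { (map, word) =>
      val count = map.getOrElse(word, 0) + 1
      map + (word -> count)
    }
}

// WordCountSketch.wordCounts("the cat and the hat")
//   ====>   Map("the" -> 2, "cat" -> 1, "and" -> 1, "hat" -> 1)
```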

So far, so good. But suppose we have an Article object containing separate fields for the headline, abstract, and body. What if we want to get a word count that includes words from all of these fields? We could call our wordCounts() function for each of the three fields, but then we would have three separate maps that would need to be merged. To avoid having to write that merging code, let’s change our wordCounts() function so that it takes an input map:

def wordCounts(str: String, currMap: Map[String, Int]): Map[String, Int] = {
  words(str).foldLeft(currMap) { (map, word) =>
    val count = map.getOrElse(word, 0) + 1
    map + (word -> count)
  }
}

We barely had to change the function at all. Now, it takes an input map and transforms it to produce a new map. You can think of our new wordCounts() function as taking in the current state as a parameter and transforming it to produce a new state. We can now build the word counts for all of the article fields as follows:

val map0 = Map.empty[String, Int]
val map1 = wordCounts(article.headline, map0)
val map2 = wordCounts(article.`abstract`, map1)  // backticks needed: abstract is a reserved word in Scala
val map3 = wordCounts(article.body, map2)
doSomethingWith(map3)

That works. At the end of that block of code, map3 contains the combined word counts across all three article fields.

But, that code is really ugly. We have to introduce a bunch of variables to hold intermediate states. It’s not terrible since we only have a few intermediate states, but it’s not hard to imagine problems where we could have 10, 20, or more intermediate states.

Also, what happens later if we need to add another field to the word counts, perhaps a sub-headline field? We’ll have to introduce a map4 variable and change any downstream references to map3 to use map4. That’s very error-prone. Similarly, if we decide we don’t want to include the article.abstract in the word counts, things get messy.

Philosophically, there’s a deeper problem here. Our state (the map we’re building) is changing over time. At any point in time, there’s exactly one correct state. But the code above forces us to keep variables around holding old, obsolete versions of the state. For example, even though map3 holds the only valid state at the end of the code, the previous states in map0, map1, and map2 are still available. That’s philosophically wrong. We’ve moved on from those states, and our code should no longer be allowed to access them.

So, our basic problem is as follows: How do we chain together arbitrary numbers of operations that immutably manipulate some state so that

  1. any operation can read the current state
  2. any operation can replace the current state with a new state
  3. no operation has access to prior obsolete states

We’ll see below how the state monad allows us to solve this problem.

The State Monad

Don’t worry about what a monad is; we won’t cover that in this blog post. Instead, we’re only going to focus on how to use the state monad. So, don’t get hung up on the funny name.

State monads are part of the scalaz library. To use them, you’ll need to change your build.sbt to include the following:

libraryDependencies ++= Seq(
 "org.scalaz" %% "scalaz-core" % "7.0.0",
 "org.scalaz" %% "scalaz-effect" % "7.0.0",
 "org.scalaz" %% "scalaz-typelevel" % "7.0.0",
 "org.scalaz" %% "scalaz-scalacheck-binding" % "7.0.0" % "test"
)

You’ll also need to add the following to the top of your .scala files:

import scalaz._
import Scalaz._

The state monad is built around the State[S, A] class. State takes 2 type parameters S and A, where S is the type of the state, and A is the type of some extra value. Don’t worry too much about what S and A mean just yet. They’re just arbitrary types as far as we’re concerned right now.

Objects of type State hold a single field which is a function that takes an object of type S as a parameter and returns a tuple holding an object of type S and an object of type A. It’s important to realize that this function is not a method on State but rather a data field within the State object. That is, different State objects can have different functions inside them. Any function you can dream up that takes an S object as a parameter and returns an (S, A) tuple can be stored in a State. Here’s an example:

val m1 = State { s: String => (s, s.size) }

In this case, we’re constructing a State with a function that takes a string as a parameter (here S is a String) and returns a tuple holding the exact same string and its size (so A is an Int in this example).

Here’s a slightly more complex example:

def repeat(num: Int): State[String, Unit] = State { s: String => (s * num, ()) }

Function repeat() returns a State object. Nothing special there. We’ve all written thousands of functions that return objects. The returned State object holds a function that takes a string as a parameter and returns a tuple holding the string repeated over and over num times and a value of type Unit. In this example, S is still a String. But now A is the Unit type.

Take a look at the (S, A) tuples produced in the two examples. In the first, the S value is unchanged and a meaningful A value is produced. In the second example, though, the S value is changed and the A value is a fairly uninteresting Unit. We are free to do whatever we want in our functions and do something interesting with the S value, the A value, or both the S and A values.

What can we do once we have a State object? The simplest thing is fetching the function held within the State so that we can call it. We fetch the function using the State’s run() method. Since the function takes an object of type S as a parameter, we can call it by passing it an object of type S. The result is a tuple of type (S, A). Here are some examples:

m1.run("hello")        ====>   ("hello", 5)
repeat(3).run("abc")   ====>   ("abcabcabc", ())

No magic here. We’re just extracting out the functions we created above and calling them by passing in values for their string parameters.
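If you want to experiment without adding scalaz to your build, the idea can be approximated with a tiny stand-in class. MiniState below is a hypothetical simplification, not scalaz’s actual State implementation; it only captures the “object holding a function” idea described above.

```scala
object MiniStateSketch {
  // Simplified stand-in for State[S, A]: a wrapper around one function field.
  // (Hypothetical; scalaz's real State is more general than this.)
  final case class MiniState[S, A](run: S => (S, A))

  // Leaves the string alone and reports its size as the A value.
  val m1: MiniState[String, Int] = MiniState((s: String) => (s, s.size))

  // Changes the string and returns an uninteresting Unit as the A value.
  def repeat(num: Int): MiniState[String, Unit] =
    MiniState((s: String) => (s * num, ()))
}

// MiniStateSketch.m1.run("hello")        ====>   ("hello", 5)
// MiniStateSketch.repeat(3).run("abc")   ====>   ("abcabcabc", ())
```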

flatMap

We can use the State object’s flatMap() method to combine together two individual State objects into a bigger, unified object. For a state of type State[S, A], the flatMap() method has the following signature:

flatMap[B](f: A => State[S, B]): State[S, B]

That is, flatMap() takes a function as a parameter that transforms a value of type A into a State[S, B]. You can think of flatMap() as pulling the A value out of the initial State[S, A] object and passing it to function f(). Function f() then transforms the A value to produce a new State[S, B] object. The flatMap() method then transforms that resulting State[S, B] object to produce yet another State[S, B] object that is then returned.

Remember that at its core, every State[S, B] object simply holds a function that takes an object of type S as a parameter and returns a tuple of type (S, B). So, flatMap() is just returning us a new function wrapped in a State object. But where does this function come from? The flatMap() method actually constructs a brand new function that chains together the smaller functions. An example will make this more clear:

m1.flatMap(repeat).run("hello")   ====> ("hellohellohellohellohello", ())

Let’s step through this carefully. Object m1 holds a function that takes a String and returns a (String, Int) where the Int is the length of the input string. The flatMap() method pulls out the Int value and passes it to the repeat() function. The repeat() function returns a State holding a function that takes a String and returns a (String, Unit), where the resulting string is the input string repeated over and over.

The flatMap() method then builds a new function that takes a String as a parameter, invokes the function inside of m1, takes the resulting String and passes it into the function in the State object returned by repeat(), and then returns the final (String, Unit) tuple returned by that function. I’ll show some diagrams in a bit that will explain this better. But for now, just realize that flatMap() builds a new function that chains together the function in the m1 State[String, Int] object and the function in the repeat() State[String, Unit] object.

Once we have this new chained together function, we can call it by passing it a string (“hello” in the example above) and the sub-functions are called and chained together. The first sub-function gets the length of the string. The second sub-function repeats the string by the “length” number of times. Since “hello” has 5 characters, the final string has “hello” repeated 5 times.
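For the curious, here is one way flatMap() could be written for a simplified stand-in type. This is a sketch under the assumption that a State is nothing more than a wrapper around a function; MiniState is a made-up name, and scalaz’s real implementation is more general.

```scala
object FlatMapSketch {
  final case class MiniState[S, A](run: S => (S, A)) {
    // Builds (but does not call) a new function that runs this state's
    // function, hands the A value to f, and then runs the function inside
    // the State returned by f on the updated S value.
    def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        f(a).run(s1)
      }
  }

  val m1: MiniState[String, Int] = MiniState((s: String) => (s, s.size))

  def repeat(num: Int): MiniState[String, Unit] =
    MiniState((s: String) => (s * num, ()))

  // Nothing has been executed yet; this is just a bigger function in a wrapper.
  val chained: MiniState[String, Unit] = m1.flatMap(repeat)
}

// FlatMapSketch.chained.run("hello")   ====>   ("hellohellohellohellohello", ())
```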

Note that we can chain together as many flatMap() calls as we like, for example:

m1.flatMap(repeat).flatMap({ _ => m1 }).run("hello")
     ====>   ("hellohellohellohellohello", 25)

Here we’ve added another function to the chain. This new function takes a parameter (which we ignore) and returns the m1 State[String, Int] object. Remember, this object holds a function that takes a string and returns the length of that string.

The previous State in the chain (returned by the repeat() function) holds a function that manipulates its string parameter and returns a result of type (String, Unit). The flatMap() method extracts the Unit value and passes it to our new { _ => m1 } function. Since the new function just takes a Unit as a parameter, you can see why we’re ignoring it with an _ since Unit isn’t particularly interesting.

But, what is interesting is how the functions are all chained together. We start by passing in “hello” as the string. That’s passed into the first sub-function, which computes its size and returns the “hello” string unchanged. The second sub-function is then called. It changes the string to be “hello” repeated 5 times. The third sub-function is then called with “hello” repeated 5 times as a parameter. It then computes the length of this string as 25 and returns that value (along with the unchanged long string).

What the hell?

Ok. So the flatMap() method is a bit magical. Here’s a diagram that may explain better how things are chained together in the previous example:

[Diagram: state_chaining]

Here’s the sequence of events:

  1. The string “hello” is passed into the function { s: String => (s, s.size) }.
  2. The function returns (“hello”, 5).
  3. 5 is passed to the repeat() function.
  4. The repeat() function returns the function { s: String => (s * num, ()) }.
  5. “hello” is passed into the function returned in step 4.
  6. The result is (“hellohellohellohellohello”, ()).
  7. () is passed into function { _ => m1 }.
  8. That function returns the function { s: String => (s, s.size) }.
  9. “hellohellohellohellohello” is passed into the function from step 8.
  10. The result is (“hellohellohellohellohello”, 25).

Maybe a bit less magical now. You should be able to see the pattern and how it can be extended to chain together any number of State objects. The key is that States hold functions, flatMap() builds a new function that chains together the State functions, and you use the run() method to execute the final function, passing it an initial S (String in our example) to process.
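The ten steps above can also be spelled out by hand with plain functions, which is roughly the work the function built by flatMap() does once it finally runs. This is only an illustrative unrolling; the names Step, size, and runChain are made up for this sketch.

```scala
object UnrolledChain {
  // A bare state function: takes the current String, returns (new String, value).
  type Step[A] = String => (String, A)

  val size: Step[Int] = s => (s, s.size)
  def repeat(num: Int): Step[Unit] = s => (s * num, ())

  def runChain(initial: String): (String, Int) = {
    val (s1, n)   = size(initial)      // steps 1-2: string unchanged, n is its length
    val repeated  = repeat(n)          // steps 3-4: build the repeating function
    val (s2, u)   = repeated(s1)       // steps 5-6: string repeated n times, u is ()
    val ignoreArg = (_: Unit) => size  // plays the role of { _ => m1 }
    val last      = ignoreArg(u)       // steps 7-8: get the size function back
    val (s3, len) = last(s2)           // steps 9-10: long string and its length
    (s3, len)
  }
}

// UnrolledChain.runChain("hello")   ====>   ("hellohellohellohellohello", 25)
```

Notice that the hand-written version needs exactly the kind of intermediate variables we were trying to avoid; flatMap() hides them inside the function it builds.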

The following diagram shows a more general case:

[Diagram: state_chaining2]

Reading and changing the state

If we think of the String flowing through the above example as being “the state”, there’s something very interesting going on. The function in m1 reads the current state (so it can compute the string length). Furthermore, depending on where we use m1 in the chain, it gets a different value when it reads the state. That is, the first time we use it, the state is “hello”. The second time, the state is “hellohellohellohellohello”. Similarly, the function in the State returned by repeat() reads the current state. It also modifies the state by changing the String that it returns.

So, we can always access the current state anywhere in the chain even though we never store it off in variables like we did in the word counting example. We never have access to old, obsolete states. We only ever have access to the current state.

The state monad provides two functions get() and put() that make it easy to read or modify the state at any point during the chain. Let’s first look at get():

def get[S] = State { s: S => (s, s) }

The get() function takes no parameters. It returns a State[S, S] object. That is, both of the type parameters in the resulting object are of type S. The resulting State[S, S] object contains a function that returns the input S value unchanged, and it puts the exact same value into the second field of the resulting tuple.

Remember that flatMap() pulls the second value out of the resulting tuple and passes it to the function passed into flatMap(). That means you can call get(), then flatMap() the result, and your function you pass to flatMap() will have access to the current state.

Before showing this in action, let’s first introduce flatMap()’s closely related cousin map(). Here are the signatures for map() and flatMap() side-by-side:

map[B](f: A => B): State[S, B]
flatMap[B](f: A => State[S, B]): State[S, B]

You can see that the methods are almost identical. The only difference is the return value of the parameter function f. For flatMap(), f returns a State[S, B]. For map(), however, f just returns a B. The map() method must wrap this B up into a new State[S, B] object.

Another way of looking at this is that the function passed into map() can read the A value, but it can neither read nor change the S value. Otherwise, map() is similar to flatMap().
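As a rough sketch of that difference, here is how map() might sit next to flatMap() on a simplified stand-in type. MiniState and sizeOfState are hypothetical names, and this is not how scalaz itself is written; the point is only that map() passes the S value through untouched.

```scala
object MapSketch {
  final case class MiniState[S, A](run: S => (S, A)) {
    // f can return a whole new State, so it may read and change S.
    def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        f(a).run(s1)
      }

    // f only sees the A value; the S value is passed through unchanged.
    def map[B](f: A => B): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        (s1, f(a))
      }
  }

  // Same shape as the get() definition above.
  def get[S]: MiniState[S, S] = MiniState((s: S) => (s, s))

  val sizeOfState: MiniState[String, Int] = get[String].map(_.size)
}

// MapSketch.sizeOfState.run("hello")   ====>   ("hello", 5)
```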

Adding in get() and map(), we can now write our “hello”-repeating example like this:

get[String]
  .flatMap({ s0 => repeat(s0.size) })
  .flatMap({ _ => get[String]})
  .map({ s1 => s1.size })
  .run("hello")

    ====>    ("hellohellohellohellohello", 25)

Here, we’ve gotten rid of our m1 object. Instead, we use get() to fetch the current state string. So, the s0 parameter in the function passed to the first flatMap() holds the string “hello”.

Let’s introduce one last function, put(), and then I’ll show another diagram of how this is all chained together. Function put() takes a new S value as a parameter, and it changes the current state to be that value. Here’s the definition of put():

def put[S](newState: S) = State { s: S => (newState, ()) }

The result is a State[S, Unit] object holding a function that ignores its input state parameter and instead returns the tuple (newState, ()). So, if we flatMap put() with another function, when the other function reads the state, it will now read the new state.

Here’s one final modification to our example that gets rid of our repeat() function and uses put() instead:

get[String]
  .flatMap({ s0 => put(s0 * s0.size) })
  .flatMap({ _ => get[String]})
  .map({ s1 => s1.size })
  .run("hello")

    ====>    ("hellohellohellohellohello", 25)

And here’s the diagram that shows it all in action:

[Diagram: state_chaining3]

Using “for comprehensions”

It turns out that Scala’s “for comprehensions” are just syntactic sugar for flatMap() and map(). Basically, if you have some container M[A] (like List, Option, or even State!) that defines flatMap() and map(), then the following

m.flatMap({ a => f(a) }).flatMap({ b => g(b) }).flatMap({ c => h(c) }).map({ d => i(d) })

is equivalent to

for {
  a <- m
  b <- f(a)
  c <- g(b)
  d <- h(c)
} yield i(d)

Here, m is an instance of M[A], f() returns an M[B], g() returns an M[C], h() returns an M[D], and i() returns an E. The result of the entire expression in both cases is M[E].

The “for comprehension” form is much easier to read.
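Here’s a small sketch of that equivalence using Option rather than State, since Option needs no extra library. The half() function and the ForSugar object are made up for illustration.

```scala
object ForSugar {
  // Hypothetical helper: halves even numbers, fails on odd ones.
  def half(n: Int): Option[Int] = if (n % 2 == 0) Some(n / 2) else None

  // Explicit flatMap()/map() chain.
  val explicit: Option[String] =
    Some(40).flatMap(a => half(a)).flatMap(b => half(b)).map(c => s"got $c")

  // The same computation written as a "for comprehension".
  val sugared: Option[String] = for {
    a <- Some(40)
    b <- half(a)
    c <- half(b)
  } yield s"got $c"
}

// ForSugar.explicit   ====>   Some("got 10")
// ForSugar.sugared    ====>   Some("got 10")
```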

When used with a State object, the arrows in the “for comprehension” can be thought of almost like assignment operators. In the expression

a <- foo(...)

function foo() must return a State[S, A] where A is some arbitrary type. You can think of the arrow as pulling out the A value from the resulting State[S, A] object and assigning it to the variable a.

The yield expression at the end of the “for comprehension” returns some arbitrary type A. The resulting value will be wrapped up in a State[S, A] object.

All State objects on the right-hand side of the arrows are chained together using flatMap(), and the value produced by the yield is added to the end of the chain via map(). The result is a single State object containing a function that calls the whole chain. This is exactly the same as explicitly using flatMap() and map(), just easier to read.

So, here’s our example using a “for comprehension”:

val m = for {
  s0 <- get[String]
  _  <- put(s0 * s0.size)
  s1 <- get[String]
} yield s1.size

m.run("hello")    ====>    ("hellohellohellohellohello", 25)

Better ways of reading and changing the state

Sadly, we’ve done a lot of work, wracked our brains, and done a lot of strange tricks only to end up with a solution that suffers from the same basic problem as our original word counting code. If you look at our “for comprehension” solution, it fetches the current state and stores it in s0. Then, it changes the state with the call to put(), meaning that the state in s0 is now obsolete. But s0 is still in scope. So, we can continue to use the state in s0 even though it’s obsolete–a recipe for bugs! For example, it would be easy to accidentally yield s0.size instead of s1.size, and that’s not at all what we want!

Luckily, there are alternatives to get() and put() that fix this problem. Let’s start with function modify():

def modify[S](f: S => S) = State { s: S => (f(s), ()) }

Function modify() takes a function f that transforms the state. Function f is passed the current state, and it transforms that state somehow to produce a new state. Using modify(), our example now becomes

val m = for {
  _  <- modify { s: String => s * s.size }
  s1 <- get[String]
} yield s1.size

m.run("hello")    ====>    ("hellohellohellohellohello", 25)

The s0 variable is gone. So, we no longer have to worry about continuing to hold a reference to that state after it becomes obsolete. However, we still are storing the final state in variable s1.

We can get rid of s1 using function gets():

def gets[S, A](f: S => A) = State { s: S => (s, f(s)) }

Like modify(), gets() also is passed a function f. But in this case, function f doesn’t transform the input state into another state. Instead, it transforms the state into a value of an arbitrary type A. The result of f is not put into the S field of the resulting tuple; rather, it goes in the A field. That is, we take the state, transform it to produce a value of an arbitrary type, and then put this in the second tuple field so that it is available in the next flatMap() call in the chain.

Here’s the final revision of the example. We no longer store any states that could be in danger of being obsolete.

val m = for {
  _    <- modify { s: String => s * s.size }
  size <- gets { s: String => s.size }
} yield size

m.run("hello")    ====>    ("hellohellohellohellohello", 25)

Our code is nice and clean. We’ve finally uncovered a pattern that solves the problem posed at the beginning of this blog post.
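If you’d like to run the final revision without pulling in scalaz, the following sketch reproduces it on a simplified stand-in type. MiniState is a hypothetical name, and its modify() and gets() simply mirror the definitions shown above; treat it as an approximation rather than the library’s code.

```scala
object ModifyGetsSketch {
  final case class MiniState[S, A](run: S => (S, A)) {
    def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        f(a).run(s1)
      }
    def map[B](f: A => B): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        (s1, f(a))
      }
  }

  // Replaces the state with f(state); the A value is Unit.
  def modify[S](f: S => S): MiniState[S, Unit] = MiniState((s: S) => (f(s), ()))

  // Leaves the state alone; the A value is f(state).
  def gets[S, A](f: S => A): MiniState[S, A] = MiniState((s: S) => (s, f(s)))

  // No variable in this chain ever holds a state value.
  val m: MiniState[String, Int] = for {
    _    <- modify((s: String) => s * s.size)
    size <- gets((s: String) => s.size)
  } yield size
}

// ModifyGetsSketch.m.run("hello")   ====>   ("hellohellohellohellohello", 25)
```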

Re-examining our word count example

Let’s re-write our word count example using state monads. Our type S is Map[String, Int], meaning our states will be of type State[Map[String, Int], A], for some arbitrary type A.

def wordCounts(str: String) = modify { currMap: Map[String, Int] =>
  words(str).foldLeft(currMap) { (map, word) =>
    val count = map.getOrElse(word, 0) + 1
    map + (word -> count)
  }
}

The new wordCounts() function returns an object of type State[Map[String, Int], Unit]. It replaces the map in the current state with a new map reflecting the word counts for the input string.

Our code to add the various article fields to the map now looks like this:

val m = for {
  _ <- wordCounts(article.headline)
  _ <- wordCounts(article.`abstract`)
  _ <- wordCounts(article.body)
} yield ()

val (wordMap, _) = m.run(Map.empty[String, Int])
doSomethingWith(wordMap)

Throughout this code, we’re not really using the A value in our State[Map[String, Int], A] objects. This is because wordCounts() always uses Unit for the A value. In other words, wordCounts() isn’t really returning a value. Instead, it’s just being called for its side effects on the state. But, our “for comprehension” must yield some value. So, we just yield a Unit value.

This means that the resulting value in m is of type State[Map[String, Int], Unit]. When we call the run() method, we’ll get back a tuple of type (Map[String, Int], Unit). We don’t really care about the Unit part of that value. So we’ll ignore it and just keep the final resulting wordMap.
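Putting the pieces together, here’s a self-contained sketch of the whole word count that runs without scalaz. Everything not shown earlier in this post is an assumption made for illustration: the MiniState stand-in, the Article case class, the words() implementation, and the countArticle() wrapper. Note that the abstract field needs backticks because abstract is a reserved word in Scala.

```scala
object ArticleWordCounts {
  type Counts = Map[String, Int]

  // Hypothetical stand-in for State, with just enough for a for comprehension.
  final case class MiniState[S, A](run: S => (S, A)) {
    def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        f(a).run(s1)
      }
    def map[B](f: A => B): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        (s1, f(a))
      }
  }

  def modify[S](f: S => S): MiniState[S, Unit] = MiniState((s: S) => (f(s), ()))

  // Hypothetical Article type and words() helper.
  final case class Article(headline: String, `abstract`: String, body: String)

  def words(str: String): List[String] =
    str.toLowerCase.split("[^a-z]+").filter(_.nonEmpty).toList

  def wordCounts(str: String): MiniState[Counts, Unit] = modify { (currMap: Counts) =>
    words(str).foldLeft(currMap) { (map, word) =>
      map + (word -> (map.getOrElse(word, 0) + 1))
    }
  }

  def countArticle(article: Article): Counts = {
    val m = for {
      _ <- wordCounts(article.headline)
      _ <- wordCounts(article.`abstract`)
      _ <- wordCounts(article.body)
    } yield ()
    val (wordMap, _) = m.run(Map.empty[String, Int])
    wordMap
  }
}

// ArticleWordCounts.countArticle(
//   ArticleWordCounts.Article("Cats", "cats and dogs", "Dogs chase cats"))
//   ====>   Map("cats" -> 3, "dogs" -> 2, "and" -> 1, "chase" -> 1)
```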

What if we’re not chaining together a fixed set of operations?

We have one last trick to learn, and then we’ll basically know how to use the state monad in our code. In the previous examples, we’ve always had a fixed set of operations to chain together. But what if we have an unbounded set?

For example, instead of building the word-count map for a single article, suppose we have a list of articles and want to build a single map with the counts of all words in all the articles.

Let’s first define the following function to get the word counts for a single article:

def wordCountsForArticle(article: Article) = for {
  _ <- wordCounts(article.headline)
  _ <- wordCounts(article.`abstract`)
  _ <- wordCounts(article.body)
} yield ()

This method returns a State[Map[String, Int], Unit], and it manipulates the map embedded in the state to include all words in the current article. That should be clear enough.

We could try processing the list of articles like this:

articles map wordCountsForArticle

That works, but the result has the following type:

List[State[Map[String, Int], Unit]]

That is, we have a list of State objects instead of just a single State object. We can no longer call them by simply invoking one run() method. We have a whole list of run() methods to call! That’s no good.

To get this down to a single State object, we have to bind all of the State objects in our list together using flatMap(). By calling flatMap() enough times, we can build a function that chains all the States together into a single resulting State. Then, we just call the run() method on that resulting State.

We can flatMap() all the States together as follows:

val ms = articles map wordCountsForArticle
val m = ms.foldLeft(State { s: Map[String, Int] => (s, ()) }) { (resultM, currM) =>
  resultM flatMap { _: Unit => currM }
}
val (wordMap, _) = m.run(Map.empty[String, Int])
doSomethingWith(wordMap)

Nothing too fancy going on there. We’re just using a foldLeft() operation to flatMap() across the States in the list. Since foldLeft() requires us to seed it with a State, we just seed it with a State containing a do-nothing function.

It would be a real drag if we had to rewrite essentially this same code every time we had an unbounded list of operations to chain together. Luckily, there’s a method traverseS() that combines the work of the map and the fold. Method traverseS() works kind of like map in that it applies a function to each item in a collection. The function that it applies, though, must return a State. Method traverseS() then takes the resulting States and automatically chains them together using flatMap(). Moreover, it collects up the A values inside each State and packages them into a list.

So, if the function you pass to traverseS() returns objects of type State[S, A], the result of traverseS() will be State[S, List[A]]. In our word counting example, we’ll end up with State[Map[String, Int], List[Unit]]. Obviously, a list of Units is kind of useless (and a bit humorous). But collecting the values into a list can be very useful in other problems.
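To see why collecting the A values can be handy, here’s a sketch of a traverseS()-like helper on a simplified stand-in type. The names MiniState, traverseState, and addAndReport are all hypothetical, and this is not scalaz’s implementation; it only illustrates the “chain with flatMap() and gather the results into a list” behavior described above.

```scala
object TraverseSketch {
  final case class MiniState[S, A](run: S => (S, A)) {
    def flatMap[B](f: A => MiniState[S, B]): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        f(a).run(s1)
      }
    def map[B](f: A => B): MiniState[S, B] =
      MiniState { (s0: S) =>
        val (s1, a) = run(s0)
        (s1, f(a))
      }
  }

  // Applies f to every item, chains the resulting states left to right,
  // and collects each B value into a list in the same order as the items.
  def traverseState[S, A, B](items: List[A])(f: A => MiniState[S, B]): MiniState[S, List[B]] =
    items.foldRight(MiniState((s: S) => (s, List.empty[B]))) { (item, rest) =>
      f(item).flatMap(b => rest.map(bs => b :: bs))
    }

  // Example operation: the state is a running total, the value is a report line.
  def addAndReport(n: Int): MiniState[Int, String] =
    MiniState((total: Int) => (total + n, s"after $n: ${total + n}"))

  val m: MiniState[Int, List[String]] = traverseState(List(1, 2, 3))(addAndReport)
}

// TraverseSketch.m.run(0)
//   ====>   (6, List("after 1: 1", "after 2: 3", "after 3: 6"))
```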

Using traverseS(), here’s our final solution to word counting:

val m = articles traverseS wordCountsForArticle
val (wordMap, _) = m.run(Map.empty[String, Int])
doSomethingWith(wordMap)

Conclusion

We’ve seen how to use scalaz’s state monad to write code that can read and modify the current state without ever having to store off in local variables intermediate states that may become obsolete. Whenever you find yourself writing lots of state-transforming functions that take input states and return output states, consider using the state monad to simplify your code.

Admittedly, the state monad has a tough learning curve; it takes a bit of time to get your head around it. You could make an argument that it’s not worth the complexity, and you should just write the word counting function the way I did at the start of this blog post.

But I think that complexity argument is a false argument. The state monad is a well-established pattern in functional programming. Yes, there’s a learning curve, but once you’re over the learning curve, the state monad is quite easy to use. It’s something that experienced functional programmers will recognize quite easily in your code. It will actually make your code more readable, not less (to an experienced functional programmer). And, it will decrease the likelihood of bugs in your code.

If you’ve gone down the Scala route in your coding (as opposed to something more mainstream like Java or PHP), you’ve already embraced the idea that investing in getting over the learning curve now pays off big-time in increased productivity down the road. Adding the state monad to your tool belt will increase your productivity even more.

8 Comments

  1. The clearest explanation of the State Monad in Scala I have ever read.

  2. I have a question regarding the 5th paragraph in the flatMap section. You start out by saying:

    The flatMap() method *then* builds a new function…

    I don’t understand this. Doesn’t flatMap() simply return the function that was described in the previous paragraph?

    • Consider just this part of the statement: m1.flatMap(repeat)

      m1 holds a function. The enclosed function takes a string as a parameter and returns a tuple containing the input string unchanged and the length of the string. But it’s important to realize that we’re not calling this enclosed function here. We’re just listing the m1 object that holds the function inside it.

      We are calling the repeat() function, though. But the return value of the repeat function is a wrapper around another function. That is, repeat() is not returning a normal type of value but rather is creating a new function on the fly (a closure) and returning it wrapped in another State object. The function returned by repeat() takes a string as a parameter and returns a tuple containing the string repeated over and over and a dummy Unit value. But it’s important to realize that while the repeat() function is called in the above statement, the function returned by repeat() is not being called.

      So, we’re left with 2 different functions above, neither of which is called in the above statement. Both functions take strings as parameters and both return tuples where the first value in the tuple is a string.

      flatMap() takes those 2 functions and stitches them together to create a brand new function. This new function also takes a string as a parameter and returns a tuple. The new function created by flatMap() first calls the function enclosed inside m1 and then it takes the results and calls the function returned by repeat().

      But, the function created by flatMap() still is not being called yet. But at least we’re down from having 2 functions floating around to now just having a single function.

      So, flatMap() just glues together two smaller functions into one bigger function.

      Finally, we call run(“hello”). That invokes the function returned by flatMap(). That function in turn passes the string “hello” to the function held inside m1 and then passes the results from m1 into the function returned by repeat(). Then, the results of that are returned.

      It’s a bit hard to get your head around all this. It took me a lot of time before it made sense to me. The key thing to remember is that while you’re calling flatMap(), in a sense, your code isn’t really doing anything. Instead, it’s dynamically building up a function. Then, at the end, you call the run() method, and that runs the function you built up. In a sense, all the real work is done when you run the function.

      That’s a bit of an inversion of the way programs normally work.

      • Now I think I am more confused. You say that repeat() is called and that flatMap stitches repeat()’s returned function together with m1’s function. However, repeat() isn’t called until you invoke flatMap() in the first place! When was repeat() called? Where did flatMap() get the two functions to stitch together unless it was inside flatMap() calling repeat() to begin with?

        So, is the call to flatMap() like a function composition of sorts?

        Also, it is confusing that in one case flatMap() is taking the Int from the tuple and passing it into the provided function (repeat) and the next case it passes the String (the anonymous function). What is happening here? How does it decide?

        I have been working with Scala for only about 6 months and it is my first real functional language (except for some Lisp – especially with Emacs but that doesn’t really count). I have come to grips with Monads in general and have looked at them in Scala, Haskell, and F#. However, I have been struggling with the State Monad for the last several weeks and haven’t been able to really understand it, regardless of the language. Every time I think I am closing in on it it slips away. Apparently I must have something wrong with the model I am using to try and understand it.

  3. Oh, wait. Just after posting that reply I noticed something.

    Apparently, flatMap() is taking the Int in the first call and passing that to repeat() to get the function (taking a String and returning a tuple). This is stitched together with the function from m1 and returns a new State object with this function.

    Then, flatMap() is taking the Unit in the second call and passing it to the anonymous function which returns m1 which is a State. However, I assume that State must have an apply() function which allows it to simply use its function (taking a String and returning a tuple).

    This would explain the diagram ‘state_chaining1.png’. Is this close?

    • You’re pretty close. 🙂 I took a few poetic liberties in my previous reply that might have confused things. You are correct that function repeat() isn’t called right away. But still, at the end of the day, repeat() and m1 both result in functions that need to be stitched together. Here’s the weird thing, though… flatMap() isn’t what calls repeat() either. Instead, all flatMap() does is build a new function. This new function is what does all the work. But, the new function isn’t actually called until the very end when we call run(). So, really run() is what calls repeat().

      In the blog post, I tried to avoid going into the implementation details of flatMap(). That’s usually where most explanations start, and I think it’s too easy to get confused by the implementation. But, it might be helpful to go into a bit more detail about what flatMap() does.

      First, all flatMap() does is construct a new function (and wrap that function in a state monad). In the case of “m1.flatMap(repeat)”, the new function takes a String as a parameter and returns a (String, Unit) tuple. Other than function construction, no other work is actually done until that function is called (by run()).

      When the new function is executed later on, the first thing it does is extract the function out of m1 and run it. m1’s function takes a String as a parameter and returns a (String, Int) tuple. So, to call this function, we need some String. Recall that the new function takes a String as a parameter. So, we can just pass the new function’s String parameter as the String parameter to m1’s function. That gives us a (String, Int) tuple. In the following description, I’ll call that String the “new string”.

      Next, repeat() is called with the Int parameter. The result is a new function that takes a String as a parameter and returns a (String, Unit). (That new function is wrapped in a State monad.)

      That new function then has to be invoked with the “new string” from the m1 function call. The resulting (String, Unit) is the final return value from the function created by flatMap().

      • Thanks. This is making more sense, I think. I have a reasonable handle on Monads in general but every time I come to the State Monad I get very confused. Whether the discussion is in Haskell, Scala, or F#, the explanation has always left me wondering. I think this explanation may have broken through the barrier. I will have to give it some thought. 🙂

  4. I was struggling to understand this monad… and this was the post with the best explanation so far.. Thank you ! great post …
