Memoization in Haskell?

Angel de Vicente · Jul 8, 2010

Any pointers on how to efficiently compute the following function in Haskell for large numbers (n > 10^8)?

f(n) = max(n, f(n/2) + f(n/3) + f(n/4))

I've seen examples of memoization in Haskell for Fibonacci numbers, which involved (lazily) computing all the Fibonacci numbers up to the required n. But in this case, for a given n, we only need to compute very few intermediate results.

Thanks

Answer

Edward KMETT · Jul 9, 2010

We can do this very efficiently by making a structure that we can index in sub-linear time.

But first,

{-# LANGUAGE BangPatterns #-}

import Data.Function (fix)

Let's define f, but make it use 'open recursion' rather than call itself directly.

f :: (Int -> Int) -> Int -> Int
f mf 0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

You can get an unmemoized f by using fix f.

This will let you test that f does what you mean for small values of n by calling, for example: fix f 123 = 144
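
As a quick sanity check of the open-recursion version, here is a minimal self-contained sketch. The name naive_f is introduced here for illustration only (it is just fix f), and the base case uses a wildcard instead of mf to avoid an unused-variable warning.

```haskell
import Data.Function (fix)

-- f with open recursion: 'mf' stands in for the recursive call.
f :: (Int -> Int) -> Int -> Int
f _  0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

-- Hypothetical name: tying the knot with fix gives the plain,
-- unmemoized function.
naive_f :: Int -> Int
naive_f = fix f

-- naive_f 12  == 13    (6 + 4 + 3 beats 12)
-- naive_f 123 == 144
```

Because f never calls itself by name, the same definition can later be handed a memoized lookup in place of fix f.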

We could memoize this by defining:

f_list :: [Int]
f_list = map (f faster_f) [0..]

faster_f :: Int -> Int
faster_f n = f_list !! n

That performs passably well, and replaces the naive recursion, which recomputes the same subproblems over and over, with something that memoizes the intermediate results.
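
For anyone who wants to load the list version on its own, a self-contained sketch might look like this. It repeats the definitions above unchanged apart from the wildcard in the base case; the comments on sharing describe ordinary GHC behaviour for top-level constants.

```haskell
import Data.Function (fix)

f :: (Int -> Int) -> Int -> Int
f _  0 = 0
f mf n = max n $ mf (n `div` 2) +
                 mf (n `div` 3) +
                 mf (n `div` 4)

-- A lazy, infinite list of results. It is a top-level constant, so each
-- element is evaluated at most once and then shared by every later lookup.
f_list :: [Int]
f_list = map (f faster_f) [0..]

-- Recursive calls go back through the list instead of recomputing.
faster_f :: Int -> Int
faster_f n = f_list !! n

-- faster_f 123 == 144, the same answer as fix f 123
```

Only the elements that are actually demanded get evaluated, although the spine of the list still has to be built up to the largest index used.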

But indexing into the list with !! still takes linear time, so every lookup of a memoized answer for mf costs O(n). This means that results like:

*Main Data.List> faster_f 123801
248604

are tolerable, but the result doesn't scale much better than that. We can do better!

First, let's define an infinite tree:

data Tree a = Tree (Tree a) a (Tree a)
instance Functor Tree where
    fmap f (Tree l m r) = Tree (fmap f l) (f m) (fmap f r)

And then we'll define a way to index into it, so we can find a node with index n in O(log n) time instead:

index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q,0) -> index l q
    (q,1) -> index r q
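
To see why the lookup is logarithmic, note that every step replaces n with (n - 1) `div` 2. A small sketch that counts those steps follows; index_depth is a hypothetical helper, not part of the answer's code, and it assumes n is non-negative.

```haskell
-- Hypothetical helper: the number of child links 'index' follows to
-- reach position n. Each step maps n to (n - 1) `div` 2, roughly halving it.
-- Assumes n >= 0; a negative argument would never reach the base case.
index_depth :: Int -> Int
index_depth 0 = 0
index_depth n = 1 + index_depth ((n - 1) `div` 2)

-- index_depth 1 == 1, index_depth 2 == 1    (children of the root)
-- index_depth 3 == 2, index_depth 6 == 2    (grandchildren)
-- index_depth 1000000 == 19
```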

... and we may find a tree full of natural numbers to be convenient so we don't have to fiddle around with those indices:

nats :: Tree Int
nats = go 0 1
    where
        go !n !s = Tree (go l s') n (go r s')
            where
                l = n + s
                r = l + s
                s' = s * 2

Since we can index, you can just convert a tree into a list:

toList :: Tree a -> [a]
toList as = map (index as) [0..]

You can check the work so far by verifying that toList nats gives you [0..].
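
A self-contained sketch of that check, collecting the pieces defined so far, could look like this. The only deliberate change is a wildcard in the second case branch of index, so the pattern match is exhaustive; everything else follows the definitions above.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
    fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- Odd indices live in the left subtree, even (non-zero) ones in the right.
index :: Tree a -> Int -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q, 0) -> index l q
    (q, _) -> index r q

-- The node at position k of 'go n s' holds n + s * k,
-- so 'go 0 1' holds k at position k.
nats :: Tree Int
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

toList :: Tree a -> [a]
toList as = map (index as) [0..]

-- take 10 (toList nats) == [0,1,2,3,4,5,6,7,8,9]
```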

Now,

f_tree :: Tree Int
f_tree = fmap (f fastest_f) nats

fastest_f :: Int -> Int
fastest_f = index f_tree

This works just like the list version above, but instead of taking linear time to find each node, it can chase it down in logarithmic time.

The result is considerably faster:

*Main> fastest_f 12380192300
67652175206

*Main> fastest_f 12793129379123
120695231674999

In fact it is so much faster that you can go through and replace Int with Integer above (the resulting function is called fastest_f' below) and get ridiculously large answers almost instantaneously:

*Main> fastest_f' 1230891823091823018203123
93721573993600178112200489

*Main> fastest_f' 12308918230918230182031231231293810923
11097012733777002208302545289166620866358
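
The Integer variant is not spelled out above, so here is one way it might look: a sketch under the assumption that every Int is simply swapped for Integer. The names f' and f_tree' are made up for this example; only fastest_f' appears in the transcript.

```haskell
{-# LANGUAGE BangPatterns #-}

data Tree a = Tree (Tree a) a (Tree a)

instance Functor Tree where
    fmap g (Tree l m r) = Tree (fmap g l) (g m) (fmap g r)

-- Same indexing scheme as before, with Integer positions.
index :: Tree a -> Integer -> a
index (Tree _ m _) 0 = m
index (Tree l _ r) n = case (n - 1) `divMod` 2 of
    (q, 0) -> index l q
    (q, _) -> index r q

nats :: Tree Integer
nats = go 0 1
  where
    go !n !s = Tree (go l s') n (go r s')
      where
        l  = n + s
        r  = l + s
        s' = s * 2

-- Hypothetical name: the open-recursion function over Integer.
f' :: (Integer -> Integer) -> Integer -> Integer
f' _  0 = 0
f' mf n = max n $ mf (n `div` 2) +
                  mf (n `div` 3) +
                  mf (n `div` 4)

-- Hypothetical name: the lazily built tree of memoized results.
f_tree' :: Tree Integer
f_tree' = fmap (f' fastest_f') nats

fastest_f' :: Integer -> Integer
fastest_f' = index f_tree'

-- fastest_f' 123 == 144, matching fix f 123
```

Only the nodes on the paths to the arguments that are actually demanded get forced, so the tree stays small even for very large inputs.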