Implementing Text Justification with Dynamic Programming

bertzzie · Aug 13, 2013

I'm trying to understand the concept of dynamic programming via the course on MIT OCW here. The explanation in the OCW video is great, but I feel like I won't really understand it until I've implemented it in code. While implementing, I'm referring to the lecture notes here, particularly page 3.

The problem is, I have no idea how to translate some of the mathematical notation into code. Here's the part of the solution I've implemented so far (and which I think is right):

import math

paragraph = "Some long lorem ipsum text."
words = paragraph.split(" ")

# Total width of a list of words when joined by single spaces.
# This function will be used by the badness function below.
def total_length(str_arr):
    total = 0

    for string in str_arr:
        total = total + len(string)

    total = total + max(len(str_arr) - 1, 0)  # one space between adjacent words
    return total

# Calculate the badness score for a line.
# str_arr is assumed to be passed as words[i:j], as in the notes.
# We don't take i and j as arguments, since that would require
# global variables.
def badness(str_arr, page_width):
    line_len = total_length(str_arr)
    if line_len > page_width:
        return float('inf')  # infinity, so min() never picks an overfull line
    else:
        return math.pow(page_width - line_len, 3)

Now, the part I don't understand is points 3 to 5 in the lecture notes. I literally don't understand them and don't know where to start implementing them. So far, I've tried iterating over the list of words and computing the badness of each candidate end of line, like this:

def justifier(str_arr, page_width):
    paragraph = str_arr
    par_len = len(paragraph)
    result = [] # stores each line as list of strings
    for i in range(0, par_len):
        if i == (par_len - 1):
            result.append(paragraph)
        else:
            dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 
            # Should I take min(dag), get its index, and declare that the end of the line?

But then, I don't know how I can continue the function, and to be honest, I don't understand this line:

dag = [badness(paragraph[i:j], page_width) + justifier(paragraph[j:], page_width) for j in range(i + 1, par_len + 1)] 

and how I would return an int from justifier (since I've already decided to store the return value in result, which is a list). Should I write another function and recurse from there? Should there be any recursion at all?

Could you please show me what to do next, and explain how this is dynamic programming? I really can't see where the recursion is, and what the subproblem is.

Thanks in advance.

Answer

Joohwan · Aug 13, 2013

In case you're having trouble understanding the core idea of dynamic programming itself, here is my take on it:

Dynamic programming essentially trades space for time (the extra space is usually small compared with the time you save, which makes it well worth it when implemented correctly). You store the result of each recursive call as you go (e.g. in an array or a dictionary), so that when the same call comes up again in another branch of the recursion tree, you look it up instead of recomputing it.
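To make the recursion and the subproblem concrete, here is a minimal top-down sketch. The subproblem is "the cheapest way to lay out the words from index i to the end", and the recursion tries every possible end j for the first line of that suffix. The names min_total_badness, badness and dp are my own, not from the lecture notes, and I'm assuming single spaces between words and that the last line is penalized like any other. The cache is what turns the plain recursion into dynamic programming: each dp(i) is computed only once.

```python
from functools import lru_cache


def min_total_badness(words, page_width):
    """Return the minimum total badness of splitting words into lines."""
    n = len(words)

    def badness(i, j):
        # Width of words[i:j] joined by single spaces.
        width = sum(len(w) for w in words[i:j]) + (j - i - 1)
        if width > page_width:
            return float("inf")  # the line does not fit
        return (page_width - width) ** 3

    @lru_cache(maxsize=None)
    def dp(i):
        # Subproblem: cheapest layout of the suffix words[i:].
        if i == n:
            return 0  # nothing left to lay out
        # Try every end j for the line that starts at word i.
        return min(badness(i, j) + dp(j) for j in range(i + 1, n + 1))

    return dp(0)


print(min_total_badness(["aaa", "bb", "cc", "ddddd"], 6))  # prints 29
```

With a width of 6, the best split here is "aaa" / "bb cc" / "ddddd" (27 + 1 + 1 = 29). Filling each line greedily would give "aaa bb" / "cc" / "ddddd" (0 + 64 + 1 = 65), which is why trying every j matters.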

And no, you do not have to use recursion. Here is my implementation of the problem you were working on, using just loops. I followed the TextAlignment.pdf linked by AlexSilva very closely. Hopefully you'll find it helpful.

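# Width of words i..j (1-indexed, inclusive) joined by single spaces:
# the word lengths plus the j - i spaces between them.
def length(wordLengths, i, j):
    return sum(wordLengths[i - 1:j]) + j - i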


def breakLine(text, L):
    # wl = lengths of words
    wl = [len(word) for word in text.split()]

    # n = number of words in the text
    n = len(wl)    

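    # m[i] = minimum total badness of the first i words
    m = dict()
    # initialization: no words, no badness
    m[0] = 0

    # s[i] = index of the first word on the line that ends at word i
    s = dict()

    # the actual algorithm
    for i in range(1, n + 1):
        # maps each candidate total badness to the start word k that yields it
        sums = dict()
        k = i
        # check k > 0 first so length() is never called with k == 0
        while (k > 0 and length(wl, k, i) <= L):
            sums[(L - length(wl, k, i))**3 + m[k - 1]] = k
            k -= 1
        m[i] = min(sums)
        s[i] = sums[min(sums)]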

    # actually do the splitting by working backwards (so "line 1" is the last line)
    line = 1
    while n > 0:
        print("line " + str(line) + ": " + str(s[n]) + "->" + str(n))
        n = s[n] - 1
        line += 1
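
If you want the lines themselves rather than printed index ranges, the same bottom-up idea can be sketched as follows. This is a variant of my own rather than the method in TextAlignment.pdf: the name justify and the lists cost and start are hypothetical, indices are 0-based, and I'm assuming single spaces between words and that a word wider than the page should be reported as an error.

```python
def justify(text, page_width):
    """Split text into lines minimizing the sum of cubed trailing space."""
    words = text.split()
    n = len(words)
    INF = float("inf")

    # cost[i] = minimum total badness of laying out the first i words.
    cost = [0] + [INF] * n
    # start[i] = index of the first word on the line that ends just before i.
    start = [0] * (n + 1)

    for i in range(1, n + 1):
        width = -1  # cancels the space counted before the line's first word
        # Grow the last line backwards: it holds words[k:i].
        for k in range(i - 1, -1, -1):
            width += len(words[k]) + 1
            if width > page_width:
                break  # adding more words only makes the line wider
            candidate = cost[k] + (page_width - width) ** 3
            if candidate < cost[i]:
                cost[i] = candidate
                start[i] = k

    if cost[n] == INF:
        raise ValueError("a word is longer than page_width")

    # Follow the start pointers back from the end, then reverse.
    lines = []
    i = n
    while i > 0:
        k = start[i]
        lines.append(" ".join(words[k:i]))
        i = k
    lines.reverse()
    return lines


print(justify("aaa bb cc ddddd", 6))  # prints ['aaa', 'bb cc', 'ddddd']
```

Storing the start index next to each cost is what lets you recover the actual line breaks afterwards. The cost table alone only tells you how good the best layout is, not what it looks like.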