Wednesday, 27 January 2016

Computing edit distance between strings in Scala

Question: Given a set $S$ of $n$-letter dictionary words, a start word $s_0 \in S$ and a target word $s_t \in S$, find the shortest chain of words $C$ in which consecutive words differ by exactly one letter and the first and last words are $s_0$ and $s_t$ respectively.

Eg $S = [cat, cab, car, tat, tar, bar], s_0 = cat, s_t = tar$
                      $=> C = [cat -> car -> tar]$ (or, equally short, $[cat -> tat -> tar]$)

Top level solution


To solve this programmatically, we can treat it as a graph problem: build a graph whose nodes are the words and whose edges join words that differ by one letter, then find the shortest path in that graph from $s_0$ to $s_t$.


Scala has no graph implementation in the standard collections, but there is a neat little project called scala-graph which you can use.

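import scalax.collection.Graph // scala-graph 1.x

def computeEditDistance(start: String, end: String, dictionary: Set[String]): Option[Int] = {
  val graph = Graph.from(dictionary, computeEdges(dictionary))
  // find returns None for a word that is not in the graph, where get would throw
  for (from <- graph.find(start); to <- graph.find(end); path <- from.shortestPathTo(to))
    yield path.edges.size
}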
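
For readers who do not have scala-graph on the classpath, the same idea can be sketched with the standard library alone, as a breadth-first search over an adjacency map. The names `ChainSearch`, `shortestChainLength` and `adjacency` are illustrative only; they are not part of the original code or of scala-graph.

```scala
import scala.collection.immutable.Queue

object ChainSearch {
  // Breadth-first search over an adjacency map (an assumed representation).
  // Returns the number of edges on a shortest path, or None if `end` is unreachable.
  def shortestChainLength(start: String, end: String,
                          adjacency: Map[String, Set[String]]): Option[Int] = {
    @annotation.tailrec
    def loop(frontier: Queue[(String, Int)], seen: Set[String]): Option[Int] =
      frontier.dequeueOption match {
        case None => None
        case Some(((word, dist), _)) if word == end => Some(dist)
        case Some(((word, dist), rest)) =>
          // Only enqueue neighbours that have not been reached yet
          val next = adjacency.getOrElse(word, Set.empty[String]) -- seen
          loop(rest ++ next.map(_ -> (dist + 1)), seen ++ next)
      }
    loop(Queue(start -> 0), Set(start))
  }
}
```

Because breadth-first search reaches words in order of increasing distance, the first time the target is dequeued its distance is minimal. For the example dictionary above, going from cat to tar takes two edges.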

How to compute the edges


The interesting part is the computeEdges function. Here is a naive implementation:

import scalax.collection.GraphPredef._ // provides the ~ edge operator
import scalax.collection.GraphEdge.UnDiEdge

private def computeEdgesSlowly(dictionary: Set[String]): List[UnDiEdge[String]] = {
  // Fix an iteration order so that each unordered pair is compared exactly once
  val words = dictionary.toList
  words.tails.flatMap {
    case word :: unvisitedWords =>
      unvisitedWords.collect { case w if differsByOneLetter(w, word) => word ~ w }
    case Nil => Nil
  }.toList
}

Every word is checked against every unvisited word after it, so the algorithm makes n(n-1)/2 comparisons, which is O(n^2) where n is the number of words in the dictionary.

Eg

cat is compared to [ cab, car, tat, tar, bar ]
cab is compared to [ car, tat, tar, bar ]
car is compared to [ tat, tar, bar ]
tat is compared to [ tar, bar ]
tar is compared to [ bar ]
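
The post does not show `differsByOneLetter`, so here is one plausible definition, assuming words of equal length, together with a standard-library sketch of the pairwise comparisons listed above. Edges are plain tuples here rather than scala-graph's `UnDiEdge`, and the object name `NaiveEdges` is made up for the example.

```scala
object NaiveEdges {
  // Assumed helper: true when the words have equal length
  // and differ in exactly one position.
  def differsByOneLetter(a: String, b: String): Boolean =
    a.length == b.length && a.zip(b).count { case (x, y) => x != y } == 1

  // Compare every unordered pair of (distinct) words once:
  // n * (n - 1) / 2 comparisons in total.
  def edges(words: Seq[String]): List[(String, String)] =
    words.combinations(2).collect {
      case Seq(a, b) if differsByOneLetter(a, b) => (a, b)
    }.toList
}
```

For the six example words this makes 15 comparisons and finds 8 edges.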

Cleverer algorithm for computing edges

It is possible to compute the edges in O(n) for a fixed word length (ignoring the cost of emitting the edges themselves), which is a huge improvement on the implementation above. :D :D

Example:  [ cat, cab, car, tat, tar, bar ]

Step 1: 

Firstly we compute 'subs' for each word in the dictionary

    'cat' -> [ '?at', 'c?t', 'ca?' ]
    'cab' -> [ '?ab', 'c?b', 'ca?' ]
    etc...

val mapOfSubs = dictionary.map { word =>
  val subs = (0 until word.length) map { case i =>
    word.patch(i, Seq('?'), 1)
  }
  word -> subs
}.toMap

Step 2:

We build a hashmap from the subs

    '?at' -> [ 'cat', 'tat' ]
    'c?t' -> [ 'cat' ]
    'ca?' -> [ 'cat', 'cab', 'car' ]
    '?ab' -> [ 'cab' ]
    'c?b' -> [ 'cab' ]
    etc...

import scala.collection.mutable

// Invert mapOfSubs: each sub points to every word that produces it
val cachedMap = mutable.HashMap.empty[String, Set[String]]
dictionary.foreach { word =>
  mapOfSubs(word).foreach { sub =>
    cachedMap(sub) = cachedMap.getOrElse(sub, Set.empty[String]) + word
  }
}

Step 3:

Now we build up the list of edges. For each word we take its subs (3 for a three-letter word) and look each of them up in the above hashmap to get all the words that the current word shares an edge with. We keep a set of the words we have already visited so that we don't add the same edge twice (the edges are undirected) and so that a word is never linked to itself.

val visited = mutable.Set.empty[String]
dictionary.toSeq.flatMap { word =>
  visited += word // mark before the lookup so a word never links to itself
  mapOfSubs(word).flatMap { sub =>
    cachedMap(sub).collect {
      case l if !visited.contains(l) => word ~ l // ~ builds an undirected edge
    }
  }
}
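
The three steps can also be condensed into an immutable, standard-library-only sketch. This is not the code from the post: edges are plain tuples instead of `UnDiEdge`, the names `WildcardEdges`, `subs` and `edges` are made up for the example, and ordering each pair alphabetically takes the place of the visited set.

```scala
object WildcardEdges {
  // Step 1: every pattern obtained by replacing one letter with '?'.
  def subs(word: String): Seq[String] =
    word.indices.map(i => word.patch(i, Seq('?'), 1))

  def edges(dictionary: Set[String]): Set[(String, String)] = {
    // Step 2: bucket the words by pattern, e.g. "ca?" -> Set(cat, cab, car).
    val buckets: Map[String, Set[String]] =
      dictionary.toSeq
        .flatMap(word => subs(word).map(_ -> word))
        .groupBy(_._1)
        .map { case (sub, pairs) => sub -> pairs.map(_._2).toSet }

    // Step 3: words sharing a bucket are neighbours. Keeping only a < b
    // drops self-loops and stores each undirected edge once.
    buckets.values.flatMap { bucket =>
      for (a <- bucket; b <- bucket if a < b) yield (a, b)
    }.toSet
  }
}
```

Pairing up the words inside a bucket is quadratic in the bucket size, but every such pair is a genuine edge, so that work is proportional to the size of the output. For the example dictionary this yields the same 8 edges as the pairwise approach.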


Full code is here.

