What is the fastest way to sum a collection in Scala?

Tala · Jun 23, 2010

I've tried summing the elements of several different Scala collections, and they are all much slower than summing a Java array with a for loop. Is there a way for Scala to be as fast as Java arrays?

I've heard that in Scala 2.8 arrays will be the same as in Java, but in practice they are much slower.

Answer

Rex Kerr · Jun 23, 2010

Indexing into arrays in a while loop is as fast in Scala as in Java. (Scala's "for" loop is not the low-level construct that Java's is: a loop like for (i <- 0 until n) is compiled into a call to foreach on a Range with a closure, so it won't work the way you want.)

Thus if in Java you see

for (int i = 0; i < array.length; i++) sum += array[i];

in Scala you should write

var i=0
while (i < array.length) {
  sum += array(i)
  i += 1
}

and if you do your benchmarks appropriately, you'll find no difference in speed.
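A minimal sketch contrasting the loop styles, with hypothetical names (forSum, foreachSum, whileSum). The first two are equivalent ways of writing the same closure-based loop; only the third is the plain indexed loop that matches Java's. How much the closure versions cost in practice can vary with the Scala version and optimizer settings.

```scala
object LoopStyles {
  // A Scala `for` over a Range: the compiler rewrites this as
  // (0 until a.length).foreach(i => sum += a(i)), so `sum` becomes a
  // variable captured by a closure.
  def forSum(a: Array[Double]): Double = {
    var sum = 0.0
    for (i <- 0 until a.length) sum += a(i)
    sum
  }

  // The same loop written out the way the compiler desugars it.
  def foreachSum(a: Array[Double]): Double = {
    var sum = 0.0
    (0 until a.length).foreach(i => sum += a(i))
    sum
  }

  // The low-level equivalent of Java's indexed for loop:
  // no Range, no closure, primitive doubles throughout.
  def whileSum(a: Array[Double]): Double = {
    var sum = 0.0
    var i = 0
    while (i < a.length) {
      sum += a(i)
      i += 1
    }
    sum
  }
}
```

All three return the same value; the difference lies only in the code the compiler generates, so it shows up only when you time them.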

If you have iterators anyway, then Scala is as fast as Java in most things. For example, if you have an ArrayList of doubles and in Java you add them using

for (double d : arraylist) { sum += d; }

then in Scala you'll be approximately as fast (if you use an equivalent data structure such as ArrayBuffer) with

arraybuffer.foreach( sum += _ )

and not too far off the mark with either of

sum = (0.0 /: arraybuffer)(_ + _)  // seed must be 0.0 so the result is a Double
sum = arraybuffer.sum  // Scala 2.8 and later
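A self-contained sketch of the three ArrayBuffer variants, wrapped in functions with hypothetical names. It spells the fold as foldLeft rather than /:, since the symbolic form is deprecated as of Scala 2.13; both perform the same left fold.

```scala
import scala.collection.mutable.ArrayBuffer

object BufferSums {
  // Closest analogue of Java's enhanced for loop over an ArrayList.
  def foreachSum(ab: ArrayBuffer[Double]): Double = {
    var sum = 0.0
    ab.foreach(sum += _)
    sum
  }

  // Left fold; (0.0 /: ab)(_ + _) is the symbolic spelling of this call.
  // The seed is 0.0 (not 0) so the accumulator is typed as Double.
  def foldSum(ab: ArrayBuffer[Double]): Double = ab.foldLeft(0.0)(_ + _)

  // Library method, available since Scala 2.8 (uses an implicit Numeric[Double]).
  def librarySum(ab: ArrayBuffer[Double]): Double = ab.sum
}
```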

Keep in mind, though, that there's a penalty to mixing high-level and low-level constructs. For example, if you decide to start with an array but then use "foreach" on it instead of indexing into it, Scala has to wrap it in a collection (ArrayOps in 2.8) to get it to work, and often will have to box the primitives as well.
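A minimal sketch of the mixed style described above, with hypothetical names. Both functions compile only because the Array[Double] is implicitly converted (via Predef) to ArrayOps, which supplies foldLeft and foreach. Whether the elements are actually boxed depends on the Scala version and on which methods are specialized, so the comments describe the risk rather than a guarantee.

```scala
object ArrayMixing {
  // High-level call on a primitive array: `a` is wrapped in ArrayOps so
  // that foldLeft is available, and the generic fold may box each Double
  // as it passes it to the (_ + _) closure.
  def foldSum(a: Array[Double]): Double = a.foldLeft(0.0)(_ + _)

  // foreach also goes through the ArrayOps wrapper and a closure that
  // updates the captured variable `s`.
  def foreachSum(a: Array[Double]): Double = {
    var s = 0.0
    a.foreach(s += _)
    s
  }
}
```

Both give the same answer as an indexed while loop over the same array; only the speed differs.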

Anyway, for benchmark testing, these two functions are your friends:

def time[F](f: => F) = {
  val t0 = System.nanoTime
  val ans = f
  printf("Elapsed: %.3f\n",1e-9*(System.nanoTime-t0))
  ans
}

def lots[F](n: Int, f: => F): F = if (n <= 1) f else { f; lots(n-1,f) }
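In the timings that follow, the first run is usually the slowest, most likely because the JVM is still compiling the code. A hypothetical variant of these helpers (timed and bestOf are not part of the original answer) returns the elapsed seconds instead of printing them, so warm-up runs can be discarded and the fastest trial kept. This is a sketch under that assumption, not a replacement for a proper benchmarking harness.

```scala
object Bench {
  // Evaluate `f` once and return its result together with the elapsed
  // wall-clock time in seconds, measured with System.nanoTime like `time`.
  def timed[F](f: => F): (F, Double) = {
    val t0 = System.nanoTime
    val ans = f
    (ans, 1e-9 * (System.nanoTime - t0))
  }

  // Run `f` `warmup` times untimed so the JIT has a chance to compile it,
  // then time `trials` runs and return the fastest one.
  def bestOf[F](warmup: Int, trials: Int)(f: => F): (F, Double) = {
    require(trials >= 1, "need at least one timed trial")
    var i = 0
    while (i < warmup) { f; i += 1 }
    var best = timed(f)
    var j = 1
    while (j < trials) {
      val r = timed(f)
      if (r._2 < best._2) best = r
      j += 1
    }
    best
  }
}
```

For example, Bench.bestOf(5, 10)(adSum(a)) would evaluate adSum(a) fifteen times and report only the quickest of the last ten.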

For example:

val a = Array.tabulate(1000000)(_.toDouble)
val ab = new collection.mutable.ArrayBuffer[Double] ++ a
def adSum(ad: Array[Double]) = {
  var sum = 0.0
  var i = 0
  while (i<ad.length) { sum += ad(i); i += 1 }
  sum
}

// Mixed array + high-level; convenient, not so fast
scala> lots(3, time( lots(100,(0.0 /: a)(_ + _)) ) )
Elapsed: 2.434
Elapsed: 2.085
Elapsed: 2.081
res4: Double = 4.999995E11

// High-level container and operations, somewhat better
scala> lots(3, time( lots(100,(0.0 /: ab)(_ + _)) ) )    
Elapsed: 1.694
Elapsed: 1.679
Elapsed: 1.635
res5: Double = 4.999995E11

// High-level collection with simpler operation
scala> lots(3, time( lots(100,{var s=0.0;ab.foreach(s += _);s}) ) )
Elapsed: 1.171
Elapsed: 1.166
Elapsed: 1.162
res7: Double = 4.999995E11

// All low level operations with primitives, no boxing, fast!
scala> lots(3, time( lots(100,adSum(a)) ) )              
Elapsed: 0.185
Elapsed: 0.183
Elapsed: 0.186
res6: Double = 4.999995E11