Can you split a stream into two streams?

user1148758 · Nov 12, 2013 · Viewed 108k times

I have a data set represented by a Java 8 stream:

Stream<T> stream = ...;

I can see how to filter it to get a random subset - for example

Random r = new Random();
PrimitiveIterator.OfInt coin = r.ints(0, 2).iterator();   
Stream<T> heads = stream.filter((x) -> (coin.nextInt() == 0));

I can also see how I could reduce this stream to get, for example, two lists representing two random halves of the data set, and then turn those back into streams. But is there a direct way to generate two streams from the initial one? Something like:

(heads, tails) = stream.[some kind of split based on filter]

Thanks for any insight.

Answer

Mark Jeronimus · May 7, 2015

A collector can be used for this.

  • For two categories, use the Collectors.partitioningBy() factory.

This will create a Map from Boolean to List, and put items in one or the other list based on a Predicate.

Note: Since the stream needs to be consumed whole, this can't work on infinite streams. And because the stream is consumed anyway, this method simply puts the elements into Lists instead of making a new stream-with-memory. You can always stream those lists if you require streams as output.

Also, no need for the iterator, not even in the heads-only example you provided.
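To illustrate the point about the iterator, here is a minimal sketch of the heads-only filter that calls Random.nextBoolean() directly in the predicate. The class name HeadsOnly and the helper heads are made up for this example; they are not part of the question or of any library.

```java
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class HeadsOnly {
    // Hypothetical helper: keeps each element with probability 1/2.
    // The Random is used directly in the predicate, so no PrimitiveIterator is needed.
    static <T> List<T> heads(Stream<T> stream, Random r) {
        return stream.filter(x -> r.nextBoolean()).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> kept = heads(Stream.of("a", "b", "c", "d"), new Random());
        // Prints a random subset of the input, in the original order.
        System.out.println(kept);
    }
}
```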

  • Binary splitting looks like this:

Random r = new Random();

Map<Boolean, List<String>> groups = stream
    .collect(Collectors.partitioningBy(x -> r.nextBoolean()));

System.out.println(groups.get(false).size());
System.out.println(groups.get(true).size());

  • For more categories, use the Collectors.groupingBy() factory:

Map<Integer, List<String>> groups = stream
    .collect(Collectors.groupingBy(x -> r.nextInt(3)));

// Note: groupingBy omits empty categories, so get(k) can return null.
System.out.println(groups.get(0).size());
System.out.println(groups.get(1).size());
System.out.println(groups.get(2).size());
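To get back to the (heads, tails) shape from the question, the two lists can simply be streamed again. The following is a rough sketch under that assumption; the class name HeadsAndTails and the helper split are illustrative names only.

```java
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.Stream;

class HeadsAndTails {
    // Hypothetical helper: consumes the whole stream once and returns
    // both halves, keyed by the result of the coin flip.
    static <T> Map<Boolean, List<T>> split(Stream<T> stream, Random r) {
        return stream.collect(Collectors.partitioningBy(x -> r.nextBoolean()));
    }

    public static void main(String[] args) {
        Map<Boolean, List<String>> groups =
                split(Stream.of("a", "b", "c", "d", "e"), new Random());

        // partitioningBy always provides both keys, so neither get() returns null.
        Stream<String> heads = groups.get(true).stream();
        Stream<String> tails = groups.get(false).stream();

        System.out.println("heads: " + heads.collect(Collectors.toList()));
        System.out.println("tails: " + tails.collect(Collectors.toList()));
    }
}
```

The two resulting streams are backed by in-memory lists, so this only works when the source stream is finite and fits in memory.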

If the stream is not a Stream<T> but one of the primitive streams, such as IntStream, then this .collect(Collector) method is not available. You'll have to do it the manual way, without a collector factory. Its implementation looks like this:

[Example 2.0 since 2020-04-16]

    IntStream    intStream = IntStream.iterate(0, i -> i + 1).limit(100000).parallel();
    IntPredicate predicate = ignored -> r.nextBoolean();

    Map<Boolean, List<Integer>> groups = intStream.collect(
            () -> Map.of(false, new ArrayList<>(100000),
                         true , new ArrayList<>(100000)),
            (map, value) -> map.get(predicate.test(value)).add(value),
            (map1, map2) -> {
                map1.get(false).addAll(map2.get(false));
                map1.get(true ).addAll(map2.get(true ));
            });

In this example I initialize the ArrayLists with the full size of the initial collection (if this is known at all). This prevents resize events even in the worst-case scenario, but can potentially gobble up space for 2*N*T elements (N = initial number of elements, T = number of threads). To trade speed for space, you can leave the initial capacity out or use your best educated guess, like the expected highest number of elements in one partition (typically just over N/2 for a balanced split).
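If tuning the list capacities is not a concern, a different route (not the manual approach shown above) is to box the IntStream first and reuse Collectors.partitioningBy(). This is only a sketch; the class name BoxedPartition and the helper split are made up here, and boxing every int has its own cost.

```java
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class BoxedPartition {
    // Hypothetical helper: boxed() turns the IntStream into a Stream<Integer>,
    // which makes the regular collector factories available again.
    static Map<Boolean, List<Integer>> split(IntStream ints, Random r) {
        return ints.boxed().collect(Collectors.partitioningBy(x -> r.nextBoolean()));
    }

    public static void main(String[] args) {
        Map<Boolean, List<Integer>> groups = split(IntStream.range(0, 100), new Random());
        System.out.println(groups.get(false).size());
        System.out.println(groups.get(true).size());
    }
}
```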

I hope I don't offend anyone by using a Java 9 method. For the Java 8 version, look at the edit history.
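The edit history is not reproduced here. As a rough idea of what a Java 8 compatible form could look like (not necessarily what the earlier revision contained), Map.of can be replaced by a HashMap filled by hand. The class name Java8IntPartition and the helper split are assumptions for this sketch.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.IntPredicate;
import java.util.stream.IntStream;

class Java8IntPartition {
    // Hypothetical helper: same three-argument IntStream.collect as above,
    // but the supplier builds the map without the Java 9 Map.of factory.
    static Map<Boolean, List<Integer>> split(IntStream ints, IntPredicate predicate) {
        return ints.collect(
                () -> {
                    Map<Boolean, List<Integer>> map = new HashMap<>();
                    map.put(false, new ArrayList<>());
                    map.put(true, new ArrayList<>());
                    return map;
                },
                // Accumulator: add each value to the list chosen by the predicate.
                (map, value) -> map.get(predicate.test(value)).add(value),
                // Combiner: merge the partial results of two threads.
                (map1, map2) -> {
                    map1.get(false).addAll(map2.get(false));
                    map1.get(true).addAll(map2.get(true));
                });
    }

    public static void main(String[] args) {
        Map<Boolean, List<Integer>> groups = split(IntStream.range(0, 10), v -> v % 2 == 0);
        System.out.println(groups.get(true));  // [0, 2, 4, 6, 8]
        System.out.println(groups.get(false)); // [1, 3, 5, 7, 9]
    }
}
```

A deterministic predicate is used in main only so the output is predictable; a random one such as ignored -> r.nextBoolean() fits the same signature.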