arg max in Java 8 streams?
Use
String longestName = names.stream().max(Comparator.comparing(String::length)).get();
to compare elements by some property (the key extractor can be more complex than that, but doesn't have to be).
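In the classic sense, arg max returns the position of the largest element rather than the element itself. The same pattern covers that case if you stream the indices and compare them by the value each one points to, so the key extractor becomes a lookup instead of a method reference. This is a minimal sketch that assumes an int[] input. The class ArgMaxIndex and the method argMax are hypothetical names used only for illustration.

```java
import java.util.Comparator;
import java.util.stream.IntStream;

public class ArgMaxIndex {
    // Returns the index of the largest element, or -1 for an empty array.
    // On ties, Stream.max does not specify which of the tied indices is returned.
    public static int argMax(int[] values) {
        return IntStream.range(0, values.length)
                .boxed()
                .max(Comparator.comparingInt(i -> values[i]))
                .orElse(-1);
    }

    public static void main(String[] args) {
        System.out.println(argMax(new int[] {3, 9, 4, 1})); // prints 1
    }
}
```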
As Brian suggests in the comments, calling Optional#get() like this is unsafe if the Stream might be empty, because get() throws NoSuchElementException in that case. You'd be better off using one of the safer retrieval methods, such as Optional#orElse(Object), which returns a default value when there is no max.
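Here is a small sketch of the safer retrieval, assuming a List<String> input. The class SafeMax, the method longestOrDefault and the "<none>" fallback are hypothetical. Use orElse when a sensible default exists. When none does, keep the Optional and consume it with ifPresent.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;

public class SafeMax {
    // Returns the longest name, or the given fallback when the list is empty.
    public static String longestOrDefault(List<String> names, String fallback) {
        return names.stream()
                .max(Comparator.comparing(String::length))
                .orElse(fallback);
    }

    public static void main(String[] args) {
        System.out.println(longestOrDefault(Arrays.asList("one", "three", "four"), "<none>")); // three
        System.out.println(longestOrDefault(Collections.<String>emptyList(), "<none>"));      // <none>

        // Without a sensible default, keep the Optional and act only if a value is present.
        Optional<String> longest = Collections.<String>emptyList().stream()
                .max(Comparator.comparing(String::length));
        longest.ifPresent(s -> System.out.println("longest: " + s)); // prints nothing here
    }
}
```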
I think one should keep in mind that while the max/min value is unique, the same is of course not true of argMax/argMin, because several elements can share the extreme value. In particular, this means the result of the reduction should be a collection, such as a List. That takes a bit more work than the approach suggested above.
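If a custom collector feels heavy, the built-in collectors can already return every tied element, though they keep all the groups in memory. This sketch swaps in a different technique: it uses groupingBy to collect into a TreeMap keyed by length, so the last entry holds every element with the maximum key and the first entry holds every element with the minimum key. The class ArgMaxByGrouping and the method byLength are hypothetical names. Note that lastEntry() and firstEntry() return null for empty input.

```java
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class ArgMaxByGrouping {
    // Groups the words by length in a map sorted by key; each group keeps encounter order.
    public static TreeMap<Integer, List<String>> byLength(List<String> words) {
        return words.stream()
                .collect(Collectors.groupingBy(String::length, TreeMap::new, Collectors.toList()));
    }

    public static void main(String[] args) {
        TreeMap<Integer, List<String>> groups =
                byLength(Arrays.asList("one", "two", "three", "four", "five", "six", "seven"));
        System.out.println(groups.lastEntry().getValue());  // [three, seven]
        System.out.println(groups.firstEntry().getValue()); // [one, two, six]
    }
}
```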
The following ArgMaxCollector<T> class provides a simple implementation of such a reduction. Its main method uses it to compute the argMax/argMin of the set of strings
one two three four five six seven
ordered by their length. The output (the results of the argMax and argMin collectors, respectively) should be
[three, seven]
[one, two, six]
that is, the two longest and the three shortest strings, respectively.
This is my first attempt at using the new Java 8 Stream API, so any comments are more than welcome!
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.stream.Collector;
class ArgMaxCollector<T> {
    private T max = null;
    private List<T> argMax = new ArrayList<T>();
    private final Comparator<? super T> comparator;
    private ArgMaxCollector( Comparator<? super T> comparator ) {
        this.comparator = comparator;
    }
    public void accept( T element ) {
        // The first element always starts the list of maxima; afterwards a larger
        // element restarts the list, an equal one is appended, a smaller one is ignored.
        int cmp = argMax.isEmpty() ? -1 : comparator.compare( max, element );
        if ( cmp < 0 ) {
            max = element;
            argMax.clear();
            argMax.add( element );
        } else if ( cmp == 0 )
            argMax.add( element );
    }
    public void combine( ArgMaxCollector<T> other ) {
        // Partial results may be empty (e.g. an empty chunk of a parallel stream);
        // their max is still null and must not be passed to the comparator.
        if ( other.argMax.isEmpty() )
            return;
        int cmp = argMax.isEmpty() ? -1 : comparator.compare( max, other.max );
        if ( cmp < 0 ) {
            max = other.max;
            argMax = other.argMax;
        } else if ( cmp == 0 ) {
            argMax.addAll( other.argMax );
        }
    }
    public List<T> get() {
        return argMax;
    }
    public static <T> Collector<T, ArgMaxCollector<T>, List<T>> collector( Comparator<? super T> comparator ) {
        return Collector.of(
            () -> new ArgMaxCollector<T>( comparator ),   // supplier: one empty accumulator per chunk
            ( a, b ) -> a.accept( b ),                    // accumulator: fold in one element
            ( a, b ) -> { a.combine( b ); return a; },    // combiner: merge two partial results
            a -> a.get()                                  // finisher: expose the list of maxima
        );
    }
}
public class ArgMax {
    public static void main( String[] args ) {
        List<String> names = Arrays.asList( "one", "two", "three", "four", "five", "six", "seven" );
        Collector<String, ArgMaxCollector<String>, List<String>> argMax = ArgMaxCollector.collector( Comparator.comparing( String::length ) );
        Collector<String, ArgMaxCollector<String>, List<String>> argMin = ArgMaxCollector.collector( Comparator.comparing( String::length ).reversed() );
        System.out.println( names.stream().collect( argMax ) ); // [three, seven]
        System.out.println( names.stream().collect( argMin ) ); // [one, two, six]
    }
}
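If the source can be streamed more than once (a List, for example), a simpler alternative to a custom collector uses two passes. The first pass finds the maximum key, and the second keeps every element that reaches it. This is a sketch under that assumption. TwoPassArgMax and argMax are hypothetical names, and the key is assumed to be an int. Negating the key gives argMin here. That is safe for string lengths, but it would overflow for a key equal to Integer.MIN_VALUE.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.OptionalInt;
import java.util.function.ToIntFunction;
import java.util.stream.Collectors;

public class TwoPassArgMax {
    // Pass 1 finds the maximum key; pass 2 keeps every element with that key, in encounter order.
    // Requires a source that can be streamed twice, unlike a one-shot Stream.
    public static <T> List<T> argMax(List<T> items, ToIntFunction<? super T> key) {
        OptionalInt best = items.stream().mapToInt(key).max();
        if (!best.isPresent()) {
            return Collections.emptyList(); // no elements, so no maximum
        }
        int max = best.getAsInt();
        return items.stream()
                .filter(x -> key.applyAsInt(x) == max)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String> names = Arrays.asList("one", "two", "three", "four", "five", "six", "seven");
        System.out.println(argMax(names, String::length));   // [three, seven]
        System.out.println(argMax(names, s -> -s.length())); // [one, two, six]
    }
}
```

This approach reads the data twice, but it needs no hand-written mutable accumulator or combiner.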