Grouping and Aggregations With Java Streams
Learn the straightforward path to solving problems using Java Streams, a framework that allows us to process large amounts of data quickly and efficiently.
When we group elements from a list, we can then aggregate the fields of the grouped elements to perform meaningful operations that help us analyze the data. Some examples are sums, averages, or max/min values. These single-field aggregations are easy to do with Java Streams and Collectors, and the documentation provides simple examples of how to do these types of calculations.
However, there are more sophisticated aggregations, like weighted averages and geometric means. Additionally, we might need to aggregate several fields simultaneously. In this article, we are going to show a straightforward path to solving these kinds of problems using Java Streams. Using this framework allows us to process large amounts of data quickly and efficiently.
We'll assume that the reader has a basic understanding of Java Streams and the utility Collectors class.
Problem Layout
Let's consider a simple example to showcase the type of issues that we want to solve. We'll keep it very generic so we can easily generalize it. Let's consider a list of TaxEntry entities defined by the following code:
public class TaxEntry {
private String state;
private String city;
private int numEntries;
private double price;
//Constructors, getters, hashCode, equals etc
}
It is very simple to compute the total number of entries for a given city:
Map<String, Integer> totalNumEntriesByCity =
taxes.stream().collect(Collectors.groupingBy(TaxEntry::getCity,
Collectors.summingInt(TaxEntry::getNumEntries)));
The two-argument overload of Collectors.groupingBy takes a classifier function that does the grouping and a downstream Collector that aggregates all the elements that belong to a given group. We use TaxEntry::getCity as the classifier function. For the downstream, we use Collectors::summingInt, which returns a Collector that sums the number of entries of the elements in each group.
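The per-city total can be tried out as a small, self-contained program. This is a sketch under a few assumptions: it replaces the article's TaxEntry class with a hypothetical record that has the same fields (so the accessors are city() and numEntries() instead of getCity() and getNumEntries()), and the sample values are made up for illustration.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class CityTotals {
    // Hypothetical stand-in for TaxEntry: a record generates city(), numEntries(), etc.
    record TaxEntry(String state, String city, int numEntries, double price) {}

    // Illustrative data only.
    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", 3, 20.0),
                new TaxEntry("New York", "NYC", 2, 10.0),
                new TaxEntry("Florida", "Orlando", 5, 13.0));
    }

    // Groups by city, then sums numEntries inside each group.
    static Map<String, Integer> totalNumEntriesByCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(Collectors.groupingBy(TaxEntry::city,
                Collectors.summingInt(TaxEntry::numEntries)));
    }

    public static void main(String[] args) {
        Map<String, Integer> totals = totalNumEntriesByCity(sampleData());
        System.out.println(totals.get("NYC"));     // 5
        System.out.println(totals.get("Orlando")); // 5
    }
}
```

Looking the values up by key, rather than printing the whole map, avoids depending on the iteration order of the HashMap that groupingBy returns by default.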
Things get a little more complicated with compound groupings, for example, finding the total number of entries for each combination of state and city. There are several ways to do this, but a very straightforward approach is to first define:
record StateCityGroup(String state, String city) {}
Notice that we're using a Java record, a concise way to define an immutable data class (records are a standard language feature since Java 16). The Java compiler generates for us a canonical constructor, field accessor methods, and hashCode, equals, and toString implementations. With this in hand, the solution is now simple:
Map<StateCityGroup, Integer> totalNumEntriesForStateCity =
taxes.stream().collect(groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()),
Collectors.summingInt(TaxEntry::getNumEntries))
);
For Collectors::groupingBy, we set the classifier function using a lambda expression that creates a new StateCityGroup record encapsulating each state-city pair. The downstream Collector is the same as before.
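A self-contained sketch of the compound grouping follows. It again assumes a hypothetical record version of TaxEntry and illustrative data; the city name "Portland" appears under two states to show why the key needs both fields. Because records get value-based equals and hashCode, a StateCityGroup created later with the same state and city finds the same map entry.

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.summingInt;

class StateCityTotals {
    // Hypothetical record version of TaxEntry (accessors instead of getters).
    record TaxEntry(String state, String city, int numEntries, double price) {}

    // Composite grouping key, as in the article.
    record StateCityGroup(String state, String city) {}

    // Illustrative data: grouping by city alone would merge the two Portlands.
    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", 3, 20.0),
                new TaxEntry("New York", "NYC", 2, 10.0),
                new TaxEntry("Oregon", "Portland", 4, 12.0),
                new TaxEntry("Maine", "Portland", 1, 15.0));
    }

    static Map<StateCityGroup, Integer> totalNumEntriesForStateCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                        summingInt(TaxEntry::numEntries)));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, Integer> totals = totalNumEntriesForStateCity(sampleData());
        // A freshly created key is equal to the one built inside the classifier.
        System.out.println(totals.get(new StateCityGroup("New York", "NYC")));    // 5
        System.out.println(totals.get(new StateCityGroup("Oregon", "Portland"))); // 4
        System.out.println(totals.get(new StateCityGroup("Maine", "Portland")));  // 1
    }
}
```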
Note: for the sake of conciseness, in the code samples, we are going to assume static imports for all the methods of the Collectors class, so we don't have to show their class qualification.
Things start to get more complicated when we want to do several aggregations simultaneously: for example, finding the sum of the number of entries and the average price for each state and city. The Collectors class does not provide a simple solution to this problem.
To begin untangling this issue, we take a cue from the previous aggregation and define a record that encapsulates all the fields that need to be aggregated:
record TaxEntryAggregation(int totalNumEntries, double averagePrice) {}
Now, how do we aggregate the two fields simultaneously? One option is to collect each group into a list and then stream that list once per aggregation, as in the following code:
Map<StateCityGroup, TaxEntryAggregation> aggregationByStateCity = taxes.stream().collect(
groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()),
collectingAndThen(Collectors.toList(),
list -> {int entries = list.stream().collect(
summingInt(TaxEntry::getNumEntries));
double priceAverage = list.stream().collect(
averagingDouble(TaxEntry::getPrice));
return new TaxEntryAggregation(entries, priceAverage);})));
The grouping is done as before, but for the downstream, we do the aggregation with Collectors::collectingAndThen, which takes two parameters:
- A downstream collector, here Collectors::toList, which gathers the elements of each group into a list.
- A finisher function: a lambda expression that streams that list twice, once per aggregation, and combines the two results into a new TaxEntryAggregation record.
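A runnable sketch of this two-pass approach is shown below, assuming hypothetical record versions of the article's types and illustrative data. The two comments mark the two separate traversals of each group's list, which is the repetition the next paragraph objects to.

```java
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.averagingDouble;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.summingInt;
import static java.util.stream.Collectors.toList;

class TwoPassAggregation {
    // Hypothetical record versions of the article's types.
    record TaxEntry(String state, String city, int numEntries, double price) {}
    record StateCityGroup(String state, String city) {}
    record TaxEntryAggregation(int totalNumEntries, double averagePrice) {}

    // Illustrative data only.
    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", 3, 20.0),
                new TaxEntry("New York", "NYC", 2, 10.0),
                new TaxEntry("Florida", "Orlando", 5, 13.0));
    }

    static Map<StateCityGroup, TaxEntryAggregation> aggregate(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                        collectingAndThen(toList(), list -> {
                            // Pass 1 over the group's list: total number of entries.
                            int entries = list.stream().collect(summingInt(TaxEntry::numEntries));
                            // Pass 2 over the same list: average price.
                            double priceAverage = list.stream().collect(averagingDouble(TaxEntry::price));
                            return new TaxEntryAggregation(entries, priceAverage);
                        })));
    }

    public static void main(String[] args) {
        TaxEntryAggregation nyc = aggregate(sampleData()).get(new StateCityGroup("New York", "NYC"));
        System.out.println(nyc); // TaxEntryAggregation[totalNumEntries=5, averagePrice=15.0]
    }
}
```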
Imagine that we wanted to aggregate more fields simultaneously. We would need one more stream over each group's list for every additional field. The code becomes inefficient, very repetitive, and less than desirable. We should look for better alternatives.
Also, the problems don't end here: in general, we're constrained in the types of aggregations that we can do with the Collectors helper class. Its summing*, averaging*, and summarizing* methods support only the int, long, and double primitive types. What do we do if we have more sophisticated types like BigInteger or BigDecimal?
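For a single BigDecimal field, one partial workaround that stays within the Collectors class is the three-argument Collectors.reducing(identity, mapper, op). The sketch below, which assumes a hypothetical record version of TaxEntry and illustrative data, sums prices per city this way. It does not by itself handle averages or several fields at once.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.reducing;

class BigDecimalSum {
    // Hypothetical record with BigDecimal fields.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}

    // Illustrative data only.
    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", new BigDecimal("0.2"), new BigDecimal("20.0")),
                new TaxEntry("New York", "NYC", new BigDecimal("0.4"), new BigDecimal("10.0")),
                new TaxEntry("Florida", "Orlando", new BigDecimal("0.3"), new BigDecimal("13.0")));
    }

    // There is no summingBigDecimal: reducing maps each entry to its price and
    // adds the prices, starting from BigDecimal.ZERO.
    static Map<String, BigDecimal> totalPriceByCity(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(TaxEntry::city,
                        reducing(BigDecimal.ZERO, TaxEntry::price, BigDecimal::add)));
    }

    public static void main(String[] args) {
        System.out.println(totalPriceByCity(sampleData()).get("NYC")); // 30.0
    }
}
```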
To add insult to injury, the summarizing* methods provide only a fixed set of summary statistics: min, max, count, sum, and average. What if we want to perform more sophisticated calculations such as weighted averages or geometric means?
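To make the limitation concrete, the sketch below prints what summarizingDouble offers and then shows one way to get a geometric mean from built-in collectors anyway, using the identity geometric mean = exp(mean(ln x)), which is valid only for strictly positive values. The class name and the price values are illustrative assumptions.

```java
import java.util.DoubleSummaryStatistics;
import java.util.List;
import static java.util.stream.Collectors.averagingDouble;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.summarizingDouble;

class SummaryAndGeometricMean {
    // The fixed set of statistics: count, sum, min, average, max.
    static DoubleSummaryStatistics summarize(List<Double> prices) {
        return prices.stream().collect(summarizingDouble(Double::doubleValue));
    }

    // Geometric mean as exp(mean(ln x)); assumes every value is > 0.
    static double geometricMean(List<Double> prices) {
        return prices.stream().collect(
                collectingAndThen(averagingDouble(x -> Math.log(x)), mean -> Math.exp(mean)));
    }

    public static void main(String[] args) {
        List<Double> prices = List.of(2.0, 8.0);
        DoubleSummaryStatistics stats = summarize(prices);
        System.out.println(stats.getCount() + " " + stats.getSum() + " " + stats.getMin()
                + " " + stats.getAverage() + " " + stats.getMax()); // 2 10.0 2.0 5.0 8.0
        // Approximately 4.0; the log/exp round trip may leave a tiny floating-point error.
        System.out.println(geometricMean(prices));
    }
}
```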
Some people will argue that we can always write custom Collectors, but this requires knowing the Collector interface and having a good understanding of the stream collector flow. It's more straightforward to build on the collectors already provided by the utility methods in the Collectors class. In the next section, we'll show a couple of strategies for accomplishing this.
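For comparison, this is roughly what writing a custom collector involves, using the JDK's Collector.of factory. It is an illustrative sketch (the Acc and Stats types are hypothetical): even a simple count-and-sum needs a mutable container plus a supplier, accumulator, combiner, and finisher that stay consistent with each other, including when the stream runs in parallel.

```java
import java.util.List;
import java.util.stream.Collector;

class CustomCollectorSketch {
    // Immutable result type.
    record Stats(long count, double sum) {}

    // Mutable accumulation container required by the Collector contract.
    static final class Acc {
        long count;
        double sum;
    }

    static Collector<Double, Acc, Stats> countAndSum() {
        return Collector.of(
                Acc::new,                                   // supplier: a fresh container
                (acc, x) -> { acc.count++; acc.sum += x; }, // accumulator: fold in one element
                (a, b) -> {                                 // combiner: merge partial results
                    a.count += b.count;
                    a.sum += b.sum;
                    return a;
                },
                acc -> new Stats(acc.count, acc.sum));      // finisher: build the result
    }

    public static void main(String[] args) {
        Stats s = List.of(20.0, 10.0, 10.0).stream().collect(countAndSum());
        System.out.println(s); // Stats[count=3, sum=40.0]
    }
}
```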
Complex Multiple Aggregations: A Resolution Path
Let’s consider a simple example that will highlight the challenges that we have mentioned in the previous section. Suppose that we have the following entity:
public class TaxEntry {
private String state;
private String city;
private BigDecimal rate;
private BigDecimal price;
record StateCityGroup(String state, String city) {
}
//Constructors, getters, hashCode/equals etc
}
We start by asking how, for each distinct state-city pair, we can find the total count of entries and the sum of the products of rate and price, ∑(rate * price). Notice that we are doing a multifield aggregation using BigDecimal.
As we did in the previous section, we define a record that encapsulates the aggregation:
record RatePriceAggregation(int count, BigDecimal ratePrice) {}
It might seem surprising at first, but a straightforward solution to groupings that are followed by simple aggregations is to use Collectors::toMap. Let's see how we would do it:
Map<StateCityGroup, RatePriceAggregation> mapAggregation = taxes.stream().collect(
toMap(p -> new StateCityGroup(p.getState(), p.getCity()),
p -> new RatePriceAggregation(1, p.getRate().multiply(p.getPrice())),
(u1,u2) -> new RatePriceAggregation( u1.count() + u2.count(), u1.ratePrice().add(u2.ratePrice()))
));
The three-argument Collectors::toMap takes the following parameters:
- A key mapper: a lambda expression that creates a StateCityGroup for each element. Using it as the map key groups the elements by state and city.
- A value mapper: it creates a RatePriceAggregation initialized with a count of 1 and the product of rate and price.
- A merge function: a BinaryOperator that combines the values of elements that map to the same state-city key. We add the counts and the rate-price products to do our aggregation.
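Putting the three parameters together, a self-contained version of the toMap aggregation could look like the sketch below. It assumes hypothetical record versions of TaxEntry (accessors rate() and price() instead of getters) and uses the same four entries as the sample data in this section.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.toMap;

class ToMapAggregation {
    // Hypothetical record versions of the article's types.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}
    record StateCityGroup(String state, String city) {}
    record RatePriceAggregation(int count, BigDecimal ratePrice) {}

    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
                new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
    }

    static Map<StateCityGroup, RatePriceAggregation> aggregate(List<TaxEntry> taxes) {
        return taxes.stream().collect(toMap(
                // Key mapper: one key per state-city pair.
                p -> new StateCityGroup(p.state(), p.city()),
                // Value mapper: each entry starts as count 1 with its own rate * price.
                p -> new RatePriceAggregation(1, p.rate().multiply(p.price())),
                // Merge function: combines two values that share a key.
                (u1, u2) -> new RatePriceAggregation(u1.count() + u2.count(),
                        u1.ratePrice().add(u2.ratePrice()))));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, RatePriceAggregation> result = aggregate(sampleData());
        System.out.println(result.get(new StateCityGroup("New York", "NYC")));
        // RatePriceAggregation[count=3, ratePrice=14.00]
    }
}
```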
Let's demonstrate how this works by setting up some sample data:
List<TaxEntry> taxes = Arrays.asList(
new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
Getting the results for New York from the previous code sample is trivial:
System.out.println("New York: " + mapAggregation.get(new StateCityGroup("New York", "NYC")));
This prints:
New York: RatePriceAggregation[count=3, ratePrice=14.00]
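As a quick check, the printed value can be reproduced by hand from the three New York entries in the sample data (count = 3):

$$
\sum_{\text{NYC}} \text{rate} \times \text{price} = 0.2 \times 20.0 + 0.4 \times 10.0 + 0.6 \times 10.0 = 4.00 + 4.00 + 6.00 = 14.00
$$

The two decimal places come from BigDecimal.multiply, whose result scale is the sum of the operands' scales (1 + 1 = 2); adding values of scale 2 keeps scale 2.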
This is a straightforward implementation that handles grouping and the aggregation of multiple fields with non-primitive data types (BigDecimal in our case). However, it has a drawback: toMap has no finisher step that would let us perform extra operations on each group after all of its elements have been merged. For example, we can't compute any kind of average inside the collector, because an average needs a final division by the count.
To showcase this issue, let's consider a more complex problem. Suppose that we want to find, for each state and city pair, the weighted average of the rate-price products and the sum of all the prices. To find the weighted average, we need to calculate the sum of the products of rate and price for all the entries that belong to each state-city pair, and then divide by the total number of entries n in that pair: (1/n) ∑(rate * price).
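Written out for a state-city group with n entries, and evaluated for the New York sample entries (n = 3, with the rate-price sum of 14.00 printed above), the weighted average is:

$$
\text{weightedAveragePrice} = \frac{1}{n}\sum_{i=1}^{n} \text{rate}_i \times \text{price}_i = \frac{14.00}{3} = 4.666\ldots \approx 4.67
$$

Rounded to two decimal places, this is the value we expect for New York.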
To tackle this problem, we start by defining a record that holds the aggregation:
record TaxEntryAggregation(int count, BigDecimal weightedAveragePrice, BigDecimal totalPrice) {}
With this in hand, we can do the following implementation:
Map<StateCityGroup, TaxEntryAggregation> groupByAggregation = taxes.stream().collect(
groupingBy(p -> new StateCityGroup(p.getState(), p.getCity()),
mapping(p -> new TaxEntryAggregation(1, p.getRate().multiply(p.getPrice()), p.getPrice()),
collectingAndThen(reducing(new TaxEntryAggregation(0, BigDecimal.ZERO, BigDecimal.ZERO),
(u1,u2) -> new TaxEntryAggregation(u1.count() + u2.count(),
u1.weightedAveragePrice().add(u2.weightedAveragePrice()),
u1.totalPrice().add(u2.totalPrice()))
),
u -> new TaxEntryAggregation(u.count(),
u.weightedAveragePrice().divide(BigDecimal.valueOf(u.count()),
2, RoundingMode.HALF_DOWN),
u.totalPrice())
)
)
));
We can see that the code is somewhat more complicated, but it gives us the solution we are looking for. Let's follow it in more detail:
- Collectors::groupingBy:
  - For the classification function, we create a StateCityGroup record.
  - For the downstream, we invoke Collectors::mapping:
    - For the first parameter, the mapper transforms each tax entry in a state-city group into a new TaxEntryAggregation with a count of 1, the product of rate and price, and the price.
    - For the downstream, we invoke Collectors::collectingAndThen, which lets us apply a finishing transformation to the result of another collector:
      - Invoke Collectors::reducing:
        - Provide an identity TaxEntryAggregation with a count of 0 and both sums set to zero. It is the starting value of each group's reduction (and would be the result for a group with no elements).
        - Provide a lambda expression that does the reduction, returning a new TaxEntryAggregation with the counts, the rate-price products, and the prices added together. During this step, the weightedAveragePrice field holds a running sum, not yet an average.
      - Perform the finishing transformation: divide the accumulated rate-price sum by the count, rounding to two decimal places with RoundingMode.HALF_DOWN, and return the final TaxEntryAggregation.
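The same pipeline, wrapped in a self-contained program, is sketched below. As before, TaxEntry is assumed to be a record (accessors instead of getters), the Collectors methods are statically imported, and the data is the article's sample data.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Map;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.reducing;

class GroupByAggregation {
    // Hypothetical record versions of the article's types.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}
    record StateCityGroup(String state, String city) {}
    // While reducing, weightedAveragePrice holds the running SUM of rate * price;
    // the finisher turns it into the average.
    record TaxEntryAggregation(int count, BigDecimal weightedAveragePrice, BigDecimal totalPrice) {}

    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
                new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
    }

    static Map<StateCityGroup, TaxEntryAggregation> aggregate(List<TaxEntry> taxes) {
        return taxes.stream().collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                        // Map each entry to a one-element aggregation.
                        mapping(p -> new TaxEntryAggregation(1, p.rate().multiply(p.price()), p.price()),
                                collectingAndThen(
                                        // Add counts, rate-price products and prices.
                                        reducing(new TaxEntryAggregation(0, BigDecimal.ZERO, BigDecimal.ZERO),
                                                (u1, u2) -> new TaxEntryAggregation(
                                                        u1.count() + u2.count(),
                                                        u1.weightedAveragePrice().add(u2.weightedAveragePrice()),
                                                        u1.totalPrice().add(u2.totalPrice()))),
                                        // Finisher: turn the running sum into an average.
                                        u -> new TaxEntryAggregation(u.count(),
                                                u.weightedAveragePrice().divide(BigDecimal.valueOf(u.count()),
                                                        2, RoundingMode.HALF_DOWN),
                                                u.totalPrice())))));
    }

    public static void main(String[] args) {
        System.out.println(aggregate(sampleData()).get(new StateCityGroup("New York", "NYC")));
        // TaxEntryAggregation[count=3, weightedAveragePrice=4.67, totalPrice=40.0]
    }
}
```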
We see that this implementation not only lets us aggregate multiple fields simultaneously but also performs complex calculations in several stages.
This can be easily generalized to solve more complex problems. The path is straightforward: define a record that encapsulates all the fields that need to be aggregated, use Collectors::mapping to initialize the records, and then apply Collectors::collectingAndThen to do the reduction and final aggregation.
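One way to package this recipe is a small generic helper. The groupAggregate method below is hypothetical (it is not part of the JDK or of the article's code): it takes the classifier, the mapping to an aggregation record, the reduction's identity and combiner, and the finisher, and wires them into groupingBy, mapping, collectingAndThen, and reducing. The usage example computes an average price per city with an illustrative CountSum record and made-up data.

```java
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import static java.util.stream.Collectors.collectingAndThen;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.reducing;

class GroupAggregate {
    // Hypothetical helper: classify -> map to an aggregation record -> reduce -> finish.
    static <T, K, A, R> Map<K, R> groupAggregate(List<T> items,
                                                 Function<T, K> classifier,
                                                 Function<T, A> toAggregation,
                                                 A identity,
                                                 BinaryOperator<A> combine,
                                                 Function<A, R> finish) {
        return items.stream().collect(
                groupingBy(classifier,
                        mapping(toAggregation,
                                collectingAndThen(reducing(identity, combine), finish))));
    }

    // Illustrative types for the usage example.
    record TaxEntry(String city, double price) {}
    record CountSum(int count, double sum) {}

    static Map<String, Double> averagePriceByCity(List<TaxEntry> taxes) {
        return groupAggregate(taxes,
                TaxEntry::city,
                t -> new CountSum(1, t.price()),
                new CountSum(0, 0.0),
                (a, b) -> new CountSum(a.count() + b.count(), a.sum() + b.sum()),
                cs -> cs.sum() / cs.count());
    }

    public static void main(String[] args) {
        List<TaxEntry> taxes = List.of(
                new TaxEntry("NYC", 20.0), new TaxEntry("NYC", 10.0),
                new TaxEntry("NYC", 10.0), new TaxEntry("Orlando", 13.0));
        Map<String, Double> averages = averagePriceByCity(taxes);
        System.out.println(averages.get("Orlando")); // 13.0
        System.out.println(averages.get("NYC"));     // 40.0 / 3, about 13.33
    }
}
```

Keeping the aggregation record immutable means the combiner always returns a new value, which matches the style of the reductions above.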
As before, we can get the aggregations for New York:
System.out.println("Finished aggregation: " + groupByAggregation.get(new StateCityGroup("New York", "NYC")));
We get the results:
Finished aggregation: TaxEntryAggregation[count=3, weightedAveragePrice=4.67, totalPrice=40.0]
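It is also worth pointing out that because TaxEntryAggregation is a Java record, it's immutable: each reduction step returns a new instance instead of mutating shared state. Together with an identity value and an associative combining function (integer and BigDecimal addition), this lets the same collector run on a parallel stream, for example taxes.parallelStream(), using the parallel support built into the streams library.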
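A quick way to check that claim is to run the same collector on a sequential and a parallel stream and compare the results. The sketch below uses a trimmed-down version of the article's reduction (count plus the sum of rate * price) on hypothetical record types and the article's sample data. With such a tiny list, the parallel run brings no speed-up; it only illustrates that the results agree.

```java
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.reducing;

class ParallelAggregation {
    // Hypothetical record versions of the article's types.
    record TaxEntry(String state, String city, BigDecimal rate, BigDecimal price) {}
    record StateCityGroup(String state, String city) {}
    record RatePriceAggregation(int count, BigDecimal ratePrice) {}

    static List<TaxEntry> sampleData() {
        return List.of(
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.2), BigDecimal.valueOf(20.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.4), BigDecimal.valueOf(10.0)),
                new TaxEntry("New York", "NYC", BigDecimal.valueOf(0.6), BigDecimal.valueOf(10.0)),
                new TaxEntry("Florida", "Orlando", BigDecimal.valueOf(0.3), BigDecimal.valueOf(13.0)));
    }

    // Identity value + associative combiner over immutable records: the work can be
    // split across threads and the partial results combined safely.
    static Map<StateCityGroup, RatePriceAggregation> aggregate(Stream<TaxEntry> taxes) {
        return taxes.collect(
                groupingBy(p -> new StateCityGroup(p.state(), p.city()),
                        mapping(p -> new RatePriceAggregation(1, p.rate().multiply(p.price())),
                                reducing(new RatePriceAggregation(0, BigDecimal.ZERO),
                                        (u1, u2) -> new RatePriceAggregation(
                                                u1.count() + u2.count(),
                                                u1.ratePrice().add(u2.ratePrice()))))));
    }

    public static void main(String[] args) {
        Map<StateCityGroup, RatePriceAggregation> sequential = aggregate(sampleData().stream());
        Map<StateCityGroup, RatePriceAggregation> parallel = aggregate(sampleData().parallelStream());
        System.out.println(sequential.equals(parallel)); // true
    }
}
```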
Conclusion
We have shown a couple of strategies for doing complex multi-field groupings with aggregations that include non-primitive data types and multi- and cross-field calculations. Both strategies rely only on Java Streams and the built-in Collectors API and process the list in a single pass, which makes them well suited to large amounts of data.
Opinions expressed by DZone contributors are their own.