Last updated: March 18, 2024
In this tutorial, we’re going to provide a solution to the problem of flattening arbitrarily nested collections.
Sometimes the information we have to deal with is not directly accessible in a single collection but is spread across nested collections. This may be due to the diversity of the information sources or to the complexity of the algorithms involved. Whatever the reason, it may be desirable to have a single flat collection that aggregates all the nested ones.
In modern versions of Scala (starting in 2.9), the collections API includes a method called flatten(), available on the Seq trait, that almost does the job for us, but not quite.
For simple cases, we can do things like this:
scala> val list = List(List(1), List(2, 3))
list: List[List[Int]] = List(List(1), List(2, 3))
scala> list.flatten
res2: List[Int] = List(1, 2, 3)
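Note that flatten() removes exactly one level of nesting, so a deeper but still uniform structure needs one call per level. A minimal sketch, assuming Scala 2.13 (the object and value names are illustrative only):

```scala
object OneLevelFlatten {
  val deeplyNested: List[List[List[Int]]] = List(List(List(1)), List(List(2, 3)))

  // A single call only unwraps the outermost layer: List(List(1), List(2, 3))
  val once: List[List[Int]] = deeplyNested.flatten

  // A second call is needed to reach the Ints: List(1, 2, 3)
  val twice: List[Int] = deeplyNested.flatten.flatten
}
```

This only works because we know the depth in advance; for arbitrary or mixed depths we need something else.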
The problem with this approach is that it’s limited to collections that are uniform in their structure, i.e., all the elements must be collections for the method to work. Let’s see what happens when we try to apply the method flatten() to a collection that is not uniform:
scala> val list = List(1, List(2))
list: List[Any] = List(1, List(2))
scala> list.flatten
<console>:13: error: No implicit view available from Any => scala.collection.GenTraversableOnce[B].
list.flatten
^
We want to extend the capabilities of the Seq trait, and therefore of its implementing classes, so that it can also handle non-uniform collections.
The error message we saw before hints at both the nature of the problem and its solution. If we read the documentation of the flatten() method, we'll see that it takes an implicit parameter: a conversion from the element type of the sequence to a traversable collection (GenTraversableOnce in Scala 2.12, IterableOnce in Scala 2.13). Our element type is Any, so no such conversion is available, and the compiler refuses to flatten the list.
But we don't want to pollute the developers' code with additional implicits, so instead let's pass that conversion to flatten() explicitly, as a pattern-matching anonymous function that treats collections and non-collections differently. For a nested Seq, it calls fullFlatten() recursively. For any other element, it wraps the element in a single-element Seq. That's how it achieves structural uniformity: by the time flatten() concatenates the results, every element has become a collection.
Finally, let's create an implicit method that wraps the Seq in an enriched class, extending its functionality. This technique is called “Pimp my Library” because it adds functionality to library classes whose source code we may not even have:
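To see what that parameter does, we can supply a conversion to flatten() by hand for the one-level List(1, List(2)) that failed before. This is a minimal sketch assuming Scala 2.13; the object and value names are illustrative, and the conversion is deliberately not recursive:

```scala
object ExplicitConversion {
  val mixed: List[Any] = List(1, List(2))

  // Pass the conversion argument explicitly instead of relying on an implicit:
  // lists are passed through unchanged, any other element is wrapped in a List
  val flat: List[Any] = mixed.flatten {
    case xs: List[Any] => xs
    case x             => List(x)
  }
}
```

Here flat is List(1, 2). A deeper element such as List(List(3)) would still come out as List(3), since this function only handles one level.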
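Before wrapping it in an implicit, the recursive rule can be written as a plain function. This is a sketch assuming Scala 2.13; RecursiveFlatten is a hypothetical name, not part of the article's code:

```scala
object RecursiveFlatten {
  // Nested Seqs are flattened recursively; any other element
  // becomes a one-element Seq so that flatten can concatenate it
  def fullFlatten(xs: Seq[Any]): Seq[Any] = xs.flatten {
    case nested: Seq[Any] => fullFlatten(nested)
    case element          => Seq(element)
  }
}
```

For example, RecursiveFlatten.fullFlatten(List(1, List(2, List(3, List(4))))) returns List(1, 2, 3, 4). The explicit return type is required because the method is recursive.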
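package com.baeldung.scala.flattening

import scala.language.implicitConversions

object Flattener {

  /** This wrapper "pimps" the Seq type, adding the `fullFlatten` method.
    *
    * @param seq
    *   sequence whose functionality will be expanded
    * @return
    *   a wrapper object that implements the `fullFlatten` method
    */
  implicit def sequenceFlattener(seq: Seq[Any]): FullFlat =
    new FullFlat(seq)

  class FullFlat(seq: Seq[Any]) {
    // Nested sequences are flattened recursively; any other element
    // is wrapped in a one-element Seq so that flatten can concatenate it
    def fullFlatten: Seq[Any] = seq.flatten {
      case nested: Seq[Any] => nested.fullFlatten
      case nonSeq           => Seq(nonSeq)
    }
  }
}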
All we have to do now is import the sequenceFlattener() implicit conversion, and we can call fullFlatten() directly on our collections!
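Since Scala 2.10, the wrapper class and the implicit conversion can also be declared together as an implicit class. The sketch below is an alternative formulation, not the article's code; FlattenerSyntax and FullFlatOps are hypothetical names, and it assumes Scala 2.13:

```scala
object FlattenerSyntax {
  // An implicit class bundles the enriching wrapper and its conversion
  implicit class FullFlatOps(seq: Seq[Any]) {
    def fullFlatten: Seq[Any] = seq.flatten {
      // The recursive call works because FullFlatOps is in scope here
      case nested: Seq[Any] => nested.fullFlatten
      case element          => Seq(element)
    }
  }
}
```

After import FlattenerSyntax._, a call such as List(1, List(2, List(3)), Vector(4)).fullFlatten yields List(1, 2, 3, 4), exactly as with the explicit implicit def.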
There are a few cases that we want to verify using ScalaTest, but before that, some boilerplate:
import org.scalatest.wordspec.AnyWordSpec
import scala.collection.immutable.Queue

class FlattenerSpec extends AnyWordSpec {
  import com.baeldung.scala.flattening.Flattener.sequenceFlattener

  "A full flattener" should {
    // our actual tests go here
  }
}
What are the cases that we want to test? Let’s see:
"respect the contents of an already flat sequence" in {
  val flatList = List(3, 7, 2, 7, 1, 3, 4)
  assertResult(flatList)(flatList.fullFlatten)
}
"flatten a nested empty list to an empty list" in {
  val list = List(List(List()))
  assertResult(List.empty)(list.fullFlatten)
}
"flatten several lists of the same type, one level deep" in {
  val list = List(
    List(1, 2, 3),
    List(4, 5),
    List(6)
  )
  assertResult(List(1, 2, 3, 4, 5, 6))(list.fullFlatten)
}
"flatten several lists of the same type, diverse levels deep" in {
  val list = List(
    List(1, List(2, 3)),
    List(4, List(List(5))),
    List(6)
  )
  assertResult(List(1, 2, 3, 4, 5, 6))(list.fullFlatten)
}
"flatten diverse types of collections" in {
  val list = List(
    Vector(1, Queue(2, 3))
  )
  assertResult(List(1, 2, 3))(list.fullFlatten)
}
"flatten several lists of diverse types, diverse levels deep" in {
  val list = List(
    List(1, List("b", 'c')),
    List(4.4, List(List(5)))
  )
  assertResult(List(1, "b", 'c', 4.4, 5))(list.fullFlatten)
}
This is an interesting case of library extension, for two reasons: