1. Introduction

In this tutorial, we’re going to provide a solution to the problem of flattening arbitrarily nested collections.

Sometimes the information we have to deal with is not directly accessible in a collection but in nested collections. This may be due to the diversity in the sources of information or to the complexity of the algorithms involved. Regardless of the reason, it may be desirable to have a flat collection containing the aggregation of all the nested ones.

2. Description of the Solution

2.1. The Collections API

In modern versions of Scala (starting in 2.9), the collections API includes a method called flatten(), available on the Seq trait, that almost does the job for us, but not quite.

For simple cases, we can do things like this:

scala> val list = List(List(1), List(2, 3))
list: List[List[Int]] = List(List(1), List(2, 3))

scala> list.flatten
res2: List[Int] = List(1, 2, 3)

The problem with this approach is that it’s limited to collections that are uniform in their structure, i.e., all the elements must be collections for the method to work. Let’s see what happens when we try to apply the method flatten() to a collection that is not uniform:

scala> val list = List(1, List(2))
list: List[Any] = List(1, List(2))

scala> list.flatten
<console>:13: error: No implicit view available from Any => scala.collection.GenTraversableOnce[B].
       list.flatten
            ^

2.2. The Generic Approach

We want to extend the capabilities of the Seq trait, and therefore of all the classes derived from it, so that it can handle non-uniform collections.

The error message we saw before hints at both the nature of the problem and its solution. The documentation of the flatten() method shows that it takes an implicit parameter: a conversion that views each element of the sequence as a collection (GenTraversableOnce in Scala 2.12, IterableOnce in Scala 2.13). For elements of type Any, no such conversion exists, hence the compilation error.
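To see what that implicit parameter does, we can supply it by hand. The following is a minimal sketch; the object name OneLevelFlatten and the function toSeq are hypothetical names chosen for illustration. Nested Seqs are returned as they are, and any other element is wrapped in a single-element Seq. Because this conversion isn't recursive, only one level of nesting is removed.

```scala
object OneLevelFlatten {
  // Hypothetical conversion: view any element as a Seq so that flatten() can consume it.
  // Nested Seqs are returned unchanged; every other element becomes a one-element Seq.
  val toSeq: Any => Seq[Any] = {
    case nested: Seq[_] => nested
    case other          => Seq(other)
  }

  def main(args: Array[String]): Unit = {
    val mixed: List[Any] = List(1, List(2, List(3)))
    // Passing the conversion explicitly fills flatten's implicit parameter list
    val once = mixed.flatten(toSeq)
    println(once) // List(1, 2, List(3))
  }
}
```

The innermost List(3) survives, which is why the conversion used in the rest of the article has to call itself on nested collections.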

But we don't want to force developers to define that conversion in their own code, so let's pass it to flatten() explicitly, as a pattern-matching anonymous function that treats elements differently depending on whether they are collections or not. For a nested collection, it calls itself recursively. For any other element, it creates a single-element collection. That's how it achieves structural uniformity.
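Before wrapping it in an enriched class, the recursion can be sketched as a plain function. This is an illustration under assumptions: deepFlatten and DeepFlattenSketch are hypothetical names, and the function applies the same matching rule just described.

```scala
object DeepFlattenSketch {
  // Hypothetical standalone version of the recursive conversion passed to flatten().
  def deepFlatten(seq: Seq[Any]): Seq[Any] =
    seq.flatten[Any] {
      case nested: Seq[_] => deepFlatten(nested) // recurse into nested Seqs
      case other          => Seq(other)          // wrap leaves so flatten accepts them
    }

  def main(args: Array[String]): Unit = {
    val nested: List[Any] = List(1, List(2, List(3, Vector(4))), 5)
    println(deepFlatten(nested)) // List(1, 2, 3, 4, 5)
  }
}
```

The only thing left is to make this function available as a method on every Seq, which is what the enrichment below does.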

Finally, let's create an implicit method that wraps the Seq in an enriched class, extending its functionality. This technique is known as “Pimp my Library” (nowadays more often called “Enrich my Library”), because it adds functionality to library classes whose source code we may not even have:

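package com.baeldung.scala.flattening

// Silences the feature warning that implicit conversions trigger
import scala.language.implicitConversions

object Flattener {

  /** This wrapper "pimps" the Seq type, adding the `fullFlatten` method.
    *
    * @param seq
    *   sequence whose functionality will be expanded
    * @return
    *   a wrapper object that implements the `fullFlatten` method
    */
  implicit def sequenceFlattener(seq: Seq[Any]): FullFlat =
    new FullFlat(seq)

  class FullFlat(seq: Seq[Any]) {
    // Nested Seqs are flattened recursively (through the implicit conversion above);
    // any other element is wrapped in a single-element Seq so that flatten accepts it.
    def fullFlatten: Seq[Any] = seq flatten {
      case nested: Seq[Any] => nested.fullFlatten
      case nonSeq           => Seq(nonSeq)
    }
  }
}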

All we have to do now is import sequenceFlattener and call fullFlatten directly on any Seq!
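Since Scala 2.10, the pair formed by an implicit def and a wrapper class can also be collapsed into a single implicit class. The sketch below shows that alternative spelling; it is not the article's code, and FlattenerSyntax and FullFlatOps are hypothetical names. The method keeps the fullFlatten name, so call sites look the same.

```scala
object FlattenerSyntax {
  // An implicit class declares the wrapper and the conversion to it in one place.
  implicit class FullFlatOps(seq: Seq[Any]) {
    def fullFlatten: Seq[Any] = seq.flatten[Any] {
      case nested: Seq[_] => nested.fullFlatten // recursion goes through the same enrichment
      case other          => Seq(other)         // non-Seq elements become one-element Seqs
    }
  }
}

object FlattenerSyntaxDemo {
  import FlattenerSyntax._

  def main(args: Array[String]): Unit =
    println(List(1, List("a", Vector(2.5))).fullFlatten) // List(1, a, 2.5)
}
```

Both spellings behave the same at the call site; the implicit class just removes the separate factory method.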

3. Testing the Solution

There are a few cases that we want to verify using ScalaTest, but before that, some boilerplate:

import com.baeldung.scala.flattening.Flattener.sequenceFlattener
import org.scalatest.wordspec.AnyWordSpec
import scala.collection.immutable.Queue

class FlattenerSpec extends AnyWordSpec {

  "A full flattener" should {
    // our actual tests go here
  }
}

What are the cases that we want to test? Let’s see:

  • What happens if the collection we flatten is already flat?
    "respect the contents of an already flat sequence" in {
      val flatList = List(3, 7, 2, 7, 1, 3, 4)
      assertResult(flatList)(flatList.fullFlatten)
    }
  • What happens if it contains nothing but empty collections?
    "flatten a nested empty list to an empty list" in {
      val list = List(List(List()))
      assertResult(List.empty)(list.fullFlatten)
    }
  • Does it behave as the normal flatten() in the happy case?
    "flatten several lists of the same type, one level deep" in {
      val list = List(
        List(1, 2, 3),
        List(4, 5),
        List(6)
      )
      assertResult(List(1, 2, 3, 4, 5, 6))(list.fullFlatten)
    }
  • Does it work correctly on collections with non-uniform structure?
    "flatten several lists of the same type, diverse levels deep" in {
      val list = List(
        List(1, List(2, 3)),
        List(4, List(List(5))),
        List(6)
      )
      assertResult(List(1, 2, 3, 4, 5, 6))(list.fullFlatten)
    }
  • Does it accept nested collections of different types (e.g., a Queue inside a Vector inside a List)?
    "flatten diverse types of collections" in {
      val list = List(
        Vector(1, Queue(2, 3))
      )
      assertResult(List(1, 2, 3))(list.fullFlatten)
    }
  • Does it flatten collections whose elements have different types (e.g., Int, String, Char, and Double mixed together)?
    "flatten several lists of diverse types, diverse levels deep" in {
      val list = List(
        List(1, List("b", 'c')),
        List(4.4, List(List(5)))
      )
      assertResult(List(1, "b", 'c', 4.4, 5))(list.fullFlatten)
    }
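The tests above only nest Seq instances. The following minimal sketch assumes the same matching rule as fullFlatten (only values that are a Seq at runtime get unpacked) and shows what happens to other containers: a String, an Array, a Set, or an Option stays a single element. FlattenLimits and its deepFlatten function are hypothetical stand-ins for the enriched method, so the snippet compiles on its own.

```scala
object FlattenLimits {
  // Hypothetical stand-in applying the same rule as fullFlatten: only Seq is unpacked.
  def deepFlatten(seq: Seq[Any]): Seq[Any] =
    seq.flatten[Any] {
      case nested: Seq[_] => deepFlatten(nested)
      case other          => Seq(other)
    }

  def main(args: Array[String]): Unit = {
    val arr = Array(2, 3)
    val result = deepFlatten(List("ab", arr, Set(4), Some(5), List(6)))
    // "ab", the Array, the Set and the Option are not Seqs at runtime,
    // so they are kept intact; only List(6) is unpacked.
    println(result.length) // 5
  }
}
```

If those containers should be unpacked as well, the pattern in the conversion is the place to widen.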

4. Conclusion

This is an interesting case of library extension, for two reasons:

  1. The flattening process becomes a recursive call, wrapping the elements that are not collections in a single-element collection.
  2. The solution imposes very few requirements on the developers using it: a single import is enough. That's always a desirable feature of SDKs.

As usual, the full code for this article is available over on GitHub.