SBT: Group annotated tests to run in forked JVMs


Running tests that use a HiveContext

On our current project, we use Spark SQL and have several ScalaTest-based suites that require a SparkContext and HiveContext. These are started before a suite runs and shut down after it completes via the BeforeAndAfterAll mixin trait. Unfortunately, due to this bug (also see this related pull request), if a HiveContext is recreated in the same JVM, this exception is thrown: "java.sql.SQLException: Another instance of Derby may have already booted the database". The same issue is triggered by Play hot reloading: during each reload a HiveContext is recreated and the exception is thrown, forcing us to kill SBT and restart the Play app. The workaround for this is to ditch Derby and use a separate database (such as Postgres, MySQL, or Oracle) for the Hive metastore. That solution works without a hitch, but we currently have no control over the client's CI environment, which makes it harder to provision a database there.
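For context, the suites in question follow roughly this shape (a minimal sketch assuming Spark 1.x; the class name and configuration details are illustrative, not our actual code):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec}

class SomeHiveQuerySpec extends WordSpec with Matchers with BeforeAndAfterAll {

  private var sc: SparkContext = _
  private var hiveContext: HiveContext = _

  override def beforeAll(): Unit = {
    // started once per suite...
    sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("test"))
    hiveContext = new HiveContext(sc)
  }

  override def afterAll(): Unit = {
    // ...and stopped when the suite completes; recreating a HiveContext later
    // in the same JVM is what triggers the Derby exception described above
    sc.stop()
  }

  // tests using hiveContext ...
}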

The obvious workaround: forked JVMs… Should be easy, right?

The obvious workaround is to run the tests in forked JVMs. Recall that in SBT, setting fork in test := true (since 0.12.0) runs all of the tests in a single, external JVM. Umm, that won't help. What I discovered is that SBT also allows groups of tests to be assigned to JVMs (along with the options to pass to each forked JVM), using the testGrouping key. If only I could group together all the suites that can run in a single JVM, and then fork a JVM for each of the suites that use Spark… The example in the SBT docs groups tests by a naming convention, which is not ideal and is easy to screw up.
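For reference, that single-external-JVM behavior is just a one-liner in build.sbt:

// forks a single external JVM and runs every test in it (not what we need here)
fork in Test := true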

Let’s start with a simple example that runs every single test in its own JVM (this can go in your build.sbt file):

import Tests.{Group, SubProcess}

// Returns a sequence of test groups, one group per test
def forkedJvmPerTest(testDefs: Seq[TestDefinition]): Seq[Group] = {
  testDefs map { test =>
    new Group(
      name = test.name,
      tests = Seq(test),
      // SubProcess => the group runs in its own forked JVM; no extra JVM options here
      runPolicy = SubProcess(javaOptions = Seq.empty[String]))
  }
}

// definedTests in Test returns all of the tests (by default those under src/test/scala)
testGrouping in Test <<= (definedTests in Test) map forkedJvmPerTest

This example maps over all of the tests under src/test/scala (returned by definedTests in Test), creating a Group for each one and setting its tests property to a single-element Seq containing that test. No extra JVM options are defined. The groups are returned and added to the testGrouping setting.

So now I have a handle on how that works, but I still need a way to mark tests so they can be grouped accordingly. This is where I took a detour: ScalaTest provides the ability to “tag” individual tests or entire suites and then use tags to filter which tests to run (or not run). If you want to tag an entire suite, you must first create a Java annotation like this:

import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.*;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@org.scalatest.TagAnnotation
@Retention(RUNTIME)
@Target({METHOD, TYPE})
public @interface RequiresSpark {
}

You can then annotate an entire suite:

@RequiresSpark
class SomeSparkSQLSpec extends WordSpec with Matchers {
  // tests
}
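For comparison, tagging an individual test rather than a whole suite uses a ScalaTest Tag object instead of the annotation. A quick sketch (the spec and test names here are made up):

import org.scalatest.{Matchers, Tag, WordSpec}

// the string passed to Tag is the name ScalaTest matches against with -n / -l
object RequiresSparkTag extends Tag("com.skapadia.scalatest.tags.RequiresSpark")

class SomeQuerySpec extends WordSpec with Matchers {
  "a Hive query" should {
    "return results" taggedAs RequiresSparkTag in {
      // test body that needs Spark
    }
  }
}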

Tags + Test Grouping: No dice

I thought I could be slick and create a custom SBT configuration that would only run tests tagged with @RequiresSpark and use the above testGrouping example that runs each test in its own JVM.

// Create a custom test config that passes the "-n" argument to ScalaTest to only run tests tagged with @RequiresSpark
lazy val SparkTest = config("sparkTest") extend (Test)
lazy val root = Project(id = "myproject", base = file("."), settings = Seq(
  libraryDependencies ++= Seq(...)
)).configs(SparkTest)
  .settings(inConfig(SparkTest)(Defaults.testSettings))
  .settings(testOptions in SparkTest := Seq(Tests.Argument(TestFrameworks.ScalaTest, "-n", "com.skapadia.scalatest.tags.RequiresSpark")))
  .settings(testOptions in Test := Seq(Tests.Argument(TestFrameworks.ScalaTest, "-l", "com.skapadia.scalatest.tags.RequiresSpark")))

This defines a new configuration called sparkTest that extends Test and passes the -n argument, telling ScalaTest to run only tests tagged with @RequiresSpark. It then tells the normal test configuration to exclude such tests, via the -l argument. The plan was to run both test and sparkTest:test.
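With that wiring in place, the plan amounted to running these two commands from the SBT prompt:

> test
> sparkTest:test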

Much to my surprise, ScalaTest's tag exclusion behavior works differently than I expected. Although it doesn't run the tests within a suite tagged with @RequiresSpark, it still runs beforeAll/afterAll (and beforeEach/afterEach), causing the SparkContext and HiveContext to start up. That defeats the whole purpose. Epic fail!!!1

Take a swig from SBT's Analysis API

At the end of the day, what I wanted was a way to determine which suites were annotated with @RequiresSpark so I could pass that to a modified forkedJvmPerTest function. As a last-ditch effort, I decided to investigate whether SBT provided a way to inspect how a class is annotated.

After some exploring, I found out that it does, via the Analysis API. This can be retrieved from the value of the compile TaskKey (in this case we're interested in the test classes, so compile in Test). That value is an instance of sbt.inc.Analysis. At first glance I see an apis method that returns an sbt.inc.APIs, which in turn exposes an internal method returning a Map[File, Source]. Hmmm, maybe I'm getting warmer…

An xsbti.api.Source instance represents a single source file and provides access to the definitions within it! Getting warmer… xsbti.api.Definition provides access to an array of xsbti.api.Annotation. Hot! There's a little more to the Annotation class but I'll let the code speak for itself.

I ended up using this to define a function and custom TaskKey to return tests annotated with @RequiresSpark:

// Returns true if any annotation on the definition is named RequiresSpark
// (note: this matches the simple name only, not the fully qualified name)
def isAnnotatedWithRequiresSpark(definition: xsbti.api.Definition): Boolean = {
  definition.annotations().exists { annotation: xsbti.api.Annotation =>
    annotation.base match {
      case proj: xsbti.api.Projection if proj.id() == "RequiresSpark" => true
      case _ => false
    }
  }
}

// Note: the type TaskKey[Seq[String]] must be specified explicitly, otherwise an error occurs
lazy val testsAnnotatedWithRequiresSpark: TaskKey[Seq[String]] = taskKey[Seq[String]]("Returns list of FQCNs of tests annotated with RequiresSpark")
testsAnnotatedWithRequiresSpark := {
  val analysis = (compile in Test).value
  analysis.apis.internal.values.flatMap { source =>
    source.api().definitions().filter(isAnnotatedWithRequiresSpark).map(_.name())
  }.toSeq
}

We iterate over the values of the Map[File, Source]: for each Source, we filter the definitions annotated with @RequiresSpark and take the name of each match (the suite's fully qualified class name). The isAnnotatedWithRequiresSpark method takes a Definition and returns true if any one of its annotations is @RequiresSpark (admittedly this checks only the simple name, not the fully qualified name of the annotation). This took a little trial and error, of course.

Finally, we can modify forkedJvmPerTest to also take a sequence of test names, as returned by testsAnnotatedWithRequiresSpark, and partition the tests accordingly:

def forkedJvmPerTest(testDefs: Seq[TestDefinition], testsToFork: Seq[String]) = {
  val (forkedTests, otherTests) = testDefs.partition { testDef => testsToFork.contains(testDef.name) }
  // All remaining tests share a single forked JVM
  val otherTestsGroup = new Group(name = "Single JVM tests", tests = otherTests, runPolicy = SubProcess(javaOptions = Seq.empty[String]))
  // Each test requiring Spark gets its own group, and therefore its own JVM
  val forkedTestGroups = forkedTests map { test =>
    new Group(
      name = test.name,
      tests = Seq(test),
      runPolicy = SubProcess(javaOptions = Seq.empty[String]))
  }
  Seq(otherTestsGroup) ++ forkedTestGroups
}
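As an aside, the javaOptions parameter is where any per-group JVM settings would go if the Spark suites need them. A sketch, with purely illustrative options:

// hypothetical JVM options for the Spark suites; adjust (or leave empty) as needed
val sparkTestJavaOptions = Seq("-Xmx2g", "-XX:MaxPermSize=256m")

val forkedTestGroups = forkedTests map { test =>
  new Group(
    name = test.name,
    tests = Seq(test),
    runPolicy = SubProcess(javaOptions = sparkTestJavaOptions))
}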

The final piece is applying this to the testGrouping setting:

testGrouping in Test <<= (definedTests in Test, testsAnnotatedWithRequiresSpark) map forkedJvmPerTest

This works as expected and the CI build no longer fails with "java.sql.SQLException: Another instance of Derby may have already booted the database".

Post-mortem

It turned out I had to dig into the SBT internals a little bit to achieve what I initially thought would be really simple. 🙂 It really shouldn't have to be this hard, should it? I would definitely classify this solution as "hacky" and I hope there's a better way that's escaping me.2 If you have a better solution, please let me know! I also feel ScalaTest's behavior with excluding tagged tests is non-intuitive and a deal-breaker for me in certain cases.


  1. While SBT test filtering properly excludes tests (that is, they don't run at all), a filter function only accepts the test name. I wanted to avoid using naming conventions to affect build behavior.

  2. We could put all suites that require Spark into one suite of suites. The top-level suite could start Spark at the beginning and stop it at the end. This approach didn't appeal to me as much, but admittedly it's far easier; a rough sketch follows below.
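A rough sketch of that alternative, using ScalaTest's Suites (the class name and nesting here are hypothetical):

import org.scalatest.{BeforeAndAfterAll, Suites}

// A single top-level suite that owns the Spark lifecycle for all nested Spark suites
class AllSparkSuites extends Suites(
  new SomeSparkSQLSpec
  // ... other suites that require Spark
) with BeforeAndAfterAll {

  override def beforeAll(): Unit = {
    // start the SparkContext/HiveContext once, before any nested suite runs
  }

  override def afterAll(): Unit = {
    // stop Spark after all nested suites have completed
  }
}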