GA-CCRi Analytical Development Services

Typeclasses for Flexible API Development

Here at GA-CCRi, we do a lot of machine learning; going from data to knowledge is kind of our thing. We’ve got a library of machine learning tools in-house, and we even use other algorithms from third-party developers sometimes.

To see how different methods compare, or to see how one method behaves under a range of tweakable parameters, we build a testbed. But different libraries expose different interfaces, and even the ones we control might grow apart over time. And what if one of our major projects decides it needs to use a different algorithm? Digging into the code to find all the uses of one library and translating them into the interface of the new library could take a long time.

What we need is a simple, universal API that’s general enough to handle whatever machine learning tasks we might need from it. Something like this:

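[code language="scala"]
import scala.concurrent.{Await, Future}
import scala.concurrent.duration.Duration

// ModelInput is our uniform input format (defined elsewhere).
trait MachineLearningModel[OutputType] extends (ModelInput => OutputType) {
  // Synchronous convenience wrapper: blocks until classification completes.
  def apply(input: ModelInput): OutputType =
    Await.result(classify(input), Duration.Inf)
  def classify(input: ModelInput): Future[OutputType]
  def classifyProbabilities(input: ModelInput): Future[Map[OutputType, Double]]
  def train(input: ModelInput, output: OutputType): Unit
  def prepareForClassify(): Unit = {} // optional hook; a no-op by default
  def initialized: Boolean
  def destroy(): Unit
}
[/code]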

where ModelInput is a uniform input format we use. To use a given modeling library, we provide a simple shim trait to translate the native interface into our own:

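[code language="scala"]
import scala.concurrent.Future

// A shim translating the native RandomForest API into our MachineLearningModel API.
trait RFModel extends MachineLearningModel[ModelOutput] {
  protected def model: RandomForest

  def classify(input: ModelInput): Future[ModelOutput] =
    ??? // implement classify in terms of the RandomForest API
  def train(input: ModelInput, mo: ModelOutput): Unit = {
    // implement train in terms of the RandomForest API
  }
  // classifyProbabilities, initialized and destroy are implemented the same way
}
[/code]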

“But wait!” you complain. “That may be universal, but it’s not very flexible! What if I need to serialize and deserialize my models? Won’t that still require messy access to the underlying API?”

This is an excellent point, and there is a very good reason for our choice: these are the core functions that almost any modeling algorithm requires. We want to keep specialized functionality, like serialization, as extensions to the basic API rather than building it into every model.

You might expect at this point that we’re going to introduce descendant traits, like SerializableModel extends MachineLearningModel, and we could do it that way. But there are some big problems with this approach. Most obvious is the combinatorial explosion as we add more specialized APIs: a new trait for every combination of special functionalities gets out of hand fast. Unlike Java interfaces (at least before Java 8’s default methods), Scala traits can mix in implementations, but the result is still unwieldy. Less obvious is that unless we know every future machine learning method in advance, this approach cannot possibly be flexible enough to cover everything we might want to do with it. Adding a new capability later means editing every model class that should support it.

Luckily, Scala provides another option: typeclasses. Instead of using inheritance to implement an interface, typeclasses allow us to implement interfaces by providing “evidence objects”, and Scala’s implicit mechanism allows us to sweep most of the details under a syntactic sugar rug in the client code.

Here’s our typeclass for serializing models:

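[code language="scala"]
import scala.concurrent.Future
import scala.language.{implicitConversions, reflectiveCalls}

// Evidence that models of type M can be written to and read back from bytes.
trait SerializableModel[M <: MachineLearningModel[_]] {
  def serialize(m: M): Future[Array[Byte]]
  def deserialize(bytes: Array[Byte]): M
}

object SerializableModel {
  // Adds m.serialize syntax to any M for which evidence can be found.
  implicit def serializeOps[M <: MachineLearningModel[_] : SerializableModel](m: M) = new {
    val witness = implicitly[SerializableModel[M]]
    def serialize: Future[Array[Byte]] = witness.serialize(m)
  }
}
[/code]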

and here’s an implementation corresponding to our RandomForest model:

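[code language="scala"]
import akka.actor.ActorSystem
import scala.concurrent.Future

object RFModel {
  // Evidence object: how to serialize and deserialize an RFModel.
  // It needs an implicit ActorSystem in scope wherever it is used.
  implicit def isSerializable(implicit as: ActorSystem): SerializableModel[RFModel] =
    new SerializableModel[RFModel] {
      def serialize(m: RFModel): Future[Array[Byte]] =
        m.model.serialize
      def deserialize(bytes: Array[Byte]): RFModel =
        new RFModel { val model = RandomForest.fromSerializedForm(bytes) }
    }
}
[/code]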

The SerializableModel typeclass tells us that evidence of serializability consists of implementations of serialize and deserialize methods with the given signatures. How do we know that RFModel belongs to the SerializableModel typeclass? Because the evidence object RFModel.isSerializable is an instance of SerializableModel[RFModel] that supplies the implementations used to serialize and deserialize RFModel instances. (Strictly speaking, isSerializable is a method that builds that instance from an implicit ActorSystem.)

The SerializableModel.serializeOps method shows this typeclass in use, and it also makes our syntax more flexible. The type parameter clause [M <: MachineLearningModel[_] : SerializableModel] tells us two things about the type M of the object we pass into this method. First, the upper bound says it must be a MachineLearningModel[A] for some output type A. Second, the context bound says the compiler must be able to find an implicit instance of SerializableModel[M]. If you don’t provide one yourself, the compiler searches the implicit scope, which includes the companion objects of M and of SerializableModel. If you’re using an RFModel, it will automatically find RFModel.isSerializable for you, provided an implicit ActorSystem is in scope to satisfy isSerializable’s own implicit parameter. Saying a type belongs to a typeclass is often just that simple from the client’s end.
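Here is a minimal, self-contained sketch of those mechanics. Model, Stump and Summary are hypothetical toy types made up for illustration and are not part of our library. describe plays the role of serializeOps, and describeExplicit writes out the implicit parameter that the context bound abbreviates. Because the evidence object sits in Stump’s companion, the call to describe needs no import.

```scala
object ContextBoundDemo {
  // Hypothetical toy types standing in for MachineLearningModel and RFModel.
  trait Model[Out] { def predict(x: Double): Out }

  final class Stump(val threshold: Double) extends Model[Boolean] {
    def predict(x: Double): Boolean = x > threshold
  }

  // Toy typeclass: evidence that a model of type M can be summarized.
  trait Summary[M <: Model[_]] {
    def summarize(m: M): String
  }

  object Stump {
    // Evidence lives in Stump's companion, so implicit search finds it unaided.
    implicit val summary: Summary[Stump] = new Summary[Stump] {
      def summarize(m: Stump): String = s"stump(threshold=${m.threshold})"
    }
  }

  // Upper bound plus context bound, like the type parameters of serializeOps ...
  def describe[M <: Model[_] : Summary](m: M): String =
    implicitly[Summary[M]].summarize(m)

  // ... which is shorthand for an extra implicit parameter list.
  def describeExplicit[M <: Model[_]](m: M)(implicit ev: Summary[M]): String =
    ev.summarize(m)

  val stump = new Stump(0.5)
  val viaContextBound: String = describe(stump)                          // evidence found implicitly
  val viaExplicitEvidence: String = describeExplicit(stump)(Stump.summary) // evidence passed by hand
}
```

Passing the evidence by hand, as in the last line, is also how a caller can override the default instance for a single call.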

What does SerializableModel.serializeOps give us? It wraps our M in a new object that has a method called serialize. Because serializeOps is an implicit method, the compiler applies this conversion automatically whenever we call serialize on an object that doesn’t already have a serialize method of its own, as long as serializeOps is in scope. (The compiler does not search SerializableModel’s companion object when looking for conversions from an RFModel, so client code brings the conversion in with import SerializableModel._.) That is, we didn’t write a method called serialize in the RFModel trait, but if we’ve got an rfModel: RFModel we can say rfModel.serialize and everything works exactly as we’d expect. Under the covers, the compiler recognizes that rfModel can use SerializableModel.serializeOps to implicitly convert into an object that has a serialize method, and then calls that method.
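The sketch below reproduces that syntax in runnable form, under some simplifying assumptions. Stump and ModelSerializer are hypothetical toy stand-ins for RFModel and SerializableModel. Serialization is synchronous instead of returning a Future. The conversion is written as an implicit class rather than an implicit def returning an anonymous object. An implicit class is a common alternative that yields the same kind of implicit conversion without reflective calls on a structural type.

```scala
import java.nio.charset.StandardCharsets.UTF_8

object SerializeSyntaxDemo {
  // Hypothetical toy model standing in for RFModel; it has no serialize member.
  final case class Stump(threshold: Double)

  // Toy typeclass in the role of SerializableModel (synchronous, serialize only).
  trait ModelSerializer[M] {
    def serialize(m: M): Array[Byte]
  }

  object ModelSerializer {
    // Plays the role of serializeOps: gives any M with evidence an m.serialize method.
    implicit class SerializeOps[M](m: M)(implicit witness: ModelSerializer[M]) {
      def serialize: Array[Byte] = witness.serialize(m)
    }
  }

  object Stump {
    // Evidence object, found in Stump's companion by implicit search.
    implicit val serializer: ModelSerializer[Stump] = new ModelSerializer[Stump] {
      def serialize(m: Stump): Array[Byte] = m.threshold.toString.getBytes(UTF_8)
    }
  }

  // Without this import, stump.serialize does not compile: SerializeOps lives in
  // ModelSerializer's companion, which is not searched for conversions from Stump.
  import ModelSerializer._

  val stump = Stump(0.25)
  val bytes: Array[Byte] = stump.serialize // the compiler inserts SerializeOps(stump)
}
```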

The SerializableModel example also shows another way in which typeclasses are more flexible than inheritance. We could declare an abstract serialize member method in a trait or interface, saying that it produces the Array[Byte] representation of the model object, but how would we write deserialize the same way? Deserialization doesn’t take a model as input (the implicit this argument of a member method); it produces one. So it has to be static, which in Scala’s idiom means a method on the companion object, and static methods do not play well with generics: code that knows only a type parameter M has no way to call a method on M’s companion object. Typeclasses don’t care whether the interface consists of member methods or not; they handle all kinds of signatures with ease.
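To illustrate, the following sketch shows generic code calling a deserializer that it selects purely by type. It uses hypothetical toy types Stump, Vote and Codec and is again synchronous for brevity. load[M] behaves like a static factory on M, which an abstract member method could not express. roundTrip uses both directions of the typeclass.

```scala
import java.nio.charset.StandardCharsets.UTF_8

object DeserializeDemo {
  // Hypothetical toy models; neither has a deserialize method of its own.
  final case class Stump(threshold: Double)
  final case class Vote(labels: List[String])

  // Toy typeclass: deserialize takes bytes rather than a model, yet fits naturally.
  trait Codec[M] {
    def serialize(m: M): Array[Byte]
    def deserialize(bytes: Array[Byte]): M
  }

  object Stump {
    implicit val codec: Codec[Stump] = new Codec[Stump] {
      def serialize(m: Stump): Array[Byte] = m.threshold.toString.getBytes(UTF_8)
      def deserialize(bytes: Array[Byte]): Stump = Stump(new String(bytes, UTF_8).toDouble)
    }
  }

  object Vote {
    implicit val codec: Codec[Vote] = new Codec[Vote] {
      def serialize(m: Vote): Array[Byte] = m.labels.mkString(",").getBytes(UTF_8)
      def deserialize(bytes: Array[Byte]): Vote =
        Vote(new String(bytes, UTF_8).split(",").toList)
    }
  }

  // A "static" factory chosen purely by the type parameter M.
  def load[M](bytes: Array[Byte])(implicit codec: Codec[M]): M =
    codec.deserialize(bytes)

  // Generic code that needs both serialize and deserialize.
  def roundTrip[M: Codec](m: M): M =
    load[M](implicitly[Codec[M]].serialize(m))

  val stump: Stump = load[Stump]("0.75".getBytes(UTF_8))
  val vote: Vote = roundTrip(Vote(List("a", "b")))
}
```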

And there’s another benefit: when writing our machine learning library, we might predict that serialization and feature importance are useful typeclasses, but we cannot predict everything in advance. With inheritance, it’s difficult to add a new specialized interface to our API after the fact; with typeclasses it couldn’t be simpler.

Let’s say in the testbed we’re studying models that consist of ensembles of classifiers, and we want to probe various statistics about the classifiers. We don’t need to rewrite our machine learning library to add this functionality; we can just define

[code language="scala"]
trait EnsembleStatistics[M <: MachineLearningModel[_]] {
  def accuracies(m: M): Array[Double]
  def depths(m: M): Array[Double]
  …
}

object EnsembleStatistics {
  implicit def hasEnsembleStatistics[M <: MachineLearningModel[_] : EnsembleStatistics](m: M) =
    new { … }

  // Evidence object for RFModel, found automatically by implicit search.
  implicit val RFHasEnsembleStatistics: EnsembleStatistics[RFModel] =
    new EnsembleStatistics[RFModel] {
      def accuracies(m: RFModel) = m.model.reportAccuracies
      def depths(m: RFModel) = m.model.reportDepths
    }
}
[/code]

The EnsembleStatistics trait defines which methods a type must implement to belong to this typeclass. The companion object provides the implicit conversion hasEnsembleStatistics to add member method syntax to our models (as with serializeOps, client code imports EnsembleStatistics._ to use it). It also provides an evidence object for RFModel, so we don’t have to provide one by hand. Not only have we added new functionality to our API, we have added an implementation of the new interface to an existing class without modifying any of the original source code! And if we find that EnsembleStatistics is useful enough across other projects, we can cut and paste it into our existing machine learning library later with no further code changes necessary.
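A self-contained sketch of that retroactive extension follows. VendorForest stands in for a class from a library we cannot edit. EnsembleStats is a cut-down, hypothetical analogue of EnsembleStatistics; its name and members are ours, chosen for illustration only. Both the evidence and the member-method syntax live in the new typeclass’s companion, and VendorForest itself is untouched.

```scala
object RetroactiveDemo {
  // Hypothetical "third-party" ensemble we cannot edit; it knows nothing of our typeclass.
  final class VendorForest(val treeDepths: Vector[Int], val treeAccuracies: Vector[Double])

  // New typeclass defined after the fact, in the spirit of EnsembleStatistics.
  trait EnsembleStats[M] {
    def accuracies(m: M): Array[Double]
    def depths(m: M): Array[Double]
  }

  object EnsembleStats {
    // Member-method syntax, playing the role of hasEnsembleStatistics.
    implicit class EnsembleStatsOps[M](m: M)(implicit ev: EnsembleStats[M]) {
      def accuracies: Array[Double] = ev.accuracies(m)
      def depths: Array[Double] = ev.depths(m)
    }

    // Evidence for VendorForest, written without touching VendorForest's source.
    implicit val vendorForestStats: EnsembleStats[VendorForest] =
      new EnsembleStats[VendorForest] {
        def accuracies(m: VendorForest): Array[Double] = m.treeAccuracies.toArray
        def depths(m: VendorForest): Array[Double] = m.treeDepths.map(_.toDouble).toArray
      }
  }

  import EnsembleStats._

  val forest = new VendorForest(Vector(3, 5, 4), Vector(0.5, 0.75, 1.0))
  val forestAccuracies: Array[Double] = forest.accuracies // via EnsembleStatsOps
  val forestDepths: Array[Double] = forest.depths
}
```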

We find that, when working with such a diverse and constantly changing category as machine learning algorithms, Scala’s typeclass pattern provides us with all the flexibility we need to write an API that will serve our needs well into the future.
