TLDR: I translated some of the code and ideas from Scrap Your Boilerplate: A Practical Design Pattern for Generic Programming by Lämmel and Peyton Jones to Rust and it’s available as the scrapmetal crate.


Say we work on some software that models companies, their departments, sub-departments, employees, and salaries. We might have some type definitions similar to this:

pub struct Company(pub Vec<Department>);

pub struct Department(pub Name, pub Manager, pub Vec<SubUnit>);

pub enum SubUnit {
    Person(Employee),
    Department(Box<Department>),
}

pub struct Employee(pub Person, pub Salary);

pub struct Person(pub Name, pub Address);

pub struct Salary(pub f64);

pub type Manager = Employee;
pub type Name = &'static str;
pub type Address = &'static str;

One of our companies has had a morale problem lately, and we want to transform it into a new company where everyone is excited to come in every Monday through Friday morning. But we can’t really change the nature of the work, so we figure we can just give the whole company a 10% raise and call it close enough. This requires writing a bunch of functions with type signatures like fn(self, k: f64) -> Self for every type that makes up a Company, and since we recognize the pattern, we should be good Rustaceans and formalize it with a trait:

pub trait Increase: Sized {
    fn increase(self, k: f64) -> Self;
}

A company with increased employee salaries is made by increasing the salaries of each of its departments’ employees:

impl Increase for Company {
    fn increase(self, k: f64) -> Company {
        Company(
            self.0
                .into_iter()
                .map(|d| d.increase(k))
                .collect()
        )
    }
}

A department with increased employee salaries is made by increasing its manager’s salary and the salary of every employee in its sub-units:

impl Increase for Department {
    fn increase(self, k: f64) -> Department {
        Department(
            self.0,
            self.1.increase(k),
            self.2
                .into_iter()
                .map(|s| s.increase(k))
                .collect(),
        )
    }
}

A sub-unit is either a single employee or a sub-department, so we either increase the employee’s salary or increase the salaries of all the people in the sub-department:

impl Increase for SubUnit {
    fn increase(self, k: f64) -> SubUnit {
        match self {
            SubUnit::Person(e) => {
                SubUnit::Person(e.increase(k))
            }
            SubUnit::Department(d) => {
                SubUnit::Department(Box::new(d.increase(k)))
            }
        }
    }
}

An employee with an increased salary is that same employee with the salary increased:

impl Increase for Employee {
    fn increase(self, k: f64) -> Employee {
        Employee(self.0, self.1.increase(k))
    }
}

And finally, a lone salary can be increased:

impl Increase for Salary {
    fn increase(self, k: f64) -> Salary {
        Salary(self.0 * (1.0 + k))
    }
}

Pretty straightforward.

But at the same time, that’s a whole lot of boilerplate. The only interesting part that has anything to do with actually increasing salaries is the impl Increase for Salary. The rest of the code is just traversal of the data structures. If we were to write a function to rename all the employees in a company, most of this code would remain the same. Surely there’s a way to factor all this boilerplate out so we don’t have to manually write it all the time?

In the paper Scrap Your Boilerplate: A Practical Design Pattern for Generic Programming, Lämmel and Peyton Jones show us a way to do just that in Haskell. And it turns out the ideas mostly translate into Rust pretty well, too. This blog post explores that translation, following much the same outline as the original paper.

When we’re done, we’ll be able to write the same salary-increasing functionality in just a couple of lines:

// Definition
let increase = |s: Salary| Salary(s.0 * 1.1);
let mut increase = Everywhere::new(Transformation::new(increase));

// Usage
let new_company = increase.transform(old_company);

We have a few different moving parts involved here:

  • A function that transforms a specific type: FnMut(T) -> T. In the increase example this is the closure |s: Salary| Salary(s.0 * 1.1).

  • We have Transformation::new, which lifts the transformation function from transforming a single, specific type (FnMut(T) -> T) into transforming all types (for<U> FnMut(U) -> U). If we call this new transformation with a value of type T, then it will apply our T-specific transformation function. If we call it with a value of any other type, it simply returns the given value.

    Of course, Rust doesn’t actually support rank-2 types, but we can work around this: anywhere we wanted to pass for<U> FnMut(U) -> U as a parameter, we instead pass a value whose type implements a trait with a generic method. This trait gets implemented by Transformation:

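// Essentially, for<U> FnMut(U) -> U
pub trait GenericTransform {
    // The Term bound (introduced below) is what lets combinators
    // like Everywhere recurse into a value's children.
    fn transform<U>(&mut self, t: U) -> U
    where
        U: Term;
}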

  • Next is Everywhere::new, whose result is also a for<U> FnMut(U) -> U (that is, it implements the GenericTransform trait). This is a combinator that takes a generic transformation function, and traverses a tree of values, applying the generic transformation function to each value along the way.

  • Finally, behind the scenes there are two traits: Term and Cast. The former provides enumeration of a value’s immediate children in the value tree. The latter enables us to ask some generic U if it is a specific T. These traits completely encapsulate the boilerplate we’ve been trying to rid ourselves of, and neither requires any implementation on our part. Term can be generated mechanically with a custom derive, and Cast can be implemented (in nightly Rust) with specialization.

Next, we’ll walk through the implementation of each of these bits.

Implementing Cast

The Cast trait is defined like so:

trait Cast<T>: Sized {
    fn cast(self) -> Result<T, Self>;
}

Given some value, we can try to cast it to a T or, if that fails, get the original value back. You can think of it like instanceof in JavaScript, but without walking some prototype or inheritance chain. In the original Haskell, cast returns the equivalent of Option<T>, but because of Rust’s ownership system we need to get the original value back if we ever want to use it again.

To implement Cast requires specialization, which is a nightly Rust feature. We start with a default blanket implementation of Cast that fails to perform the conversion:

impl<T, U> Cast<T> for U {
    default fn cast(self) -> Result<T, Self> {
        Err(self)
    }
}

Then we define a specialization for when Self is T that allows the cast to succeed:

impl<T> Cast<T> for T {
    fn cast(self) -> Result<T, Self> {
        Ok(self)
    }
}

That’s it!

Here is Cast in action:

assert_eq!(Cast::<bool>::cast(1), Err(1));
assert_eq!(Cast::<bool>::cast(true), Ok(true));
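Specialization is still nightly-only, so here is a minimal sketch of the same idea on stable Rust, assuming we are willing to restrict ourselves to 'static types. It swaps specialization for std::any::Any downcasting. The free function cast is a hypothetical stand-in for the Cast trait, not how scrapmetal does it:

```rust
use std::any::Any;

/// Hypothetical stable stand-in for `Cast::<U>::cast`: succeeds only when
/// `T` and `U` are the same type, and hands the value back otherwise.
pub fn cast<T: Any, U: Any>(t: T) -> Result<U, T> {
    // Park the value in an `Option` so it can be moved out through a
    // `&mut dyn Any` once the type check has succeeded.
    let mut slot = Some(t);
    if let Some(same) = (&mut slot as &mut dyn Any).downcast_mut::<Option<U>>() {
        return Ok(same.take().expect("slot was filled just above"));
    }
    Err(slot.expect("slot is untouched when the types differ"))
}
```

The price of this version is the 'static bound that Any imposes, which the specialization-based Cast does not need.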

Implementing Transformation

Once we have Cast, implementing generic transformations is easy. If we can cast the value to our underlying non-generic transformation function’s input type, then we call it. If we can’t, then we return the given value:

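pub struct Transformation<F, U>
where
    F: FnMut(U) -> U,
{
    f: F,
    // Elided here and in the similar structs later on: a PhantomData
    // marker field, which rustc requires because U is otherwise unused.
}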

impl<F, U> GenericTransform for Transformation<F, U>
where
    F: FnMut(U) -> U,
{
    fn transform<T>(&mut self, t: T) -> T {
        // Try to cast the T into a U.
        match Cast::<U>::cast(t) {
            // Call the transformation function and then cast
            // the resulting U back into a T.
            Ok(u) => match Cast::<T>::cast((self.f)(u)) {
                Ok(t) => t,
                Err(_) => unreachable!("If T=U, then U=T."),
            },
            // Not a U, return unchanged.
            Err(t) => t,
        }
    }
}

For example, we can lift the logical negation function into a generic transformer. For booleans, it returns the complement of the value; all other values are left unchanged:

let mut not = Transformation::new(|b: bool| !b);
assert_eq!(not.transform(true), false);
assert_eq!(not.transform("str"), "str");
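For readers without a nightly toolchain, here is a rough, self-contained sketch of Transformation on stable Rust. It makes several assumptions that are not in the code above: the cast helper is a hypothetical Any-based stand-in for Cast, GenericTransform is bounded by Any rather than Term, and the constructor and PhantomData field are filled in the way I would expect them to look:

```rust
use std::any::Any;
use std::marker::PhantomData;

// Hypothetical stable stand-in for `Cast` (works for 'static types only).
fn cast<T: Any, U: Any>(t: T) -> Result<U, T> {
    let mut slot = Some(t);
    if let Some(same) = (&mut slot as &mut dyn Any).downcast_mut::<Option<U>>() {
        return Ok(same.take().expect("slot was filled just above"));
    }
    Err(slot.expect("slot is untouched when the types differ"))
}

// Same shape as the post's trait, with an `Any` bound for this sketch.
pub trait GenericTransform {
    fn transform<T: Any>(&mut self, t: T) -> T;
}

pub struct Transformation<F, U>
where
    F: FnMut(U) -> U,
{
    f: F,
    // Marks `U` as used; no `U` value is actually stored.
    phantom: PhantomData<fn(U) -> U>,
}

impl<F, U> Transformation<F, U>
where
    F: FnMut(U) -> U,
{
    pub fn new(f: F) -> Self {
        Transformation { f, phantom: PhantomData }
    }
}

impl<F, U> GenericTransform for Transformation<F, U>
where
    F: FnMut(U) -> U,
    U: Any,
{
    fn transform<T: Any>(&mut self, t: T) -> T {
        match cast::<T, U>(t) {
            // T is U: run the function, then cast the result back to T.
            Ok(u) => match cast::<U, T>((self.f)(u)) {
                Ok(t) => t,
                Err(_) => unreachable!("T and U are the same type here"),
            },
            // Not a U: return the value unchanged.
            Err(t) => t,
        }
    }
}
```

With these definitions, the not example above behaves the same way: booleans are flipped and everything else passes through untouched.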

Implementing Term

The next piece of the puzzle is Term, which enumerates the direct children of a value. It is defined as follows:

pub trait Term: Sized {
    fn map_one_transform<F>(self, f: &mut F) -> Self
    where
        F: GenericTransform;
}

In the original Haskell, map_one_transform is called gmapT, for “generic map transform”. As mentioned earlier, GenericTransform is a workaround for the lack of rank-2 types, and would otherwise be for<U> FnMut(U) -> U.

It is important that map_one_transform does not recursively call its children’s map_one_transform methods. We want a building block for making all different kinds of traversals, not one specific traversal hard coded.

If we were to implement Term for Employee, we would write this:

impl Term for Employee {
    fn map_one_transform<F>(self, f: &mut F) -> Self
    where
        F: GenericTransform,
    {
        Employee(f.transform(self.0), f.transform(self.1))
    }
}

And for SubUnit, it would look like this:

impl Term for SubUnit {
    fn map_one_transform<F>(self, f: &mut F) -> Self
    where
        F: GenericTransform,
    {
        match self {
            SubUnit::Person(e) => SubUnit::Person(f.transform(e)),
            SubUnit::Department(d) => SubUnit::Department(f.transform(d)),
        }
    }
}

On the other hand, a floating point number has no children to speak of, and so it would do less:

impl Term for f64 {
    fn map_one_transform<F>(self, _: &mut F) -> Self
    where
        F: GenericTransform,
    {
        self
    }
}

Note that each of these implementations is driven purely by the structure of the implementation’s type: enums transform whichever variant they are, structs and tuples transform each of their fields, etc. It’s 100% mechanical and 100% uninteresting.

It’s easy to write a custom derive for implementing Term. After that’s done, we just add #[derive(Term)] to our type definitions:

#[derive(Term)]
pub struct Employee(pub Person, pub Salary);
// Etc...
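Our company types also contain Vec and Box, so those need Term implementations too. The post doesn't show them, so treat the following as a sketch of how they might look rather than as scrapmetal's actual code. CountCalls and direct_children are hypothetical helpers that exist only to make the one-layer behavior observable:

```rust
pub trait GenericTransform {
    fn transform<T: Term>(&mut self, t: T) -> T;
}

pub trait Term: Sized {
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self;
}

// A leaf: no children, so nothing to transform.
impl Term for f64 {
    fn map_one_transform<F: GenericTransform>(self, _: &mut F) -> Self {
        self
    }
}

// A vector's direct children are its elements.
impl<T: Term> Term for Vec<T> {
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self {
        self.into_iter().map(|x| f.transform(x)).collect()
    }
}

// A box's only direct child is the value it owns.
impl<T: Term> Term for Box<T> {
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self {
        Box::new(f.transform(*self))
    }
}

/// Hypothetical transformer: leaves values alone and counts its calls.
pub struct CountCalls(pub usize);

impl GenericTransform for CountCalls {
    fn transform<T: Term>(&mut self, t: T) -> T {
        self.0 += 1;
        t
    }
}

/// Number of values that a single `map_one_transform` call hands to `f`.
pub fn direct_children<T: Term>(t: T) -> usize {
    let mut counter = CountCalls(0);
    let _ = t.map_one_transform(&mut counter);
    counter.0
}
```

A Vec<Vec<f64>> holding two inner vectors reports two direct children, not the total number of floats: map_one_transform really does stop after one layer.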

Implementing Everywhere

Everywhere takes a generic transformation and then uses Term::map_one_transform to recursively apply it to the whole tree. It does so in a bottom-up, left-to-right order.

Its definition and constructor are trivial:

pub struct Everywhere<F>
where
    F: GenericTransform,
{
    f: F,
}

impl<F> Everywhere<F>
where
    F: GenericTransform,
{
    pub fn new(f: F) -> Everywhere<F> {
        Everywhere { f }
    }
}

Then, we implement GenericTransform for Everywhere. First we recursively map across the value’s children, then we transform the given value. This transforming of children first is what causes the traversal to be bottom up.

impl<F> GenericTransform for Everywhere<F>
where
    F: GenericTransform,
{
    fn transform<T>(&mut self, t: T) -> T
    where
        T: Term,
    {
        let t = t.map_one_transform(self);
        self.f.transform(t)
    }
}

If instead we wanted to perform a top down traversal, our choice to implement mapping non-recursively for Term enables us to do so:

impl<F> GenericTransform for EverywhereTopDown<F>
where
    F: GenericTransform,
{
    fn transform<T>(&mut self, t: T) -> T
    where
        T: Term,
    {
        // Calling `transform` before `map_one_transform` now.
        let t = self.f.transform(t);
        t.map_one_transform(self)
    }
}
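To make the difference in visiting order concrete, here is a small self-contained sketch. It takes some liberties: Term carries an extra NAME constant purely so that the hypothetical Log transformer can record what it sees (the real trait has no such member), Employee is simplified to a name and a salary, and the two combinators are written as tuple structs for brevity:

```rust
pub trait GenericTransform {
    fn transform<T: Term>(&mut self, t: T) -> T;
}

pub trait Term: Sized {
    /// Demo-only label; an assumption of this sketch, not in the post's trait.
    const NAME: &'static str;
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self;
}

pub struct Salary(pub f64);
pub struct Employee(pub &'static str, pub Salary);

impl Term for f64 {
    const NAME: &'static str = "f64";
    fn map_one_transform<F: GenericTransform>(self, _: &mut F) -> Self {
        self
    }
}

impl Term for &'static str {
    const NAME: &'static str = "str";
    fn map_one_transform<F: GenericTransform>(self, _: &mut F) -> Self {
        self
    }
}

impl Term for Salary {
    const NAME: &'static str = "Salary";
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self {
        Salary(f.transform(self.0))
    }
}

impl Term for Employee {
    const NAME: &'static str = "Employee";
    fn map_one_transform<F: GenericTransform>(self, f: &mut F) -> Self {
        Employee(f.transform(self.0), f.transform(self.1))
    }
}

/// Hypothetical transformer: records the label of every value it is handed.
pub struct Log(pub Vec<&'static str>);

impl GenericTransform for Log {
    fn transform<T: Term>(&mut self, t: T) -> T {
        self.0.push(T::NAME);
        t
    }
}

pub struct Everywhere<F>(pub F);

impl<F: GenericTransform> GenericTransform for Everywhere<F> {
    fn transform<T: Term>(&mut self, t: T) -> T {
        let t = t.map_one_transform(self); // children first
        self.0.transform(t)
    }
}

pub struct EverywhereTopDown<F>(pub F);

impl<F: GenericTransform> GenericTransform for EverywhereTopDown<F> {
    fn transform<T: Term>(&mut self, t: T) -> T {
        let t = self.0.transform(t); // parent first
        t.map_one_transform(self)
    }
}

/// Returns the (bottom-up, top-down) visiting orders for one small tree.
pub fn visit_orders() -> (Vec<&'static str>, Vec<&'static str>) {
    let mut bottom_up = Everywhere(Log(Vec::new()));
    bottom_up.transform(Employee("Ada", Salary(1.0)));
    let mut top_down = EverywhereTopDown(Log(Vec::new()));
    top_down.transform(Employee("Ada", Salary(1.0)));
    ((bottom_up.0).0, (top_down.0).0)
}
```

For this tree, the bottom-up log reads str, f64, Salary, Employee, while the top-down log reads Employee, str, Salary, f64.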

So What?

At this point, you might be throwing up your hands and complaining about all the infrastructure we had to write in order to get to the two line solution for increasing salaries in a company. Surely all this infrastructure is at least as much code as the original boilerplate? Yes, but this infrastructure can be shared for all the transformations we ever write, and not even just for companies, but values of all types!

For example, if we wanted to make sure every employee in the company was a good culture fit, we might want to rename them all to “Juan Offus”. This is all the code we’d have to write:

// Definition
let rename = |p: Person| Person("Juan Offus", p.1);
let mut rename = Everywhere::new(Transformation::new(rename));

// Usage
let new_company = rename.transform(old_company);

Finally, the paper notes that this technique is more future proof than writing out the boilerplate:

Furthermore, if the data types change – for example, a new form of SubUnit is added – then the per-data-type boilerplate code must be re-generated, but the code for increase [..] is unchanged.

Queries

What if instead of consuming a T and transforming it into a new T, we wanted to non-destructively produce some other kind of result type R? In the Haskell code, generic queries have this type signature:

forall a. Term a => a -> R

Translating this into Rust, thinking about ownership and borrowing semantics, and using a trait with a generic method to avoid rank-2 function types, we get this:

// Essentially, for<T> FnMut(&T) -> R
pub trait GenericQuery<R> {
    fn query<T>(&mut self, t: &T) -> R
    where
        T: Term;
}

Similar to the Transformation type, we have a Query type, which lifts a query function for a particular U type (FnMut(&U) -> R) into a generic query over all types (for<T> FnMut(&T) -> R aka GenericQuery). The catch is that we need some way to create a default instance of R for the cases where our generic query function is invoked on a value that isn’t of type &U. This is what the D: FnMut() -> R is for.

pub struct Query<Q, U, D, R>
where
    Q: FnMut(&U) -> R,
    D: FnMut() -> R,
{
    make_default: D,
    query: Q,
}

When constructing a Query whose result type R implements the Default trait, we can use Default::default as D:

impl<Q, U, R> Query<Q, U, fn() -> R, R>
where
    Q: FnMut(&U) -> R,
    R: Default,
{
    pub fn new(query: Q) -> Query<Q, U, fn() -> R, R> {
        Query {
            make_default: Default::default,
            query,
        }
    }
}

Otherwise, we require a function that we can invoke to give us a default value when we need one:

impl<Q, U, D, R> Query<Q, U, D, R>
where
    Q: FnMut(&U) -> R,
    D: FnMut() -> R,
{
    pub fn or_else(make_default: D, query: Q) -> Query<Q, U, D, R> {
        Query {
            make_default,
            query,
        }
    }
}

Here we can see Query in action:

let mut char_to_u32 = Query::or_else(|| 42, |c: &char| *c as u32);
assert_eq!(char_to_u32.query(&'a'), 97);
assert_eq!(char_to_u32.query(&'b'), 98);
assert_eq!(char_to_u32.query(&"str is not a char"), 42);
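The GenericQuery implementation for Query isn't shown above, so here is a minimal stable-Rust sketch of how it could work. The assumptions: Any::downcast_ref stands in for a by-reference Cast, GenericQuery is bounded by Any instead of Term, and the PhantomData field is my guess at how the otherwise unused U and R parameters get tied to the struct:

```rust
use std::any::Any;
use std::marker::PhantomData;

pub trait GenericQuery<R> {
    // The post bounds `T` by `Term`; this sketch substitutes `Any`.
    fn query<T: Any>(&mut self, t: &T) -> R;
}

pub struct Query<Q, U, D, R>
where
    Q: FnMut(&U) -> R,
    D: FnMut() -> R,
{
    make_default: D,
    query: Q,
    // Marks `U` and `R` as used; nothing of either type is stored.
    phantom: PhantomData<fn(&U) -> R>,
}

impl<Q, U, D, R> Query<Q, U, D, R>
where
    Q: FnMut(&U) -> R,
    D: FnMut() -> R,
{
    pub fn or_else(make_default: D, query: Q) -> Self {
        Query { make_default, query, phantom: PhantomData }
    }
}

impl<Q, U, D, R> GenericQuery<R> for Query<Q, U, D, R>
where
    Q: FnMut(&U) -> R,
    D: FnMut() -> R,
    U: Any,
{
    fn query<T: Any>(&mut self, t: &T) -> R {
        match (t as &dyn Any).downcast_ref::<U>() {
            // `t` is a `U`: run the specific query on it.
            Some(u) => (self.query)(u),
            // Any other type: fall back to the default.
            None => (self.make_default)(),
        }
    }
}
```

Because queries only borrow their input, there is no need to hand the value back on failure, which is why a plain Option-returning downcast is enough here.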

Next, we extend the Term trait with a map_one_query method, similar to map_one_transform, that applies the generic query to each of self’s direct children.

Note that this produces zero or more R values, not a single R! The original Haskell code returns a list of R values, and its laziness allows one to only actually compute as many as end up getting used. But Rust is not lazy, and is much more explicit about things like physical layout and storage of values. We don’t want to allocate a (generally small) vector on the heap for every single map_one_query call. Instead, we use a callback interface, so that callers can decide if and when to heap allocate the results.

pub trait Term: Sized {
    // ...

    fn map_one_query<Q, R, F>(&self, query: &mut Q, each: F)
    where
        Q: GenericQuery<R>,
        F: FnMut(&mut Q, R);
}

Implementing map_one_query for Employee would look like this:

impl Term for Employee {
    // ...

    fn map_one_query<Q, R, F>(&self, q: &mut Q, mut f: F)
    where
        Q: GenericQuery<R>,
        F: FnMut(&mut Q, R),
    {
        let r = q.query(&self.0);
        f(q, r);
        let r = q.query(&self.1);
        f(q, r);
    }
}

And implementing it for SubUnit like this:

impl Term for SubUnit {
    // ...

    fn map_one_query<Q, R, F>(&self, q: &mut Q, mut f: F)
    where
        Q: GenericQuery<R>,
        F: FnMut(&mut Q, R),
    {
        match *self {
            SubUnit::Person(ref p) => {
                let r = q.query(p);
                f(q, r);
            }
            SubUnit::Department(ref d) => {
                let r = q.query(d);
                f(q, r);
            }
        }
    }
}

Once again, map_one_query’s implementation directly falls out of the structure of the type: querying each field of a struct, matching on a variant and querying each of the matched variant’s children. It is also mechanically implemented inside #[derive(Term)].

The final querying puzzle piece is a combinator that puts the one-layer querying traversal together with generic query functions to form a recursive querying traversal. This is very similar to the Everywhere combinator, but now we also need a folding function to reduce the multiple R values we get from map_one_query into a single resulting R value.

Here is its definition and constructor:

pub struct Everything<Q, R, F>
where
    Q: GenericQuery<R>,
    F: FnMut(R, R) -> R,
{
    q: Q,
    fold: F,
}

impl<Q, R, F> Everything<Q, R, F>
where
    Q: GenericQuery<R>,
    F: FnMut(R, R) -> R,
{
    pub fn new(q: Q, fold: F) -> Everything<Q, R, F> {
        Everything {
            q,
            fold,
        }
    }
}

We implement the Everything query traversal top down by querying the given value before mapping the query across its children and folding their results together. The wrapping into and unwrapping out of Options allow fold and the closure to take r by value; Option is essentially acting as a “move cell”.

impl<Q, R, F> GenericQuery<R> for Everything<Q, R, F>
where
    Q: GenericQuery<R>,
    F: FnMut(R, R) -> R,
{
    fn query<T>(&mut self, t: &T) -> R
    where
        T: Term,
    {
        let mut r = Some(self.q.query(t));
        t.map_one_query(
            self,
            |me, rr| { r = Some((me.fold)(r.take().unwrap(), rr)); },
        );
        r.unwrap()
    }
}

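With Everything defined, we can perform generic queries! For example, to find the highest salary paid out in a company, we can query by grabbing an Employee’s salary (wrapped in an Option because we could have a shell company with no employees), and folding all the results together with a max function. Note that std::cmp::max requires Ord, which f64 does not implement, so we fold with f64::max instead:

// Larger of two optional salaries; None only if both are None.
let max = |a: Option<f64>, b: Option<f64>| match (a, b) {
    (Some(x), Some(y)) => Some(x.max(y)),
    (x, y) => x.or(y),
};

// Definition
let get_salary = |e: &Employee| Some((e.1).0);
let mut query_max_salary = Everything::new(Query::new(get_salary), max);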

// Usage
let max_salary = query_max_salary.query(&some_company);
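Here is a self-contained version of the max-salary query that should build on stable Rust, under a few assumptions of my own: Term gets an Any supertrait in place of Cast, the company model is cut down to three types, SalaryOf is a hand-written query instead of a lifted closure, and the fold is the hypothetical helper larger, since f64 has no total order:

```rust
use std::any::Any;
use std::marker::PhantomData;

pub trait GenericQuery<R> {
    fn query<T: Term>(&mut self, t: &T) -> R;
}

/// Cut-down `Term`; the `Any` supertrait is this sketch's stand-in for `Cast`.
pub trait Term: Any + Sized {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, q: &mut Q, each: F);
}

pub struct Salary(pub f64);
pub struct Employee(pub Salary);
pub struct Company(pub Vec<Employee>);

impl Term for f64 {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, _: &mut Q, _: F) {}
}
impl Term for Salary {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, q: &mut Q, mut each: F) {
        let r = q.query(&self.0);
        each(q, r);
    }
}
impl Term for Employee {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, q: &mut Q, mut each: F) {
        let r = q.query(&self.0);
        each(q, r);
    }
}
impl Term for Company {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, q: &mut Q, mut each: F) {
        let r = q.query(&self.0);
        each(q, r);
    }
}
impl<T: Term> Term for Vec<T> {
    fn map_one_query<Q: GenericQuery<R>, R, F: FnMut(&mut Q, R)>(&self, q: &mut Q, mut each: F) {
        for child in self {
            let r = q.query(child);
            each(q, r);
        }
    }
}

/// Hypothetical query: `Some(amount)` for a `Salary`, `None` for anything else.
pub struct SalaryOf;
impl GenericQuery<Option<f64>> for SalaryOf {
    fn query<T: Term>(&mut self, t: &T) -> Option<f64> {
        (t as &dyn Any).downcast_ref::<Salary>().map(|s| s.0)
    }
}

pub struct Everything<Q, R, F> {
    q: Q,
    fold: F,
    phantom: PhantomData<fn() -> R>, // marks `R` as used
}

impl<Q: GenericQuery<R>, R, F: FnMut(R, R) -> R> Everything<Q, R, F> {
    pub fn new(q: Q, fold: F) -> Self {
        Everything { q, fold, phantom: PhantomData }
    }
}

impl<Q: GenericQuery<R>, R, F: FnMut(R, R) -> R> GenericQuery<R> for Everything<Q, R, F> {
    fn query<T: Term>(&mut self, t: &T) -> R {
        // Query this value, then fold in the results from its children.
        let mut r = Some(self.q.query(t));
        t.map_one_query(self, |me, rr| {
            r = Some((me.fold)(r.take().unwrap(), rr));
        });
        r.unwrap()
    }
}

/// Larger of two optional amounts; `None` only if both are `None`.
fn larger(a: Option<f64>, b: Option<f64>) -> Option<f64> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x.max(y)),
        (x, y) => x.or(y),
    }
}

pub fn max_salary(company: &Company) -> Option<f64> {
    Everything::new(SalaryOf, larger).query(company)
}
```

A company with no employees yields None, which is exactly the shell-company case the Option wrapper was meant to cover.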

If we were only querying for a single value, for example a Department with a particular name, the Haskell paper shows how we could leverage laziness to avoid traversing the whole search tree once we’ve found an acceptable answer. This is not an option for Rust. To have equivalent functionality, we would need to thread a break-or-continue control value from the query function through to map_one_query implementations. I haven’t implemented this, but if you want to, send me a pull request ;-)
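For anyone tempted by that pull request, here is one shape the break-or-continue threading could take, using std::ops::ControlFlow. Everything in it is hypothetical and none of it exists in scrapmetal: the GenericSearch trait, the try_map_one method, and the FindDepartment query are names invented for this sketch, and Any again stands in for Cast:

```rust
use std::any::Any;
use std::ops::ControlFlow;

/// Hypothetical short-circuiting cousin of `GenericQuery`:
/// `Break(r)` means "found it, stop"; `Continue(())` means "keep looking".
pub trait GenericSearch<R> {
    fn search<T: Term>(&mut self, t: &T) -> ControlFlow<R>;
}

/// Cut-down `Term` for this sketch; the `Any` supertrait stands in for `Cast`.
pub trait Term: Any + Sized {
    /// Searches direct children left to right, stopping at the first `Break`.
    fn try_map_one<S: GenericSearch<R>, R>(&self, s: &mut S) -> ControlFlow<R>;
}

pub struct Department(pub &'static str, pub Vec<Department>);

impl Term for &'static str {
    fn try_map_one<S: GenericSearch<R>, R>(&self, _: &mut S) -> ControlFlow<R> {
        ControlFlow::Continue(())
    }
}

impl<T: Term> Term for Vec<T> {
    fn try_map_one<S: GenericSearch<R>, R>(&self, s: &mut S) -> ControlFlow<R> {
        for child in self {
            s.search(child)?; // `?` propagates a `Break` immediately
        }
        ControlFlow::Continue(())
    }
}

impl Term for Department {
    fn try_map_one<S: GenericSearch<R>, R>(&self, s: &mut S) -> ControlFlow<R> {
        s.search(&self.0)?;
        s.search(&self.1)
    }
}

/// Finds the first department called `name`, top down, and reports how many
/// direct sub-departments it has. `visited` counts the values inspected.
pub struct FindDepartment {
    pub name: &'static str,
    pub visited: usize,
}

impl GenericSearch<usize> for FindDepartment {
    fn search<T: Term>(&mut self, t: &T) -> ControlFlow<usize> {
        self.visited += 1;
        if let Some(d) = (t as &dyn Any).downcast_ref::<Department>() {
            if d.0 == self.name {
                return ControlFlow::Break(d.1.len());
            }
        }
        // No match here: descend, stopping as soon as a child breaks.
        t.try_map_one(self)
    }
}
```

The visited counter is only there to show that a successful search inspects fewer values than a failed one, which has to walk the whole tree.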

However, we can prune subtrees from the search/traversal with the building blocks we’ve defined so far. For example, EverywhereBut is a generic transformer combinator that only transforms the subtrees for which its predicate returns true, and leaves other subtrees as they are:

pub struct EverywhereBut<F, P>
where
    F: GenericTransform,
    P: GenericQuery<bool>,
{
    f: F,
    predicate: P,
}

impl<F, P> GenericTransform for EverywhereBut<F, P>
where
    F: GenericTransform,
    P: GenericQuery<bool>,
{
    fn transform<T>(&mut self, t: T) -> T
    where
        T: Term,
    {
        if self.predicate.query(&t) {
            let t = t.map_one_transform(self);
            self.f.transform(t)
        } else {
            t
        }
    }
}

What’s Next?

The paper continues by generalizing transforms, queries, and monadic transformations into brain-twisting generic folds over the value tree. Unfortunately, I don’t think that this can be ported to Rust, but maybe you can prove me wrong. I don’t fully grok it yet :)

If the generic folds can’t be expressed in Rust, that means that for every new kind of generic operation we might want to perform (e.g., adding a generic cloning operation for<T> FnMut(&T) -> T) we would need to extend the Term trait and its custom derive. The consequence is that downstream crates are constrained to only use the operations predefined by scrapmetal, and can’t define their own arbitrary new operations.

The paper is a fun read — go read it!

Finally, check out the scrapmetal crate, play with it, and send me pull requests. I still need to implement Term for all the types that are exported in the standard library, and would love some help in this department. I’d also like to figure out what kinds of operations should come prepackaged, what kinds of traversals and combinators should be built in, and of course some help implementing them.