Differentiating Types in Haskell

April 26th, 2009

In the last article, we explored the concept of data-types as algebraic expressions, and the idea that the derivative of these expressions produces their type of one-hole contexts.  This “leap of logic” gives us a firm foundation for purely-functional incremental mutation, and (more generally) proof that there are interesting things happening at the level of types!

Still, many good ideas are destroyed (or exalted) when the time comes to make a real program out of them.  So in this article, we will write that program, and in the process shore up this notion of a derivative by defining it over any recursive type.

Defining the Derivative

In his paper first exploring this topic, Conor McBride determines the six laws for differentiating data types, roughly as follows:

1. ∂x(x) = 1
2. ∂x(T) = 0, where x ∉ T
3. ∂x(S + T) = ∂x(S) + ∂x(T)
4. ∂x(S * T) = ∂x(S) * T + S * ∂x(T)
5. ∂x(μy.F) = μz.((∂x(F)|y=μy.F) + (∂y(F)|y=μy.F) * z)
6. ∂x(F|y=S) = (∂x(F)|y=S) + (∂y(F)|y=S) * ∂x(S)

This table is slightly more complicated than the rosy picture painted by the last article.  Before getting into the details of this definition, it’s worth pointing out one of the important things done here with recursive types.  We’ll need to distinguish between a recursive type, μy.F, and its one-step unrolling, F|y=S (see the left-hand sides of rules 5 and 6).  Our previous definition of recursive types still applies, but instead of informally saying that a recursive type defines an infinite expansion by repeated substitution, we now capture a single step of that substitution in this new form.

Also — a point that was slightly glossed over in the last article — for typical data-types we’re actually interested in the partial derivative.  In the definition above, ∂x(T) means the partial derivative of T with respect to x, and ∂y(T) means the partial derivative with respect to y.  Likewise, when we use ∂x(T) we are poking an x-hole in T (i.e.: we’re finding one x in T to single out).

The first four rules are just the same rules as in Calculus, but it’s instructive to informally consider what they mean in the context of types.  The first rule says that poking an x-hole in just the single variable x produces the unit type (this also tells us that holes have the unit type, perhaps a subtle point that was not made in the last article).  The second rule says that poking an x-hole in any type where x is not mentioned produces the void type (i.e.: you can’t poke a hole in an x that isn’t there).  The third rule says that if you have a type that could either be an S or a T, then poking an x-hole in this type produces a type that’s either an S with an x-hole or a T with an x-hole.  And the fourth rule says that poking an x-hole in a pair of an S and a T is either an S with an x-hole and a T (left intact), or it’s an S (left intact) and a T with an x-hole — in other words, in a pair, the hole is either in the first or the second value.

In other words, rules 3 and 4 are just the normal sum and product rules of Differential Calculus (respectively) and rules 1 and 2 are the regular base cases of differentiation (for the differentiated variable, or a constant — respectively).
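Rule 4 can be made concrete with a few lines of Haskell.  The sketch below (the names PairHole and plug are made up for illustration, not taken from the article’s program) shows that the one-hole contexts of the pair type x*x with respect to x form the type x + x, and that filling the hole recovers the original pair:

```haskell
-- Rule 4, concretely (illustrative names): an x-hole in the pair (x, x)
-- leaves behind either the second component (hole in the first slot) or
-- the first component (hole in the second slot) -- the derivative x + x.
data PairHole x = HoleInFst x | HoleInSnd x
  deriving (Eq, Show)

-- Filling the hole with a value reconstructs the original pair.
plug :: PairHole x -> x -> (x, x)
plug (HoleInFst other) v = (v, other)  -- other = the untouched second slot
plug (HoleInSnd other) v = (other, v)  -- other = the untouched first slot
```

For instance, plug (HoleInFst 2) 1 rebuilds the pair (1, 2).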

So what explains rules 5 and 6?  These are two sides of the plain old chain rule.  It’s (arguably) easiest to make this case for rule 6, so let’s repeat it for reference:

∂x(F|y=S) = (∂x(F)|y=S) + (∂y(F)|y=S) * ∂x(S)

Now consider the definition of the chain rule for partial derivatives:

For z = f(u, v), where u and v are functions of x:
∂z/∂x = (∂f/∂u)*(∂u/∂x) + (∂f/∂v)*(∂v/∂x)

This definition looks almost identical to rule 6: take v = y (with S defining v as a function of x), and take u = x itself, so that ∂u/∂x = 1 and the first term collapses to just ∂f/∂x — matching the left side of the sum, (∂x(F)|y=S).

The connection is a little more obscure in rule 5:

∂x(μy.F) = μz.((∂x(F)|y=μy.F) + (∂y(F)|y=μy.F) * z)

Here we actually have the same structure as in rule 6, except that we see “z” where the chain rule tells us to expect ∂v/∂x.  Informally, this anomaly is resolved by recognizing that “z” — the recursion variable of the result, standing for the derivative of the entire recursive type — is in fact that ∂v/∂x.

Let’s try a non-trivial example to develop some intuition for this method before we set out to write the differentiation program.  Consider the recursive type of lists of integers:

μ X.(1+int*X)

And we will want to poke an int-hole, so we’re going to find:

∂int(μ X.(1+int*X))

By rule #5 (and given that “y” from rule #5 matches “X” in our expression, and “F” from rule #5 matches “1+int*X” in our expression), this becomes:

μz.((∂int(1+int*X)|X=μX.1+int*X)+(∂X(1+int*X)|X=μX.1+int*X)*z)

Yikes!  There’s a lot going on there, but perhaps we can break this problem into easier pieces by a couple of choice substitutions.  Since this is just a recursive type with a sum in it, let’s choose a couple of variables to represent both sides of the sum:

μz.(U+V*z)
where
   U = ∂int(1+int*X)|X=(μX.1+int*X)
   V = ∂X(1+int*X)|X=(μX.1+int*X)

Now we can evaluate U and V independently.  First let’s consider U:

∂int(1+int*X)|X=(μX.1+int*X)

By rule #3 (the sum rule) this becomes:

(∂int(1)+∂int(int*X))|X=(μX.1+int*X)

Now we have two derivatives to evaluate.  First the one on the left becomes 0 (by rule #2) and we can eliminate it:

∂int(int*X)|X=(μX.1+int*X)

Then by rule #4 (the product rule) this last derivative becomes:

(∂int(int) * X + int * ∂int(X))|X=(μX.1+int*X)

Now we’re almost done.  Here we have one more sum, and the derivative on the left becomes 1 (by rule #1), leaving a lone X:

(X + int * ∂int(X))|X=(μX.1+int*X)

And the derivative on the right becomes 0 (by rule #2), which means that the right side of the sum can be eliminated.  Therefore we have solved “U”:

U = X|X=(μX.1+int*X)

Now, since we’re done, we could substitute for X here to eliminate the redundant context:

U = μX.(1+int*X)

In other words, U is a list of ints.  Now we must solve V (the trickier side of the chain rule, where we switch to poking an int-list hole rather than an int-hole):

∂X(1+int*X)|X=(μX.1+int*X)

By rule #3, this becomes:

(∂X(1)+∂X(int*X))|X=(μX.1+int*X)

Here the left side of the sum is eliminated (by rule #2) and the right side is elaborated (by rule #4) becoming:

(∂X(int) * X + int * ∂X(X))|X=(μX.1+int*X)

Now the left side of the sum is eliminated (by rule #2), and the right side just simplifies to int (by rule #1), so that we’ve also now solved for V:

V = int|X=(μX.1+int*X)

In fact, we don’t need this trailing context information anymore (since “X” doesn’t appear in the simple type variable “int”, we don’t need to remember how to substitute “X”), so we can simply say:

V = int

Now, substituting U and V back into our original solution for the derivative of a list of ints, this becomes:

μz.((μX.(1+int*X))+int*z)

To help us make some sense of this answer, it’s useful to observe that the following identity holds:

μ a.(T+S*a) = list<S> * T

We know this to be true in the normal case of lists — because in that case T = 1 (and list<S>*1 = list<S>).  But if we apply this identity to the answer we’ve derived, we can conclude that the derivative of a list of ints is a pair of lists of ints (which makes sense — the first list is the prefix up to the hole, and the second list is the suffix following the hole).  In fact, it’s a good thing that we derived this answer, because we’ve already made use of it!
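The identity can also be checked mechanically.  In the hedged sketch below (ListT, toPair, and fromPair are hypothetical names, not part of the article’s program), the recursive type μa.(T+S*a) is written as a Haskell data type and converted losslessly to and from a pair of a list and a T:

```haskell
-- mu a.(T + S*a) as a Haskell data type: either we stop with a T,
-- or we have an S and the rest of the structure.
data ListT s t = Done t | More s (ListT s t)
  deriving (Eq, Show)

-- Unwind the recursion into a list of S's paired with the final T.
toPair :: ListT s t -> ([s], t)
toPair (Done t)   = ([], t)
toPair (More s k) = let (ss, t) = toPair k in (s : ss, t)

-- Rebuild the recursive structure from the pair.
fromPair :: ([s], t) -> ListT s t
fromPair ([], t)     = Done t
fromPair (s : ss, t) = More s (fromPair (ss, t))
```

Since toPair and fromPair are mutually inverse, the two types carry exactly the same information, which is what the identity claims.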

Now, assuming that we’re satisfied with this differentiation procedure, all that’s left is to actually implement it in a real program.

A Brief Haskell Program

The first thing that we need to do is set down a data structure that represents types.  The circularity of this definition may seem odd (we’re defining a data type for structures that represent data types), but because we’ve already covered in detail the forms that types can take, it shouldn’t be very surprising:

data Type =
     Ty String
   | Sum (Type, Type)
   | Prod (Type, Type)
   | Fix (String, Type)
   | FApp (Type, String, Type)
   deriving Eq

Stated plainly, a type is a type variable (e.g.: int, string, date, unit, void, etc), or the sum of two types, or the product of two types, or a recursive type, or the one-step unrolling of a recursive type.

For convenience, we can define a Show instance that displays these type structures in the shorthand notation we’ve developed so far:

instance Show Type where
   show (Ty "unit")      = "1"
   show (Ty "void")      = "0"
   show (Ty v)           = v
   show (Sum  (t1, t2))  = (show t1)++"+"++(show t2)
   show (Prod (t1, t2))  = (show t1)++"*"++(show t2)
   show (Fix  (v,  ty))  = "mu "++v++"."++(show ty)
   show (FApp (t, x, s)) = "["++(show t)++"|"++x++"="++(show s)++"]"

This way, if we enter an interesting type at the Haskell prompt (like the recursive type of lists of integers), we should see it in the shorthand we’ve developed:

Prelude> Fix ("X", Sum (Ty "unit", Prod (Ty "int", Ty "X")))
mu X.1+int*X

Finally, we can define a “derivative” function on types with respect to a particular type-variable by explicitly translating each of the 6 rules we’ve considered above.  First rule #1:

derivative (Ty v) x | x == v = Ty "unit"

Then rule #2 (a type other than x):

derivative (Ty v) x = Ty "void"

Then rule #3 (the rule for sums):

derivative (Sum  (t1, t2)) x = Sum (derivative t1 x, derivative t2 x)

And rule #4, for products:

derivative (Prod (t1, t2))   x = Sum (Prod (derivative t1 x, t2), Prod (t1, derivative t2 x))

Rule #5 (assuming for the moment that we have a function unique_name to choose a unique type-variable name that does not appear in a particular type expression):

derivative rty@(Fix (v, ty)) x =
   let z = unique_name rty in
   let dt_dx = derivative ty x in
   let dt_dv = derivative ty v in
     Fix (z,
      Sum (
        FApp (dt_dx, v, rty),
        Prod (FApp (dt_dv, v, rty), Ty z)))

And finally rule #6:

derivative (FApp (t, y, s))  x =
    let dt_dx = derivative t x in
    let dt_dy = derivative t y in
    let ds_dx = derivative s x in
       Sum (FApp(dt_dx, y, s), Prod (FApp (dt_dy, y, s), ds_dx))
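With all six clauses written down, it’s worth a quick sanity check that the non-recursive rules behave as expected.  The snippet below restates the Type declaration and rules 1–4 so it can run on its own — a sketch of just this fragment, not the full script:

```haskell
-- Self-contained restatement of the Type declaration (sum and product
-- cases only) and derivative rules 1-4 from this article.
data Type =
     Ty String
   | Sum (Type, Type)
   | Prod (Type, Type)
   deriving (Eq, Show)

derivative :: Type -> String -> Type
derivative (Ty v) x | x == v = Ty "unit"                 -- rule 1
derivative (Ty _) _          = Ty "void"                 -- rule 2
derivative (Sum  (t1, t2)) x =                           -- rule 3
    Sum (derivative t1 x, derivative t2 x)
derivative (Prod (t1, t2)) x =                           -- rule 4
    Sum (Prod (derivative t1 x, t2), Prod (t1, derivative t2 x))
```

For example, derivative (Prod (Ty "int", Ty "X")) "int" produces 1*X + int*0, which the simplification pass below will reduce to a lone X — just as in the worked example.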

Now, as a question of correctness, we have finished this problem — we have a Haskell program that can compute derivatives of types (modulo the definition of this unique_name function).  However, as in the normal differential calculus, the terms we produce can be a little messy — with lots of 1*X and X+0 hanging around.  There can also be unnecessary recursive types (where the recursion variable is not mentioned in the body of the recursive type), and those should be eliminated as well.

The simplest approach is probably to define the single-step simplification of a type expression:

simplify (Sum  (Ty "void", x)) = simplify x
simplify (Sum  (x, Ty "void")) = simplify x
simplify (Sum  (x, y))         = Sum (simplify x, simplify y)
simplify (Prod (Ty "void", x)) = Ty "void"
simplify (Prod (x, Ty "void")) = Ty "void"
simplify (Prod (Ty "unit", x)) = simplify x
simplify (Prod (x, Ty "unit")) = simplify x
simplify (Prod (x, y))         = Prod (simplify x, simplify y)
simplify (Fix  (v, ty))        | not (v `member` names ty) = ty
simplify (Fix  (v, ty))        = Fix (v, simplify ty)
simplify (FApp (t, y, s))      | not (y `member` names t) = t
simplify (FApp (t, y, s))      = FApp (simplify t, y, simplify s)
simplify x                    = x
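The helpers member and names used in the guards above live in the accompanying script; one plausible sketch of them (not necessarily the script’s exact definitions, and restating the Type declaration so the snippet stands alone) simply collects every type-variable name occurring in a type expression:

```haskell
-- Restated from the article so this snippet is self-contained.
data Type =
     Ty String
   | Sum (Type, Type)
   | Prod (Type, Type)
   | Fix (String, Type)
   | FApp (Type, String, Type)
   deriving (Eq, Show)

-- A plausible sketch: gather every type-variable name that occurs
-- in a type expression (binders of Fix contribute only via their uses).
names :: Type -> [String]
names (Ty v)           = [v]
names (Sum  (t1, t2))  = names t1 ++ names t2
names (Prod (t1, t2))  = names t1 ++ names t2
names (Fix  (_, ty))   = names ty
names (FApp (t, x, s)) = x : names t ++ names s

member :: Eq a => a -> [a] -> Bool
member = elem
```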

Then, given that we can perform a single step of simplification to a term (keep in mind, one 0*X becoming 0 could propagate into a new 0*Y, as in the expression Y*(0*X)), we can completely simplify a term by the fixed-point of this process:

fixed_point f x | x == (f x) = x
fixed_point f x              = fixed_point f (f x)

csimplify x = fixed_point simplify x
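As a quick sanity check, fixed_point behaves just as well on ordinary values as on types.  It is restated below (with an explicit signature) so the example is self-contained:

```haskell
-- fixed_point from the article: iterate f until the result stops changing.
fixed_point :: Eq a => (a -> a) -> a -> a
fixed_point f x | x == f x  = x
                | otherwise = fixed_point f (f x)
```

For example, fixed_point (`div` 2) 100 repeatedly halves its argument (100, 50, 25, …, 1, 0) until 0 maps to itself, yielding 0.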

Also, because the way we display these types loses some grouping information (the type “int*(string+date)” is printed “int*string+date”, which is actually a different type!), we should probably flatten types into sums of products (e.g.: converting “int*(string+date)” to “int*string + int*date”):

flatten ty = distribute Nothing ty
  where
    distribute Nothing t@(Ty _)              = t
    distribute (Just lty) t@(Ty _)           = Prod (lty, t)
    distribute mlty (Sum (t1, t2))           = Sum (distribute mlty t1, distribute mlty t2)
    distribute mlty (Prod (t1, t2))          = distribute (Just (distribute mlty t1)) t2
    distribute Nothing (Fix (v, ty))         = Fix (v, distribute Nothing ty)
    distribute (Just lty) t@(Fix (v, ty))    = Prod (lty, distribute Nothing t)
    distribute Nothing (FApp (t, x, s))      = FApp (distribute Nothing t, x, distribute Nothing s)
    distribute (Just lty) t@(FApp (f, x, s)) = Prod (lty, distribute Nothing t)

Finally we can package all of these niceties on top of the “derivative” function into a single definition:

d ty x = flatten (csimplify (derivative ty x))

For convenience, I’ve prepared a script — please rename it to dtype.hs; somehow our weblog software thinks that *.hs files are insecure — which contains all of these definitions (along with several utility functions, like unique_name, which are necessary but not vital to the understanding of this differentiation procedure).  Provided that you have GHC installed, you can load the script like this:

> ghci dtype.hs

We can then test it out on two types whose one-hole contexts we’ve already made use of.  First the derivative of a list:

Prelude> d (Fix ("X", Sum (Ty "unit", Prod (Ty "int", Ty "X")))) "int"
mu a.[X|X=mu X.1+int*X]+int*a

Also we can take the derivative of a binary tree:

Prelude> d (Fix ("X", Sum (Ty "int", Prod (Ty "X", Ty "X")))) "int"
mu a.1+[X+X|X=mu X.int+X*X]*a

Again, using the list identity (and the observation that X is the type of binary trees of ints), this tells us that the derivative of a binary tree is a list whose elements are each a binary tree: at every step on the path from the hole up to the root of the tree, we record the sibling branch (left or right) that we passed.
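This reading of the answer can be sketched directly.  In the snippet below (Tree, Step, Context, and plug are hypothetical names chosen for illustration), the one-hole context of a binary tree of ints is a list of sibling subtrees, recorded from the hole outward, and plugging an int back in rebuilds the whole tree:

```haskell
-- Binary trees of ints, as in the article's example mu X.(int + X*X).
data Tree = Leaf Int | Node Tree Tree
  deriving (Eq, Show)

-- One step of the context: we went left (carrying the unexplored right
-- sibling) or went right (carrying the unexplored left sibling).
data Step = WentLeft Tree | WentRight Tree
  deriving (Eq, Show)

-- The derived one-hole context: a list of steps from the hole to the root.
type Context = [Step]

-- Fill the int-hole and rebuild the tree, working outward from the hole.
plug :: Context -> Int -> Tree
plug ctx v = foldl step (Leaf v) ctx
  where
    step t (WentLeft  r) = Node t r  -- t was the left child
    step t (WentRight l) = Node l t  -- t was the right child
```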
