Symbolic Differentiation

Have you ever thought about how to write a program that differentiates any expression? It’s easy to write a program that computes a numerical approximation of the derivative of a given function. The question is: can you write a program that gives the exact, general form of the derivative of any expression? It’s not as difficult as it seems. In fact, a good way to check whether you understand differentiation and recursion is to try writing the program yourself. If you get stuck, read the first section and try again.
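For contrast, here is a minimal sketch of the numerical approach mentioned above. The name numerical_derivative and the step size h are assumptions made for illustration; the result is only an approximation at a single point, not a general formula.

```python
import math

def numerical_derivative(f, x, h=1e-6):
    """Approximate f'(x) with a central difference (illustrative helper)."""
    return (f(x + h) - f(x - h)) / (2 * h)

# The derivative of sin at 0 is cos(0) = 1; we only get something close to it.
approx = numerical_derivative(math.sin, 0.0)
print(abs(approx - 1.0) < 1e-6)  # True
```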

How to represent an expression

The first problem we need to tackle is that many programming languages don’t have a data structure that directly represents expressions. Programming languages like Python support arrays, dictionaries (a.k.a. hash tables), strings and tuples but not expressions. One way to start approaching this problem is to write down a formal definition of what a valid expression is. Formally defining any notion makes it easier to write a precise computer program based on that notion. So here is a formal definition:

Numbers
Any numeric symbol such as 5, or 3.14 is a valid expression.
Variables
Any string of characters such as "x", "y" or "abcd" is a valid expression.
Functions with one argument
Given that e is a valid expression, the following are also valid expressions:
  • sin(e)
  • cos(e)
  • exp(e)
Arithmetic operations
Given that e1 and e2 are valid expressions, the following are also valid expressions:
  • e1 + e2
  • e1 * e2

Of course, we can name many more functions than the three listed above but, as we’ll see, it’s quite easy to add any other function we like.

We then say that valid expressions are “freely generated” by the above rules. What this means, intuitively speaking, is that any meaningful way of combining the rules results in a valid expression, and that combining these rules is the only way one can create a valid expression. For example, x is a valid expression by the Variables rule and so is 5 by the Numbers rule. Therefore, by the second case of the Arithmetic operations rule, so is 5 * x, and by the Functions with one argument rule so is sin(5 * x). Explaining precisely what it means, mathematically and in the general case, to “freely generate” something by listing the ways of building it is actually quite difficult. In fact, doing this in a completely general context gives one a foundation of mathematics, where mathematical structures are described by listing the ways of building them and asking that the structure in question is freely generated. But that has to wait for another time.

So, the next question is: how do we translate the above definition into Python? Let’s go through the rules one by one, starting with the first, which is pretty easy:

Numbers
We can represent numbers directly in Python.
Variables
Strings of characters can also be directly represented.
Functions with one argument
These are all created by starting with an expression e and putting some function symbol in front of it. So given that e is a valid expression, we’ll say that the following are also valid expressions:
  • ("sin", e)
  • ("cos", e)
  • ("exp", e)
Arithmetic operations
In a similar way, given that e1 and e2 are valid expressions the following are also valid expressions:
  • ("+", e1, e2)
  • ("*", e1, e2).

That’s it, we now have a way of representing, in Python, any expression that follows our rules that defined “valid” expressions above.
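As an illustration of this representation, the sketch below builds sin(5 * x) as nested tuples and checks it against the four rules. The helper is_valid and the sets FUNCTIONS and OPERATORS are hypothetical names introduced here, not part of the program developed in this article.

```python
FUNCTIONS = {"sin", "cos", "exp"}   # one-argument function symbols
OPERATORS = {"+", "*"}              # two-argument arithmetic operations

def is_valid(e):
    """Return True if e follows the four rules for valid expressions."""
    if isinstance(e, (int, float)):          # Numbers
        return True
    if isinstance(e, str):                   # Variables
        return True
    if isinstance(e, tuple) and len(e) == 2: # Functions with one argument
        f, u = e
        return isinstance(f, str) and f in FUNCTIONS and is_valid(u)
    if isinstance(e, tuple) and len(e) == 3: # Arithmetic operations
        op, e1, e2 = e
        return (isinstance(op, str) and op in OPERATORS
                and is_valid(e1) and is_valid(e2))
    return False

expr = ("sin", ("*", 5, "x"))     # represents sin(5 * x)
print(is_valid(expr))             # True
print(is_valid(("tan", "x")))     # False: "tan" is not one of our functions
```

Note that each branch of is_valid mirrors exactly one rule of the definition, which is the same shape the differentiation function will have.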

Programming differentiation

So we want to write a function diff that takes as input a valid expression e and the variable x that we are differentiating with respect to, and returns a valid expression that represents the derivative of e with respect to x. We will assume that all variables other than x are constants, i.e. they don’t depend on x. Let’s begin by determining which of the above four kinds of expression e is:

def diff(e, x):
    if <e is a number>:
        ...
    elif <e is a variable>:
        ...
    elif <e is a function with one argument>:
        ...
    elif <e is an arithmetic operation>:
        ...
    else:
        ...

Now, how do we know whether e is a number? There are two kinds of numbers in Python, ints and floats. We can use isinstance(e, int), which is True when e is an int and False otherwise; the same applies to floats. The next question is, how do we check whether e is a variable? Variables are represented by strings, so we can use isinstance(e, str). On to the next case: how do we check whether e is a function with one argument? Well, e would have to be a tuple with exactly two elements, so we may use isinstance(e, tuple) and len(e) == 2. We may check whether e is an arithmetic operation in a similar way, this time looking for a tuple with three elements. Here’s what we have so far:

def diff(e, x):
    if isinstance(e, int) or isinstance(e, float):
        ...
    elif isinstance(e, str):
        ...
    elif isinstance(e, tuple) and len(e) == 2:
        ...
    elif isinstance(e, tuple) and len(e) == 3:
        ...
    else:
        ...

Let’s tackle the first branch of the if statement. Here we know that e is a number. What is the derivative of a constant number with respect to any variable? Obviously zero, so we may return 0 in the first case. On to the next branch: what is the derivative of some variable e with respect to another variable x? Well, if e is the same variable as x, we are calculating $\frac{dx}{dx}$, which is one. If e is a different variable, we assume that its value is independent of x; again, the derivative of a constant is zero, so we return zero in this case. Here’s what we get:

def diff(e, x):
    if isinstance(e, int) or isinstance(e, float):
        return 0
    elif isinstance(e, str):
        if e == x:
            return 1
        else:
            return 0
    elif isinstance(e, tuple) and len(e) == 2:
        ...
    elif isinstance(e, tuple) and len(e) == 3:
        ...
    else:
        ...

Now on to the first interesting case. Recalling our definition, when e is a function with one argument it is of the form (f, u), where f is a string that names the function and u is the argument to the function. So, how do we differentiate $f(u)$? Of course, we use the chain rule. It tells us that $f(u)' = f'(u)u'$. Notice how the chain rule gives the derivative of a composite $f(u)$ in terms of the derivative of $f$ and the derivative of $u$. This means that we have to use recursion to compute the derivative. So we first have to find out what the derivative of f is, which isn’t difficult: according to our definition there are only three possible functions, so all we need is one case for each possible value of f. We then recursively compute the derivative of u and multiply the resulting expressions.
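As a small worked instance of the chain rule (the expression $\sin(5x)$ is chosen here purely for illustration), take $f = \sin$ and $u = 5x$:

$\sin(5x)' = \sin'(5x) \cdot (5x)' = \cos(5x) \cdot 5$

The factor $(5x)'$ is exactly the part that the recursive call on u is responsible for.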

def diff(e, x):
    if isinstance(e, int) or isinstance(e, float):
        return 0
    elif isinstance(e, str):
        if e == x:
            return 1
        else:
            return 0
    elif isinstance(e, tuple) and len(e) == 2:
        (f, u) = e
        if isinstance(f, str):
            if f == "sin":
                ...
            elif f == "cos":
                ...
            elif f == "exp":
                ...
            else:
                ...
        else:
            ...
    elif isinstance(e, tuple) and len(e) == 3:
        ...
    else:
        ...

Starting with the case where f == "sin", we have to differentiate $\sin(u)$. By the chain rule, $\sin(u)' = \sin'(u)u' = \cos(u)u'$. Translating this into Python, we have to return ("*", ("cos", u), diff(u, x)). The other cases are similar. If we reach the final else of this branch, then we know that the expression we were given is invalid (by definition), so we return None to indicate the lack of a meaningful result. For the same reason, we return None when f is not a string.

def diff(e, x):
    if isinstance(e, int) or isinstance(e, float):
        return 0
    elif isinstance(e, str):
        if e == x:
            return 1
        else:
            return 0
    elif isinstance(e, tuple) and len(e) == 2:
        (f, u) = e
        if isinstance(f, str):
            if f == "sin":
                return ("*", ("cos", u), diff(u, x))
            elif f == "cos":
                return ("*", ("*", -1, ("sin", u)), diff(u, x))
            elif f == "exp":
                return ("*", ("exp", u), diff(u, x))
            else:
                return None
        else:
            return None
    elif isinstance(e, tuple) and len(e) == 3:
        ...
    else:
        ...

Just like the chain rule, the sum and product rules allow us to define the derivative of a sum or a product recursively. We determine which type of expression we have by a similar case analysis. Try writing the rest yourself before looking at the solution.
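For reference, here are the two rules in the same notation as the chain rule, writing $u$ and $v$ for the two subexpressions (these symbols are introduced here only for the statement of the rules):

$(u + v)' = u' + v'$

$(uv)' = u'v + uv'$

In both cases the right-hand side mentions only $u$, $v$ and their derivatives, which is what makes a recursive definition possible.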

Final program

So we are done, here is the program (Dramatic Trumpet Music …)

def diff(e, x):
    # Numbers: the derivative of a constant is zero.
    if isinstance(e, int) or isinstance(e, float):
        return 0
    # Variables: dx/dx is one, every other variable is treated as a constant.
    elif isinstance(e, str):
        if e == x:
            return 1
        else:
            return 0
    # Functions with one argument: chain rule.
    elif isinstance(e, tuple) and len(e) == 2:
        (f, u) = e
        if isinstance(f, str):
            if f == "sin":
                return ("*", ("cos", u), diff(u, x))
            elif f == "cos":
                return ("*", ("*", -1, ("sin", u)), diff(u, x))
            elif f == "exp":
                return ("*", ("exp", u), diff(u, x))
            else:
                return None
        else:
            return None
    # Arithmetic operations: sum rule and product rule.
    elif isinstance(e, tuple) and len(e) == 3:
        (op, e1, e2) = e
        if isinstance(op, str):
            if op == "+":
                return ("+", diff(e1, x), diff(e2, x))
            elif op == "*":
                return ("+", ("*", diff(e1, x), e2),
                             ("*", e1, diff(e2, x)))
            else:
                return None
        else:
            return None
    # Anything else is not a valid expression.
    else:
        return None

Let’s do some test cases.

>>> diff(("sin", ("cos", ("*", 2, "x"))), "x")
('*', ('cos', ('cos', ('*', 2, 'x'))), 
      ('*', ('*', -1, ('sin', ('*', 2, 'x'))), 
            ('+', ('*', 0, 'x'), ('*', 2, 1))))

If we simplify the result manually, we get that $\sin(\cos(2x))' = -2\cos(\cos(2x))\sin(2x)$, which is the right result. Let’s try another one:

>>> diff(("*", ("exp", ("*", 5, "x")), ("sin", "x")), "x")
('+', ('*', ('*', ('exp', ('*', 5, 'x')), 
                  ('+', ('*', 0, 'x'), 
                        ('*', 5, 1))), 
            ('sin', 'x')), 
      ('*', ('exp', ('*', 5, 'x')), 
            ('*', ('cos', 'x'), 1)))

Again, after simplifying we get that $(e^{5x}\sin(x))' = 5e^{5x}\sin(x) + e^{5x}\cos(x)$, which again is the right result. So, let’s just cut to the chase. It works!
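Instead of simplifying by hand, one could also check a result numerically. The sketch below is only one possible way to do that: it restates diff compactly (assuming tuples of length 2 for functions and length 3 for arithmetic operations, as in the representation above), adds a hypothetical helper evaluate that computes the value of an expression, and compares the symbolic derivative with a central-difference estimate at an arbitrarily chosen point.

```python
import math

def diff(e, x):
    # Compact restatement of the differentiation rules from this article.
    if isinstance(e, (int, float)):
        return 0
    if isinstance(e, str):
        return 1 if e == x else 0
    if isinstance(e, tuple) and len(e) == 2:
        f, u = e
        if f == "sin":
            return ("*", ("cos", u), diff(u, x))
        if f == "cos":
            return ("*", ("*", -1, ("sin", u)), diff(u, x))
        if f == "exp":
            return ("*", ("exp", u), diff(u, x))
        return None
    if isinstance(e, tuple) and len(e) == 3:
        op, e1, e2 = e
        if op == "+":
            return ("+", diff(e1, x), diff(e2, x))
        if op == "*":
            return ("+", ("*", diff(e1, x), e2), ("*", e1, diff(e2, x)))
    return None

def evaluate(e, env):
    """Hypothetical helper: value of a valid expression e, with variable values in env."""
    if isinstance(e, (int, float)):
        return e
    if isinstance(e, str):
        return env[e]
    if len(e) == 2:
        f, u = e
        functions = {"sin": math.sin, "cos": math.cos, "exp": math.exp}
        return functions[f](evaluate(u, env))
    op, e1, e2 = e
    a, b = evaluate(e1, env), evaluate(e2, env)
    return a + b if op == "+" else a * b

expr = ("*", ("exp", ("*", 5, "x")), ("sin", "x"))  # exp(5 * x) * sin(x)
x0, h = 0.3, 1e-6                                   # arbitrary test point and step
symbolic = evaluate(diff(expr, "x"), {"x": x0})
numeric = (evaluate(expr, {"x": x0 + h}) - evaluate(expr, {"x": x0 - h})) / (2 * h)
print(abs(symbolic - numeric) < 1e-4)  # True
```

Agreement at one point is evidence rather than proof, but it catches most mistakes in a rule quickly.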

To conclude, I’ll make a few remarks. Firstly, although the way we represented expressions is quite convenient, there is a way of representing them using classes in Python that is much more elegant. Even then, in my favourite kinds of programming languages, namely functional programming languages like Haskell, there is a better way called algebraic data types where the above mess can be written simply as:

data Expr = Var String
          | Const Int
          | Plus Expr Expr
          | Times Expr Expr
          | Sin Expr
          | Cos Expr
          deriving Show

diff :: Expr -> String -> Expr
diff (Var y) x | y == x = Const 1
               | otherwise = Const 0
diff (Const _) _ = Const 0
diff (Plus e1 e2) x = Plus (diff e1 x) (diff e2 x)
diff (Times e1 e2) x = Plus (Times (diff e1 x) e2)
                            (Times (diff e2 x) e1)
diff (Sin e1) x = Times (Cos e1) (diff e1 x)
diff (Cos e1) x = Times (Times (Const (-1)) (Sin e1)) (diff e1 x)

Notice how we don’t have to make sure the given expression is valid. The reason is that we can define what a valid expression is; that is what the data Expr declaration is for. Once the language knows what a valid expression is, it can make sure that no one can call diff with an invalid expression.

But most languages don’t support algebraic data types directly. Also, notice how diff(u, x) is repeated in the Python code. The reason for this is that I didn’t bother defining what it means to substitute a particular expression for a variable; defining that is actually trickier than it seems. Also, the representation of expressions we used is quite cumbersome compared to the usual notation for expressions. It would be cool if you could enter an expression like "sin(cos(x))" as a string and get its derivative "-cos(cos(x))sin(x)" out. To implement this we would need a parser that converts strings into the nested tuples that we used to represent expressions. We would also need a simplifier, and something that outputs formatted expressions, e.g. as a string or, even better, an image.
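To give a feel for what a simplifier might look like, here is a deliberately small sketch. The function simplify is a hypothetical helper that assumes its input is a valid expression; it only knows a handful of identities (adding zero, multiplying by zero or one, and folding two numbers), so it is far from a complete simplifier.

```python
def simplify(e):
    """Hypothetical helper: apply a few algebraic identities, bottom-up."""
    if not isinstance(e, tuple):
        return e                      # numbers and variables are already simple
    if len(e) == 2:
        f, u = e
        return (f, simplify(u))       # simplify the argument of sin/cos/exp
    op, a, b = e
    a, b = simplify(a), simplify(b)   # simplify both operands first
    both_numbers = isinstance(a, (int, float)) and isinstance(b, (int, float))
    if op == "+":
        if both_numbers:
            return a + b
        if a == 0:
            return b
        if b == 0:
            return a
    elif op == "*":
        if both_numbers:
            return a * b
        if a == 0 or b == 0:
            return 0
        if a == 1:
            return b
        if b == 1:
            return a
    return (op, a, b)

raw = ("+", ("*", 0, "x"), ("*", 2, 1))   # the derivative of 2 * x, unsimplified
print(simplify(raw))                      # 2
print(simplify(("*", ("cos", "x"), 1)))   # ('cos', 'x')
```

Like diff, this is a recursion over the structure of the expression; a more serious simplifier would also need to reorder and combine terms.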

One last point: the base cases of the recursive program correspond to the base cases of the recursive definition of valid expressions, and the recursive cases of the program correspond to the recursive cases of the definition. This is part of that deep connection I mentioned between recursive definitions and recursion/induction principles. Unfortunately, all these things also have to wait for another time.