Automatic Differentiation#
using Plots
How to find derivatives of functions#
Suppose we want to find the slope of a function
f(x)=(x+1)*x*(x-2)
plot(f,-1.5,2.2)
In this particular case, we can do it analytically.
dfdx(x)=3*x^2-2*x-2
plot(f,-1.5,2.2,label="function")
plot!(dfdx,-1,2,label="slope")
An obvious numerical approach is to use finite differences: the forward difference

\[\frac{df}{dx} \approx \frac{f(x+\delta)-f(x)}{\delta}.\]

One just makes \(\delta\) small enough and one should approximate the derivative. A better approximation is the symmetric difference

\[\frac{df}{dx} \approx \frac{f(x+\delta)-f(x-\delta)}{2\delta}.\]
approxder1(func,x,delta)=(func(x+delta)-func(x))/(delta)
approxder2(func,x,delta)=(func(x+delta)-func(x-delta))/(2*delta)
deltalist=[(10.)^-s for s in 1:20]
derlist1=[abs(approxder1(f,1.,delta)-dfdx(1)) for delta in deltalist]
derlist2=[abs(approxder2(f,1.,delta)-dfdx(1)) for delta in deltalist]
scatter(deltalist,derlist1,xscale=:log10,yscale=:log10,xlabel="δ",ylabel="error in derivative",label="forward")
scatter!(deltalist,derlist2,label="symmetric")
Indeed, the symmetric derivative formula is orders of magnitude better than the forward formula, but the error is actually non-monotonic in the difference \(\delta\). The reason for this is round-off error. There are tricks to mitigate this error, for example
function approxder3(func,x,delta)
x2=x+delta
x1=x-delta
return (func(x2)-func(x1))/(x2-x1)
end
derlist3=[abs(approxder3(f,1.,delta)-dfdx(1)) for delta in deltalist]
derlist3
20-element Vector{Float64}:
0.009999999999999454
9.999999999177334e-5
9.999999162069173e-7
1.000088900582341e-8
1.0547118733938987e-10
1.6653345369377348e-10
1.6653345369377348e-9
5.551115123125783e-9
5.551114967694559e-8
0.0
0.0
5.5509297807398994e-5
0.0005552470849528035
0.011111111111111072
0.052631578947368474
1.0
NaN
NaN
NaN
NaN
As you can see, we actually got the exact result (to machine precision) with that trick, for the appropriate \(\delta\). The reason the trick helped is that, with round-off error, \((x+\delta)-(x-\delta)\) is not exactly equal to \(2\delta\). The NaN’s at the end mean “Not a Number”. They came from dividing zero by zero – for those really small \(\delta\)’s, \(x+\delta\) and \(x-\delta\) both round to exactly \(x\).
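We can see this rounding directly. A minimal check (assuming standard Float64 arithmetic; x1 and x2 are scratch names, not from the lecture):
x2 = 1.0 + 1e-17   # rounds to exactly 1.0, since eps(1.0) ≈ 2.2e-16
x1 = 1.0 - 1e-17   # also rounds to exactly 1.0
(f(x2) - f(x1)) / (x2 - x1)   # 0/0 produces NaN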
The problem with using finite differences for derivatives is that it is expensive: we needed two function evaluations to approximate the derivative. If we have a function of \(N\) variables, we would need \(2N\) evaluations to find the gradient.
You will find code out there that uses finite differences to approximate derivatives. The modern solution, however, is Automatic Differentiation, where we algorithmically step through a program, calculating the derivatives as we go.
There are two flavors of Automatic Differentiation: forward and backward. I’ll mainly focus on the forward method in this lecture.
Dispatch, methods, types and overloading#
Every object in any programming language has a type
@show typeof(1.2)
@show typeof(1)
@show typeof(sin)
@show typeof([1,2,3])
@show typeof(1:5);
typeof(1.2) = Float64
typeof(1) = Int64
typeof(sin) = typeof(sin)
typeof([1, 2, 3]) = Vector{Int64}
typeof(1:5) = UnitRange{Int64}
Julia combines features of dynamically typed languages (like Python) and statically typed languages (like C). So far we have essentially been treating it as if it were solely dynamically typed. We are now going to take advantage of the fact that we can make functions (called methods) which behave differently depending on the type of the object they are called with. This programming strategy is called multiple dispatch. For example, here is a function which returns 1 when called with an integer, and 0 when called with a floating-point number.
real_or_int(x::Int64)=1
real_or_int(x::Float64)=0
real_or_int (generic function with 2 methods)
@show real_or_int(1.2)
@show real_or_int(3);
real_or_int(1.2) = 0
real_or_int(3) = 1
This is particularly useful when defining our own data types. It allows us to overload the standard arithmetic operations:
struct quaternion
a
i
j
k
end
q=quaternion(1,2,3,4)
quaternion(1, 2, 3, 4)
q.k
4
Base.:+(a::quaternion,b::quaternion)=quaternion(a.a+b.a,a.i+b.i,a.j+b.j,a.k+b.k)
Base.:-(a::quaternion,b::quaternion)=quaternion(a.a-b.a,a.i-b.i,a.j-b.j,a.k-b.k)
Base.:*(a::quaternion,b::quaternion)=quaternion(
a.a*b.a-a.i*b.i-a.j*b.j-a.k*b.k,
a.a*b.i+a.i*b.a+a.j*b.k-a.k*b.j,
a.a*b.j+a.j*b.a+a.k*b.i-a.i*b.k,
a.a*b.k+a.k*b.a+a.i*b.j-a.j*b.i)
function Base.show(io::IO,::MIME"text/latex",x::quaternion)
outputstring="\$("*string(x.a)*") + ("*string(x.i)*")i +("*string(x.j)*")j +("*string(x.k)*")k\$"
print(io,outputstring)
end
quaternion(1,1,0,0)*quaternion(1,1,0,0)
quaternion(1,1,0,0)+quaternion(1,1,0,2)
(1+1im)*(1+1im)
0 + 2im
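As a quick sanity check of the product rule, squaring the quaternion \(1+i\) should match the complex computation above componentwise (a sketch; qsq is a scratch name):
qsq = quaternion(1,1,0,0)*quaternion(1,1,0,0)
(qsq.a, qsq.i, qsq.j, qsq.k)   # (0, 2, 0, 0), matching (1+1im)^2 == 0 + 2im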
Forward Automatic Differentiation#
We now introduce dual numbers, which keep track of both their value and their derivative with respect to x.
struct dual
val
deriv
end
We then teach the dual numbers how to update not only the value, but also the derivative. For example,
Base.:*(a::dual,b)=dual(a.val*b,a.deriv*b)
Base.:*(b,a::dual)=dual(b*a.val,b*a.deriv)
x=dual(2.,1.)
2*x
dual(4.0, 2.0)
We then implement the product rule, \((uv)' = u'v + uv'\)
Base.:*(a::dual,b::dual)=dual(a.val*b.val,a.deriv*b.val+a.val*b.deriv)
x*x
dual(4.0, 4.0)
Finally we need to deal with addition and subtraction
Base.:+(a::dual,b)=dual(a.val+b,a.deriv)
Base.:+(b,a::dual)=dual(b+a.val,a.deriv)
Base.:+(a::dual,b::dual)=dual(a.val+b.val,a.deriv+b.deriv)
Base.:-(a::dual,b)=dual(a.val-b,a.deriv)
Base.:-(b,a::dual)=dual(b-a.val,-a.deriv)
Base.:-(a::dual,b::dual)=dual(a.val-b.val,a.deriv-b.deriv)
f(x)=(x+1)*x*(x-2)
f (generic function with 1 method)
for x in -1:0.1:1
@show f(dual(x,1))
print(" ")
@show f(x)
print(" ")
@show dfdx(x)
println("***************************************************************")
end
f(dual(x, 1)) = dual(0.0, 3.0)
f(x) = 0.0
dfdx(x) = 3.0
***************************************************************
f(dual(x, 1)) = dual(0.26099999999999995, 2.23)
f(x) = 0.26099999999999995
dfdx(x) = 2.2300000000000004
***************************************************************
f(dual(x, 1)) = dual(0.4479999999999999, 1.5200000000000002)
f(x) = 0.4479999999999999
dfdx(x) = 1.5200000000000005
***************************************************************
f(dual(x, 1)) = dual(0.5670000000000001, 0.8699999999999999)
f(x) = 0.5670000000000001
dfdx(x) = 0.8699999999999997
***************************************************************
f(dual(x, 1)) = dual(0.624, 0.2799999999999999)
f(x) = 0.624
dfdx(x) = 0.28000000000000025
***************************************************************
f(dual(x, 1)) = dual(0.625, -0.25)
f(x) = 0.625
dfdx(x) = -0.25
***************************************************************
f(dual(x, 1)) = dual(0.576, -0.7199999999999999)
f(x) = 0.576
dfdx(x) = -0.7199999999999998
***************************************************************
f(dual(x, 1)) = dual(0.48299999999999993, -1.13)
f(x) = 0.48299999999999993
dfdx(x) = -1.13
***************************************************************
f(dual(x, 1)) = dual(0.3520000000000001, -1.4800000000000004)
f(x) = 0.3520000000000001
dfdx(x) = -1.48
***************************************************************
f(dual(x, 1)) = dual(0.18900000000000003, -1.7700000000000002)
f(x) = 0.18900000000000003
dfdx(x) = -1.77
***************************************************************
f(dual(x, 1)) = dual(-0.0, -2.0)
f(x) = -0.0
dfdx(x) = -2.0
***************************************************************
f(dual(x, 1)) = dual(-0.20900000000000002, -2.1700000000000004)
f(x) = -0.20900000000000002
dfdx(x) = -2.17
***************************************************************
f(dual(x, 1)) = dual(-0.432, -2.2800000000000002)
f(x) = -0.432
dfdx(x) = -2.2800000000000002
***************************************************************
f(dual(x, 1)) = dual(-0.663, -2.33)
f(x) = -0.663
dfdx(x) = -2.33
***************************************************************
f(dual(x, 1)) = dual(-0.8959999999999999, -2.32)
f(x) = -0.8959999999999999
dfdx(x) = -2.32
***************************************************************
f(dual(x, 1)) = dual(-1.125, -2.25)
f(x) = -1.125
dfdx(x) = -2.25
***************************************************************
f(dual(x, 1)) = dual(-1.3439999999999999, -2.12)
f(x) = -1.3439999999999999
dfdx(x) = -2.12
***************************************************************
f(dual(x, 1)) = dual(-1.547, -1.9300000000000002)
f(x) = -1.547
dfdx(x) = -1.9300000000000002
***************************************************************
f(dual(x, 1)) = dual(-1.7280000000000002, -1.68)
f(x) = -1.7280000000000002
dfdx(x) = -1.6799999999999997
***************************************************************
f(dual(x, 1)) = dual(-1.881, -1.37)
f(x) = -1.881
dfdx(x) = -1.3699999999999999
***************************************************************
f(dual(x, 1)) = dual(-2.0, -1.0)
f(x) = -2.0
dfdx(x) = -1.0
***************************************************************
To better see what is going on under the hood, we can add some logging
using Logging
default_logger=global_logger()
debug_logger=ConsoleLogger(stderr, Logging.Debug); # set up logger
global_logger(debug_logger)
ConsoleLogger(IJulia.IJuliaStdio{Base.PipeEndpoint}(IOContext(Base.PipeEndpoint(RawFD(44) open, 0 bytes waiting))), Info, Logging.default_metafmt, true, 0, Dict{Any, Int64}())
function Base.:*(a::dual,b)
result=dual(a.val*b,a.deriv*b)
@debug "a*b" a b result
return result
end
function Base.:*(b,a::dual)
result=dual(b*a.val,b*a.deriv)
@debug "b*a" b a result
return result
end
function Base.:*(a::dual,b::dual)
result=dual(a.val*b.val,a.deriv*b.val+a.val*b.deriv)
@debug "a*b" a b result
return result
end
function Base.:+(a::dual,b)
result=dual(a.val+b,a.deriv)
@debug "a+b" a b result
return result
end
function Base.:+(b,a::dual)
result=dual(b+a.val,a.deriv)
@debug "b+a" b a result
return result
end
function Base.:+(a::dual,b::dual)
result=dual(a.val+b.val,a.deriv+b.deriv)
@debug "a+b" a b result
return result
end
function Base.:-(a::dual,b)
result=dual(a.val-b,a.deriv)
@debug "a-b" a b result
return result
end
function Base.:-(b,a::dual)
result=dual(b-a.val,-a.deriv)
@debug "b-a" b a result
return result
end
function Base.:-(a::dual,b::dual)
result=dual(a.val-b.val,a.deriv-b.deriv)
@debug "a-b" a b result
return result
end
f(x)=(x+1)*x*(x-2)
f(dual(1,1))
┌ Debug: a+b
│ a = dual(1, 1)
│ b = 1
│ result = dual(2, 1)
└ @ Main In[25]:21
┌ Debug: a-b
│ a = dual(1, 1)
│ b = 2
│ result = dual(-1, 1)
└ @ Main In[25]:39
┌ Debug: a*b
│ a = dual(2, 1)
│ b = dual(1, 1)
│ result = dual(2, 3)
└ @ Main In[25]:15
┌ Debug: a*b
│ a = dual(2, 3)
│ b = dual(-1, 1)
│ result = dual(-2, -1)
└ @ Main In[25]:15
dual(-2, -1)
#turn off the debug logger
global_logger(default_logger)
ConsoleLogger(IJulia.IJuliaStdio{Base.PipeEndpoint}(IOContext(Base.PipeEndpoint(RawFD(44) open, 0 bytes waiting))), Debug, Logging.default_metafmt, true, 0, Dict{Any, Int64}())
Application: What is the derivative of the range of a projectile with respect to the launch angle?#
In your lab you created a function which gives the distance that a projectile travels. We can use our Automatic Differentiation scheme to calculate the derivative of that distance with respect to a parameter.
include("Projectilelab.jl")
distance (generic function with 2 methods)
thetas=collect(pi/50:pi/50:(pi/2-pi/50))
ds=[distance(initialv=(cos(theta),sin(theta)),m=1,g=1,gamma=10.,delta=2,dt=0.1) for theta in thetas]
scatter(thetas,ds,xlabel="θ",ylabel="d",label="γ=10")
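Since distance from Projectilelab.jl presumably steps the trajectory forward in time until the projectile lands, pushing dual numbers through it needs more than the arithmetic defined so far. We therefore redefine the plain (non-logging) methods, and add the other operations that appear in the projectile code: comparison operators (so the stopping test can inspect a dual’s value), cos and sin (for the launch angle), and rules for powers, negation, and division.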
struct dual
val
deriv
end
Base.:*(a::dual,b)=dual(a.val*b,a.deriv*b)
Base.:*(b,a::dual)=dual(b*a.val,b*a.deriv)
Base.:*(a::dual,b::dual)=dual(a.val*b.val,a.deriv*b.val+a.val*b.deriv)
Base.:+(a::dual,b)=dual(a.val+b,a.deriv)
Base.:+(b,a::dual)=dual(b+a.val,a.deriv)
Base.:+(a::dual,b::dual)=dual(a.val+b.val,a.deriv+b.deriv)
Base.:-(a::dual,b)=dual(a.val-b,a.deriv)
Base.:-(b,a::dual)=dual(b-a.val,-a.deriv)
Base.:-(a::dual,b::dual)=dual(a.val-b.val,a.deriv-b.deriv)
Base.:>(a::dual,b)=a.val>b
Base.:<(a::dual,b)=a.val<b
Base.:>(b,a::dual)=b>a.val
Base.:<(b,a::dual)=b<a.val
Base.cos(a::dual)=dual(cos(a.val),-sin(a.val)*a.deriv)
Base.sin(a::dual)=dual(sin(a.val),cos(a.val)*a.deriv)
Base.:^(a::dual,b)=dual(a.val^b,b*a.val^(b-1)*a.deriv)
Base.:-(a::dual)=dual(-a.val,-a.deriv)
Base.:/(a::dual,b::dual)=dual(a.val/b.val,a.deriv/b.val-a.val/(b.val)^2*b.deriv)
Base.:/(a::dual,b)=dual(a.val/b,a.deriv/b)
function dist1(theta)
initialv=(cos(theta),sin(theta))
distance(initialv=initialv,m=1,g=1,gamma=0,delta=2,dt=0.01)
end
@time dist1(dual(0.5,1))
0.224447 seconds (656.41 k allocations: 43.057 MiB, 99.44% compilation time)
dual(0.841461602137972, 1.093065387505064)
Let’s compare with our finite differences
@time approxder3(dist1,0.5,1e-4)
0.000393 seconds (14.05 k allocations: 477.094 KiB)
1.0930654018138735
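The two results agree to about \(10^{-8}\), which is consistent with the \(O(\delta^2)\) truncation error of the symmetric formula at \(\delta=10^{-4}\); the dual-number result is exact up to round-off.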
We could also look at the derivative with respect to some other quantity – say, the strength of gravity
function dist2(g)
initialv=(dual(1.,0.),dual(1.,0.))
distance(initialv=initialv,m=1,g=g,gamma=0,delta=2,dt=0.1)
end
dist2(dual(9.8,1.))
dual(0.20275862068965514, -0.014268727705112963)
Derivatives with respect to multiple parameters#
distance(initialv=[dual(1.,[1.,0.]),dual(1.,[0.,1.])],m=1,g=1,gamma=0,delta=2,dt=0.01)
dual(2.0000000000000004, [2.0000000000000004, 2.010050251256236])
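This works because we never constrained the types of the val and deriv fields: deriv can itself be a vector holding the derivative with respect to each input, and the same arithmetic rules then propagate the whole gradient at once. A minimal sketch with a made-up two-argument function g:
g(x,y) = x*y + x
gval = g(dual(2.,[1.,0.]), dual(3.,[0.,1.]))   # seed each input with a unit vector
gval.val     # g(2,3) == 8.0
gval.deriv   # the gradient [∂g/∂x, ∂g/∂y] == [4.0, 2.0]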
Second derivatives#
f(x)=x*x*x
f (generic function with 1 method)
f(dual(dual(1,1),dual(1,0)))
dual(dual(1, 3), dual(3, 6))
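The trick here is that nothing stops val and deriv from themselves being duals: the inner dual carries \((f,f')\), and the outer layer differentiates that pair once more, so the second derivative lands in the deriv-of-deriv slot. We seed with dual(dual(1,1),dual(1,0)) because \(dx/dx=1\), and that constant 1 in turn has derivative 0. Reading off the components (d2 is a scratch name):
d2 = f(dual(dual(1,1),dual(1,0)))
d2.val.val       # f(1)  == 1
d2.val.deriv     # f'(1) == 3
d2.deriv.deriv   # f''(1) == 6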
Packages for Automatic Differentiation#
In practice one would probably never write one’s own automatic differentiation code. The one we wrote worked fine for our purposes, but once you understand the principle it is nice to let someone else do the hard work.
using ForwardDiff
fd=ForwardDiff
ForwardDiff
@time fd.derivative(dist1,0.5)
0.283059 seconds (1.20 M allocations: 81.469 MiB, 4.12% gc time, 99.79% compilation time)
1.093065387505064
@time dist1(0.5)
0.000148 seconds (7.02 k allocations: 238.047 KiB)
0.841461602137972
@time dist1(dual(0.5,1))
0.000520 seconds (29.63 k allocations: 671.500 KiB)
dual(0.841461602137972, 1.093065387505064)
Under the hood, ForwardDiff.jl just uses the same dual number structure – but with some extra tricks to make things run faster (about a factor of 2 faster than our code).
d1=fd.Dual{Float64}(0.5,fd.Partials((1.,)))
@time dist1(d1)
0.158059 seconds (644.37 k allocations: 43.868 MiB, 2.75% gc time, 99.64% compilation time)
Dual{Float64}(0.841461602137972,1.093065387505064)
Back Propagation / Backward AD#
The algorithm that we developed is known as forward pass automatic differentiation. There is an alternative algorithm known as backward pass automatic differentiation.
I may not have time to go through all of this section. The punch line is that forward-pass is good for the case where you have a small number of inputs, and an arbitrary number of outputs. It then efficiently calculates the derivatives of each of the outputs with respect to the input. Backward-pass is good when you have many inputs and a small number of outputs.
To elaborate on this, let’s think about chained unary operations \(x\rightarrow y\rightarrow z\rightarrow f\).
In forward pass we first calculate \(dy/dx\) as a number, and then \(dz/dx=dz/dy*dy/dx\), and then \(df/dx=df/dz*dz/dx\). We work from left to right.
In backward pass we first calculate \(df/dz\) as a number. Then we get the numerical value of \(df/dy=df/dz*dz/dy\). Finally we calculate \(df/dx=df/dy*dy/dx\).
The end result is the same:

\[\frac{df}{dx} = \frac{df}{dz}\,\frac{dz}{dy}\,\frac{dy}{dx}.\]

The only difference is the order in which the numbers are calculated. As you will see, the book-keeping is easier for the forward algorithm – that’s what the dual number construction does.
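As a concrete numerical check (a made-up chain, not the projectile problem): take \(y=x^2\), \(z=\sin(y)\), \(f=3z\) at \(x=2\). Both accumulation orders produce the same number:
xv = 2.0
# forward accumulation: push d(⋅)/dx through the chain, left to right
yv = xv^2;    dydx = 2*xv              # dy/dx
zv = sin(yv); dzdx = cos(yv)*dydx      # dz/dx = dz/dy * dy/dx
fv = 3*zv;    dfdx_fwd = 3*dzdx        # df/dx = df/dz * dz/dx
# reverse accumulation: pull df/d(⋅) back through the chain, right to left
dfdz = 3.0
dfdy = dfdz*cos(yv)                    # df/dy = df/dz * dz/dy
dfdx_rev = dfdy*2*xv                   # df/dx = df/dy * dy/dx
dfdx_fwd ≈ dfdx_rev                    # true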
Suppose \(x,y,z,f\) are vectors of length \(N_x,N_y,N_z,N_f\). In the forward algorithm we will be storing
\(dy/dx\) – a matrix of size \(N_y\times N_x\)
\(dz/dx\) – a matrix of size \(N_z\times N_x\)
\(df/dx\) – a matrix of size \(N_f\times N_x\)
Conversely, in the backward algorithm we need
\(df/dz\) – a matrix of size \(N_f\times N_z\)
\(df/dy\) – a matrix of size \(N_f\times N_y\)
\(df/dx\) – a matrix of size \(N_f\times N_x\)
There are therefore likely to be efficiencies in the forward method if \(N_x\ll N_f\). Conversely, backward is probably better if \(N_f\ll N_x\).
Our case is a little more complicated, as we are not just doing a simple chain of unary operations – rather we are doing binary operations \(+,-,*\). Here is a sketch of what the Forward algorithm was doing when we calculated the derivative of \(f(x)=(x+1)x(x-2)\) at \(x=1\) – where I name all the intermediate results
In that diagram, each operation is a box. We track the derivative with respect to x of each intermediate result.
We want to calculate

\[\left.\frac{df}{dx}\right|_{x=1} \quad\text{for}\quad f(x)=(x+1)\,x\,(x-2).\]

To do so we first calculate

\[y_1 = x+1 = 2, \qquad \frac{dy_1}{dx}=1,\]

which leaves us with \(f=y_1*x*(x-2)\). Next we calculate

\[y_2 = x-2 = -1, \qquad \frac{dy_2}{dx}=1,\]

which leads to \(f=y_1*x*y_2\). We then calculate

\[y_3 = y_1*x = 2, \qquad \frac{dy_3}{dx} = \frac{dy_1}{dx}\,x + y_1\,\frac{dx}{dx} = 1\cdot1 + 2\cdot1 = 3.\]

Leaving us with

\[f = y_3*y_2 = -2, \qquad \frac{df}{dx} = \frac{dy_3}{dx}\,y_2 + y_3\,\frac{dy_2}{dx} = 3\cdot(-1) + 2\cdot1 = -1.\]
The magic of dual numbers is that all of the book-keeping is automatically taken care of.
Reverse auto-differentiation is similar, but we need to do more of the book-keeping by hand. It is a two-pass algorithm. In the forward pass we construct a flowchart, similar to the one I drew for the forward differentiation case, but storing slightly different information. We then do a backward pass in which we use the stored information to extract the derivative.
One advantage of the reverse method is that it automatically gives us not just the derivative with respect to the input – but also the derivative with respect to all of the parameters in the function: The ‘1’ and ‘-2’ in \(f(x)=(x+{\color{red} 1})x(x{\color{red}-2})\). This is important in applications like machine learning.
For many physics problems, forward mode AD works better than backwards mode.
For our problem, we will think of \(f(x)=(x+1)x(x-2)\) as \(f=(x_1+c_1)x_2(x_3-c_2)\), with \(x_1=x_2=x_3=1\) and \(c_1=1\), \(c_2=2\). In our forward pass we will generate a flowchart, keeping track of the derivatives of each operation with respect to its inputs:
Thus we store, at each node, the value it produces and its local derivatives with respect to its inputs:

\[y_1 = x_1+c_1 = 2:\qquad \frac{\partial y_1}{\partial x_1}=1,\quad \frac{\partial y_1}{\partial c_1}=1\]

\[y_2 = x_3-c_2 = -1:\qquad \frac{\partial y_2}{\partial x_3}=1,\quad \frac{\partial y_2}{\partial c_2}=-1\]

\[y_3 = y_1\,x_2 = 2:\qquad \frac{\partial y_3}{\partial y_1}=x_2=1,\quad \frac{\partial y_3}{\partial x_2}=y_1=2\]

\[f = y_3\,y_2 = -2:\qquad \frac{\partial f}{\partial y_3}=y_2=-1,\quad \frac{\partial f}{\partial y_2}=y_3=2\]
Next, we do a backward pass to generate \(df/dz\) on the bond where \(z\) is created. This is a simple application of the chain rule. Starting from the bottom, we know

\[\frac{df}{df}=1.\]

We then go backwards to the node that generated each intermediate and use

\[\frac{df}{dz_{\rm in}} = \frac{df}{dz_{\rm out}}\,\frac{dz_{\rm out}}{dz_{\rm in}}.\]
In that way we generate the following graph:
We can then just read off \(df/dx\) from the leaves at the top of the graph by adding up the contributions from all of the red arrows leading to it. As promised, we also get \(df/dc_1\) and \(df/dc_2\) for free.
Coding backward AD is a little more involved than coding forward AD. As with ForwardDiff.jl, there are some great packages which do it for us. Nonetheless, here is a primitive implementation.
The first thing we need is a data structure which stores our graph. We will call each box in our flowchart a node. It knows about its inputs – which we will call parents. It knows the derivative of its output with respect to each input – which we will call local_derivatives. It knows the value of its output, value, and the derivative of f with respect to that output – derivative.
Note: I borrowed this code from a forum post on discourse.julia.org
mutable struct node
value::Float64 # Stores the value of the variable
derivative::Float64 # Stores the value of derivative: df/dnode
parents::Vector{node} # Stores the input variables
local_derivatives::Vector{Float64} # Local derivatives of outputs with respect to input variables
function node(value) # Constructor
x = new()
x.value = value
x.derivative = 0.0 # Needs to be set during back-propagation
x.parents = []
x.local_derivatives = []
return x
end
end
function Base.:+(a::node, b::node)
value = a.value + b.value
C = node(value)
C.parents = [a, b]
C.local_derivatives = [1.0, 1.0] # [dC/da,dC/db]
return C
end
function Base.:-(a::node, b::node)
value = a.value - b.value
C = node(value)
C.parents = [a, b]
C.local_derivatives = [1.0, -1.0] # [dC/da,dC/db]
return C
end
function Base.:*(a::node, b::node)
value = a.value * b.value
C = node(value)
C.parents = [a, b]
C.local_derivatives = [b.value, a.value] # [dC/da,dC/db]
return C
end
"""
set_derivatives_to_zero!(C::node)
recursively follows the flowchart backwards, setting the derivatives at all nodes to zero. (When created, nodes
have their derivative set to zero, but we may want to be able to reuse variables. Certainly causes no harm.)
"""
function set_derivatives_to_zero!(C::node)
for i = 1:length(C.parents)
C.parents[i].derivative = 0.0
set_derivatives_to_zero!(C.parents[i])
end
return nothing
end
"""
back_propegate!(C::node)
Propagates the derivative backward through the flowchart
"""
function back_propegate!(C::node)
for i = 1:length(C.parents)
C.parents[i].derivative += C.derivative * C.local_derivatives[i] #df/di=df/dC dC/di
back_propegate!(C.parents[i])
end
return nothing
end
"""
calc_derivative!(f::node)
starts the back-propagation from the last node
"""
function calc_derivative!(f::node)
set_derivatives_to_zero!(f)
f.derivative=1. #df/df=1
back_propegate!(f)
return nothing
end
calc_derivative!
xn=node(1)
c1n=node(1)
c2n=node(2)
fn=(xn+c1n)*xn*(xn-c2n);
fn.value
-2.0
calc_derivative!(fn)
xn.derivative # should give df/dx
-1.0
c1n.derivative # should give df/dc1
-1.0
c2n.derivative # should give df/dc2
-2.0
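As a sanity check, the analytic slope from the start of the lecture gives the same answer:
dfdx(1)   # 3 - 2 - 2 = -1, matching xn.derivative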