As the "variables" are f, x, and v in Jf(x)⊤v, we can define a new operation vjp(f,x)(v).1We could use vjp(f,x,v), but currying
v out simplifies later calculations.
Note that vjp:(Rn→Rm,Rn)→Rm→Rn.
vjp(f,x)(v):=Jf(x)⊤v
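JAX exposes essentially this curried form as `jax.vjp` (it additionally returns $f(x)$). A minimal sketch checking the definition against an explicitly computed Jacobian, where `f`, `x`, and `v` are arbitrary examples rather than anything from the text:

```python
# Check that jax.vjp(f, x) yields a function v -> J_f(x)^T v.
import jax
import jax.numpy as jnp

def f(x):  # f : R^3 -> R^2
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, -1.0])           # cotangent vector, v in R^m = R^2

y, vjp_fn = jax.vjp(f, x)            # y = f(x), vjp_fn curries v out
(jtv,) = vjp_fn(v)                   # one cotangent per argument of f

J = jax.jacobian(f)(x)               # explicit 2 x 3 Jacobian for comparison
print(jnp.allclose(jtv, J.T @ v))    # True
```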
But that still requires us to calculate $\partial f / \partial x_k$, which can be hard depending on $f$.
This is especially the case if $f$ is a composition of multiple functions. ($f_k$ was used as the $k$th component of $f$ before, but from here on we use it to denote the $k$th function in a composition.)
The chain rule for Jacobians is $J_{f_{k+1} \circ f_k}(x) = J_{f_{k+1}}(f_k(x)) \, J_{f_k}(x)$.
We can define $x_k = (f_k \circ f_{k-1} \circ \cdots \circ f_1)(x)$ as the $k$th intermediate function value.
Reconciling this with VJPs:

$$\operatorname{vjp}(f_{k+1} \circ f_k, x_{k-1})(v) = J_{f_k}(x_{k-1})^\top J_{f_{k+1}}(x_k)^\top v = \operatorname{vjp}(f_k, x_{k-1})\bigl(\operatorname{vjp}(f_{k+1}, x_k)(v)\bigr)$$
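A quick numerical check of this identity, chaining two made-up functions with `jax.vjp`:

```python
# Chaining the VJPs of f1 and f2 reproduces the VJP of their composition.
# f1, f2, x, and v are arbitrary examples.
import jax
import jax.numpy as jnp

f1 = lambda x: jnp.tanh(x)                    # R^3 -> R^3
f2 = lambda x1: jnp.stack([x1 @ x1, x1[0]])   # R^3 -> R^2

x = jnp.array([0.1, 0.2, 0.3])
v = jnp.array([1.0, 2.0])

# VJP of the composition, taken in one piece.
_, vjp_comp = jax.vjp(lambda x: f2(f1(x)), x)
(direct,) = vjp_comp(v)

# Chained constituent VJPs: pull v back through f2, then through f1.
x1, vjp_f1 = jax.vjp(f1, x)
_, vjp_f2 = jax.vjp(f2, x1)
(u,) = vjp_f2(v)            # u = J_{f2}(x1)^T v
(chained,) = vjp_f1(u)      # J_{f1}(x)^T u

print(jnp.allclose(direct, chained))  # True
```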
The value of this representation of VJPs lies in the fact that it enables reverse mode automatic differentiation.
Automatic differentiation
Automatic differentiation (autodiff) follows from the fact that we can obtain the VJP of a composite function algorithmically by evaluating the VJPs of its constituent functions.
Since even the most complicated functions are compositions of elementary functions, and the VJPs of elementary functions are trivial, we can build up VJPs for "complicated" (i.e. many variables, deep composition, etc.) functions step by step.
We require numeric values for $v$ and $x$, which sets this method apart from symbolic differentiation.
Reverse mode autodiff refers to the fact that two passes -- forward and backward -- are required to calculate VJPs.
Calculating the VJP of $f_k$ requires us to have calculated the VJP of $f_{k+1}$ as well as the intermediate value $x_{k-1}$.
The forward pass, in which we evaluate $f$ from $f_1$ to $f_i$, gives us the intermediate values $x_{k-1}$.
The backward pass, in which we build up the VJPs from $f_i$ back to $f_1$, gives us the VJP of $f_{k+1}$.
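The sketch below makes both passes explicit for a chain $f = f_i \circ \cdots \circ f_1$, leaning on `jax.vjp` for the per-function VJPs; the chain `fs`, `x`, and `v` are arbitrary examples:

```python
# A minimal reverse-mode sketch for f = f_i ∘ ... ∘ f_1, assuming each
# constituent function's VJP is available (here obtained from jax.vjp).
import jax
import jax.numpy as jnp

def chain_vjp(fs, x, v):
    """Return J_f(x)^T v for f = fs[-1] ∘ ... ∘ fs[0]."""
    # Forward pass: evaluate f_1 up to f_i, keeping each VJP closure.
    # Each closure captures the intermediate value x_{k-1} it was built at.
    vjps = []
    for fk in fs:
        x, vjp_fk = jax.vjp(fk, x)
        vjps.append(vjp_fk)
    # Backward pass: pull v back through f_i down to f_1.
    for vjp_fk in reversed(vjps):
        (v,) = vjp_fk(v)
    return v

# Arbitrary example chain R^3 -> R^3 -> R^3 -> R^3 and cotangent v.
fs = [jnp.tanh, lambda z: z ** 2, jnp.cumsum]
x = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.5, -2.0])

# Compare against the VJP of the composition taken in one piece.
_, vjp_whole = jax.vjp(lambda x: fs[2](fs[1](fs[0](x))), x)
print(jnp.allclose(chain_vjp(fs, x, v), vjp_whole(v)[0]))  # True
```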
Wait, are VJPs actually doing differentiation though?
Sure, we can calculate the product of a vector and a Jacobian, but differentiation would imply something like a gradient or the entire Jacobian.
Turns out VJPs allow us to calculate both.
Using $f : \mathbb{R}^n \to \mathbb{R}$ and $v = 1$, we obtain the VJP:

$$\operatorname{vjp}(f, x)(1) = J_f(x)^\top \cdot 1 = \nabla f(x)$$
So the VJP of a scalar function can be used to calculate its gradient at a point.
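As a sketch (with an arbitrary scalar-valued `f`), the VJP at $v = 1$ matches what `jax.grad` computes:

```python
# The VJP of a scalar-valued function at v = 1 is its gradient.
# f and x are arbitrary examples.
import jax
import jax.numpy as jnp

def f(x):                      # f : R^3 -> R
    return jnp.sin(x[0]) * x[1] + x[2] ** 2

x = jnp.array([0.5, 1.5, -2.0])

_, vjp_fn = jax.vjp(f, x)
(grad_via_vjp,) = vjp_fn(jnp.ones(()))   # v = 1, since m = 1

print(jnp.allclose(grad_via_vjp, jax.grad(f)(x)))  # True
```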
Similarly, we can calculate the entire Jacobian of a function row-by-row.
Choosing v as one-hot encoded vectors results in the VJP being the row of the Jacobian encoded by v.
$$J_f(x) = \begin{bmatrix} \operatorname{vjp}(f, x)(\delta_1)^\top \\ \vdots \\ \operatorname{vjp}(f, x)(\delta_m)^\top \end{bmatrix}$$
It is worth noting that building the entire Jacobian with VJPs requires $m$ passes for $f : \mathbb{R}^n \to \mathbb{R}^m$ with an $m \times n$ Jacobian.
This method is therefore more efficient the smaller $m$ is compared to $n$.
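A sketch of the row-by-row construction, where the $m$ passes show up as a loop over the $m$ one-hot cotangents (`f` and `x` are arbitrary examples):

```python
# Build the full m x n Jacobian with m VJP passes, one per one-hot cotangent.
import jax
import jax.numpy as jnp

def f(x):                                   # f : R^3 -> R^2, so m = 2, n = 3
    return jnp.stack([x[0] * x[1] * x[2], jnp.exp(x[0]) + x[2]])

x = jnp.array([1.0, 2.0, 3.0])
m = 2

_, vjp_fn = jax.vjp(f, x)
rows = [vjp_fn(jnp.eye(m)[i])[0] for i in range(m)]   # one pass per row
J = jnp.stack(rows)

# jax.jacrev builds the same matrix (it applies the VJP to the basis vectors).
print(jnp.allclose(J, jax.jacrev(f)(x)))  # True
```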
In a machine learning context, $m = 1 \ll n$ is common as the outputs are scalar loss values and the inputs are model weights.
This makes VJPs suitable for machine learning backpropagation.
Vector-Jacobian products of composite functions can be built up from the VJPs of their constituent functions, and they produce gradients and Jacobians.
This enables reverse mode automatic differentiation, which is efficient for machine learning. Everything falls into place.