
Vector-Jacobian products and automatic differentiation

Consider $f: \R^n \to \R^m$. The Jacobian of $f$, denoted $\mathbf{J}_f$, is an $m \times n$ matrix of all the partial derivatives. Pretty basic:

$$\mathbf{J}_f = \frac{\partial f}{\partial \mathbf{x}} = \begin{bmatrix} \dfrac{\partial f_1}{\partial x_1} & \cdots & \dfrac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \dfrac{\partial f_m}{\partial x_1} & \cdots & \dfrac{\partial f_m}{\partial x_n} \end{bmatrix}$$
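For example, with $f(x_1, x_2) = (x_1 x_2, \sin x_1)$, the Jacobian is

$$\mathbf{J}_f = \begin{bmatrix} x_2 & x_1 \\ \cos x_1 & 0 \end{bmatrix}$$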

For $\mathbf{v} \in \R^m$ and $\mathbf{x} \in \R^n$, let's look at the product of the row vector $\mathbf{v}^\top$ and $\mathbf{J}_f(\mathbf{x})$. This is known as the vector-Jacobian product (VJP).

$$\begin{aligned} \mathbf{v}^\top \mathbf{J}_f (\mathbf{x}) &= \begin{bmatrix} v_1 & v_2 & \cdots & v_m \end{bmatrix}\begin{bmatrix} \dfrac{\partial f_1}{\partial x_1} (\mathbf{x}) & \cdots & \dfrac{\partial f_1}{\partial x_n} (\mathbf{x}) \\ \vdots & \ddots & \vdots \\ \dfrac{\partial f_m}{\partial x_1} (\mathbf{x}) & \cdots & \dfrac{\partial f_m}{\partial x_n} (\mathbf{x}) \end{bmatrix} \\ &= \begin{bmatrix} v_1 \dfrac{\partial f_1}{\partial x_1} (\mathbf{x}) + \cdots + v_m \dfrac{\partial f_m}{\partial x_1} (\mathbf{x}) \\ \vdots \\ v_1 \dfrac{\partial f_1}{\partial x_n} (\mathbf{x}) + \cdots + v_m \dfrac{\partial f_m}{\partial x_n} (\mathbf{x}) \end{bmatrix}^\top \end{aligned}$$

The row vectors are pesky, so we take transposes and work with $\mathbf{J}_f(\mathbf{x})^\top \mathbf{v}$ instead.

$$\mathbf{J}_f (\mathbf{x})^\top \mathbf{v} = \begin{bmatrix} v_1 \dfrac{\partial f_1}{\partial x_1} (\mathbf{x}) + \cdots + v_m \dfrac{\partial f_m}{\partial x_1} (\mathbf{x}) \\ \vdots \\ v_1 \dfrac{\partial f_1}{\partial x_n} (\mathbf{x}) + \cdots + v_m \dfrac{\partial f_m}{\partial x_n} (\mathbf{x}) \end{bmatrix}$$

As the "variables" are ff, x\mathbf{x}, and v\mathbf{v} in Jf(x)v{\mathbf{J}_f} (\mathbf{x})^\top\mathbf{v}, we can define a new operation vjp(f,x)(v)\operatorname{vjp}(f, \mathbf{x})(\mathbf{v}).1We could use vjp(f,x,v)\operatorname{vjp}(f, \mathbf{x}, \mathbf{v}), but currying v\mathbf{v} out simplifies later calculations.

Note that $\operatorname{vjp} : (\R^n \to \R^m, \R^n) \to (\R^m \to \R^n)$.

$$\operatorname{vjp}(f, \mathbf{x})(\mathbf{v}) \coloneqq \mathbf{J}_f (\mathbf{x})^\top \mathbf{v}$$
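The curried form maps naturally onto a closure. Here is a minimal NumPy sketch for one function whose Jacobian we know analytically, the elementwise sine (the name `vjp_sin` is just an illustrative choice):

```python
import numpy as np

def vjp_sin(x):
    """vjp(sin, x): returns the pullback v -> J_sin(x)^T v."""
    def pullback(v):
        # J_sin(x) = diag(cos(x)), so J^T v is an elementwise product.
        return np.cos(x) * v
    return pullback

x = np.array([0.0, 1.0, 2.0])
v = np.ones(3)
print(vjp_sin(x)(v))  # == cos(x), the gradient of sum(sin(x))
```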

But that still requires us to calculate $\partial f / \partial x_k$, which can be hard depending on $f$. This is especially the case if $f$ is a composition of multiple functions. ($f_k$ denoted the $k$th component of $f$ above, but from here on it denotes the $k$th function in the composition.)

The chain rule for Jacobians is $\mathbf{J}_{f_{k+1} \circ f_k} (\mathbf{x}) = \mathbf{J}_{f_{k+1}} (f_k (\mathbf{x})) \, \mathbf{J}_{f_k}(\mathbf{x})$. We can define $\mathbf{x}_k = (f_k \circ f_{k-1} \circ \cdots \circ f_1) (\mathbf{x})$ as the $k$th intermediate function value. Reconciling this with VJPs for $f = f_i \circ \cdots \circ f_2 \circ f_1$:

$$\begin{aligned}\operatorname{vjp}(f,\mathbf{x})(\mathbf{v}) &= \mathbf{J}_f (\mathbf{x})^\top \mathbf{v} \\ &= (\mathbf{J}_{f_i \circ \cdots \circ f_3 \circ f_2 \circ f_1}(\mathbf{x}))^\top \mathbf{v} \\ &= (\mathbf{J}_{f_i \circ \cdots \circ f_3 \circ f_2} (f_1(\mathbf{x})) \, \mathbf{J}_{f_1} (\mathbf{x}))^\top \mathbf{v} \\ &= (\mathbf{J}_{f_i \circ \cdots \circ f_3} (f_2(f_1(\mathbf{x}))) \, \mathbf{J}_{f_2} (\mathbf{x}_1) \, \mathbf{J}_{f_1} (\mathbf{x}))^\top \mathbf{v} \\ &= (\mathbf{J}_{f_i} (\mathbf{x}_{i-1}) \cdots \mathbf{J}_{f_2} (\mathbf{x}_1) \, \mathbf{J}_{f_1} (\mathbf{x}))^\top \mathbf{v} \\ &= \mathbf{J}_{f_1} (\mathbf{x})^\top \mathbf{J}_{f_2} (\mathbf{x}_1)^\top \cdots \mathbf{J}_{f_i} (\mathbf{x}_{i-1})^\top \mathbf{v} \end{aligned}$$

Noting that $\mathbf{J}_{f_i}(\mathbf{x})^\top\mathbf{v} = \operatorname{vjp}(f_i, \mathbf{x})(\mathbf{v})$, we can express $\operatorname{vjp}(f, \mathbf{x})(\mathbf{v})$ as a composition of $\operatorname{vjp}$s.

$$\begin{aligned} \operatorname{vjp}(f,\mathbf{x})(\mathbf{v}) &= \mathbf{J}_{f_1} (\mathbf{x})^\top \cdots \mathbf{J}_{f_{i-1}} (\mathbf{x}_{i-2})^\top \mathbf{J}_{f_i} (\mathbf{x}_{i-1})^\top \mathbf{v} \\ &= \mathbf{J}_{f_1} (\mathbf{x})^\top \cdots \mathbf{J}_{f_{i-1}} (\mathbf{x}_{i-2})^\top \operatorname{vjp}(f_i,\mathbf{x}_{i-1})(\mathbf{v}) \\ &= \mathbf{J}_{f_1} (\mathbf{x})^\top \cdots \operatorname{vjp}(f_{i-1}, \mathbf{x}_{i-2})(\operatorname{vjp}(f_i,\mathbf{x}_{i-1})(\mathbf{v})) \\ &= (\operatorname{vjp} (f_1, \mathbf{x}) \circ \cdots \circ \operatorname{vjp}(f_i,\mathbf{x}_{i-1})) (\mathbf{v}) \end{aligned}$$

Which brings us to our conclusion:

$$\operatorname{vjp}(f,\mathbf{x})(\mathbf{v}) = (\operatorname{vjp} (f_1, \mathbf{x}) \circ \cdots \circ \operatorname{vjp}(f_i,\mathbf{x}_{i-1})) (\mathbf{v})$$
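To make this concrete, here is a small sketch for a two-function chain $f = f_2 \circ f_1$ with $f_1(\mathbf{x}) = A\mathbf{x}$ and $f_2$ the elementwise sine; the matrix $A$ and the pullback helpers are illustrative choices, not anything canonical:

```python
import numpy as np

A = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

def vjp_f1(x):   # f1(x) = A x,    so J_f1(x) = A
    return lambda v: A.T @ v

def vjp_f2(y):   # f2(y) = sin(y), so J_f2(y) = diag(cos(y))
    return lambda v: np.cos(y) * v

x = np.array([0.1, 0.2])
x1 = A @ x                          # intermediate value x_1 = f1(x)
v = np.array([1.0, -1.0, 2.0])

# vjp(f, x)(v) = (vjp(f1, x) ∘ vjp(f2, x1))(v): apply the last pullback first.
vjp_f = lambda v: vjp_f1(x)(vjp_f2(x1)(v))

# Check against the explicit Jacobian J_f(x) = diag(cos(A x)) A.
J = np.diag(np.cos(x1)) @ A
print(np.allclose(vjp_f(v), J.T @ v))   # True
```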

The value of this representation of VJPs lies in the fact that it enables reverse mode automatic differentiation.

Automatic differentiation

Automatic differentiation (autodiff) follows from the fact that we can obtain the VJP of a composite function algorithmically by evaluating the VJPs of its constituent functions. Since even the most complicated of functions are made up of compositions of elementary functions, and the VJPs of elementary functions are trivial, we can build up VJPs for "complicated" (i.e. lots of variables, deep composition, etc.) functions step-by-step. We require numeric values for $\mathbf{v}$ and $\mathbf{x}$, which sets this method apart from symbolic differentiation.
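For example, two such elementary rules (for the elementwise sine and for a linear map $\mathbf{x} \mapsto A\mathbf{x}$) are

$$\operatorname{vjp}(\sin, \mathbf{x})(\mathbf{v}) = \cos(\mathbf{x}) \odot \mathbf{v}, \qquad \operatorname{vjp}(\mathbf{x} \mapsto A\mathbf{x}, \mathbf{x})(\mathbf{v}) = A^\top \mathbf{v}$$

where $\odot$ is the elementwise product.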

Reverse mode autodiff requires two passes -- forward and backward -- to calculate VJPs. Calculating the VJP of $f_k$ requires us to have calculated the VJP of $f_{k+1}$ as well as $\mathbf{x}_{k-1}$. The forward pass, evaluating $f$ from $f_1$ to $f_i$, gives us the intermediate values $\mathbf{x}_{k-1}$. The backward pass, building up the VJPs from $f_i$ down to $f_1$, gives us the VJP of $f_{k+1}$.
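A minimal sketch of the two passes, assuming each function comes paired with its VJP rule (the `forward`/`backward` helpers and the example chain are hypothetical, for illustration only):

```python
import numpy as np

def forward(steps, x):
    """steps: list of (fk, vjp_rule) pairs. Returns f(x) and a tape of pullbacks."""
    tape = []
    for fk, vjp_rule in steps:
        tape.append(vjp_rule(x))   # record vjp(fk, x_{k-1}) at the current input
        x = fk(x)                  # advance to the next intermediate value x_k
    return x, tape

def backward(tape, v):
    """Thread v through the recorded pullbacks in reverse order."""
    for pullback in reversed(tape):
        v = pullback(v)
    return v

# Example chain: f = sum ∘ sin ∘ (x -> A x); its gradient is A^T cos(A x).
A = np.array([[1.0, 2.0], [3.0, 4.0]])
steps = [
    (lambda x: A @ x,      lambda x: (lambda v: A.T @ v)),
    (lambda x: np.sin(x),  lambda x: (lambda v: np.cos(x) * v)),
    (lambda x: np.sum(x),  lambda x: (lambda v: v * np.ones_like(x))),
]

x = np.array([0.5, -0.3])
y, tape = forward(steps, x)
grad = backward(tape, 1.0)                       # v = 1 for a scalar output
print(np.allclose(grad, A.T @ np.cos(A @ x)))    # True
```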

Wait, are VJPs actually doing differentiation though? Sure, we can calculate the product of a vector and a Jacobian, but differentiation would imply something like a gradient or the entire Jacobian. It turns out VJPs allow us to calculate both. Using $f : \R^n \to \R$ and $\mathbf{v} = \mathbf{1}_{1,1}$ (the one-element vector $[1]$), we obtain the VJP:

$$\operatorname{vjp}(f, \mathbf{x})(\mathbf{1}_{1,1}) = \mathbf{J}_f(\mathbf{x})^\top \mathbf{1}_{1,1} = \begin{bmatrix} \dfrac{\partial f}{\partial x_1} (\mathbf{x}) \\ \vdots \\ \dfrac{\partial f}{\partial x_n} (\mathbf{x}) \end{bmatrix} = \nabla f(\mathbf{x})$$

So the VJP of a scalar function can be used to calculate its gradient at a point. Similarly, we can calculate the entire Jacobian of a function row by row: choosing $\mathbf{v}$ to be a one-hot vector makes the VJP return the row of the Jacobian selected by $\mathbf{v}$.

$$\mathbf{J}_f(\mathbf{x}) = \begin{bmatrix} \operatorname{vjp}(f, \mathbf{x})(\delta_{i=1})^\top \\ \vdots \\ \operatorname{vjp}(f, \mathbf{x})(\delta_{i=m})^\top \end{bmatrix}$$
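Both recipes are easy to check with JAX, whose `jax.vjp` has the same curried shape as our $\operatorname{vjp}(f, \mathbf{x})(\mathbf{v})$: it returns $f(\mathbf{x})$ along with a pullback function. The two example functions below are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp

def f_scalar(x):   # f : R^n -> R
    return jnp.sum(jnp.sin(x) ** 2)

def f_vec(x):      # f : R^n -> R^m, here m = 2
    return jnp.stack([x[0] * x[1], jnp.sin(x[0]) + x[2]])

x = jnp.array([0.3, -1.2, 2.0])

# Gradient of a scalar function: the "vector" is just a 1 with the output's shape.
y, pullback = jax.vjp(f_scalar, x)
(grad,) = pullback(jnp.ones_like(y))
print(jnp.allclose(grad, jax.grad(f_scalar)(x)))   # True

# Full Jacobian, one row per one-hot cotangent vector.
y, pullback = jax.vjp(f_vec, x)
J = jnp.stack([pullback(row)[0] for row in jnp.eye(y.shape[0])])
print(jnp.allclose(J, jax.jacrev(f_vec)(x)))       # True
```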

It is worth noting that building the entire Jacobian with VJPs requires $m$ passes for $f: \R^n \to \R^m$ with an $m \times n$ Jacobian. This method is therefore more efficient the smaller $m$ is compared to $n$. In a machine learning context, $m = 1 \ll n$ is common, as the outputs are scalar loss values and the inputs are model weights. This makes VJPs well suited to backpropagation.

The VJP of a composite function can be built up from the VJPs of its constituent functions, and VJPs produce gradients and Jacobians. This enables reverse mode automatic differentiation, which is efficient for machine learning. Everything falls into place.