In this post, approximation methods for the Hessian matrix are summarized (see also PRML, Chapter 5).
Let us consider a function \(f(\mathbf{x})\in\mathbb{R}\), where \(\mathbf{x}\in\mathbb{R}^n\). Then, the Hessian matrix is defined as
\[ \mathbf{H} = \nabla^2 f(\mathbf{x}), \]
where \(\mathbf{H}_{ij} = \frac{\partial^2 f}{\partial x_i\partial x_j}\).
When is the Hessian matrix used? Influence functions, fast retraining, and pruning (of the least significant weights).
{: .notice--success}
Main challenge: computational complexity.
{: .notice--success}
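To see why this is a challenge, here is a minimal sketch (a toy example assumed for illustration, not from the original post) that forms the full Hessian of a small quadratic with `torch.autograd.functional.hessian`; materializing the full \(n\times n\) matrix this way is exactly what the methods below try to avoid.

```python
# Minimal sketch (assumed toy example): brute-force Hessian of f(x) = (1/2) x^T A x.
# The full Hessian needs O(n^2) memory and repeated backward passes, which quickly
# becomes infeasible for models with millions of parameters.
import torch

n = 3
A = torch.randn(n, n)

def f(x):
    return 0.5 * x @ A @ x

x = torch.randn(n)
H = torch.autograd.functional.hessian(f, x)
print(torch.allclose(H, 0.5 * (A + A.T), atol=1e-6))  # analytic Hessian of the quadratic
```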
0. Incremental Inverse Hessian
In many studies, the inverse of the Hessian matrix is used more often than the Hessian itself (e.g., in quasi-Newton methods).
By adopting Bartlett's second identity, we have
\[\mathbb{E}\left[\nabla_\theta^2\ell(x;\theta)\right] + \mathbb{E}\left[\nabla \ell(x;\theta)\nabla\ell(x;\theta)^T\right]=0.\]
Under the assumption that the trained model estimates the target distribution well, we obtain the following approximation of the Hessian:
\[ H \approx \mathbb{E}_{x\sim q_x}\left[\mathbb{E}_{y\sim p_{y|x}} \nabla_\theta\log p(y|x;\theta)\nabla_\theta\log p(y|x;\theta)^T\right]. \] Then, we can represent the Hessian as a sum of outer products of per-sample gradients \(\mathbf{v}_i = \nabla_\theta\log p(y_i|x_i;\theta)\):
\[ \mathbf{H}_n = \sum_{i=1}^n \mathbf{v}_i\mathbf{v}_i^T \]
From the above equation, we can obtain the inverse Hessian matrix progressively, since \(\mathbf{H}_{n} = \mathbf{H}_{n-1} + \mathbf{v}_n\mathbf{v}_n^T\). Now, by applying matrix inversion to both sides of this equation, we have the following equality:
Sherman–Morrison–Woodbury formula: \[ (A+BD^{-1}C)^{-1} = A^{-1} - A^{-1}B(D+CA^{-1}B)^{-1}CA^{-1} \]
{: .notice--success}
\[ (\mathbf{H}_{n-1} + \mathbf{v}_n\mathbf{v}_n^T)^{-1} = \mathbf{H}_{n-1}^{-1} - \frac{\mathbf{H}_{n-1}^{-1}\mathbf{v}_n\mathbf{v}_n^T\mathbf{H}_{n-1}^{-1}}{1+\mathbf{v}_n^T\mathbf{H}_{n-1}^{-1}\mathbf{v}_n} \]
Unfortunately, this update only works once an inverse Hessian is already available, i.e., \(\mathbf{H}_{n-1}\) must be full rank (invertible).
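As a concrete illustration, the following is a minimal sketch of the rank-one update in PyTorch. The initialization \(\mathbf{H}_0=\alpha\mathbf{I}\) (which makes every \(\mathbf{H}_{n-1}\) invertible) and the random stand-in gradients \(\mathbf{v}_i\) are assumptions made for the sketch, not part of the derivation above.

```python
# Minimal sketch (assumed setup): incremental inverse Hessian via Sherman-Morrison.
import torch

torch.manual_seed(0)
dim, alpha = 3, 0.1
H = alpha * torch.eye(dim)              # H_0 = alpha * I, kept only to verify the result
H_inv = (1.0 / alpha) * torch.eye(dim)  # exact inverse of H_0

# Random stand-ins for per-sample gradients v_i = grad_theta log p(y_i | x_i; theta)
vs = [torch.randn(dim) for _ in range(100)]

for v_i in vs:
    H = H + torch.outer(v_i, v_i)       # H_n = H_{n-1} + v_n v_n^T
    # Sherman-Morrison rank-one update of the inverse (no matrix inversion needed)
    Hv = H_inv @ v_i
    H_inv = H_inv - torch.outer(Hv, Hv) / (1.0 + v_i @ Hv)

# The incrementally updated inverse should match the direct inverse
print((H_inv - torch.linalg.inv(H)).abs().max())
```

Each rank-one update costs \(O(d^2)\), rather than the \(O(d^3)\) of recomputing the inverse from scratch.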
1. Hessian-vector product
In this section, we focus on computing a Hessian-vector product without explicitly forming the Hessian matrix. The motivation is simple: computing and storing the full Hessian is expensive.
Let us define a function \(f(\cdot):\mathbb{R}^n\rightarrow \mathbb{R}\), a point \(\mathbf{x}\in\mathbb{R}^n\), and a vector \(\mathbf{v}\in\mathbb{R}^n\). The Hessian-vector product is formulated as
\[ \mathbf{H}\mathbf{v}=J(\nabla f(\mathbf{x}))^T \cdot \mathbf{v} \]
Let’s start with a Taylor expansion of \(f(\mathbf{x} + \Delta\mathbf{x})\): \[
f(\mathbf{x}+\Delta \mathbf{x}) \approx f(\mathbf{x}) + \nabla f(\mathbf{x})^T\mathbf{\Delta x} + \frac{1}{2}\mathbf{\Delta x}^T\mathbf{H}\mathbf{\Delta x}+ \cdots .
\] Then, substituting \(\Delta \mathbf{x} = r \mathbf{v}\) with \(r\) small, we obtain \[
f(\mathbf{x}+r\mathbf{v}) \approx f(\mathbf{x}) + r \nabla f(\mathbf{x})^T \mathbf{v} + O(r^2)
\] To obtain the Hessian-vector product, we apply \(\nabla\) to both sides of the above equation (treating \(\mathbf{v}\) as a constant), which gives \(\nabla f(\mathbf{x}+r\mathbf{v}) \approx \nabla f(\mathbf{x}) + r\mathbf{H}\mathbf{v}\), and hence \[
\mathbf{H}\mathbf{v} = \frac{\nabla f(\mathbf{x}+r\mathbf{v}) - \nabla f(\mathbf{x})}{r} + O(r).
\] Taking the limit \(r\rightarrow 0\) removes the \(O(r)\) term; equivalently, the Hessian-vector product can be computed exactly as the gradient of the scalar \(\nabla f(\mathbf{x})^T\mathbf{v}\):
\[
\mathbf{Hv} = \nabla (\nabla f(\mathbf{x})^T \cdot \mathbf{v}).
\]
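As a sanity check on the finite-difference formula above, here is a minimal sketch; the step size `r`, the helper `grad_f`, and the quadratic test function are assumptions made for this sketch rather than part of the original post.

```python
# Minimal sketch (assumed setup): finite-difference Hessian-vector product
# Hv ~= (grad f(x + r v) - grad f(x)) / r for a small step r.
import torch

torch.manual_seed(0)
n = 3
A = torch.randn(n, n)

def grad_f(x):
    # Gradient of f(x) = (1/2) x^T A x, computed with autograd
    x = x.detach().requires_grad_(True)
    f = 0.5 * x @ A @ x
    return torch.autograd.grad(f, x)[0]

x = torch.randn(n)
v = torch.randn(n)
r = 1e-3

hv_fd = (grad_f(x + r * v) - grad_f(x)) / r   # finite-difference estimate
hv_exact = 0.5 * (A + A.T) @ v                # analytic Hessian-vector product
print(hv_fd, hv_exact)
```

In practice the exact form \(\mathbf{Hv}=\nabla(\nabla f(\mathbf{x})^T\mathbf{v})\), used in Section 3.1 below, is usually preferred, since the finite difference is sensitive to the choice of \(r\).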
2. Inverse-Hessian-vector product
By using the series expansion \(\frac{1}{x} = \frac{1}{1-(1-x)} = 1 + (1-x) + (1-x)^2 + \cdots\) (in matrix form, a Neumann series), we can obtain the inverse-Hessian-vector product iteratively: \[ \mathrm{IHV}_n \leftarrow \mathbf{v} + (\mathbf{I}-\mathbf{H})\cdot \mathrm{IHV}_{n-1}, \] where \(\mathrm{IHV}_{0}=\mathbf{v}\). Note that each iteration only requires a Hessian-vector product with \(\mathrm{IHV}_{n-1}\), never the Hessian itself.
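Unrolling this recursion (a standard Neumann-series argument, added here for completeness) shows what it computes: \[ \mathrm{IHV}_n = \sum_{k=0}^{n}(\mathbf{I}-\mathbf{H})^k\,\mathbf{v} \;\longrightarrow\; \mathbf{H}^{-1}\mathbf{v} \quad (n\rightarrow\infty), \] which converges when \(\|\mathbf{I}-\mathbf{H}\|<1\), i.e., when the eigenvalues of \(\mathbf{H}\) lie in \((0,2)\). This is why the example in Section 3.2 scales \(A\) down before using it.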
3. Examples
3.1. Hessian-vector product
```python
import torch
from torch import nn

# Define parameters
param_dim = 3
A = torch.randn(param_dim, param_dim)
x = nn.Parameter(torch.randn(param_dim))
v = torch.randn(param_dim)

# Compute a loss function: f = (1/2) x^T A x
f = 1/2 * x.T @ A @ x

# 1. Direct computation: from the above equation, the Hessian matrix is (1/2)(A + A^T).
print(f'Real hessian-vector product is {(1/2)*(A+A.T)@v}')

# 2. Hessian-vector product method:
x_grad = torch.autograd.grad(f, x, create_graph=True)[0]
Hv = torch.autograd.grad(x_grad.T @ v, x)[0]
print(f'Estimated hessian-vector product is {Hv}')
```

```
# output
Real hessian-vector product is tensor([-0.1761, 0.4509, 0.1387])
Estimated hessian-vector product is tensor([-0.1761, 0.4509, 0.1387])
```
3.2. Inverse-Hessian-vector product
```python
import torch
from torch import nn

# Define parameters
param_dim = 3
A = torch.randn(param_dim, param_dim)
A = A @ A.T / 10  # make A symmetric and scale it down so the iteration below converges
x = nn.Parameter(torch.randn(param_dim))
v = torch.randn(param_dim)

# Compute a loss function: f = (1/2) x^T A x
f = 1/2 * x.T @ A @ x

# 1. Direct inverse-Hessian-vector product
HV_org = 1/2 * (A + A.T)
IHV = torch.linalg.inv(HV_org)
print(f'Real inverse-hessian-vector product is {IHV@v}')

# 2. Estimated inverse-Hessian-vector product
"""
IHV_n = v + (I-H)IHV_{n-1}
IHV_0 = v
"""
iteration = 500
IHV_n = v
x_grad = torch.autograd.grad(f, x, create_graph=True)[0]

for i in range(iteration):
    # Hessian-vector product H @ IHV_{n-1} via a second backward pass
    HV_org_IHV_n = torch.autograd.grad(x_grad.T @ IHV_n, x, retain_graph=True)[0]
    IHV_n = v + IHV_n - HV_org_IHV_n

print(f'Estimated inverse-hessian-vector product is {IHV_n}')
```

```
# output
Real inverse-hessian-vector product is tensor([-2.6567, -2.1620, 9.9240])
Estimated inverse-hessian-vector product is tensor([-2.6567, -2.1620, 9.9240])
```
Sources:
1. https://velog.io/@veglog/Hessian-vector-product
2. https://github.com/howonlee/bobdobbshess/blob/master/writeup.pdf