Hessian Matrix in ML


This post summarizes approximation methods for the Hessian matrix (see also PRML, Chapter 5).

Consider a function \(f(\mathbf{x})\in\mathbb{R}\), where \(\mathbf{x}\in\mathbb{R}^n\). The Hessian matrix is defined as

\[ \mathbf{H} = \nabla^2f(\mathbf{x}), \]

where \(\mathbf{H}_{ij} = \frac{\partial^2 f}{\partial x_i\partial x_j}\).
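For small \(n\) the definition can be checked directly. The following is a minimal PyTorch sketch, assuming an arbitrary example function \(f(\mathbf{x}) = x_1^2x_2 + \sin x_3\) (chosen only for illustration). `torch.autograd.functional.hessian` materializes the full \(n\times n\) matrix, which we compare against the hand-derived second partial derivatives.

```python
import math
import torch

def f(x):
    # Example function (hypothetical, for illustration only): f(x) = x1^2 * x2 + sin(x3)
    return x[0] ** 2 * x[1] + torch.sin(x[2])

x1, x2, x3 = 1.0, 2.0, 0.5
x = torch.tensor([x1, x2, x3])

# Full n x n Hessian: H[i, j] = d^2 f / (dx_i dx_j)
H = torch.autograd.functional.hessian(f, x)

# Hand-derived Hessian of the same function, for comparison
H_exact = torch.tensor([
    [2 * x2, 2 * x1, 0.0],
    [2 * x1, 0.0, 0.0],
    [0.0, 0.0, -math.sin(x3)],
])

print(torch.allclose(H, H_exact))  # expected: True
print(torch.allclose(H, H.T))      # expected: True (the Hessian is symmetric)
```

Materializing \(\mathbf{H}\) like this is only practical for tiny \(n\), which motivates the approximations in the rest of the post.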

When is the Hessian matrix used?
- Influence functions
- Fast retraining
- Pruning (removing the least significant weights)
{: .notice--success}

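Main challenge: computational complexity. For a model with \(n\) parameters, the dense Hessian has \(n^2\) entries, and inverting it directly costs \(O(n^3)\) time.
{: .notice--success}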
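To make the memory cost concrete, here is a back-of-the-envelope calculation, assuming dense float32 storage (4 bytes per entry); the helper `hessian_bytes` is hypothetical. A model with one million parameters would already need \(4\times 10^{12}\) bytes (about 4 TB) just to store its Hessian.

```python
def hessian_bytes(num_params: int, bytes_per_entry: int = 4) -> int:
    """Memory for a dense num_params x num_params Hessian (float32 = 4 bytes/entry by default)."""
    return num_params ** 2 * bytes_per_entry

for n in (1_000, 1_000_000, 100_000_000):
    print(f"{n:>11,d} parameters -> {hessian_bytes(n) / 1e9:,.3f} GB")
```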

0. Incremental Inverse Hessian

In many applications (e.g., quasi-Newton methods), the inverse of the Hessian is needed more often than the Hessian itself.

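Let \(\ell(x;\theta) = \log p(x;\theta)\) denote the log-likelihood of a sample \(x\) under a model with parameters \(\theta\). Bartlett's second identity states that, when \(x\) is drawn from the model itself,

\[\mathbb{E}\left[\nabla_\theta^2\ell(x;\theta)\right] + \mathbb{E}\left[\nabla_\theta \ell(x;\theta)\nabla_\theta\ell(x;\theta)^T\right]=0.\]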
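A minimal numerical check of the identity, assuming a one-parameter Bernoulli model \(p(x=1;\theta)=\sigma(\theta)\) (an illustrative choice, not taken from the post). Because \(x\) takes only the values 0 and 1, both expectations can be computed exactly by weighting each outcome by its model probability.

```python
import torch
import torch.nn.functional as F

theta = torch.tensor(0.7, requires_grad=True)  # arbitrary parameter value

def log_lik(x, theta):
    # Bernoulli log-likelihood with p(x=1) = sigmoid(theta); log(1 - sigmoid(t)) = logsigmoid(-t)
    return x * F.logsigmoid(theta) + (1 - x) * F.logsigmoid(-theta)

p = torch.sigmoid(theta).detach()
expected_hess = torch.tensor(0.0)
expected_outer = torch.tensor(0.0)

# Exact expectation over x ~ p(x; theta): x takes only the values 1 and 0
for x, prob in ((1.0, p), (0.0, 1 - p)):
    g = torch.autograd.grad(log_lik(x, theta), theta, create_graph=True)[0]  # first derivative
    h = torch.autograd.grad(g, theta)[0]                                    # second derivative
    expected_hess = expected_hess + prob * h
    expected_outer = expected_outer + prob * g.detach() ** 2

print(expected_hess.item(), expected_outer.item())
print((expected_hess + expected_outer).item())  # ~0 up to floating-point rounding
```

The identity relies on \(x\) being sampled from the model; if \(x\) comes from a different distribution, the two terms need not cancel.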

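Assuming the trained model approximates the target distribution well, the Hessian of the negative log-likelihood can therefore be approximated by the Fisher information matrix:

\[ \mathbf{H} \approx \mathbb{E}_{x\sim q_x}\left[\mathbb{E}_{y\sim p_{y|x}} \nabla_\theta\log p(y|x;\theta)\nabla_\theta\log p(y|x;\theta)^T\right], \]

where \(q_x\) is the input distribution. Replacing the expectations with a sum over \(n\) samples \((x_i, y_i)\) (and dropping the \(1/n\) factor), the Hessian becomes a sum of outer products of per-sample gradients \(\mathbf{v}_i = \nabla_\theta\log p(y_i|x_i;\theta)\):

\[ \mathbf{H}_n = \sum_{i=1}^n \mathbf{v}_i\mathbf{v}_i^T. \]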
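As a sanity check, consider logistic regression (an assumed example with synthetic data). There the Fisher form is not just an approximation: the Hessian of the summed negative log-likelihood, \(\sum_i\sigma_i(1-\sigma_i)\mathbf{x}_i\mathbf{x}_i^T\), equals \(\sum_i\mathbb{E}_{y\sim p(y|\mathbf{x}_i)}[\mathbf{v}\mathbf{v}^T]\) exactly. The sketch computes both in PyTorch, using the closed-form gradient \(\mathbf{v} = (y-\sigma(\mathbf{x}_i^T\theta))\mathbf{x}_i\) of \(\log p(y|\mathbf{x}_i;\theta)\).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d = 50, 3
X = torch.randn(n, d, dtype=torch.float64)        # synthetic inputs x_i (illustrative)
y = torch.randint(0, 2, (n,)).to(torch.float64)   # observed labels (do not affect the Hessian)
theta = torch.randn(d, dtype=torch.float64)

def nll(theta):
    # Negative log-likelihood of logistic regression, summed over the n samples
    return F.binary_cross_entropy_with_logits(X @ theta, y, reduction="sum")

H = torch.autograd.functional.hessian(nll, theta)  # exact d x d Hessian

# Fisher form: sum_i E_{y ~ p(y|x_i)}[v v^T], v = (y - sigmoid(x_i^T theta)) x_i
p = torch.sigmoid(X @ theta)
F_mat = torch.zeros(d, d, dtype=torch.float64)
for i in range(n):
    for y_i, prob in ((1.0, p[i]), (0.0, 1 - p[i])):
        v = (y_i - p[i]) * X[i]
        F_mat += prob * torch.outer(v, v)

print(torch.allclose(H, F_mat))  # expected: True for logistic regression
```

For models such as deep networks the two matrices generally differ, which is why the outer-product form is treated as an approximation.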

Writing the Hessian as a sum of outer products lets us build it incrementally, \(\mathbf{H}_{n} = \mathbf{H}_{n-1} + \mathbf{v}_n\mathbf{v}_n^T\), and update its inverse at the same time. Inverting both sides with the Sherman–Morrison–Woodbury formula (taking \(A = \mathbf{H}_{n-1}\), \(B = \mathbf{v}_n\), \(C = \mathbf{v}_n^T\), \(D = 1\)) gives the following equality:

Sherman–Morrison–Woodbury formula: \[ (A+BD^{-1}C)^{-1} = A^{-1} - A^{-1}B(D+CA^{-1}B)^{-1}CA^{-1} \]
{: .notice--success}

\[ (\mathbf{H}_{n-1} + \mathbf{v}_n\mathbf{v}_n^T)^{-1} = \mathbf{H}_{n-1}^{-1} - \frac{\mathbf{H}_{n-1}^{-1}\mathbf{v}_n\mathbf{v}_n^T\mathbf{H}_{n-1}^{-1}}{1+\mathbf{v}_n^T\mathbf{H}_{n-1}^{-1}\mathbf{v}_n} \]

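Note that this recursion only works once an initial inverse is available: \(\mathbf{H}_{n-1}\) must be full rank (invertible) before the update can be applied. A common workaround is to start from a damped matrix \(\mathbf{H}_0 = \lambda\mathbf{I}\) with small \(\lambda > 0\), whose inverse is simply \(\lambda^{-1}\mathbf{I}\).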
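The rank-one update can be sketched as follows, assuming PyTorch, random vectors standing in for the per-sample gradients \(\mathbf{v}_i\), and a damped starting point \(\mathbf{H}_0=\lambda\mathbf{I}\) (the value of `lam` is an arbitrary choice) so that an initial inverse exists. Because \(\mathbf{H}_{n-1}^{-1}\) is symmetric, the numerator \(\mathbf{H}_{n-1}^{-1}\mathbf{v}_n\mathbf{v}_n^T\mathbf{H}_{n-1}^{-1}\) is the outer product of \(\mathbf{u}=\mathbf{H}_{n-1}^{-1}\mathbf{v}_n\) with itself.

```python
import torch

torch.manual_seed(0)
d, n = 4, 20
lam = 1e-2                                       # damping: H_0 = lam * I (an assumption)
vs = torch.randn(n, d, dtype=torch.float64)      # stand-ins for per-sample gradients v_i

H_inv = torch.eye(d, dtype=torch.float64) / lam  # inverse of H_0
for v in vs:
    u = H_inv @ v                                # H_{n-1}^{-1} v_n
    # Rank-one Sherman-Morrison update; relies on H_{n-1}^{-1} being symmetric
    H_inv = H_inv - torch.outer(u, u) / (1.0 + v @ u)

H_direct = lam * torch.eye(d, dtype=torch.float64) + vs.T @ vs  # lam*I + sum_i v_i v_i^T
print(torch.allclose(H_inv, torch.linalg.inv(H_direct)))        # expected: True
```

Each update costs \(O(d^2)\) for \(d\) parameters, compared with \(O(d^3)\) for inverting the accumulated matrix from scratch.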

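1. Hessian-vector product

In this section, we compute the Hessian-vector product \(\mathbf{H}\mathbf{v}\) without forming the Hessian matrix explicitly. The motivation is simple: storing \(\mathbf{H}\) takes \(O(n^2)\) memory, whereas \(\mathbf{H}\mathbf{v}\) can be computed at a small constant multiple of the cost of one gradient evaluation.

Consider a function \(f:\mathbb{R}^n\rightarrow \mathbb{R}\), a point \(\mathbf{x}\in\mathbb{R}^n\), and a vector \(\mathbf{v}\in\mathbb{R}^n\). The Hessian-vector product can be written as

\[ \mathbf{H}\mathbf{v}=J(\nabla f(\mathbf{x}))^T \mathbf{v}, \]

where \(J(\nabla f(\mathbf{x}))\) is the Jacobian of the gradient map. This Jacobian is the Hessian itself, which is symmetric when \(f\) is twice continuously differentiable, so the transpose does not change the result.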
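The Jacobian-of-the-gradient form maps directly onto forward-over-reverse automatic differentiation. A minimal sketch, assuming PyTorch 2.x (where `torch.func.grad` and `torch.func.jvp` are available) and an arbitrary example function: `jvp` pushes \(\mathbf{v}\) through the Jacobian of \(\nabla f\) without ever forming that Jacobian.

```python
import torch
from torch.func import grad, jvp

def f(x):
    # Example function (hypothetical): f(x) = sum_i x_i^3 + x_0 * x_1
    return (x ** 3).sum() + x[0] * x[1]

x = torch.tensor([1.0, -2.0, 0.5])
v = torch.tensor([0.3, 1.0, -1.0])

# Forward-mode JVP of the gradient function: J(grad f)(x) v = H v
_, Hv = jvp(grad(f), (x,), (v,))

# Hand-derived Hessian for comparison: diag(6 x) plus 1 at positions (0, 1) and (1, 0)
H_manual = torch.diag(6 * x)
H_manual[0, 1] = H_manual[1, 0] = 1.0
print(Hv)
print(torch.allclose(Hv, H_manual @ v))  # expected: True
```

This and the double-backward approach used in the examples section compute the same vector; they differ only in how the two derivatives are chained.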

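Start from the second-order Taylor expansion of \(f\) around \(\mathbf{x}\):

\[ f(\mathbf{x}+\Delta \mathbf{x}) \approx f(\mathbf{x}) + \nabla f(\mathbf{x})^T\Delta\mathbf{x} + \frac{1}{2}\Delta\mathbf{x}^T\mathbf{H}\Delta\mathbf{x}+ \cdots . \]

Substituting \(\Delta \mathbf{x} = r \mathbf{v}\) with a small scalar \(r\) gives

\[ f(\mathbf{x}+r\mathbf{v}) = f(\mathbf{x}) + r \nabla f(\mathbf{x})^T \mathbf{v} + O(r^2). \]

Taking the gradient with respect to \(\mathbf{x}\) on both sides (with \(\mathbf{v}\) held fixed) and using \(\nabla(\nabla f(\mathbf{x})^T\mathbf{v}) = \mathbf{H}\mathbf{v}\), we get \(\nabla f(\mathbf{x}+r\mathbf{v}) = \nabla f(\mathbf{x}) + r\mathbf{H}\mathbf{v} + O(r^2)\), i.e.,

\[ \mathbf{H}\mathbf{v} = \frac{\nabla f(\mathbf{x}+r\mathbf{v}) - \nabla f(\mathbf{x})}{r} + O(r). \]

Letting \(r\rightarrow 0\) removes the \(O(r)\) error term:

\[ \mathbf{H}\mathbf{v} = \lim_{r\rightarrow 0}\frac{\nabla f(\mathbf{x}+r\mathbf{v}) - \nabla f(\mathbf{x})}{r} = \nabla (\nabla f(\mathbf{x})^T \mathbf{v}). \]

In an automatic differentiation framework, this means differentiating twice: once to obtain \(\nabla f(\mathbf{x})\), and once more to take the gradient of the scalar \(\nabla f(\mathbf{x})^T\mathbf{v}\).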
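The derivation yields two ways to compute \(\mathbf{H}\mathbf{v}\): the finite-difference quotient for a small \(r\), and its exact limit \(\nabla(\nabla f(\mathbf{x})^T\mathbf{v})\). The sketch below compares them in PyTorch on an arbitrary example function; the helper names, the step `r=1e-6`, and float64 precision are assumptions chosen to keep both the \(O(r)\) truncation error and rounding error small.

```python
import torch

def f(x):
    # Example function (hypothetical): f(x) = sum_i exp(x_i) + (x_0 * x_1)^2
    return torch.exp(x).sum() + (x[0] * x[1]) ** 2

def grad_f(x):
    # Gradient of f at x (a fresh leaf tensor so autograd can track it)
    x = x.detach().requires_grad_(True)
    return torch.autograd.grad(f(x), x)[0]

def hvp_finite_diff(x, v, r=1e-6):
    # (grad f(x + r v) - grad f(x)) / r, which has an O(r) truncation error
    return (grad_f(x + r * v) - grad_f(x)) / r

def hvp_exact(x, v):
    # Limit r -> 0: Hv = grad_x(grad f(x)^T v), with v held constant
    x = x.detach().requires_grad_(True)
    g = torch.autograd.grad(f(x), x, create_graph=True)[0]
    return torch.autograd.grad(g @ v, x)[0]

x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64)
v = torch.tensor([1.0, 0.5, -0.25], dtype=torch.float64)

exact = hvp_exact(x, v)
approx = hvp_finite_diff(x, v)
print(exact)
print((exact - approx).abs().max())  # small, roughly on the order of r
```

The finite-difference version needs only first-order gradients but trades accuracy for that simplicity; the exact version is what the examples in this post use.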

2. Inverse-Hessian-vector product

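For a scalar \(x\) with \(|1-x|<1\), the geometric series gives \(\frac{1}{x} = \frac{1}{1-(1-x)} = 1 + (1-x) + (1-x)^2 + \cdots\). The matrix analogue is the Neumann series \(\mathbf{H}^{-1} = \sum_{k=0}^{\infty}(\mathbf{I}-\mathbf{H})^k\), which converges when every eigenvalue of \(\mathbf{I}-\mathbf{H}\) has magnitude below 1 (for a symmetric \(\mathbf{H}\): when all eigenvalues of \(\mathbf{H}\) lie in \((0, 2)\)). Multiplying by \(\mathbf{v}\) and truncating after \(n\) terms yields an iterative scheme that needs only Hessian-vector products:

\[ \mathrm{IHV}_n \leftarrow \mathbf{v} + (\mathbf{I}-\mathbf{H})\,\mathrm{IHV}_{n-1}, \]

where \(\mathrm{IHV}_{0}=\mathbf{v}\), so that \(\mathrm{IHV}_n = \sum_{k=0}^{n}(\mathbf{I}-\mathbf{H})^k\mathbf{v} \rightarrow \mathbf{H}^{-1}\mathbf{v}\).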
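The recursion can be tried on an explicit matrix before wiring it to autograd. The sketch below, assuming PyTorch, takes any function `hvp(u)` that returns \(\mathbf{H}\mathbf{u}\). The `scale` parameter (and the helper name `neumann_ihvp`) is an addition not in the formula above: it runs the recursion on \(c\mathbf{H}\) and uses \(\mathbf{H}^{-1}\mathbf{v} = c\,(c\mathbf{H})^{-1}\mathbf{v}\), so that a Hessian with eigenvalues above 2 can still meet the convergence condition. The test matrix is arbitrary.

```python
import torch

def neumann_ihvp(hvp, v, num_iters=200, scale=1.0):
    """Approximate H^{-1} v with the truncated Neumann series.

    hvp(u) must return H @ u. The recursion runs on scale * H, whose eigenvalues
    should lie in (0, 2); the result is multiplied by scale at the end because
    H^{-1} v = scale * (scale * H)^{-1} v.
    """
    ihv = v.clone()
    for _ in range(num_iters):
        ihv = v + ihv - scale * hvp(ihv)  # IHV_n = v + (I - scale*H) IHV_{n-1}
    return scale * ihv

H = torch.tensor([[10.0, 2.0], [2.0, 5.0]], dtype=torch.float64)  # eigenvalues ~10.7 and ~4.3
v = torch.tensor([1.0, -1.0], dtype=torch.float64)

approx = neumann_ihvp(lambda u: H @ u, v, num_iters=200, scale=0.1)  # 0.1*H has eigenvalues in (0, 2)
exact = torch.linalg.solve(H, v)
print(torch.allclose(approx, exact))  # expected: True
```

With `scale=1.0` the same call diverges for this matrix, because \(\mathbf{I}-\mathbf{H}\) has an eigenvalue of magnitude about 9.7.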

3. Examples

3.1. Hessian-vector product

```python
import torch
from torch import nn

# Define parameters
param_dim = 3
A = torch.randn(param_dim, param_dim)
x = nn.Parameter(torch.randn(param_dim))
v = torch.randn(param_dim)

# Compute a loss function: f = (1/2) x^T A x
f = 0.5 * x @ A @ x

# 1. Direct computation: the Hessian of f is (1/2)(A + A^T)
print(f'Real hessian-vector product is {0.5 * (A + A.T) @ v}')

# 2. Hessian-vector product method: Hv = grad_x(grad_x f(x)^T v)
x_grad = torch.autograd.grad(f, x, create_graph=True)[0]  # keep the graph for the second derivative
Hv = torch.autograd.grad(x_grad @ v, x)[0]
print(f'Estimated hessian-vector product is {Hv}')
```

Output (values vary from run to run because no random seed is set):

```
Real hessian-vector product is tensor([-0.1761, 0.4509, 0.1387])
Estimated hessian-vector product is tensor([-0.1761, 0.4509, 0.1387])
```

3.2. Inverse-Hessian-vector product

```python
import torch
from torch import nn

# Define parameters
param_dim = 3
A = torch.randn(param_dim, param_dim)
A = A @ A.T / 10  # symmetric positive (semi)definite; /10 shrinks eigenvalues toward (0, 2)
x = nn.Parameter(torch.randn(param_dim))
v = torch.randn(param_dim)

# Compute a loss function: f = (1/2) x^T A x
f = 0.5 * x @ A @ x

# 1. Direct inverse-Hessian-vector product
H = 0.5 * (A + A.T)  # exact Hessian (equal to A, since A is symmetric)
H_inv = torch.linalg.inv(H)
print(f'Real inverse-hessian-vector product is {H_inv @ v}')

# 2. Estimated inverse-Hessian-vector product (truncated Neumann series):
#    IHV_n = v + (I - H) IHV_{n-1},  IHV_0 = v
#    Converges only if every eigenvalue of H lies in (0, 2); not guaranteed for every random A.
iteration = 500
IHV_n = v
x_grad = torch.autograd.grad(f, x, create_graph=True)[0]

for _ in range(iteration):
    # H @ IHV_n via a Hessian-vector product; retain_graph lets us reuse x_grad's graph
    H_IHV_n = torch.autograd.grad(x_grad @ IHV_n, x, retain_graph=True)[0]
    IHV_n = v + IHV_n - H_IHV_n

print(f'Estimated inverse-hessian-vector product is {IHV_n}')
```

Output (values vary from run to run because no random seed is set):

```
Real inverse-hessian-vector product is tensor([-2.6567, -2.1620, 9.9240])
Estimated inverse-hessian-vector product is tensor([-2.6567, -2.1620, 9.9240])
```

Sources:
1. https://velog.io/@veglog/Hessian-vector-product
2. https://github.com/howonlee/bobdobbshess/blob/master/writeup.pdf