import torch
import numpy as np
import math
[docs]
def _flatten(sequence):
flat = [p.contiguous().view(-1) for p in sequence]
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
def compute_cc_weights(nb_steps):
    """Return Clenshaw-Curtis quadrature weights and nodes on ``[-1, 1]``.

    The rule uses the ``nb_steps + 1`` Chebyshev extreme points
    ``steps[j] = cos(j * pi / nb_steps)``. The integral of ``f`` over
    ``[-1, 1]`` is approximated by ``sum_j cc_weights[j] * f(steps[j])``.
    The rule is exact for polynomials of degree up to ``nb_steps``.

    Args:
        nb_steps: number of quadrature intervals (positive integer).

    Returns:
        ``(cc_weights, steps)``: two float32 tensors of shape
        ``(nb_steps + 1, 1)``.

    Raises:
        ValueError: if ``nb_steps < 1`` (every formula below divides by it).
    """
    if nb_steps < 1:
        raise ValueError(f"nb_steps must be a positive integer, got {nb_steps}")
    idx = np.arange(0, nb_steps + 1, 1).reshape(-1, 1)

    # lam[k, j] = cos(k * j * pi / N) * c_j * 2 / N, with c_j = 1/2 for the
    # endpoint nodes j = 0 and j = N, and 1 otherwise.
    lam = np.cos((idx @ idx.T) * math.pi / nb_steps)
    lam[:, 0] = 0.5
    lam[:, -1] = 0.5 * lam[:, -1]
    lam = lam * 2 / nb_steps

    # W[k] = b_k * 2 / (1 - k^2) for even k (the integral of T_k over
    # [-1, 1]) and 0 for odd k, where b_k = 1/2 for k = 0 and k = N.
    W = np.zeros((nb_steps + 1, 1))
    even = np.arange(0, nb_steps + 1, 2)
    W[even, 0] = 2.0 / (1.0 - even.astype(np.float64) ** 2)
    W[0] = 1.0
    if nb_steps % 2 == 0:
        # The k = N moment must be halved too. Without this, even-N rules
        # are not exact; e.g. N = 2 gave 1/6, 5/3, 1/6 instead of Simpson's
        # 1/3, 4/3, 1/3. For odd N the k = N moment is already 0.
        W[-1] = 0.5 * W[-1]

    cc_weights = torch.tensor(lam.T @ W).float()
    steps = torch.tensor(np.cos(idx * math.pi / nb_steps)).float()
    return cc_weights, steps
def integrate(x0, nb_steps, step_sizes, integrand, h, compute_grad=False, x_tot=None):
    """Apply Clenshaw-Curtis quadrature over ``[x0, x0 + nb_steps * step_sizes]``.

    The integrand is evaluated at the ``nb_steps + 1`` nodes returned by
    ``compute_cc_weights``, mapped from ``[-1, 1]`` onto the interval.

    Args:
        x0: lower integration bound (tensor); also fixes the device/dtype of
            the quadrature weights and nodes.
        nb_steps: number of quadrature intervals.
        step_sizes: per-step width; only ``nb_steps * step_sizes`` is used to
            locate the upper bound.
        integrand: called as ``integrand(x, h)`` in value mode. In gradient
            mode it is handed to ``computeIntegrand``, which uses
            ``integrand.forward`` and ``integrand.parameters()``.
        h: extra tensor forwarded to every integrand evaluation.
        compute_grad: False -> return the integral estimate. True -> return
            the quadrature of the vector-Jacobian products w.r.t. the
            integrand's parameters and ``h``.
        x_tot: vector for those vector-Jacobian products; required when
            ``compute_grad`` is True.

    Returns:
        The integral estimate, or a ``(g_param, g_h)`` pair in gradient mode.
    """
    weights, nodes = compute_cc_weights(nb_steps)
    weights = weights.to(x0)
    nodes = nodes.to(x0)
    upper = x0 + nb_steps * step_sizes

    if not compute_grad:
        total = 0.0
        for weight, node in zip(weights, nodes):
            # Map the node from [-1, 1] onto [x0, upper].
            point = x0 + (upper - x0) * (node + 1) / 2
            total = total + weight * integrand(point, h)
        # Jacobian of the [-1, 1] -> [x0, upper] change of variables.
        return total * (upper - x0) / 2

    # The change-of-variables Jacobian is folded into the VJP vector once,
    # since it does not depend on the node.
    scaled_vec = x_tot * (upper - x0) / 2
    g_param = 0.0
    g_h = 0.0
    for weight, node in zip(weights, nodes):
        point = x0 + (upper - x0) * (node + 1) / 2
        part_param, part_h = computeIntegrand(point, h, integrand, scaled_vec)
        g_param += weight * part_param
        g_h += weight * part_h
    return g_param, g_h
def computeIntegrand(x, h, integrand, x_tot):
    """Return vector-Jacobian products of ``integrand.forward(x, h)``.

    ``x_tot`` is used as the ``grad_outputs`` vector for two autograd
    queries: one w.r.t. every tensor in ``integrand.parameters()`` and one
    w.r.t. ``h``. Everything runs under ``torch.enable_grad()``, so it works
    even when the caller has gradients disabled.

    Returns:
        ``(g_param, g_h)``: the parameter gradients concatenated into one
        1-D tensor, and the flattened gradient w.r.t. ``h``.
    """
    with torch.enable_grad():
        out = integrand.forward(x, h)
        # retain_graph keeps the graph alive for the second query on `h`.
        param_grads = torch.autograd.grad(
            out, integrand.parameters(), x_tot, create_graph=True, retain_graph=True
        )
        g_param = _flatten(param_grads)
        h_grads = torch.autograd.grad(out, h, x_tot)
        g_h = _flatten(h_grads)
    return g_param, g_h
class NeuralIntegral(torch.autograd.Function):
    """Autograd function for ``integral_{x0}^{x} integrand(t, h) dt``.

    The forward pass evaluates the integral with Clenshaw-Curtis quadrature
    under ``torch.no_grad()``, so no graph is stored for the integrand
    evaluations. The backward pass recomputes the parameter and ``h``
    gradients with a second quadrature over the same interval. It obtains
    the bound gradients from the Leibniz integral rule.
    """

    @staticmethod
    def forward(ctx, x0, x, integrand, flat_params, h, nb_steps=20):
        """Integrate ``integrand`` over ``[x0, x]``.

        Args:
            x0: lower integration bound (tensor).
            x: upper integration bound (tensor).
            integrand: module called as ``integrand(t, h)``. Its parameters
                receive gradients through ``flat_params``.
            flat_params: not read here. It is presumably the flattened
                parameters of ``integrand`` (verify against callers), passed
                so autograd routes the parameter gradient computed in
                ``backward`` to it.
            h: extra tensor forwarded to the integrand.
            nb_steps: number of quadrature intervals.
        """
        with torch.no_grad():
            x_tot = integrate(x0, nb_steps, (x - x0) / nb_steps, integrand, h, False)
        # Save for backward.
        ctx.integrand = integrand
        ctx.nb_steps = nb_steps
        ctx.save_for_backward(x0.clone(), x.clone(), h)
        return x_tot

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(x0, x, integrand, flat_params, h, nb_steps)``."""
        x0, x, h = ctx.saved_tensors
        integrand = ctx.integrand
        nb_steps = ctx.nb_steps
        # Integrate over the same interval [x0, x] as forward. The previous
        # step size `x / nb_steps` integrated over [x0, x0 + x] instead, which
        # gave wrong parameter / h gradients whenever x0 != 0.
        integrand_grad, h_grad = integrate(
            x0, nb_steps, (x - x0) / nb_steps, integrand, h, True, grad_output
        )
        x_grad = integrand(x, h)
        x0_grad = integrand(x0, h)
        # Leibniz rule: d/dx = f(x), d/dx0 = -f(x0).
        return (
            -x0_grad * grad_output,
            x_grad * grad_output,
            None,
            integrand_grad,
            h_grad.view(h.shape),
            None,
        )