Source code for calpit.nn.umnn.MonotonicNN

import torch
import torch.nn as nn
from .NeuralIntegral import NeuralIntegral  # noqa
from .ParallelNeuralIntegral import ParallelNeuralIntegral


[docs] def _flatten(sequence): flat = [p.contiguous().view(-1) for p in sequence] return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
[docs] def clipped_relu(x): return torch.minimum(torch.maximum(torch.Tensor([0]), x), torch.Tensor([1]))
[docs] class IntegrandNN(nn.Module): def __init__(self, in_d, hidden_layers): super(IntegrandNN, self).__init__() self.net = [] hs = [in_d] + hidden_layers + [1] for h0, h1 in zip(hs, hs[1:]): self.net.extend( [ nn.Linear(h0, h1), nn.ReLU(), ] ) self.net.pop() # pop the last ReLU for the output layer self.net.append(nn.ELU()) self.net = nn.Sequential(*self.net)
[docs] def forward(self, x, h): return self.net(torch.cat((x, h), 1)) + 1.0
[docs] class MonotonicNN(nn.Module): def __init__(self, in_d, hidden_layers, nb_steps=50, sigmoid=False): super(MonotonicNN, self).__init__() self.integrand = IntegrandNN(in_d, hidden_layers) self.net = [] hs = [in_d - 1] + hidden_layers + [2] for h0, h1 in zip(hs, hs[1:]): self.net.extend( [ nn.Linear(h0, h1), nn.ReLU(), ] ) self.net.pop() # pop the last ReLU for the output layer # It will output the scaling and offset factors. self.net = nn.Sequential(*self.net) self.nb_steps = nb_steps self.sigmoid = sigmoid """ The forward procedure takes as input x which is the variable for which the integration must be made, h is just other conditionning variables. """
[docs] def forward(self, x_input): x = x_input[:, 0][:, None] h = x_input[:, 1:] x0 = torch.zeros(x.shape).to(x_input) out = self.net(h) offset = out[:, [0]] scaling = torch.exp(out[:, [1]]) if self.sigmoid: return torch.sigmoid( scaling * ParallelNeuralIntegral.apply( x0, x, self.integrand, _flatten(self.integrand.parameters()), h, self.nb_steps ) + offset ) else: return torch.squeeze( scaling * ParallelNeuralIntegral.apply( x0, x, self.integrand, _flatten(self.integrand.parameters()), h, self.nb_steps ) + offset )