# Source code for calpit.nn.ispline_nn

import torch
import torch.nn as nn
from splinebasis import ISplineBasis


class ISplineLayer(nn.Module):
    """Output layer computing a weighted sum of I-spline basis functions.

    A small head (``Linear -> Softmax -> Dropout``) maps the incoming features
    to ``num_basis`` weights. The I-spline basis is pre-tabulated on a uniform
    grid over [0, 1] by ``ISplineBasis``. That table is linearly interpolated
    at ``alpha`` and combined with the weights.

    Args:
        in_features: Size of the last dimension of the feature tensor ``x``.
        num_basis: Number of I-spline basis functions.
        dropout_p: Dropout probability applied to the softmax weights.
        order: Spline order passed to ``ISplineBasis`` (default 3).
        grid_size: Number of uniform grid points on [0, 1] at which the basis
            is tabulated (default 1000).
    """

    def __init__(self, in_features, num_basis, dropout_p=0, order=3, grid_size=1000):
        super().__init__()
        self.in_features = in_features
        self.num_basis = num_basis
        # NOTE(review): dropout acts on the softmax output. During training the
        # surviving weights are rescaled by 1/(1-p), so they no longer sum to 1.
        self.coefs = nn.Sequential(
            nn.Linear(in_features, num_basis),
            nn.Softmax(dim=-1),
            nn.Dropout(p=dropout_p),
        )
        grid = torch.linspace(0, 1, grid_size)
        basis_vectors = ISplineBasis(
            order=order, num_basis=num_basis, lower=0, upper=1, grid=grid
        ).basis_vectors
        # Non-persistent buffers follow ``module.to(device)``, but stay out of
        # ``state_dict``, so existing checkpoints keep loading unchanged.
        self.register_buffer("grid", grid, persistent=False)
        self.register_buffer(
            "basis_vectors", torch.as_tensor(basis_vectors), persistent=False
        )

    def interp1d(self, x, y, x_new):
        """Piecewise-linearly interpolate the rows of ``y`` at ``x_new``.

        Args:
            x: 1-D tensor of sample locations, sorted ascending
                (``torch.searchsorted`` requires this).
            y: 2-D tensor of shape ``(len(x), k)``. It is indexed along dim 0
                and the slope is broadcast over a trailing column dimension.
            x_new: 1-D tensor of query points. Queries outside
                ``[x[0], x[-1]]`` are linearly extrapolated from the first or
                last segment.

        Returns:
            Tensor of shape ``(len(x_new), k)``.
        """
        # searchsorted warns and copies when the queries are non-contiguous,
        # e.g. a column slice such as ``inputs[:, 0]``.
        x_new = x_new.contiguous()
        # Right end of the segment containing each query. If
        # x_new[n] == x[m], searchsorted (left side) returns m.
        idx = torch.searchsorted(x, x_new)
        # Clamp to [1, len(x) - 1] so every query has a valid (lo, hi) pair.
        # This covers x_new == x[0] and out-of-range queries.
        idx = idx.clamp(1, len(x) - 1)
        lo, hi = idx - 1, idx
        x_lo, x_hi = x[lo], x[hi]
        y_lo, y_hi = y[lo], y[hi]
        # [:, None] broadcasts the per-query scalars over y's columns.
        slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
        return slope * (x_new - x_lo)[:, None] + y_lo

    def forward(self, x, alpha):
        """Evaluate the weighted I-spline sum at ``alpha``.

        Args:
            x: Features whose last dimension is ``in_features``. Presumably
                ``(batch, in_features)`` (TODO confirm with callers).
            alpha: 1-D tensor with one evaluation point per row of ``x``.

        Returns:
            Tensor with one value per row, the weighted basis summed over the
            last dimension.
        """
        # Match alpha's device and dtype (the tabulated basis may be float64).
        grid = self.grid.to(alpha)
        basis_vectors = self.basis_vectors.to(alpha)
        basis = self.interp1d(grid, basis_vectors, alpha)
        weighted_basis = self.coefs(x) * basis
        return weighted_basis.sum(dim=-1)
class IsplineNN(nn.Module):
    """MLP feature extractor followed by an ``ISplineLayer`` head.

    The first input column is used twice: as an MLP input feature and as the
    spline evaluation point ``alpha``.

    Args:
        input_dim: Number of input columns excluding the alpha column. The
            input tensor therefore has ``input_dim + 1`` columns.
        hidden_layers: Widths of the hidden layers. Each is a ``Linear``
            followed by ``PReLU``. May be empty, in which case the spline layer
            reads the raw input.
        dropout_p: Dropout probability forwarded to ``ISplineLayer``.
        num_basis: Number of I-spline basis functions.
    """

    def __init__(self, input_dim, hidden_layers=(512, 512, 512), dropout_p=0.5, num_basis=10):
        super().__init__()
        # Copy, so the stored attribute never aliases the caller's list.
        self.hidden_layers = list(hidden_layers)
        # The +1 is the alpha column, which the MLP also receives as an input.
        self.all_layers = [input_dim + 1, *self.hidden_layers]
        self.num_basis = num_basis
        self.dropout_p = dropout_p
        # Size the head from the MLP output width. Without hidden layers this
        # is the raw input width instead of an IndexError.
        # Created before the MLP to keep the original RNG/initialization order.
        self.spline_layer = ISplineLayer(
            in_features=self.all_layers[-1],
            num_basis=self.num_basis,
            dropout_p=self.dropout_p,
        )
        self.mlp_layer_list = []
        for fan_in, fan_out in zip(self.all_layers[:-1], self.all_layers[1:]):
            self.mlp_layer_list.append(nn.Linear(fan_in, fan_out))
            self.mlp_layer_list.append(nn.PReLU())
        self.mlp_layers = nn.Sequential(*self.mlp_layer_list)

        def init_weights(m):
            # Kaiming-normal weights and a small positive bias for every Linear layer.
            if isinstance(m, nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight)
                m.bias.data.fill_(0.01)

        self.mlp_layers.apply(init_weights)

    def forward(self, x):
        """Run the MLP on ``x`` and evaluate the spline head at ``x[:, 0]``.

        Args:
            x: 2-D input of shape ``(batch, input_dim + 1)``. Column 0 is alpha.

        Returns:
            The output of ``ISplineLayer.forward``, one value per row.
        """
        alpha = x[:, 0]
        features = self.mlp_layers(x)
        return self.spline_layer(features, alpha)