import torch
import numpy as np
from torch.utils.data import Dataset
from prettytable import PrettyTable
[docs]
class RandomDataset(Dataset):
"""
A custom dataset class to randomly select a data point.
The data point is prepended with a random value between
0 and 1 from a Uniform distribution (coverage parameter).
The target value is 0 if Y value is less than or equal to
the coverage parameter and 1 otherwise.
The data set can be oversampled by a given factor.
Args:
X (list or array-like): The input features.
Y (list or array-like): The target values.
oversample (float, optional): The oversampling factor. Defaults to 1.
Returns:
tuple: A tuple containing the input feature and target value.
"""
def __init__(self, x_data, y_data, oversample=1):
[docs]
self.len_x = len(x_data)
[docs]
self.oversample = oversample
[docs]
def __len__(self):
return int(len(self.x_data) * self.oversample)
[docs]
def __getitem__(self, idx):
alpha = torch.rand(1)
feature = torch.hstack((alpha, torch.Tensor(self.x_data[idx % self.len_x])))
target = (self.y_data[idx % self.len_x] <= alpha).float()
return feature, target
[docs]
def count_parameters(model: torch.nn.Module) -> int:
"""
Count the number of trainable parameters in a model.
"""
table = PrettyTable(["Modules", "Parameters"])
total_params = 0
for name, parameter in model.named_parameters():
if not parameter.requires_grad:
continue
param = parameter.numel()
table.add_row([name, param])
total_params += param
print(table)
print(f"Total Trainable Params: {total_params}")
return total_params
[docs]
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience.
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
path (str): Path for the checkpoint to be saved to.
Default: 'checkpoint.pt'
trace_func (function): trace print function.
Default: print
"""
def __init__(self, patience=7, verbose=False, delta=0, path="checkpoint.pt", trace_func=print):
[docs]
self.patience = patience
[docs]
self.early_stop = False
[docs]
self.val_loss_min = np.inf
[docs]
self.trace_func = trace_func
[docs]
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
self.trace_func(f"EarlyStopping counter: {self.counter} out of {self.patience}")
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
[docs]
def save_checkpoint(self, val_loss, model):
"""Saves model when validation loss decrease."""
if self.verbose:
self.trace_func(
f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
)
torch.save(model.state_dict(), self.path)
self.val_loss_min = val_loss
[docs]
def cde_loss(cde_estimates: torch.Tensor, y_grid: torch.Tensor, y_test: torch.Tensor) -> tuple:
"""
Calculates conditional density estimation loss on holdout data.
This is a PyTorch version of the original function.
Args:
cde_estimates (torch.Tensor): An array where each row is a density estimate on y_grid
y_grid (torch.Tensor): An array of the grid points at which cde_estimates is evaluated.
y_test (torch.Tensor): An array of the true y values corresponding to the rows of cde_estimates
Returns:
tuple: A tuple containing the loss and the standard error of the loss.
Raises:
ValueError: If the dimensions of the input tensors are not compatible.
"""
if len(z_test.shape) == 1:
z_test = z_test.reshape(-1, 1)
if len(z_grid.shape) == 1:
z_grid = z_grid.reshape(-1, 1)
n_obs, n_grid = cde_estimates.shape
n_samples, feats_samples = z_test.shape
n_grid_points, feats_grid = z_grid.shape
if n_obs != n_samples:
raise ValueError(
f"Number of samples in CDEs should be the same as in z_test.Currently {n_obs} and {n_samples}."
)
if n_grid != n_grid_points:
raise ValueError(
f"Number of grid points in CDEs should be the same as in z_grid. Currently {n_grid} and {n_grid_points}."
)
if feats_samples != feats_grid:
raise ValueError(
f"Dimensionality of test points and grid points need to coincise. Currently {feats_samples} and {feats_grid}."
)
integrals = torch.trapz(cde_estimates**2, torch.squeeze(y_grid), axis=1)
nn_ids = torch.argmin(torch.abs(y_grid - y_test.T), axis=0)
likeli = cde_estimates[(tuple(torch.arange(n_samples)), tuple(nn_ids))]
losses = integrals - 2 * likeli
loss = torch.mean(losses)
se_error = torch.std(losses, axis=0) / (n_obs**0.5)
return loss, se_error