calpit.nn.utils
Classes
- RandomDataset: A custom dataset class that randomly selects a data point.
- EarlyStopping: Stops training early if the validation loss does not improve within a given patience.
Functions
- count_parameters: Counts the number of trainable parameters in a model.
- cde_loss: Calculates the conditional density estimation loss on holdout data.
Module Contents
- class RandomDataset(x_data, y_data, oversample=1)
Bases: torch.utils.data.Dataset
A custom dataset class that randomly selects a data point. Each data point is prepended with a random value drawn from a uniform distribution on [0, 1] (the coverage parameter). The returned target is 0 if the data point's y value is less than or equal to the coverage parameter and 1 otherwise. The dataset can be oversampled by a given factor.
- Parameters:
x_data (list or array-like) – The input features.
y_data (list or array-like) – The target values.
oversample (float, optional) – The oversampling factor. Defaults to 1.
- Returns:
For each item, a tuple containing the input features (with the coverage parameter prepended) and the target value.
- Return type:
tuple
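The behaviour described above can be sketched as a map-style PyTorch dataset. This is a minimal sketch, not the calpit implementation: the class name CoverageDataset is hypothetical, and it assumes that each item access ignores the requested index and draws a random data point, and that oversampling scales the dataset length by the given factor.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CoverageDataset(Dataset):
    """Hypothetical sketch of the behaviour described for RandomDataset."""

    def __init__(self, x_data, y_data, oversample=1):
        self.x_data = torch.as_tensor(x_data, dtype=torch.float32)
        self.y_data = torch.as_tensor(y_data, dtype=torch.float32)
        self.oversample = oversample

    def __len__(self):
        # Assumption: oversampling stretches the epoch by the given factor.
        return int(len(self.x_data) * self.oversample)

    def __getitem__(self, index):
        # Assumption: the requested index is ignored and a point is drawn at random.
        i = int(torch.randint(len(self.x_data), (1,)))
        # Coverage parameter drawn from Uniform(0, 1).
        alpha = torch.rand(1)
        # Prepend the coverage parameter to the (flattened) input features.
        x = torch.cat([alpha, self.x_data[i].reshape(-1)])
        # Target: 0 if y <= alpha, 1 otherwise, as the docstring states.
        y = (self.y_data[i] > alpha).float()
        return x, y


ds = CoverageDataset([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], [0.2, 0.7, 0.9], oversample=2)
loader = DataLoader(ds, batch_size=4)
x_batch, y_batch = next(iter(loader))
print(len(ds), x_batch.shape, y_batch.shape)
```

Because every access draws a fresh coverage parameter, the same underlying data point yields different (input, target) pairs across epochs, which is what makes oversampling useful here.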
- count_parameters(model: torch.nn.Module) → int
Count the number of trainable parameters in a model.
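Counting trainable parameters in PyTorch is commonly done by summing numel() over the parameters whose requires_grad flag is set. The helper below is a hypothetical stand-in (count_trainable) showing that idiom; it is an assumption that count_parameters behaves the same way. The example also shows that frozen layers drop out of the count.

```python
from torch import nn


def count_trainable(model: nn.Module) -> int:
    """Hypothetical stand-in for count_parameters."""
    # Sum the element counts of every parameter tensor that receives gradients.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
n_all = count_trainable(model)  # (4*8 + 8) + (8*1 + 1) = 49

# Freezing the first layer removes its 40 parameters from the count.
for p in model[0].parameters():
    p.requires_grad_(False)
n_unfrozen = count_trainable(model)  # 9
print(n_all, n_unfrozen)
```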
- class EarlyStopping(patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print)
Stops training early if the validation loss does not improve within a given patience.
- Parameters:
patience (int) – How many validation checks to wait after the last improvement in validation loss before stopping. Default: 7
verbose (bool) – If True, prints a message for each validation loss improvement. Default: False
delta (float) – Minimum change in the monitored quantity to qualify as an improvement. Default: 0
path (str) – Path where the checkpoint is saved. Default: ‘checkpoint.pt’
trace_func (function) – Function used to print trace messages. Default: print
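A patience-based early-stopping object with this signature is usually called once per epoch with the current validation loss and the model. The sketch below assumes that call convention (__call__(val_loss, model)), a boolean early_stop flag, and checkpointing by saving the model's state_dict with torch.save; the class name EarlyStoppingSketch is hypothetical and this is not necessarily how calpit implements it.

```python
import os
import tempfile

import torch
from torch import nn


class EarlyStoppingSketch:
    """Hypothetical sketch of patience-based early stopping."""

    def __init__(self, patience=7, verbose=False, delta=0, path="checkpoint.pt", trace_func=print):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            # Improvement by more than delta: checkpoint and reset the counter.
            if self.verbose:
                self.trace_func(f"Validation loss improved to {val_loss:.6f}; saving model.")
            self.best_loss = val_loss
            torch.save(model.state_dict(), self.path)
            self.counter = 0
        else:
            # No sufficient improvement: count this check against the patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


model = nn.Linear(2, 1)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
stopper = EarlyStoppingSketch(patience=2, path=path)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.7]):
    stopper(val_loss, model)
    if stopper.early_stop:
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)
```

With patience=2, the two non-improving epochs after the best loss of 0.8 trigger the stop, so the later 0.7 is never seen. A training loop would then typically restore the best weights with model.load_state_dict(torch.load(path)).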
- cde_loss(cde_estimates: torch.Tensor, y_grid: torch.Tensor, y_test: torch.Tensor) → tuple
Calculates the conditional density estimation (CDE) loss on holdout data. This is a PyTorch version of the original function.
- Parameters:
cde_estimates (torch.Tensor) – A tensor where each row is a density estimate evaluated on y_grid.
y_grid (torch.Tensor) – A tensor of the grid points at which cde_estimates is evaluated.
y_test (torch.Tensor) – A tensor of the true y values corresponding to the rows of cde_estimates.
- Returns:
A tuple containing the loss and the standard error of the loss.
- Return type:
tuple
- Raises:
ValueError – If the dimensions of the input tensors are not compatible.
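The CDE loss commonly used for this purpose (for example in the NumPy-based cdetools package) estimates, for holdout pairs (x_i, y_i), i = 1..n, the quantity L = mean_i[ ∫ f(y|x_i)^2 dy - 2 f(y_i|x_i) ], with standard error std_i(per-sample loss) / sqrt(n). The sketch below assumes that definition, a one-dimensional grid, trapezoidal integration, and evaluation of each density at the grid point nearest to the true y. The name cde_loss_sketch is hypothetical, and calpit's exact conventions may differ.

```python
import torch


def cde_loss_sketch(cde_estimates: torch.Tensor, y_grid: torch.Tensor, y_test: torch.Tensor):
    """Hypothetical 1-D sketch of the CDE loss and its standard error."""
    y_grid = y_grid.reshape(-1)
    y_test = y_test.reshape(-1)
    n_obs, n_grid = cde_estimates.shape
    if n_obs != y_test.numel():
        raise ValueError("cde_estimates must have one row per entry of y_test")
    if n_grid != y_grid.numel():
        raise ValueError("cde_estimates must have one column per grid point")
    # Integral of the squared density over the grid (trapezoidal rule), per row.
    integrals = torch.trapezoid(cde_estimates ** 2, y_grid, dim=1)
    # Density evaluated at the grid point nearest to each true y value.
    nearest = torch.argmin(torch.abs(y_grid[None, :] - y_test[:, None]), dim=1)
    likelihoods = cde_estimates[torch.arange(n_obs), nearest]
    # Per-sample losses; their mean is the loss estimate.
    losses = integrals - 2 * likelihoods
    loss = losses.mean()
    # Standard error of the mean (population standard deviation / sqrt(n)).
    se = torch.sqrt(((losses - loss) ** 2).mean() / n_obs)
    return loss, se


y_grid = torch.linspace(0, 1, 101)
cde = torch.ones(5, 101)  # uniform density on [0, 1] for 5 holdout points
y_test = torch.rand(5)
loss, se = cde_loss_sketch(cde, y_grid, y_test)
print(loss.item(), se.item())
```

For a uniform density on [0, 1], the squared-density integral is 1 and every likelihood is 1, so the loss is 1 - 2 = -1 with zero standard error. Lower values indicate better density estimates.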