# maps.py
# This file defines transformation maps used for importance sampling in Monte Carlo integration.
# It includes the Configuration class for storing samples and various Map implementations.
import numpy as np
import torch
from torch import nn
from MCintegration.base import Uniform
from MCintegration.utils import get_device
import sys
# Tiny positive constant used to guard divisions and logarithms against zero.
# NOTE(review): with IEEE-754 doubles this evaluates to 1e-257, which is far
# below the float32 range, so it has no effect when added to float32 tensors
# (the default dtype in this module) - confirm that is intended.
TINY = 10 ** (50 + sys.float_info.min_10_exp)
class Configuration:
    """
    Container for one batch of Monte Carlo samples and their results.

    Holds pre-allocated tensors for the points in the original space (``u``),
    the points in the transformed space (``x``), the function values (``fx``),
    the per-sample weights (``weight``) and the Jacobian determinants
    (``detJ``).
    """

    def __init__(self, batch_size, dim, f_dim, device=None, dtype=torch.float32):
        """
        Allocate the sample buffers.

        Args:
            batch_size (int): Number of samples in a batch.
            dim (int): Dimensionality of the integration space.
            f_dim (int): Dimensionality of the function output.
            device (str or torch.device): Device on which the buffers live.
                When ``None`` the result of ``get_device()`` is used.
            dtype (torch.dtype): Data type of every buffer.
        """
        self.device = get_device() if device is None else device
        self.batch_size = batch_size
        self.dim = dim
        self.f_dim = f_dim

        def _buffer(*shape):
            # torch.empty: contents are uninitialised until a sampler fills them.
            return torch.empty(shape, dtype=dtype, device=self.device)

        self.u = _buffer(batch_size, dim)        # (batch_size, dim)
        self.x = _buffer(batch_size, dim)        # (batch_size, dim)
        self.fx = _buffer(batch_size, f_dim)     # (batch_size, f_dim)
        self.weight = _buffer(batch_size)        # (batch_size,)
        self.detJ = _buffer(batch_size)          # (batch_size,)
class Map(nn.Module):
    """
    Base class for all transformation maps.

    Maps transform points from the unit hypercube to the target integration
    domain with a potentially better sampling distribution. Subclasses must
    implement :meth:`forward` and :meth:`inverse`.
    """

    def __init__(self, device=None, dtype=torch.float32):
        """
        Initialize a Map object.

        Args:
            device (str or torch.device): Device to use for computation.
                When ``None`` the result of ``get_device()`` is used.
            dtype (torch.dtype): Data type for computations.
        """
        super().__init__()
        self.device = get_device() if device is None else device
        self.dtype = dtype

    def forward(self, u):
        """
        Transform points from the unit hypercube to the target domain.

        Args:
            u (torch.Tensor): Points in the unit hypercube.

        Returns:
            tuple: (transformed points, log_det_jacobian)

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def forward_with_detJ(self, u):
        """
        Transform points and return the Jacobian determinant (not its log).

        Args:
            u (torch.Tensor): Points in the unit hypercube.

        Returns:
            tuple: (transformed points, det_jacobian)
        """
        x, log_detJ = self.forward(u)
        # Exponentiate out of place: the tensor returned by forward() may still
        # be owned (or cached) by the subclass and must not be overwritten.
        return x, torch.exp(log_detJ)

    def inverse(self, x):
        """
        Transform points from the target domain back to the unit hypercube.

        Args:
            x (torch.Tensor): Points in the target domain.

        Returns:
            tuple: (transformed points, log_det_jacobian)

        Raises:
            NotImplementedError: This is an abstract method.
        """
        raise NotImplementedError("Subclasses must implement this method")
class CompositeMap(Map):
    """
    Composite transformation map that applies multiple maps sequentially.

    Allows for complex transformations by composing simpler ones. The
    log-Jacobians reported by the individual maps are summed.
    """

    def __init__(self, maps, device=None, dtype=None):
        """
        Initialize a CompositeMap with a list of maps.

        Args:
            maps (list): Non-empty list of Map objects to apply in sequence.
            device (str or torch.device): Device to use for computation.
                Defaults to the device of the last map.
            dtype (torch.dtype): Data type for computations. Defaults to the
                dtype of the last map.

        Raises:
            ValueError: If ``maps`` is empty.
        """
        if not maps:
            raise ValueError("Maps can not be empty.")
        last_map = maps[-1]
        if dtype is None:
            dtype = last_map.dtype
        if device is None:
            device = last_map.device
        elif device != last_map.device:
            for sub_map in maps:
                sub_map.to(device)
        super().__init__(device, dtype)
        # Register the sub-maps as submodules so that .to(), .parameters(),
        # state_dict() etc. on the composite reach them (a plain list would be
        # invisible to nn.Module).
        self.maps = nn.ModuleList(maps)

    def forward(self, u):
        """
        Apply all maps in sequence to transform points.

        Args:
            u (torch.Tensor): Points in the unit hypercube.

        Returns:
            tuple: (transformed points, log_det_jacobian), where the log-det is
            the sum of the values returned by each map's ``forward``.
        """
        log_detJ = torch.zeros(len(u), device=u.device, dtype=self.dtype)
        for sub_map in self.maps:
            u, step_log_detJ = sub_map.forward(u)
            log_detJ += step_log_detJ
        return u, log_detJ

    def inverse(self, x):
        """
        Apply the inverse of all maps, last map first, to transform points back.

        Args:
            x (torch.Tensor): Points in the target domain.

        Returns:
            tuple: (transformed points, log_det_jacobian), where the log-det is
            the sum of the values returned by each map's ``inverse``.
        """
        log_detJ = torch.zeros(len(x), device=x.device, dtype=self.dtype)
        for sub_map in reversed(self.maps):
            x, step_log_detJ = sub_map.inverse(x)
            log_detJ += step_log_detJ
        return x, log_detJ
class Vegas(Map):
    """
    VEGAS algorithm implementation as a transformation map.

    VEGAS adapts to the integrand by adjusting the sampling density based on
    the magnitude of the integrand in different regions. Each dimension ``d``
    owns a piecewise-linear grid with ``ninc[d]`` increments:

    * ``grid[d, i]`` is the i-th grid edge (``ninc[d] + 1`` edges are used),
    * ``inc[d, i] = grid[d, i + 1] - grid[d, i]`` is the i-th increment width.

    Both tensors are padded to ``max_ninc`` columns when the dimensions use
    different numbers of increments.
    """

    def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
        """
        Initialize a VEGAS map with a uniform grid.

        Args:
            dim (int): Dimensionality of the integration space.
            ninc (int, sequence, numpy array or torch.Tensor): Number of grid
                increments, either one value shared by every dimension or one
                value per dimension.
            device (str or torch.device): Device to use for computation.
            dtype (torch.dtype): Data type for computations.

        Raises:
            ValueError: If ``ninc`` has an unsupported type, the wrong shape,
                or contains a value smaller than one.
        """
        super().__init__(device, dtype)
        self.dim = dim

        # Normalise ninc to a 1-D int32 tensor with one entry per dimension.
        if isinstance(ninc, (int, np.integer)):
            self.ninc = torch.full(
                (self.dim,), int(ninc), dtype=torch.int32, device=self.device
            )
        elif isinstance(ninc, (list, tuple, np.ndarray)):
            self.ninc = torch.tensor(
                ninc, dtype=torch.int32, device=self.device)
        elif isinstance(ninc, torch.Tensor):
            self.ninc = ninc.to(dtype=torch.int32, device=self.device)
        else:
            raise ValueError(
                "'ninc' must be an int, list, numpy array, or torch tensor."
            )

        if self.ninc.shape != (self.dim,):
            raise ValueError(
                f"'ninc' must be a scalar or a 1D array of length {self.dim}."
            )
        if bool((self.ninc < 1).any()):
            raise ValueError("'ninc' must be at least 1 in every dimension.")

        self.max_ninc = self.ninc.max().item()

        # Accumulators filled by add_training_data() and consumed by adapt().
        self.sum_f = torch.zeros(
            self.dim, self.max_ninc, dtype=self.dtype, device=self.device
        )
        self.n_f = torch.zeros(
            self.dim, self.max_ninc, dtype=self.dtype, device=self.device
        )
        # Scratch tensors kept for backward compatibility; adapt() allocates
        # its own temporaries.
        self.avg_f = torch.ones(
            (self.dim, self.max_ninc), dtype=self.dtype, device=self.device
        )
        self.tmp_f = torch.zeros(
            (self.dim, self.max_ninc), dtype=self.dtype, device=self.device
        )

        self.make_uniform()

    @torch.no_grad()
    def make_uniform(self):
        """
        Reset the map to a uniform grid and discard accumulated training data.

        The constructor takes no integration bounds, so every dimension is
        gridded uniformly on the unit interval [0, 1]. Columns beyond
        ``ninc[d]`` are padding: grid padding repeats the last edge (keeping
        each row sorted) and increment padding is zero.
        """
        self.grid = torch.zeros(
            (self.dim, self.max_ninc + 1), dtype=self.dtype, device=self.device
        )
        self.inc = torch.zeros(
            (self.dim, self.max_ninc), dtype=self.dtype, device=self.device
        )
        for d in range(self.dim):
            n = int(self.ninc[d])
            self.grid[d, : n + 1] = torch.linspace(
                0.0, 1.0, n + 1, dtype=self.dtype, device=self.device
            )
            self.grid[d, n + 1:] = self.grid[d, n]
            self.inc[d, :n] = self.grid[d, 1: n + 1] - self.grid[d, :n]
        self.clear()

    def adaptive_training(
        self,
        batch_size,
        f,
        f_dim=1,
        epoch=10,
        alpha=0.5,
    ):
        """
        Perform adaptive training to adjust the grid based on the training function.

        Each epoch draws ``batch_size`` points from a ``Uniform`` sampler, maps
        them through the current grid, evaluates ``f`` and calls
        :meth:`add_training_data` followed by :meth:`adapt`.

        Args:
            batch_size (int): Number of samples per batch.
            f (callable): Training function called as ``f(x, fx)``; its return
                value is stored as the per-sample training weight.
            f_dim (int, optional): Dimension of the function f. Defaults to 1.
            epoch (int, optional): Number of training epochs. Defaults to 10.
            alpha (float, optional): Adaptation rate passed to :meth:`adapt`.
                Defaults to 0.5.
        """
        q0 = Uniform(self.dim, device=self.device, dtype=self.dtype)
        sample = Configuration(
            batch_size, self.dim, f_dim, device=self.device, dtype=self.dtype
        )
        for _ in range(epoch):
            sample.u, log_detJ0 = q0.sample(batch_size)
            sample.x[:], log_detJ = self.forward(sample.u)
            sample.weight = f(sample.x, sample.fx)
            sample.detJ = torch.exp(log_detJ0 + log_detJ)
            self.add_training_data(sample)
            self.adapt(alpha)

    @torch.no_grad()
    def add_training_data(self, sample):
        """
        Accumulate training data for later use by :meth:`adapt`.

        For every point the value ``(detJ * weight) ** 2`` is added to the
        increment that contains its ``u`` coordinate, separately in each
        dimension. Grid increments will be made smaller in regions where this
        value is larger than average, and larger where it is smaller.

        Args:
            sample (Configuration): Batch providing ``u`` of shape
                ``(batch_size, dim)`` and ``weight`` / ``detJ`` of shape
                ``(batch_size,)``.
        """
        fval = (sample.detJ * sample.weight) ** 2
        last_bin = self.ninc.long() - 1
        iu = torch.floor(sample.u * self.ninc).long()
        # u == 1 (or round-off just above it) would land one past the last
        # increment; fold such points into the boundary increments.
        iu = torch.minimum(iu.clamp_min(0), last_bin)
        ones = torch.ones_like(fval)
        for d in range(self.dim):
            indices = iu[:, d]
            self.sum_f[d].scatter_add_(0, indices, fval)
            self.n_f[d].scatter_add_(0, indices, ones)

    @torch.no_grad()
    def adapt(self, alpha=0.5):
        """
        Adapt the grid based on accumulated training data.

        Shrinks grid increments in regions where the accumulated f is large,
        and grows them where f is small. A dimension keeps its current grid
        when ``alpha == 0``, when it has a single increment, or when it has no
        usable training data. The accumulated data is cleared afterwards.

        Args:
            alpha (float, optional): Determines the speed with which the grid
                adapts to training data. Large (positive) values imply
                rapid evolution; small values (much less than one) imply
                slow evolution. Typical values are of order one. Choosing
                ``alpha<0`` causes adaptation to the unmodified training
                data (usually not a good idea).
        """
        # Aggregate training data across distributed processes if applicable.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                self.sum_f, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(
                self.n_f, op=torch.distributed.ReduceOp.SUM)

        # Start from a copy so that end points, padding columns and every
        # dimension that is skipped below keep their current, valid values.
        new_grid = self.grid.clone()

        if alpha != 0:
            for d in range(self.dim):
                ninc = int(self.ninc[d])
                if ninc < 2:
                    continue  # a single increment has no interior edge to move
                weights = self._increment_weights(d, ninc, alpha)
                if weights is None:
                    continue  # no usable training data for this dimension
                new_grid[d, 1:ninc] = self._regrid_edges(d, ninc, weights)

        self.grid = new_grid

        # Recompute the increment widths from the new edges.
        self.inc.zero_()
        for d in range(self.dim):
            ninc = int(self.ninc[d])
            self.inc[d, :ninc] = self.grid[d, 1: ninc + 1] - self.grid[d, :ninc]

        # Clear accumulated training data for the next adaptation cycle.
        self.clear()

    def _increment_weights(self, d, ninc, alpha):
        """Return the per-increment weights for dimension ``d``, or ``None`` if there is no usable data."""
        counts = self.n_f[d, :ninc]
        sums = self.sum_f[d, :ninc]
        # Average f per increment; increments without samples contribute zero.
        avg_f = torch.where(counts > 0, sums / counts, torch.zeros_like(sums))

        if alpha > 0:
            # Smooth with the neighbouring increments (weights 1-6-1, or 7-1
            # at the two ends).
            smoothed = torch.empty_like(avg_f)
            smoothed[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0
            smoothed[-1] = (7.0 * avg_f[-1] + avg_f[-2]).abs() / 8.0
            smoothed[1:-1] = (
                6.0 * avg_f[1:-1] + avg_f[:-2] + avg_f[2:]
            ).abs() / 8.0

            norm = smoothed.sum()
            if not bool(norm > 0):  # also rejects NaN
                return None
            # Normalise so the weights sum to (about) one.
            avg_f = smoothed / norm.clamp_min(TINY) + TINY

            # Damping controlled by alpha. The expression is 0/0 at exactly 1,
            # so it is only applied strictly inside (0, 1).
            damped = (-(1 - avg_f) / torch.log(avg_f)).pow(alpha)
            avg_f = torch.where((avg_f > 0) & (avg_f < 1), damped, avg_f)

        total = avg_f.sum()
        if not bool(torch.isfinite(total) & (total > 0)):
            return None
        return avg_f

    def _regrid_edges(self, d, ninc, avg_f):
        """Return the ``ninc - 1`` interior edges that split ``avg_f`` into equal shares on dimension ``d``."""
        # Target accumulated weight per new increment.
        f_ninc = avg_f.sum() / ninc
        targets = (
            torch.arange(1, ninc, device=self.device, dtype=avg_f.dtype) * f_ninc
        )
        cumulative = torch.cat((avg_f.new_zeros(1), torch.cumsum(avg_f, dim=0)))
        # Old increment in which each target falls; the clamp guards against
        # round-off pushing a target onto or past the total.
        idx = (
            torch.searchsorted(cumulative, targets, right=True) - 1
        ).clamp_(0, ninc - 1)
        # Linear interpolation inside the selected old increment.
        fraction = (targets - cumulative[idx]) / avg_f[idx]
        return self.grid[d, idx] + fraction * self.inc[d, idx]

    @torch.no_grad()
    def clear(self):
        "Clear information accumulated by :meth:`Vegas.add_training_data`."
        self.sum_f.zero_()
        self.n_f.zero_()

    @torch.no_grad()
    def forward(self, u):
        """
        Map points from u-space to x-space through the current grid.

        Args:
            u (torch.Tensor): Tensor of shape (batch_size, dim) with points in
                the unit hypercube.

        Returns:
            x (torch.Tensor): Tensor of shape (batch_size, dim) with the mapped points.
            log_detJ (torch.Tensor): Tensor of shape (batch_size,) holding
                ``sum_d log(inc[d, i_d] * ninc[d])`` for the increments
                ``i_d`` that contain each point.
        """
        ninc = self.ninc.long()  # (dim,)
        u_ninc = u * self.ninc
        iu = torch.floor(u_ninc).long()
        du_ninc = u_ninc - iu

        # Clamp to [0, ninc - 1] so that boundary points use the edge increments.
        iu_clamped = torch.minimum(iu.clamp_min(0), ninc - 1)

        # dims broadcasts against iu_clamped: element [b, d] reads grid[d, iu[b, d]].
        dims = torch.arange(self.dim, device=self.device)
        grid_gather = self.grid[dims, iu_clamped]  # (batch_size, dim)
        inc_gather = self.inc[dims, iu_clamped]    # (batch_size, dim)

        x = grid_gather + inc_gather * du_ninc
        # The clamped index already selects the last increment for points at
        # or beyond the upper boundary, so every dimension is counted exactly once.
        log_detJ = (inc_gather * self.ninc).log_().sum(dim=1)

        # Points with u >= 1 are placed exactly on the upper grid boundary.
        out_of_bounds = iu >= ninc
        upper_edge = self.grid[dims, ninc].to(x.dtype)  # (dim,)
        x = torch.where(out_of_bounds, upper_edge, x)

        return x, log_detJ

    @torch.no_grad()
    def inverse(self, x):
        """
        Inverse map from x-space to u-space.

        Args:
            x (torch.Tensor): Tensor of shape (batch_size, dim) representing points in x-space.

        Returns:
            u (torch.Tensor): Tensor of shape (batch_size, dim) representing points in u-space.
            log_detJ (torch.Tensor): Tensor of shape (batch_size,) holding
                ``sum_d log(inc[d, i_d] * ninc[d])`` for the increments that
                contain each point.
                NOTE(review): this is the same quantity ``forward`` reports
                (log of dx/du), not its negative - confirm callers expect that sign.
        """
        x = x.to(self.device)
        batch_size, dim = x.shape

        u = torch.empty_like(x)
        log_detJ = torch.zeros(
            batch_size, device=self.device, dtype=self.dtype)

        for d in range(dim):
            ninc_d = int(self.ninc[d])
            # Only the edges/increments in use; padding columns are excluded.
            grid_d = self.grid[d, : ninc_d + 1]
            inc_d = self.inc[d, :ninc_d]
            x_d = x[:, d].contiguous()

            # Index of the increment containing each point, clamped to
            # [0, ninc_d - 1] so out-of-range points use the edge increments.
            iu = torch.searchsorted(grid_d, x_d, right=True) - 1
            iu_clamped = iu.clamp(0, ninc_d - 1)

            grid_gather = grid_d[iu_clamped]
            inc_gather = inc_d[iu_clamped]

            # Fractional position inside the increment.
            du = (x_d - grid_gather) / (inc_gather + TINY)
            u_d = (du + iu_clamped) / ninc_d

            # Points on or beyond the grid boundaries map to the ends of [0, 1].
            u_d.masked_fill_(x_d <= grid_d[0], 0.0)
            u_d.masked_fill_(x_d >= grid_d[ninc_d], 1.0)
            u[:, d] = u_d

            # One contribution per sample and dimension; the clamped index
            # already covers the boundary cases.
            log_detJ += (inc_gather * ninc_d + TINY).log_()

        return u, log_detJ
# class NormalizingFlow(Map):
# def __init__(self, dim, flow_model, device="cpu"):
# super().__init__(dim, device)
# self.flow_model = flow_model.to(device)
# def forward(self, u):
# return self.flow_model.forward(u)
# def inverse(self, x):
# return self.flow_model.inverse(x)