# multi_loss.py
from itertools import combinations
from typing import List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChaoticFFTLoss(nn.Module):
"""
A specialized loss function for comparing frequency spectra of chaotic signals
generated by neural ODEs.
This loss function is designed to handle the unique characteristics of chaotic systems:
1. High sensitivity to initial conditions
2. Potentially non-stationary behavior
3. Complex frequency structures
Parameters:
-----------
window_size : int, optional
Size of the sliding window for short-time Fourier transform.
Set to None to use the full signal length.
Default: None
window_overlap : float, optional
Overlap between consecutive windows as a fraction (0.0-0.9).
Default: 0.5
spectral_norm : str, optional
Type of normalization to apply to spectra:
'none': no normalization
'l1': L1 normalization (sum to 1)
'l2': L2 normalization (sum of squares to 1)
'max': normalize by maximum value
Default: 'none'
focus_range : tuple or None, optional
Tuple of (min_freq_idx, max_freq_idx) to focus on specific frequency bands.
Default: None (use all frequencies)
magnitude_weight : float, optional
Weight for comparing magnitude spectra.
Default: 1.0
phase_weight : float, optional
Weight for comparing phase spectra.
Default: 0.0
log_magnitude : bool, optional
Whether to use log magnitude spectra.
Default: True
divergence_type : str, optional
Type of divergence measure to use:
'mse': Mean squared error (L2 distance)
'mae': Mean absolute error (L1 distance)
'kl': Kullback-Leibler divergence (for normalized spectra only)
'js': Jensen-Shannon divergence (for normalized spectra only)
'wasserstein': Wasserstein distance (for normalized spectra only)
Default: 'mse'
epsilon : float, optional
        Small value added for numerical stability (log arguments and
        normalization denominators).
Default: 1e-8
"""
def __init__(
self,
window_size=None,
window_overlap=0.5,
spectral_norm="none",
focus_range=None,
magnitude_weight=1.0,
phase_weight=0.0,
log_magnitude=True,
divergence_type="mse",
        epsilon=1e-8,
        *args,
        **kwargs,  # absorb kwargs meant for sibling losses (see MultiscaleChaoticPeakLoss)
    ):
super(ChaoticFFTLoss, self).__init__()
self.window_size = window_size
self.window_overlap = window_overlap
self.spectral_norm = spectral_norm
self.focus_range = focus_range
self.magnitude_weight = magnitude_weight
self.phase_weight = phase_weight
self.log_magnitude = log_magnitude
self.divergence_type = divergence_type
self.epsilon = epsilon
# Validate inputs
if spectral_norm not in ["none", "l1", "l2", "max"]:
raise ValueError(f"Invalid spectral_norm: {spectral_norm}")
if divergence_type not in ["mse", "mae", "kl", "js", "wasserstein"]:
raise ValueError(f"Invalid divergence_type: {divergence_type}")
if divergence_type in ["kl", "js", "wasserstein"] and spectral_norm == "none":
raise ValueError(
f"{divergence_type} divergence requires normalized spectra"
)
def _normalize_spectrum(self, spectrum):
"""Apply the selected normalization to a spectrum."""
if self.spectral_norm == "none":
return spectrum
elif self.spectral_norm == "l1":
# L1 normalization (sum to 1)
norm = torch.sum(spectrum, dim=-1, keepdim=True) + self.epsilon
return spectrum / norm
elif self.spectral_norm == "l2":
# L2 normalization (sum of squares to 1)
norm = (
torch.sqrt(torch.sum(spectrum**2, dim=-1, keepdim=True)) + self.epsilon
)
return spectrum / norm
elif self.spectral_norm == "max":
# Normalize by maximum value
norm = torch.max(spectrum, dim=-1, keepdim=True)[0] + self.epsilon
return spectrum / norm
def _compute_divergence(self, input_spectrum, target_spectrum):
"""Compute the specified divergence between spectra."""
if self.divergence_type == "mse":
return F.mse_loss(input_spectrum, target_spectrum)
elif self.divergence_type == "mae":
return F.l1_loss(input_spectrum, target_spectrum)
elif self.divergence_type == "kl":
# KL divergence for normalized spectra (treats them as distributions)
# Add epsilon to avoid log(0)
return F.kl_div(
torch.log(input_spectrum + self.epsilon),
target_spectrum,
reduction="batchmean",
)
        elif self.divergence_type == "js":
            # Jensen-Shannon divergence: 0.5 * KL(P || M) + 0.5 * KL(Q || M)
            # with M = (P + Q) / 2. F.kl_div(log_m, p) computes KL(p || m),
            # so log(M) is passed as the first (log-probability) argument.
            m = 0.5 * (input_spectrum + target_spectrum)
            log_m = torch.log(m + self.epsilon)
            kl1 = F.kl_div(log_m, input_spectrum, reduction="batchmean")
            kl2 = F.kl_div(log_m, target_spectrum, reduction="batchmean")
            return 0.5 * (kl1 + kl2)
elif self.divergence_type == "wasserstein":
# Approximate 1D Wasserstein distance for normalized spectra
# Using cumulative distribution approximation
input_cdf = torch.cumsum(input_spectrum, dim=-1)
target_cdf = torch.cumsum(target_spectrum, dim=-1)
return torch.mean(torch.abs(input_cdf - target_cdf))
def compute_fft(self, signal):
"""
Compute FFT of the signal, with windowing if specified.
Returns magnitude and phase.
"""
if self.window_size is None:
# Full signal FFT
fft = torch.fft.fft(signal)
# Get only the positive frequencies (first half)
n = signal.shape[-1]
fft = fft[..., : n // 2 + 1]
magnitude = torch.abs(fft)
phase = torch.angle(fft) if self.phase_weight > 0 else None
# Apply log if needed
if self.log_magnitude:
magnitude = torch.log(magnitude + self.epsilon)
# Apply normalization
magnitude = self._normalize_spectrum(magnitude)
# Apply focus range if specified
if self.focus_range is not None:
min_idx, max_idx = self.focus_range
magnitude = magnitude[..., min_idx:max_idx]
if phase is not None:
phase = phase[..., min_idx:max_idx]
return magnitude, phase
else:
# Short-time Fourier transform with overlapping windows
step_size = int(self.window_size * (1 - self.window_overlap))
signal_length = signal.shape[-1]
# Calculate number of windows
num_windows = max(1, (signal_length - self.window_size) // step_size + 1)
# Initialize tensors to store results
window_magnitudes = []
window_phases = []
# Apply windowing
window_func = torch.hann_window(self.window_size, device=signal.device)
for i in range(num_windows):
start_idx = i * step_size
end_idx = start_idx + self.window_size
if end_idx > signal_length:
# Pad the last window if needed
padded = F.pad(
signal[..., start_idx:], (0, end_idx - signal_length)
)
windowed = padded * window_func
else:
windowed = signal[..., start_idx:end_idx] * window_func
# Compute FFT for this window
fft = torch.fft.fft(windowed)
# Get only the positive frequencies
n = windowed.shape[-1]
fft = fft[..., : n // 2 + 1]
window_mag = torch.abs(fft)
# Apply log if needed
if self.log_magnitude:
window_mag = torch.log(window_mag + self.epsilon)
# Apply normalization
window_mag = self._normalize_spectrum(window_mag)
# Apply focus range if specified
if self.focus_range is not None:
min_idx, max_idx = self.focus_range
window_mag = window_mag[..., min_idx:max_idx]
window_magnitudes.append(window_mag)
if self.phase_weight > 0:
window_phase = torch.angle(fft)
if self.focus_range is not None:
window_phase = window_phase[..., min_idx:max_idx]
window_phases.append(window_phase)
# Stack all windows
magnitude = torch.stack(window_magnitudes, dim=1) # [batch, windows, freq]
phase = torch.stack(window_phases, dim=1) if self.phase_weight > 0 else None
return magnitude, phase
def forward(self, input_signal, target_signal):
"""
Compute the loss between input and target chaotic signals.
Parameters:
-----------
input_signal : torch.Tensor
Predicted signal from neural ODE, shape [batch_size, signal_length]
target_signal : torch.Tensor
Target chaotic signal, shape [batch_size, signal_length]
Returns:
--------
loss : torch.Tensor
The computed spectral loss
"""
# Handle multi-dimensional signals
original_shape = input_signal.shape
if len(original_shape) > 2:
# For multi-dimensional signals, flatten all but batch dimension
input_signal = input_signal.reshape(original_shape[0], -1)
target_signal = target_signal.reshape(target_signal.shape[0], -1)
# Ensure signals have the same shape
if input_signal.shape != target_signal.shape:
raise ValueError(
f"Input shape {input_signal.shape} and target shape {target_signal.shape} must match"
)
# Compute FFT spectra
input_magnitude, input_phase = self.compute_fft(input_signal)
target_magnitude, target_phase = self.compute_fft(target_signal)
# Compute magnitude loss
magnitude_loss = self._compute_divergence(input_magnitude, target_magnitude)
# Initialize with magnitude loss
loss = self.magnitude_weight * magnitude_loss
# Add phase loss if needed
if self.phase_weight > 0:
            # Handle phase wrapping by taking the minimum angular distance on the circle
phase_diff = torch.abs(input_phase - target_phase)
phase_diff = torch.min(phase_diff, 2 * torch.pi - phase_diff)
phase_loss = torch.mean(phase_diff**2)
loss = loss + self.phase_weight * phase_loss
return loss
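# Usage sketch for ChaoticFFTLoss (illustrative only: the signal shapes and
# hyperparameters below are arbitrary demo choices, not values from any
# experiment). JS divergence expects non-negative normalized spectra, so it
# is paired here with l1 normalization and log_magnitude=False.
def example_chaotic_fft_loss():
    pred = torch.randn(4, 1024)    # [batch_size, signal_length]
    target = torch.randn(4, 1024)
    loss_fn = ChaoticFFTLoss(
        window_size=256,
        window_overlap=0.5,
        spectral_norm="l1",
        divergence_type="js",
        log_magnitude=False,  # log spectra can be negative, which breaks JS
    )
    return loss_fn(pred, target)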
class MainPeakLoss(nn.Module):
"""
A focused loss function that specifically targets the main peak in a spectrum.
This loss is ideal when the primary frequency component carries the most important
information about the chaotic system.
Parameters:
-----------
main_peak_width_factor : float
Width around the main peak to consider as a factor of the total spectrum width.
For example, 0.1 means consider 10% of the spectrum width around the main peak.
Default: 0.1
adaptive_width : bool
If True, adapts the width based on the actual peak width.
Default: True
position_weight : float
Weight for the peak position component of the loss.
Default: 0.6
shape_weight : float
Weight for the peak shape component of the loss.
Default: 0.4
log_magnitude : bool
Whether to use log magnitudes for comparison.
Default: True
epsilon : float
Small value added to avoid log(0).
Default: 1e-8
"""
def __init__(
self,
main_peak_width_factor=0.1,
adaptive_width=True,
position_weight=0.6,
shape_weight=0.4,
log_magnitude=True,
epsilon=1e-8,
*args,
**kwargs,
):
super(MainPeakLoss, self).__init__()
self.main_peak_width_factor = main_peak_width_factor
self.adaptive_width = adaptive_width
self.position_weight = position_weight
self.shape_weight = shape_weight
self.log_magnitude = log_magnitude
self.epsilon = epsilon
def find_main_peak(self, spectrum):
"""
Find the main peak in a spectrum.
Parameters:
-----------
spectrum : torch.Tensor
Spectrum with shape [batch_size, freq_bins]
Returns:
--------
peak_indices : torch.Tensor
Indices of main peaks with shape [batch_size]
peak_widths : torch.Tensor
Estimated widths of main peaks with shape [batch_size]
"""
batch_size, freq_bins = spectrum.shape
device = spectrum.device
# Initialize outputs
peak_indices = torch.zeros(batch_size, dtype=torch.long, device=device)
peak_widths = torch.zeros(batch_size, device=device)
# Find main peak for each batch item
for b in range(batch_size):
spec = spectrum[b]
# Skip the DC component (0 Hz) if we have enough bins
start_idx = 1 if freq_bins > 10 else 0
# Find the maximum value (main peak)
if start_idx < freq_bins:
max_idx = torch.argmax(spec[start_idx:]) + start_idx
peak_indices[b] = max_idx
# Estimate peak width
if self.adaptive_width:
max_val = spec[max_idx]
half_height = max_val / 2
# Find left bound (where value drops below half height)
left_bound = max_idx
for i in range(max_idx - 1, -1, -1):
if spec[i] < half_height:
left_bound = i
break
# Find right bound
right_bound = max_idx
for i in range(max_idx + 1, freq_bins):
if spec[i] < half_height:
right_bound = i
break
# Calculate width
peak_widths[b] = right_bound - left_bound
else:
# Use fixed width based on factor
peak_widths[b] = max(
1, int(freq_bins * self.main_peak_width_factor)
)
return peak_indices, peak_widths
def extract_peak_regions(self, spectrum, peak_indices, peak_widths):
"""
Extract regions around peaks for comparison.
Parameters:
-----------
spectrum : torch.Tensor
Spectrum with shape [batch_size, freq_bins]
peak_indices : torch.Tensor
Indices of peaks with shape [batch_size]
peak_widths : torch.Tensor
Widths of peaks with shape [batch_size]
Returns:
--------
peak_regions : list of torch.Tensor
List of tensors containing peak regions
region_indices : list of tuples
List of (start, end) indices for each region
"""
batch_size, freq_bins = spectrum.shape
device = spectrum.device
peak_regions = []
region_indices = []
for b in range(batch_size):
peak_idx = peak_indices[b].item()
# Determine region width
if self.adaptive_width:
width = max(1, int(peak_widths[b].item()))
else:
width = max(1, int(freq_bins * self.main_peak_width_factor))
# Calculate region bounds
half_width = width // 2
start_idx = max(0, peak_idx - half_width)
end_idx = min(freq_bins, peak_idx + half_width + 1)
# Extract region
region = spectrum[b, start_idx:end_idx]
peak_regions.append(region)
region_indices.append((start_idx, end_idx))
return peak_regions, region_indices
def compute_peak_loss(
self, input_regions, input_indices, target_regions, target_indices
):
"""
Compute loss between peak regions.
Parameters:
-----------
input_regions : list of torch.Tensor
Peak regions from input signals
input_indices : list of tuples
(start, end) indices for input regions
target_regions : list of torch.Tensor
Peak regions from target signals
target_indices : list of tuples
(start, end) indices for target regions
Returns:
--------
position_loss : torch.Tensor
Loss component for peak positions
shape_loss : torch.Tensor
Loss component for peak shapes
"""
batch_size = len(input_regions)
device = input_regions[0].device if batch_size > 0 else torch.device("cpu")
position_losses = torch.zeros(batch_size, device=device)
shape_losses = torch.zeros(batch_size, device=device)
for b in range(batch_size):
            # Calculate position loss from the distance between peak-region centers
in_start, in_end = input_indices[b]
# Convert to tensor
in_center = torch.tensor((in_start + in_end) / 2, device=device)
tgt_start, tgt_end = target_indices[b]
# Convert to tensor
tgt_center = torch.tensor((tgt_start + tgt_end) / 2, device=device)
            # Normalize the center distance by the region length so the
            # position term is relative rather than in raw frequency-bin units
region_length = torch.tensor(
len(input_regions[b]), dtype=torch.float, device=device
)
position_losses[b] = torch.abs(in_center - tgt_center) / region_length
# Calculate shape loss
# We need to handle different region sizes
in_region = input_regions[b]
tgt_region = target_regions[b]
# Interpolate to match sizes if different
if len(in_region) != len(tgt_region):
# Use the larger size for interpolation
target_size = max(len(in_region), len(tgt_region))
# Interpolate both to the target size
in_region_interp = (
F.interpolate(
in_region.unsqueeze(0).unsqueeze(0),
size=target_size,
mode="linear",
align_corners=False,
)
.squeeze(0)
.squeeze(0)
)
tgt_region_interp = (
F.interpolate(
tgt_region.unsqueeze(0).unsqueeze(0),
size=target_size,
mode="linear",
align_corners=False,
)
.squeeze(0)
.squeeze(0)
)
# Compute MSE between interpolated regions
shape_losses[b] = F.mse_loss(in_region_interp, tgt_region_interp)
else:
# Compute MSE directly
shape_losses[b] = F.mse_loss(in_region, tgt_region)
# Compute mean losses
position_loss = torch.mean(position_losses)
shape_loss = torch.mean(shape_losses)
return position_loss, shape_loss
def forward(self, input_signal, target_signal):
"""
Compute the main peak focused loss between input and target signals.
Parameters:
-----------
input_signal : torch.Tensor
Input signal of shape [batch_size, signal_length]
target_signal : torch.Tensor
Target signal of shape [batch_size, signal_length]
Returns:
--------
loss : torch.Tensor
Computed loss focused on main spectral peak
"""
# Handle multi-dimensional signals
original_shape = input_signal.shape
if len(original_shape) > 2:
# Reshape to [batch_size, -1]
input_signal = input_signal.reshape(original_shape[0], -1)
target_signal = target_signal.reshape(target_signal.shape[0], -1)
# Check shape match
if input_signal.shape != target_signal.shape:
raise ValueError(
f"Input shape {input_signal.shape} must match target shape {target_signal.shape}"
)
batch_size, signal_length = input_signal.shape
# Compute FFT
input_fft = torch.fft.fft(input_signal)
target_fft = torch.fft.fft(target_signal)
# Extract positive frequencies
n = signal_length
input_fft = input_fft[:, : n // 2 + 1]
target_fft = target_fft[:, : n // 2 + 1]
# Calculate magnitude spectra
input_mag = torch.abs(input_fft)
target_mag = torch.abs(target_fft)
# Apply log if requested
if self.log_magnitude:
input_mag = torch.log(input_mag + self.epsilon)
target_mag = torch.log(target_mag + self.epsilon)
# Find main peaks
input_peak_indices, input_peak_widths = self.find_main_peak(input_mag)
target_peak_indices, target_peak_widths = self.find_main_peak(target_mag)
# Extract regions around peaks
input_regions, input_region_indices = self.extract_peak_regions(
input_mag, input_peak_indices, input_peak_widths
)
target_regions, target_region_indices = self.extract_peak_regions(
target_mag, target_peak_indices, target_peak_widths
)
# Compute loss components
position_loss, shape_loss = self.compute_peak_loss(
input_regions, input_region_indices, target_regions, target_region_indices
)
# Combine losses
total_loss = (
self.position_weight * position_loss + self.shape_weight * shape_loss
)
return total_loss
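# Usage sketch for MainPeakLoss on two toy sinusoids whose dominant
# frequencies differ slightly (the frequencies are arbitrary demo values).
def example_main_peak_loss():
    t = torch.linspace(0, 1, 512).unsqueeze(0)    # [1, signal_length]
    pred = torch.sin(2 * torch.pi * 50 * t)       # dominant peak near bin 50
    target = torch.sin(2 * torch.pi * 55 * t)     # dominant peak near bin 55
    loss_fn = MainPeakLoss(position_weight=0.6, shape_weight=0.4)
    return loss_fn(pred, target)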
class MultiscaleChaoticPeakLoss(nn.Module):
"""
A comprehensive loss function that combines:
1. Main peak analysis
2. Multiple peak analysis
3. Overall spectral shape
4. Time-domain comparison
This is ideal for chaotic neural ODEs where both the dominant modes and
overall dynamical behavior need to be matched.
Parameters:
-----------
main_peak_weight : float
Weight for the main peak component.
Default: 0.4
multi_peak_weight : float
Weight for the multiple peaks component.
Default: 0.3
spectral_shape_weight : float
Weight for the overall spectral shape component.
Default: 0.2
time_domain_weight : float
Weight for the time-domain comparison component.
Default: 0.1
n_peaks : int
Number of peaks to consider for multi-peak analysis.
Default: 3
"""
def __init__(
self,
main_peak_weight=0.4,
multi_peak_weight=0.3,
spectral_shape_weight=0.2,
time_domain_weight=0.1,
n_peaks=3,
**kwargs,
):
super(MultiscaleChaoticPeakLoss, self).__init__()
self.main_peak_weight = main_peak_weight
self.multi_peak_weight = multi_peak_weight
self.spectral_shape_weight = spectral_shape_weight
self.time_domain_weight = time_domain_weight
        # Create individual loss components. PeakSpectraLoss is defined later
        # in this module; the name is resolved at call time, so instantiating
        # this class after the module has fully loaded works as expected.
        self.main_peak_loss = MainPeakLoss(**kwargs)
        self.multi_peak_loss = PeakSpectraLoss(n_peaks=n_peaks, **kwargs)
        self.spectral_loss = ChaoticFFTLoss(**kwargs)
def forward(self, input_signal, target_signal):
"""
Compute the multi-component loss between input and target chaotic signals.
Parameters:
-----------
input_signal : torch.Tensor
Predicted signal from neural ODE, shape [batch_size, signal_length]
target_signal : torch.Tensor
Target chaotic signal, shape [batch_size, signal_length]
Returns:
--------
loss : torch.Tensor
The computed multi-component loss
"""
# Compute individual loss components
main_loss = self.main_peak_loss(input_signal, target_signal)
multi_loss = self.multi_peak_loss(input_signal, target_signal)
spectral_loss = self.spectral_loss(input_signal, target_signal)
# Time-domain loss (basic MSE)
time_loss = F.mse_loss(input_signal, target_signal)
# Combine all loss components
total_loss = (
self.main_peak_weight * main_loss
+ self.multi_peak_weight * multi_loss
+ self.spectral_shape_weight * spectral_loss
+ self.time_domain_weight * time_loss
)
return total_loss
# Example usage with a neural ODE chaotic circuit
def example_with_neural_ode_chaotic_circuit():
"""
Example showing how to use the peak-focused loss functions with a neural ODE
model that simulates a chaotic circuit.
"""
    # torch and torch.nn are already imported at module level; only the ODE
    # solver needs a local import here
    from torchdiffeq import odeint
class ChaoticCircuitODE(nn.Module):
"""
Neural ODE model of a chaotic circuit (e.g., Chua's circuit).
"""
def __init__(self, hidden_dim=32):
super(ChaoticCircuitODE, self).__init__()
# Neural network part for learning chaotic dynamics
self.net = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 3),
)
# Parameters similar to Chua's circuit
self.alpha = nn.Parameter(torch.tensor(9.0))
self.beta = nn.Parameter(torch.tensor(14.0))
self.gamma = nn.Parameter(torch.tensor(0.1))
def chua_nonlinearity(self, x):
"""Nonlinear function similar to Chua's diode"""
return torch.tanh(x) - 0.2 * torch.tanh(3 * x)
def forward(self, t, state):
"""ODE function for the chaotic circuit"""
x, y, z = state[:, 0], state[:, 1], state[:, 2]
# Chua's circuit-like dynamics
dx = self.alpha * (y - x - self.chua_nonlinearity(x))
dy = x - y + z
dz = -self.beta * y - self.gamma * z
# Base dynamics
dstate = torch.stack([dx, dy, dz], dim=1)
# Neural network contribution (can learn additional dynamics)
nn_contrib = self.net(state)
return dstate + 0.01 * nn_contrib
# Create the model
model = ChaoticCircuitODE()
# Generate some example data
batch_size = 8
initial_state = torch.randn(batch_size, 3) * 0.1
t = torch.linspace(0, 10, 1000)
# Generate target trajectories
with torch.no_grad():
trajectories = odeint(model, initial_state, t)
# Shape: [time_steps, batch_size, 3]
trajectories = trajectories.permute(1, 0, 2)
# Now shape: [batch_size, time_steps, 3]
# Use the first component as the signal
target_signals = trajectories[:, :, 0]
# Create slightly perturbed initial conditions
perturbed_initial = initial_state + torch.randn_like(initial_state) * 0.01
# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Create the loss function - focusing on peaks
peak_loss = MultiscaleChaoticPeakLoss(
main_peak_weight=0.4,
multi_peak_weight=0.3,
spectral_shape_weight=0.2,
time_domain_weight=0.1,
n_peaks=3,
log_magnitude=True,
)
# Training loop (simplified example)
for epoch in range(10):
optimizer.zero_grad()
# Forward pass with perturbed initial conditions
pred_trajectories = odeint(model, perturbed_initial, t)
pred_trajectories = pred_trajectories.permute(1, 0, 2)
pred_signals = pred_trajectories[:, :, 0]
# Compute loss
loss = peak_loss(pred_signals, target_signals)
# Backward pass
loss.backward()
optimizer.step()
print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
return model, peak_loss
class PeakSpectraLoss(nn.Module):
"""
A specialized loss function that focuses on comparing multiple peaks
in frequency spectra of signals from chaotic systems.
Parameters:
-----------
n_peaks : int
Number of dominant peaks to consider in the comparison.
Default: 3
peak_match_type : str
How to match peaks between input and target:
'position': Compare peaks at the same frequency positions
'magnitude': Match peaks by magnitude (largest to largest, etc.)
'nearest': Match each peak to the nearest peak in the other spectrum
Default: 'nearest'
position_weight : float
Weight for comparing peak positions.
Default: 0.7
magnitude_weight : float
Weight for comparing peak magnitudes.
Default: 0.3
log_magnitude : bool
Whether to use log magnitudes for comparison.
Default: True
epsilon : float
Small value added to avoid log(0).
Default: 1e-8
"""
def __init__(
self,
n_peaks=3,
peak_match_type="nearest",
position_weight=0.7,
magnitude_weight=0.3,
log_magnitude=True,
epsilon=1e-8,
*args,
**kwargs,
):
super(PeakSpectraLoss, self).__init__()
self.n_peaks = n_peaks
self.peak_match_type = peak_match_type
self.position_weight = position_weight
self.magnitude_weight = magnitude_weight
self.log_magnitude = log_magnitude
self.epsilon = epsilon
def find_peaks(self, spectrum):
"""
Find the dominant peaks in a spectrum.
Parameters:
-----------
spectrum : torch.Tensor
Spectrum with shape [batch_size, freq_bins]
Returns:
--------
peak_indices : torch.Tensor
Indices of peaks with shape [batch_size, n_peaks]
peak_values : torch.Tensor
Magnitude of peaks with shape [batch_size, n_peaks]
"""
batch_size, freq_bins = spectrum.shape
device = spectrum.device
# Initialize outputs with default values
peak_indices = torch.zeros(
(batch_size, self.n_peaks), dtype=torch.long, device=device
)
peak_values = torch.zeros((batch_size, self.n_peaks), device=device)
# Process each batch item separately
for b in range(batch_size):
# Get this spectrum
spec = spectrum[b]
# Need at least 3 points to detect a peak
if freq_bins < 3:
continue
            # Create a mask for local maxima (vectorized neighbor comparison)
            is_peak = torch.zeros(freq_bins, dtype=torch.bool, device=device)
            is_peak[1:-1] = (spec[1:-1] > spec[:-2]) & (spec[1:-1] > spec[2:])
# Get peak indices and values
peak_idx = torch.where(is_peak)[0]
if len(peak_idx) == 0:
continue
peak_vals = spec[peak_idx]
# Filter by threshold relative to maximum
max_val = torch.max(spec)
threshold = max_val * 0.1 # Keep peaks at least 10% of max
mask = peak_vals >= threshold
peak_idx = peak_idx[mask]
peak_vals = peak_vals[mask]
if len(peak_idx) == 0:
continue
# Sort by magnitude (descending)
sorted_indices = torch.argsort(peak_vals, descending=True)
peak_idx = peak_idx[sorted_indices]
peak_vals = peak_vals[sorted_indices]
# Keep top n_peaks
n_found = min(len(peak_idx), self.n_peaks)
peak_indices[b, :n_found] = peak_idx[:n_found]
peak_values[b, :n_found] = peak_vals[:n_found]
return peak_indices, peak_values
def compute_peak_loss(
self, input_indices, input_values, target_indices, target_values, max_freq_bin
):
"""
Compute loss between peaks based on matching strategy.
Parameters:
-----------
input_indices, target_indices : torch.Tensor
Indices of peaks, shape [batch_size, n_peaks]
input_values, target_values : torch.Tensor
Values of peaks, shape [batch_size, n_peaks]
max_freq_bin : int
Maximum frequency bin index for normalization
Returns:
--------
loss : torch.Tensor
Computed loss between peak features
"""
batch_size = input_indices.shape[0]
device = input_indices.device
# Initialize loss components
position_loss = torch.zeros(batch_size, device=device)
magnitude_loss = torch.zeros(batch_size, device=device)
for b in range(batch_size):
# Get peaks for this batch item
in_idx = input_indices[b]
in_val = input_values[b]
tgt_idx = target_indices[b]
tgt_val = target_values[b]
# Normalize peak indices by max frequency for position loss
norm_in_idx = in_idx.float() / max_freq_bin
norm_tgt_idx = tgt_idx.float() / max_freq_bin
# Check if we have valid peaks to compare
valid_in = torch.any(in_val > 0)
valid_tgt = torch.any(tgt_val > 0)
if not (valid_in and valid_tgt):
continue
            # Match peaks according to strategy
            if self.peak_match_type in ("position", "magnitude"):
                # find_peaks returns peaks sorted by magnitude, so both of
                # these strategies currently reduce to pairing the k-th
                # strongest input peak with the k-th strongest target peak
                idx_diff = torch.abs(norm_in_idx - norm_tgt_idx)
                val_diff = torch.abs(in_val - tgt_val)
elif self.peak_match_type == "nearest":
# Match each input peak to nearest target peak
idx_diff = torch.zeros_like(in_idx, dtype=torch.float)
val_diff = torch.zeros_like(in_val)
for i in range(len(in_idx)):
if in_val[i] > 0: # Only consider valid peaks
# Find nearest target peak by frequency
distances = torch.abs(norm_in_idx[i] - norm_tgt_idx)
nearest_idx = torch.argmin(distances)
idx_diff[i] = distances[nearest_idx]
val_diff[i] = torch.abs(in_val[i] - tgt_val[nearest_idx])
# Compute mean losses for this batch item
valid_count = torch.sum(in_val > 0).float()
if valid_count > 0:
position_loss[b] = torch.sum(idx_diff) / valid_count
magnitude_loss[b] = torch.sum(val_diff) / valid_count
# Compute weighted average loss across batch
total_loss = self.position_weight * torch.mean(
position_loss
) + self.magnitude_weight * torch.mean(magnitude_loss)
return total_loss
def forward(self, input_signal, target_signal):
"""
Compute the peak-focused loss between input and target signals.
Parameters:
-----------
input_signal : torch.Tensor
Input signal of shape [batch_size, signal_length]
target_signal : torch.Tensor
Target signal of shape [batch_size, signal_length]
Returns:
--------
loss : torch.Tensor
Computed loss focused on spectral peaks
"""
# Handle multi-dimensional signals
original_shape = input_signal.shape
if len(original_shape) > 2:
# Reshape to [batch_size, -1]
input_signal = input_signal.reshape(original_shape[0], -1)
target_signal = target_signal.reshape(target_signal.shape[0], -1)
# Check shape match
if input_signal.shape != target_signal.shape:
raise ValueError(
f"Input shape {input_signal.shape} must match target shape {target_signal.shape}"
)
batch_size, signal_length = input_signal.shape
# Compute FFT
input_fft = torch.fft.fft(input_signal)
target_fft = torch.fft.fft(target_signal)
# Extract positive frequencies
n = signal_length
input_fft = input_fft[:, : n // 2 + 1]
target_fft = target_fft[:, : n // 2 + 1]
# Calculate magnitude spectra
input_mag = torch.abs(input_fft)
target_mag = torch.abs(target_fft)
# Apply log if requested
if self.log_magnitude:
input_mag = torch.log(input_mag + self.epsilon)
target_mag = torch.log(target_mag + self.epsilon)
# Find peaks in both spectra
input_peak_indices, input_peak_values = self.find_peaks(input_mag)
target_peak_indices, target_peak_values = self.find_peaks(target_mag)
# Compute peak-based loss
max_freq_bin = n // 2
total_loss = self.compute_peak_loss(
input_peak_indices,
input_peak_values,
target_peak_indices,
target_peak_values,
max_freq_bin,
)
return total_loss
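# Usage sketch for PeakSpectraLoss: compare the strongest spectral peaks of
# two multi-tone signals (the tone frequencies are arbitrary demo values).
def example_peak_spectra_loss():
    t = torch.linspace(0, 1, 1024).unsqueeze(0)
    pred = torch.sin(2 * torch.pi * 30 * t) + 0.5 * torch.sin(2 * torch.pi * 90 * t)
    target = torch.sin(2 * torch.pi * 32 * t) + 0.5 * torch.sin(2 * torch.pi * 88 * t)
    loss_fn = PeakSpectraLoss(n_peaks=3, peak_match_type="nearest")
    return loss_fn(pred, target)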
class AttractorHistogramLoss(nn.Module):
"""
A loss function that compares chaotic attractors by generating and comparing
2D histograms of their state space distributions.
This approach is especially effective for chaotic systems where the exact
trajectories may diverge due to sensitivity to initial conditions, but the
overall attractor shape and density distribution remain similar.
Parameters:
-----------
bins : int or tuple(int, int)
Number of bins for the histogram in each dimension.
If int: same number of bins used for both dimensions
If tuple: (bins_x, bins_y) for different binning in each dimension
Default: 50
range_mode : str
How to determine the histogram range:
            'adaptive': Automatically determined from the input data (min/max taken across the whole batch)
'fixed': Use fixed_range parameter
'combined': Use the union of adaptive ranges from both input and target
Default: 'combined'
fixed_range : tuple of tuple of float
Only used when range_mode='fixed'
Format: ((x_min, x_max), (y_min, y_max))
Default: ((-2.0, 2.0), (-2.0, 2.0))
normalize : bool
Whether to normalize histograms to sum to 1 (probability distribution)
Default: True
smooth : bool
Whether to apply Gaussian smoothing to histograms
Default: True
smooth_sigma : float
Standard deviation for Gaussian smoothing kernel
Default: 1.0
epsilon : float
Small value added to histograms to avoid log(0) and division by zero
Default: 1e-8
distance_metric : str
Metric to compare histograms:
'kl': Kullback-Leibler divergence
'js': Jensen-Shannon divergence
'wasserstein': Wasserstein distance (Earth Mover's Distance)
'l2': L2 distance (Mean Squared Error)
'l1': L1 distance (Mean Absolute Error)
'cosine': Cosine similarity
'bhattacharyya': Bhattacharyya distance
        Default: 'l2'
"""
def __init__(
self,
bins=50,
range_mode="combined",
fixed_range=((-2.0, 2.0), (-2.0, 2.0)),
normalize=True,
smooth=True,
smooth_sigma=1.0,
epsilon=1e-8,
distance_metric="l2",
):
super(AttractorHistogramLoss, self).__init__()
# Set parameters
if isinstance(bins, int):
self.bins = (bins, bins)
else:
self.bins = bins
self.range_mode = range_mode
self.fixed_range = fixed_range
self.normalize = normalize
self.smooth = smooth
self.smooth_sigma = smooth_sigma
self.epsilon = epsilon
self.distance_metric = distance_metric
# Validate parameters
if self.range_mode not in ["adaptive", "fixed", "combined"]:
raise ValueError(f"Invalid range_mode: {range_mode}")
if self.distance_metric not in [
"kl",
"js",
"wasserstein",
"l2",
"l1",
"cosine",
"bhattacharyya",
]:
raise ValueError(f"Invalid distance_metric: {distance_metric}")
def compute_histogram_range(self, x, y=None):
"""
Compute appropriate range for histogram based on the specified mode.
Parameters:
-----------
x : torch.Tensor
Input tensor of shape [batch_size, n_steps, 2]
y : torch.Tensor, optional
Target tensor of shape [batch_size, n_steps, 2]
Only used when range_mode='combined'
Returns:
--------
ranges : tuple of tuples
((x_min, x_max), (y_min, y_max))
"""
if self.range_mode == "fixed":
return self.fixed_range
# Compute range from x
x_min_0, _ = torch.min(x[..., 0], dim=1)
x_max_0, _ = torch.max(x[..., 0], dim=1)
x_min_1, _ = torch.min(x[..., 1], dim=1)
x_max_1, _ = torch.max(x[..., 1], dim=1)
# Take min/max across batch
range_x0 = (torch.min(x_min_0), torch.max(x_max_0))
range_x1 = (torch.min(x_min_1), torch.max(x_max_1))
if self.range_mode == "adaptive":
return (range_x0, range_x1)
# If mode is 'combined', combine with y's range
if y is not None:
y_min_0, _ = torch.min(y[..., 0], dim=1)
y_max_0, _ = torch.max(y[..., 0], dim=1)
y_min_1, _ = torch.min(y[..., 1], dim=1)
y_max_1, _ = torch.max(y[..., 1], dim=1)
# Take overall min/max
range_0 = (
min(torch.min(x_min_0), torch.min(y_min_0)),
max(torch.max(x_max_0), torch.max(y_max_0)),
)
range_1 = (
min(torch.min(x_min_1), torch.min(y_min_1)),
max(torch.max(x_max_1), torch.max(y_max_1)),
)
return (range_0, range_1)
# Fallback
return (range_x0, range_x1)
def compute_2d_histogram(self, x, hist_range):
"""
Compute 2D histogram of the attractor points.
Parameters:
-----------
x : torch.Tensor
Input tensor of shape [batch_size, n_steps, 2]
hist_range : tuple of tuples
((x_min, x_max), (y_min, y_max))
Returns:
--------
histograms : torch.Tensor
Batch of histograms with shape [batch_size, bins[0], bins[1]]
"""
batch_size, n_steps, _ = x.shape
device = x.device
# Create meshgrid for binning
range_x, range_y = hist_range
x_min, x_max = range_x
y_min, y_max = range_y
# Initialize histograms
histograms = torch.zeros(
(batch_size, self.bins[0], self.bins[1]), device=device
)
# Process each batch item separately
for b in range(batch_size):
# Extract 2D points for this batch
points = x[b] # [n_steps, 2]
# Skip if no points (zeros histogram will be returned)
if n_steps == 0:
continue
# Compute bin indices for each point
x_indices = torch.clamp(
((points[:, 0] - x_min) / (x_max - x_min) * self.bins[0]).long(),
0,
self.bins[0] - 1,
)
y_indices = torch.clamp(
((points[:, 1] - y_min) / (y_max - y_min) * self.bins[1]).long(),
0,
self.bins[1] - 1,
)
            # Accumulate counts via a flattened bincount (vectorized
            # equivalent of incrementing one bin per trajectory point)
            flat_idx = x_indices * self.bins[1] + y_indices
            histograms[b] = (
                torch.bincount(flat_idx, minlength=self.bins[0] * self.bins[1])
                .reshape(self.bins[0], self.bins[1])
                .float()
            )
# Normalize if requested
if self.normalize:
# Add epsilon to avoid division by zero
sums = (
torch.sum(histograms.view(batch_size, -1), dim=1).view(batch_size, 1, 1)
+ self.epsilon
)
histograms = histograms / sums
# Apply smoothing if requested
if self.smooth:
# Create Gaussian kernel
kernel_size = min(13, min(self.bins) // 2 * 2 + 1) # Odd kernel size
sigma = torch.tensor([self.smooth_sigma], device=device)
# Apply 2D Gaussian filter
padding = kernel_size // 2
# Reshape for conv2d [batch, channels, height, width]
histograms_reshaped = histograms.unsqueeze(
1
) # [batch, 1, bins[0], bins[1]]
# Create 1D Gaussian kernels and apply separably for efficiency
coords = torch.arange(kernel_size, device=device) - padding
kernel_1d = torch.exp(-(coords**2) / (2 * sigma**2))
kernel_1d = kernel_1d / kernel_1d.sum()
# Apply horizontal and vertical Gaussian blur
kernel_h = kernel_1d.view(1, 1, 1, kernel_size)
kernel_v = kernel_1d.view(1, 1, kernel_size, 1)
            # Apply horizontal blur
smoothed = F.conv2d(
F.pad(histograms_reshaped, (padding, padding, 0, 0), mode="reflect"),
kernel_h,
padding=(0, 0),
)
# Apply vertical blur
smoothed = F.conv2d(
F.pad(smoothed, (0, 0, padding, padding), mode="reflect"),
kernel_v,
padding=(0, 0),
)
# Return to original shape
histograms = smoothed.squeeze(1)
return histograms
def compute_histogram_distance(self, hist1, hist2):
"""
Compute distance between two histograms based on the selected metric.
Parameters:
-----------
hist1, hist2 : torch.Tensor
Histograms to compare with shape [batch_size, bins[0], bins[1]]
Returns:
--------
distances : torch.Tensor
Computed distances with shape [batch_size]
"""
batch_size = hist1.shape[0]
device = hist1.device
# Reshape for certain metrics
flat_hist1 = hist1.reshape(batch_size, -1)
flat_hist2 = hist2.reshape(batch_size, -1)
# Initialize distances
distances = torch.zeros(batch_size, device=device)
# Compute selected distance metric
if self.distance_metric == "kl":
# KL divergence (add epsilon to avoid log(0))
kl_div = flat_hist1 * (
torch.log(flat_hist1 + self.epsilon)
- torch.log(flat_hist2 + self.epsilon)
)
distances = torch.sum(kl_div, dim=1)
elif self.distance_metric == "js":
# Jensen-Shannon divergence
m = 0.5 * (flat_hist1 + flat_hist2)
js_div = 0.5 * (
flat_hist1
* (torch.log(flat_hist1 + self.epsilon) - torch.log(m + self.epsilon))
+ flat_hist2
* (torch.log(flat_hist2 + self.epsilon) - torch.log(m + self.epsilon))
)
distances = torch.sum(js_div, dim=1)
elif self.distance_metric == "wasserstein":
# Simple approximation of 2D Wasserstein distance using cumulative histograms
# This is a simplified version and not the true 2D EMD
# Compute marginal cumulative histograms
cum_x1 = torch.cumsum(torch.sum(hist1, dim=2), dim=1)
cum_x2 = torch.cumsum(torch.sum(hist2, dim=2), dim=1)
cum_y1 = torch.cumsum(torch.sum(hist1, dim=1), dim=1)
cum_y2 = torch.cumsum(torch.sum(hist2, dim=1), dim=1)
# Normalize
cum_x1 = cum_x1 / (torch.max(cum_x1, dim=1, keepdim=True)[0] + self.epsilon)
cum_x2 = cum_x2 / (torch.max(cum_x2, dim=1, keepdim=True)[0] + self.epsilon)
cum_y1 = cum_y1 / (torch.max(cum_y1, dim=1, keepdim=True)[0] + self.epsilon)
cum_y2 = cum_y2 / (torch.max(cum_y2, dim=1, keepdim=True)[0] + self.epsilon)
# Compute L1 distances between cumulative histograms
x_dist = torch.sum(torch.abs(cum_x1 - cum_x2), dim=1)
y_dist = torch.sum(torch.abs(cum_y1 - cum_y2), dim=1)
# Combine X and Y distances
distances = x_dist + y_dist
elif self.distance_metric == "l2":
# L2 distance (MSE)
distances = torch.sum((flat_hist1 - flat_hist2) ** 2, dim=1)
elif self.distance_metric == "l1":
# L1 distance (MAE)
distances = torch.sum(torch.abs(flat_hist1 - flat_hist2), dim=1)
elif self.distance_metric == "cosine":
# Cosine distance (1 - cosine similarity)
dot_product = torch.sum(flat_hist1 * flat_hist2, dim=1)
norm1 = torch.sqrt(torch.sum(flat_hist1**2, dim=1) + self.epsilon)
norm2 = torch.sqrt(torch.sum(flat_hist2**2, dim=1) + self.epsilon)
cosine_sim = dot_product / (norm1 * norm2)
distances = 1.0 - cosine_sim
elif self.distance_metric == "bhattacharyya":
# Bhattacharyya distance
bc = torch.sqrt(flat_hist1 * flat_hist2)
bc_coeff = torch.sum(bc, dim=1)
distances = -torch.log(bc_coeff + self.epsilon)
return distances
def forward(self, input_traj, target_traj):
"""
Compute the loss between input and target attractors.
Parameters:
-----------
input_traj : torch.Tensor
Input trajectory of shape [batch_size, n_steps, 2]
Represents the 2D attractor from the model
target_traj : torch.Tensor
Target trajectory of shape [batch_size, n_steps, 2]
Represents the 2D attractor to match
Returns:
--------
loss : torch.Tensor
Computed histogram-based loss
"""
# Validate input shapes
if input_traj.dim() != 3 or target_traj.dim() != 3:
raise ValueError(
"Input and target trajectories must be 3D tensors [batch, steps, 2]"
)
if input_traj.shape[0] != target_traj.shape[0]:
raise ValueError(
f"Batch sizes don't match: {input_traj.shape[0]} vs {target_traj.shape[0]}"
)
if input_traj.shape[2] != 2 or target_traj.shape[2] != 2:
raise ValueError("Last dimension must be 2 for 2D attractors")
# Compute histogram range
hist_range = self.compute_histogram_range(
input_traj, target_traj if self.range_mode == "combined" else None
)
# Compute histograms
input_hist = self.compute_2d_histogram(input_traj, hist_range)
target_hist = self.compute_2d_histogram(target_traj, hist_range)
# Compute distances between histograms
distances = self.compute_histogram_distance(input_hist, target_hist)
# Return mean distance as the loss
return torch.mean(distances)
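# Usage sketch for AttractorHistogramLoss on toy 2D trajectories; random
# walks stand in for real attractor orbits, and the bin count and metric
# below are arbitrary demo choices.
def example_attractor_histogram_loss():
    torch.manual_seed(0)
    pred = torch.cumsum(torch.randn(2, 500, 2) * 0.05, dim=1)
    target = torch.cumsum(torch.randn(2, 500, 2) * 0.05, dim=1)
    loss_fn = AttractorHistogramLoss(bins=32, distance_metric="js")
    return loss_fn(pred, target)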
class AttractorMomentLoss(nn.Module):
"""
A loss function that compares chaotic attractors using multivariate statistical moments
and distribution properties.
This approach captures the overall shape, spread, and orientation of the attractor
without being sensitive to exact point positions. Uses multivariate skewness and
kurtosis (Mardia's measures) for rotation-invariant comparison.
Parameters:
-----------
mean_weight : float
Weight for the mean position component
Default: 0.1
cov_weight : float
Weight for the covariance matrix component
Default: 0.4
skewness_weight : float
Weight for the multivariate skewness component
Default: 0.3
kurtosis_weight : float
Weight for the multivariate kurtosis component
Default: 0.2
normalize_components : bool
Whether to normalize each loss component by target magnitude
Recommended: True for scale-invariant comparison
Default: True
epsilon : float
Small value for numerical stability
Default: 1e-6
max_samples_for_skewness : int
Maximum number of samples to use for skewness computation (memory optimization)
If trajectory has more steps, random sampling is used
Default: 1000
regularization_strength : float
Strength of regularization added to covariance matrix for numerical stability
Default: 1e-4
"""
def __init__(
self,
mean_weight=0.1,
cov_weight=0.4,
skewness_weight=0.3,
kurtosis_weight=0.2,
normalize_components=True,
epsilon=1e-6,
max_samples_for_skewness=1000,
regularization_strength=1e-4,
):
super(AttractorMomentLoss, self).__init__()
self.mean_weight = mean_weight
self.cov_weight = cov_weight
self.skewness_weight = skewness_weight
self.kurtosis_weight = kurtosis_weight
self.normalize_components = normalize_components
self.epsilon = epsilon
self.max_samples_for_skewness = max_samples_for_skewness
self.regularization_strength = regularization_strength
# Validate weights
total_weight = mean_weight + cov_weight + skewness_weight + kurtosis_weight
if abs(total_weight - 1.0) > 0.01:
print(
f"Warning: Weights sum to {total_weight:.3f}, not 1.0. "
f"Consider normalizing for better interpretability."
)
def compute_moments(self, traj):
"""
Compute multivariate statistical moments of the attractor.
Parameters:
-----------
traj : torch.Tensor
Trajectory tensor of shape [batch_size, n_steps, dim]
Returns:
--------
means : torch.Tensor
Mean positions with shape [batch_size, dim]
covs : torch.Tensor
Covariance matrices with shape [batch_size, dim, dim]
skews : torch.Tensor
Mardia's multivariate skewness with shape [batch_size]
kurts : torch.Tensor
Mardia's multivariate kurtosis with shape [batch_size]
"""
batch_size, n_steps, dim = traj.shape
device = traj.device
# Validate minimum steps
if n_steps < 2:
raise ValueError(
f"Need at least 2 time steps to compute moments, got {n_steps}"
)
# Check for NaN or Inf
if torch.isnan(traj).any() or torch.isinf(traj).any():
raise ValueError("Input trajectory contains NaN or Inf values")
# Vectorized mean computation
means = torch.mean(traj, dim=1) # [batch_size, dim]
# Center the data
centered = traj - means.unsqueeze(1) # [batch_size, n_steps, dim]
# Vectorized covariance computation
# covs[b] = (1/(n-1)) * centered[b].T @ centered[b]
covs = torch.bmm(centered.transpose(1, 2), centered) / (
n_steps - 1
) # [batch_size, dim, dim]
# Add regularization for numerical stability
eye = torch.eye(dim, device=device).unsqueeze(0) # [1, dim, dim]
cov_reg = covs + self.regularization_strength * eye # [batch_size, dim, dim]
# Compute inverse covariance (with error handling)
try:
cov_inv = torch.linalg.inv(cov_reg)
except RuntimeError as e:
# Fallback to pseudo-inverse if singular
print(
f"Warning: Covariance matrix inversion failed, using pseudo-inverse. Error: {e}"
)
cov_inv = torch.linalg.pinv(cov_reg)
# Compute Mahalanobis-transformed centered data
# centered_transformed[b] = centered[b] @ cov_inv[b]
centered_transformed = torch.bmm(
centered, cov_inv
) # [batch_size, n_steps, dim]
# Mardia's multivariate kurtosis (4th moment)
# β₂,d = (1/n) Σᵢ [(xᵢ - μ)ᵀ Σ⁻¹ (xᵢ - μ)]²
mahal_sq = torch.sum(
centered * centered_transformed, dim=2
) # [batch_size, n_steps]
kurts = torch.mean(mahal_sq**2, dim=1) # [batch_size]
# Mardia's multivariate skewness (3rd moment)
# β₁,d = (1/n²) Σᵢ Σⱼ [(xᵢ - μ)ᵀ Σ⁻¹ (xⱼ - μ)]³
# This requires an n×n matrix, so we sample for memory efficiency
if n_steps > self.max_samples_for_skewness:
# Random sampling for efficiency
indices = torch.randperm(n_steps, device=device)[
: self.max_samples_for_skewness
]
centered_sample = centered[:, indices, :] # [batch_size, n_sample, dim]
# Use full transformed data for the second part
# [batch_size, n_sample, dim] @ [batch_size, dim, n_steps]
mahal_matrix = torch.bmm(
centered_sample, centered_transformed.transpose(1, 2)
)
else:
# Use all data
# [batch_size, n_steps, dim] @ [batch_size, dim, n_steps]
mahal_matrix = torch.bmm(
centered, centered_transformed.transpose(1, 2)
) # [batch_size, n_steps, n_steps]
# Average over both dimensions
skews = torch.mean(mahal_matrix**3, dim=(1, 2)) # [batch_size]
return means, covs, skews, kurts
def forward(self, input_traj, target_traj):
"""
Compute the moment-based loss between input and target attractors.
Parameters:
-----------
input_traj : torch.Tensor
Input trajectory of shape [batch_size, n_steps, dim]
target_traj : torch.Tensor
Target trajectory of shape [batch_size, n_steps, dim]
Returns:
--------
        loss : torch.Tensor
            Computed moment-based loss (use get_component_losses for a
            per-component breakdown)
        """
# Validate input shapes
if input_traj.dim() != 3 or target_traj.dim() != 3:
raise ValueError(
"Input and target trajectories must be 3D tensors [batch, steps, dim]"
)
if input_traj.shape[0] != target_traj.shape[0]:
raise ValueError(
f"Batch sizes don't match: {input_traj.shape[0]} vs {target_traj.shape[0]}"
)
if input_traj.shape[2] != target_traj.shape[2]:
raise ValueError(
f"Dimension mismatch: {input_traj.shape[2]} vs {target_traj.shape[2]}"
)
# Compute moments
input_means, input_covs, input_skews, input_kurts = self.compute_moments(
input_traj
)
target_means, target_covs, target_skews, target_kurts = self.compute_moments(
target_traj
)
# Compute loss components
# Mean loss (Euclidean distance)
mean_loss = F.mse_loss(input_means, target_means, reduction="mean")
# Covariance loss (Frobenius norm via MSE)
cov_loss = F.mse_loss(input_covs, target_covs, reduction="mean")
# Multivariate skewness loss
skew_loss = F.mse_loss(input_skews, target_skews, reduction="mean")
# Multivariate kurtosis loss
kurt_loss = F.mse_loss(input_kurts, target_kurts, reduction="mean")
# Normalize components if requested (makes loss scale-invariant)
if self.normalize_components:
# Normalize by target magnitude to make scale-invariant
mean_scale = torch.mean(target_means**2) + self.epsilon
cov_scale = torch.mean(target_covs**2) + self.epsilon
skew_scale = torch.mean(target_skews**2) + self.epsilon
kurt_scale = torch.mean(target_kurts**2) + self.epsilon
mean_loss = mean_loss / mean_scale
cov_loss = cov_loss / cov_scale
skew_loss = skew_loss / skew_scale
kurt_loss = kurt_loss / kurt_scale
# Combine components with weights
total_loss = (
self.mean_weight * mean_loss
+ self.cov_weight * cov_loss
+ self.skewness_weight * skew_loss
+ self.kurtosis_weight * kurt_loss
)
return total_loss
def get_component_losses(self, input_traj, target_traj):
"""
Compute and return individual loss components for analysis.
Useful for debugging and understanding which moments are contributing most to the loss.
Parameters:
-----------
input_traj : torch.Tensor
Input trajectory of shape [batch_size, n_steps, dim]
target_traj : torch.Tensor
Target trajectory of shape [batch_size, n_steps, dim]
Returns:
--------
loss_dict : dict
Dictionary containing:
- 'total': total weighted loss
- 'mean': mean loss component
- 'cov': covariance loss component
- 'skewness': skewness loss component
- 'kurtosis': kurtosis loss component
- 'mean_raw': unweighted mean loss
- 'cov_raw': unweighted covariance loss
- 'skewness_raw': unweighted skewness loss
- 'kurtosis_raw': unweighted kurtosis loss
"""
# Compute moments
input_means, input_covs, input_skews, input_kurts = self.compute_moments(
input_traj
)
target_means, target_covs, target_skews, target_kurts = self.compute_moments(
target_traj
)
# Compute raw loss components
mean_loss = F.mse_loss(input_means, target_means, reduction="mean")
cov_loss = F.mse_loss(input_covs, target_covs, reduction="mean")
skew_loss = F.mse_loss(input_skews, target_skews, reduction="mean")
kurt_loss = F.mse_loss(input_kurts, target_kurts, reduction="mean")
# Store raw losses
loss_dict = {
"mean_raw": mean_loss.item(),
"cov_raw": cov_loss.item(),
"skewness_raw": skew_loss.item(),
"kurtosis_raw": kurt_loss.item(),
}
# Normalize if requested
if self.normalize_components:
mean_scale = torch.mean(target_means**2) + self.epsilon
cov_scale = torch.mean(target_covs**2) + self.epsilon
skew_scale = torch.mean(target_skews**2) + self.epsilon
kurt_scale = torch.mean(target_kurts**2) + self.epsilon
mean_loss = mean_loss / mean_scale
cov_loss = cov_loss / cov_scale
skew_loss = skew_loss / skew_scale
kurt_loss = kurt_loss / kurt_scale
# Weighted losses
loss_dict.update(
{
"mean": (self.mean_weight * mean_loss).item(),
"cov": (self.cov_weight * cov_loss).item(),
"skewness": (self.skewness_weight * skew_loss).item(),
"kurtosis": (self.kurtosis_weight * kurt_loss).item(),
}
)
# Total loss
total_loss = (
self.mean_weight * mean_loss
+ self.cov_weight * cov_loss
+ self.skewness_weight * skew_loss
+ self.kurtosis_weight * kurt_loss
)
loss_dict["total"] = total_loss.item()
return loss_dict
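# Usage sketch for AttractorMomentLoss, including the per-component
# breakdown from get_component_losses (toy 3D trajectories; the default
# weights are kept).
def example_attractor_moment_loss():
    torch.manual_seed(0)
    pred = torch.randn(2, 400, 3)
    target = torch.randn(2, 400, 3)
    loss_fn = AttractorMomentLoss()
    total = loss_fn(pred, target)
    components = loss_fn.get_component_losses(pred, target)
    return total, components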
class CombinedAttractorLoss(nn.Module):
"""
A comprehensive loss function that combines histogram-based and moment-based
approaches for comparing chaotic attractors.
Parameters:
-----------
histogram_weight : float
Weight for the histogram-based component
Default: 0.6
moment_weight : float
Weight for the moment-based component
Default: 0.4
histogram_kwargs : dict
Arguments for the histogram loss component
moment_kwargs : dict
Arguments for the moment loss component
"""
def __init__(
self,
histogram_weight=0.6,
moment_weight=0.4,
histogram_kwargs=None,
moment_kwargs=None,
):
super(CombinedAttractorLoss, self).__init__()
self.histogram_weight = histogram_weight
self.moment_weight = moment_weight
# Initialize components with default parameters if none provided
histogram_kwargs = histogram_kwargs or {}
moment_kwargs = moment_kwargs or {}
self.histogram_loss = AttractorHistogramLoss(**histogram_kwargs)
self.moment_loss = AttractorMomentLoss(**moment_kwargs)
    def compute_single(self, input_traj, target_traj):
        """
        Compute the combined loss for a single 2D projection of the attractors.
        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, 2]
        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, 2]
        Returns:
        --------
        loss : torch.Tensor
            Combined histogram-plus-moment loss for this projection
        """
# Compute individual loss components
hist_loss = self.histogram_loss(input_traj, target_traj)
moment_loss = self.moment_loss(input_traj, target_traj)
# Combine components
total_loss = (
self.histogram_weight * hist_loss + self.moment_weight * moment_loss
)
return total_loss
def forward(self, input_traj: torch.Tensor, target_traj: torch.Tensor):
"""
Compute the combined loss between input and target attractors.
Parameters:
-----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, dim]
        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, dim]
        Returns:
        --------
        loss : torch.Tensor
            Combined loss summed over every 2D projection of the state space
        """
        total_loss = 0.0
        # Compare every 2D projection of the attractor: for a dim-dimensional
        # state this sums the combined loss over all pairs of state components
        for i1, i2 in combinations(range(input_traj.shape[2]), 2):
            in_tr = input_traj[..., [i1, i2]]  # [batch_size, n_steps, 2]
            sol_tr = target_traj[..., [i1, i2]]
            total_loss += self.compute_single(in_tr, sol_tr)
        return total_loss
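# Usage sketch for CombinedAttractorLoss on a 3-dimensional trajectory; the
# forward pass sums the combined loss over every 2D projection (xy, xz, yz).
# All kwargs below are arbitrary demo choices.
def example_combined_attractor_loss():
    torch.manual_seed(0)
    pred = torch.randn(2, 300, 3)
    target = torch.randn(2, 300, 3)
    loss_fn = CombinedAttractorLoss(
        histogram_kwargs={"bins": 32},
        moment_kwargs={"mean_weight": 0.1, "cov_weight": 0.5,
                       "skewness_weight": 0.2, "kurtosis_weight": 0.2},
    )
    return loss_fn(pred, target)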
class CombinedLoss(nn.Module):
def __init__(self, losses: List[nn.Module], weights: np.ndarray):
assert len(losses) == len(weights), (
"Number of losses must match number of weights"
)
super(CombinedLoss, self).__init__()
self.losses = nn.ModuleList(losses)
        # as_tensor accepts numpy arrays and sequences; the scalar weights
        # stay on CPU and broadcast against losses on any device
        self.weights = torch.as_tensor(weights, dtype=torch.float32)
def forward(self, true_vals, pred_vals):
"""
Compute the combined loss based on the individual losses and their weights.
Parameters:
-----------
true_vals : torch.Tensor
True values of shape [batch_size, n_steps, n_features]
pred_vals : torch.Tensor
Predicted values of shape [batch_size, n_steps, n_features]
Returns:
--------
total_loss : torch.Tensor
Combined loss value
"""
total_loss = 0.0
for loss, weight in zip(self.losses, self.weights):
total_loss += weight * loss(true_vals, pred_vals)
return total_loss
def get_vector(self, true_vals, pred_vals) -> np.ndarray:
"""
        Compute the individual (unweighted) loss values.
Parameters:
-----------
true_vals : torch.Tensor
True values of shape [batch_size, n_steps, n_features]
pred_vals : torch.Tensor
Predicted values of shape [batch_size, n_steps, n_features]
Returns:
--------
        loss_vector : np.ndarray
            Vector of individual unweighted loss values
"""
loss_vector = []
        for loss in self.losses:
loss_vector.append(loss(true_vals, pred_vals))
return torch.stack(loss_vector).detach().cpu().numpy()
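# Usage sketch for CombinedLoss: weight a plain time-domain MSE against a
# spectral loss (the weights below are illustrative, not tuned values).
def example_combined_loss():
    losses = [nn.MSELoss(), ChaoticFFTLoss(spectral_norm="l1")]
    combined = CombinedLoss(losses, np.array([0.7, 0.3]))
    pred = torch.randn(4, 256)
    target = torch.randn(4, 256)
    total = combined(target, pred)                # weighted sum of both losses
    per_loss = combined.get_vector(target, pred)  # individual values as numpy
    return total, per_loss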