# multi_loss.py
from itertools import combinations
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChaoticFFTLoss(nn.Module):
    """
    A specialized loss function for comparing frequency spectra of chaotic signals
    generated by neural ODEs.

    This loss function is designed to handle the unique characteristics of chaotic systems:
    1. High sensitivity to initial conditions
    2. Potentially non-stationary behavior
    3. Complex frequency structures

    Parameters:
    -----------
    window_size : int, optional
        Size of the sliding window for short-time Fourier transform.
        Set to None to use the full signal length.
        Default: None

    window_overlap : float, optional
        Overlap between consecutive windows as a fraction (0.0-0.9).
        Default: 0.5

    spectral_norm : str, optional
        Type of normalization to apply to spectra:
        'none': no normalization
        'l1': L1 normalization (sum to 1)
        'l2': L2 normalization (sum of squares to 1)
        'max': normalize by maximum value
        Default: 'none'

    focus_range : tuple or None, optional
        Tuple of (min_freq_idx, max_freq_idx) to focus on specific frequency bands.
        Default: None (use all frequencies)

    magnitude_weight : float, optional
        Weight for comparing magnitude spectra.
        Default: 1.0

    phase_weight : float, optional
        Weight for comparing phase spectra.
        Default: 0.0

    log_magnitude : bool, optional
        Whether to use log magnitude spectra.
        Default: True

    divergence_type : str, optional
        Type of divergence measure to use:
        'mse': Mean squared error (L2 distance)
        'mae': Mean absolute error (L1 distance)
        'kl': Kullback-Leibler divergence (normalized, linear spectra only)
        'js': Jensen-Shannon divergence (normalized, linear spectra only)
        'wasserstein': Wasserstein distance (normalized, linear spectra only)
        Default: 'mse'

    epsilon : float, optional
        Small value to add before taking log.
        Default: 1e-8
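
    Example (a minimal sketch; shapes and settings are illustrative):
    >>> loss_fn = ChaoticFFTLoss(
    ...     spectral_norm="l1", divergence_type="js", log_magnitude=False
    ... )
    >>> pred = torch.randn(4, 1024)    # [batch, signal_length]
    >>> target = torch.randn(4, 1024)
    >>> loss = loss_fn(pred, target)   # scalar tensor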
    """

    def __init__(
        self,
        window_size=None,
        window_overlap=0.5,
        spectral_norm="none",
        focus_range=None,
        magnitude_weight=1.0,
        phase_weight=0.0,
        log_magnitude=True,
        divergence_type="mse",
        epsilon=1e-8,
        *args,
        **kwargs,
    ):
        super(ChaoticFFTLoss, self).__init__()
        self.window_size = window_size
        self.window_overlap = window_overlap
        self.spectral_norm = spectral_norm
        self.focus_range = focus_range
        self.magnitude_weight = magnitude_weight
        self.phase_weight = phase_weight
        self.log_magnitude = log_magnitude
        self.divergence_type = divergence_type
        self.epsilon = epsilon

        # Validate inputs
        if spectral_norm not in ["none", "l1", "l2", "max"]:
            raise ValueError(f"Invalid spectral_norm: {spectral_norm}")
        if divergence_type not in ["mse", "mae", "kl", "js", "wasserstein"]:
            raise ValueError(f"Invalid divergence_type: {divergence_type}")
        if divergence_type in ["kl", "js", "wasserstein"]:
            if spectral_norm == "none":
                raise ValueError(
                    f"{divergence_type} divergence requires normalized spectra"
                )
            if log_magnitude:
                raise ValueError(
                    f"{divergence_type} divergence treats spectra as probability "
                    "distributions; set log_magnitude=False"
                )
        if not 0.0 <= window_overlap <= 0.9:
            raise ValueError(
                f"window_overlap must be in [0.0, 0.9], got {window_overlap}"
            )

    def _normalize_spectrum(self, spectrum):
        """Apply the selected normalization to a spectrum."""
        if self.spectral_norm == "none":
            return spectrum
        elif self.spectral_norm == "l1":
            # L1 normalization (sum to 1)
            norm = torch.sum(spectrum, dim=-1, keepdim=True) + self.epsilon
            return spectrum / norm
        elif self.spectral_norm == "l2":
            # L2 normalization (sum of squares to 1)
            norm = (
                torch.sqrt(torch.sum(spectrum**2, dim=-1, keepdim=True)) + self.epsilon
            )
            return spectrum / norm
        elif self.spectral_norm == "max":
            # Normalize by maximum value
            norm = torch.max(spectrum, dim=-1, keepdim=True)[0] + self.epsilon
            return spectrum / norm

    def _compute_divergence(self, input_spectrum, target_spectrum):
        """Compute the specified divergence between spectra."""
        if self.divergence_type == "mse":
            return F.mse_loss(input_spectrum, target_spectrum)
        elif self.divergence_type == "mae":
            return F.l1_loss(input_spectrum, target_spectrum)
        elif self.divergence_type == "kl":
            # KL divergence for normalized spectra (treats them as distributions)
            # Add epsilon to avoid log(0)
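            # F.kl_div(log_input, target) computes KL(target || input), so the
            # target spectrum serves as the reference distribution here.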
            return F.kl_div(
                torch.log(input_spectrum + self.epsilon),
                target_spectrum,
                reduction="batchmean",
            )
        elif self.divergence_type == "js":
            # Jensen-Shannon divergence
            m = 0.5 * (input_spectrum + target_spectrum)
            kl1 = F.kl_div(
                torch.log(input_spectrum + self.epsilon), m, reduction="batchmean"
            )
            kl2 = F.kl_div(
                torch.log(target_spectrum + self.epsilon), m, reduction="batchmean"
            )
            return 0.5 * (kl1 + kl2)
        elif self.divergence_type == "wasserstein":
            # Approximate 1D Wasserstein distance for normalized spectra
            # Using cumulative distribution approximation
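            # For 1-D distributions, W1(p, q) = sum_k |P(k) - Q(k)| where P and
            # Q are the cumulative sums; taking the mean rescales by bin count.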
            input_cdf = torch.cumsum(input_spectrum, dim=-1)
            target_cdf = torch.cumsum(target_spectrum, dim=-1)
            return torch.mean(torch.abs(input_cdf - target_cdf))

    def compute_fft(self, signal):
        """
        Compute FFT of the signal, with windowing if specified.
        Returns magnitude and phase.
        """

        if self.window_size is None:
            # Full signal FFT
            fft = torch.fft.fft(signal)
            # Get only the positive frequencies (first half)
            n = signal.shape[-1]
            fft = fft[..., : n // 2 + 1]
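            # (equivalent to torch.fft.rfft(signal) for real-valued signals)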

            magnitude = torch.abs(fft)
            phase = torch.angle(fft) if self.phase_weight > 0 else None

            # Apply log if needed
            if self.log_magnitude:
                magnitude = torch.log(magnitude + self.epsilon)

            # Apply normalization
            magnitude = self._normalize_spectrum(magnitude)

            # Apply focus range if specified
            if self.focus_range is not None:
                min_idx, max_idx = self.focus_range
                magnitude = magnitude[..., min_idx:max_idx]
                if phase is not None:
                    phase = phase[..., min_idx:max_idx]

            return magnitude, phase
        else:
            # Short-time Fourier transform with overlapping windows
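            # (torch.stft offers similar windowed FFTs; the explicit loop is
            # kept so that log scaling, normalization, and focus_range can be
            # applied per window, mirroring the full-signal path above)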
            step_size = int(self.window_size * (1 - self.window_overlap))
            signal_length = signal.shape[-1]

            # Calculate number of windows
            num_windows = max(1, (signal_length - self.window_size) // step_size + 1)

            # Initialize tensors to store results
            window_magnitudes = []
            window_phases = []

            # Apply windowing
            window_func = torch.hann_window(self.window_size, device=signal.device)

            for i in range(num_windows):
                start_idx = i * step_size
                end_idx = start_idx + self.window_size

                if end_idx > signal_length:
                    # Pad the last window if needed
                    padded = F.pad(
                        signal[..., start_idx:], (0, end_idx - signal_length)
                    )
                    windowed = padded * window_func
                else:
                    windowed = signal[..., start_idx:end_idx] * window_func

                # Compute FFT for this window
                fft = torch.fft.fft(windowed)

                # Get only the positive frequencies
                n = windowed.shape[-1]
                fft = fft[..., : n // 2 + 1]

                window_mag = torch.abs(fft)

                # Apply log if needed
                if self.log_magnitude:
                    window_mag = torch.log(window_mag + self.epsilon)

                # Apply normalization
                window_mag = self._normalize_spectrum(window_mag)

                # Apply focus range if specified
                if self.focus_range is not None:
                    min_idx, max_idx = self.focus_range
                    window_mag = window_mag[..., min_idx:max_idx]

                window_magnitudes.append(window_mag)

                if self.phase_weight > 0:
                    window_phase = torch.angle(fft)
                    if self.focus_range is not None:
                        min_idx, max_idx = self.focus_range
                        window_phase = window_phase[..., min_idx:max_idx]
                    window_phases.append(window_phase)

            # Stack all windows
            magnitude = torch.stack(window_magnitudes, dim=1)  # [batch, windows, freq]
            phase = torch.stack(window_phases, dim=1) if self.phase_weight > 0 else None

            return magnitude, phase

    def forward(self, input_signal, target_signal):
        """
        Compute the loss between input and target chaotic signals.

        Parameters:
        -----------
        input_signal : torch.Tensor
            Predicted signal from neural ODE, shape [batch_size, signal_length]

        target_signal : torch.Tensor
            Target chaotic signal, shape [batch_size, signal_length]

        Returns:
        --------
        loss : torch.Tensor
            The computed spectral loss
        """
        # Handle multi-dimensional signals
        original_shape = input_signal.shape
        if len(original_shape) > 2:
            # For multi-dimensional signals, flatten all but batch dimension
            input_signal = input_signal.reshape(original_shape[0], -1)
            target_signal = target_signal.reshape(target_signal.shape[0], -1)

        # Ensure signals have the same shape
        if input_signal.shape != target_signal.shape:
            raise ValueError(
                f"Input shape {input_signal.shape} and target shape {target_signal.shape} must match"
            )

        # Compute FFT spectra
        input_magnitude, input_phase = self.compute_fft(input_signal)
        target_magnitude, target_phase = self.compute_fft(target_signal)

        # Compute magnitude loss
        magnitude_loss = self._compute_divergence(input_magnitude, target_magnitude)

        # Initialize with magnitude loss
        loss = self.magnitude_weight * magnitude_loss

        # Add phase loss if needed
        if self.phase_weight > 0:
            # Handle phase wrapping by computing the minimum distance in the circle
            phase_diff = torch.abs(input_phase - target_phase)
            phase_diff = torch.min(phase_diff, 2 * torch.pi - phase_diff)

            phase_loss = torch.mean(phase_diff**2)
            loss = loss + self.phase_weight * phase_loss

        return loss


class MainPeakLoss(nn.Module):
    """
    A focused loss function that specifically targets the main peak in a spectrum.

    This loss is ideal when the primary frequency component carries the most important
    information about the chaotic system.

    Parameters:
    -----------
    main_peak_width_factor : float
        Width around the main peak to consider as a factor of the total spectrum width.
        For example, 0.1 means consider 10% of the spectrum width around the main peak.
        Default: 0.1

    adaptive_width : bool
        If True, adapts the width based on the actual peak width.
        Default: True

    position_weight : float
        Weight for the peak position component of the loss.
        Default: 0.6

    shape_weight : float
        Weight for the peak shape component of the loss.
        Default: 0.4

    log_magnitude : bool
        Whether to use log magnitudes for comparison.
        Default: True

    epsilon : float
        Small value added to avoid log(0).
        Default: 1e-8
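
    Example (a minimal sketch; shapes are illustrative):
    >>> loss_fn = MainPeakLoss(position_weight=0.6, shape_weight=0.4)
    >>> pred = torch.randn(4, 512)     # [batch, signal_length]
    >>> target = torch.randn(4, 512)
    >>> loss = loss_fn(pred, target)   # scalar tensor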
    """

    def __init__(
        self,
        main_peak_width_factor=0.1,
        adaptive_width=True,
        position_weight=0.6,
        shape_weight=0.4,
        log_magnitude=True,
        epsilon=1e-8,
        *args,
        **kwargs,
    ):
        super(MainPeakLoss, self).__init__()
        self.main_peak_width_factor = main_peak_width_factor
        self.adaptive_width = adaptive_width
        self.position_weight = position_weight
        self.shape_weight = shape_weight
        self.log_magnitude = log_magnitude
        self.epsilon = epsilon

    def find_main_peak(self, spectrum):
        """
        Find the main peak in a spectrum.

        Parameters:
        -----------
        spectrum : torch.Tensor
            Spectrum with shape [batch_size, freq_bins]

        Returns:
        --------
        peak_indices : torch.Tensor
            Indices of main peaks with shape [batch_size]

        peak_widths : torch.Tensor
            Estimated widths of main peaks with shape [batch_size]
        """
        batch_size, freq_bins = spectrum.shape
        device = spectrum.device

        # Initialize outputs
        peak_indices = torch.zeros(batch_size, dtype=torch.long, device=device)
        peak_widths = torch.zeros(batch_size, device=device)

        # Find main peak for each batch item
        for b in range(batch_size):
            spec = spectrum[b]

            # Skip the DC component (0 Hz) if we have enough bins
            start_idx = 1 if freq_bins > 10 else 0

            # Find the maximum value (main peak)
            if start_idx < freq_bins:
                max_idx = torch.argmax(spec[start_idx:]) + start_idx
                peak_indices[b] = max_idx

                # Estimate peak width at half prominence. Measuring from the
                # spectrum minimum keeps this meaningful for log magnitudes;
                # for linear spectra with min ~ 0 it reduces to max_val / 2.
                if self.adaptive_width:
                    max_val = spec[max_idx]
                    half_height = 0.5 * (max_val + torch.min(spec))

                    # Find left bound (where value drops below half height)
                    left_bound = max_idx
                    for i in range(max_idx - 1, -1, -1):
                        if spec[i] < half_height:
                            left_bound = i
                            break

                    # Find right bound
                    right_bound = max_idx
                    for i in range(max_idx + 1, freq_bins):
                        if spec[i] < half_height:
                            right_bound = i
                            break

                    # Calculate width
                    peak_widths[b] = right_bound - left_bound
                else:
                    # Use fixed width based on factor
                    peak_widths[b] = max(
                        1, int(freq_bins * self.main_peak_width_factor)
                    )

        return peak_indices, peak_widths

    def extract_peak_regions(self, spectrum, peak_indices, peak_widths):
        """
        Extract regions around peaks for comparison.

        Parameters:
        -----------
        spectrum : torch.Tensor
            Spectrum with shape [batch_size, freq_bins]

        peak_indices : torch.Tensor
            Indices of peaks with shape [batch_size]

        peak_widths : torch.Tensor
            Widths of peaks with shape [batch_size]

        Returns:
        --------
        peak_regions : list of torch.Tensor
            List of tensors containing peak regions

        region_indices : list of tuples
            List of (start, end) indices for each region
        """
        batch_size, freq_bins = spectrum.shape
        device = spectrum.device

        peak_regions = []
        region_indices = []

        for b in range(batch_size):
            peak_idx = peak_indices[b].item()

            # Determine region width
            if self.adaptive_width:
                width = max(1, int(peak_widths[b].item()))
            else:
                width = max(1, int(freq_bins * self.main_peak_width_factor))

            # Calculate region bounds
            half_width = width // 2
            start_idx = max(0, peak_idx - half_width)
            end_idx = min(freq_bins, peak_idx + half_width + 1)

            # Extract region
            region = spectrum[b, start_idx:end_idx]
            peak_regions.append(region)
            region_indices.append((start_idx, end_idx))

        return peak_regions, region_indices

    def compute_peak_loss(
        self, input_regions, input_indices, target_regions, target_indices, freq_bins
    ):
        """
        Compute loss between peak regions.

        Parameters:
        -----------
        input_regions : list of torch.Tensor
            Peak regions from input signals

        input_indices : list of tuples
            (start, end) indices for input regions

        target_regions : list of torch.Tensor
            Peak regions from target signals

        target_indices : list of tuples
            (start, end) indices for target regions

        freq_bins : int
            Total number of frequency bins in the spectrum, used to normalize
            peak positions

        Returns:
        --------
        position_loss : torch.Tensor
            Loss component for peak positions

        shape_loss : torch.Tensor
            Loss component for peak shapes
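
        Note: the position component is derived from argmax peak indices and
        therefore carries no gradient with respect to the input signal; in
        practice the shape component supplies the training gradient.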
        """
        batch_size = len(input_regions)
        device = input_regions[0].device if batch_size > 0 else torch.device("cpu")

        position_losses = torch.zeros(batch_size, device=device)
        shape_losses = torch.zeros(batch_size, device=device)

        for b in range(batch_size):
            # Calculate position loss: peak centers are normalized by the
            # total number of frequency bins, giving a relative frequency
            # offset that is comparable across spectra
            in_start, in_end = input_indices[b]
            in_center = torch.tensor((in_start + in_end) / 2, device=device)

            tgt_start, tgt_end = target_indices[b]
            tgt_center = torch.tensor((tgt_start + tgt_end) / 2, device=device)

            position_losses[b] = torch.abs(in_center - tgt_center) / freq_bins

            # Calculate shape loss
            # We need to handle different region sizes
            in_region = input_regions[b]
            tgt_region = target_regions[b]

            # Interpolate to match sizes if different
            if len(in_region) != len(tgt_region):
                # Use the larger size for interpolation
                target_size = max(len(in_region), len(tgt_region))

                # Interpolate both to the target size
                in_region_interp = (
                    F.interpolate(
                        in_region.unsqueeze(0).unsqueeze(0),
                        size=target_size,
                        mode="linear",
                        align_corners=False,
                    )
                    .squeeze(0)
                    .squeeze(0)
                )

                tgt_region_interp = (
                    F.interpolate(
                        tgt_region.unsqueeze(0).unsqueeze(0),
                        size=target_size,
                        mode="linear",
                        align_corners=False,
                    )
                    .squeeze(0)
                    .squeeze(0)
                )

                # Compute MSE between interpolated regions
                shape_losses[b] = F.mse_loss(in_region_interp, tgt_region_interp)
            else:
                # Compute MSE directly
                shape_losses[b] = F.mse_loss(in_region, tgt_region)

        # Compute mean losses
        position_loss = torch.mean(position_losses)
        shape_loss = torch.mean(shape_losses)

        return position_loss, shape_loss

    def forward(self, input_signal, target_signal):
        """
        Compute the main peak focused loss between input and target signals.

        Parameters:
        -----------
        input_signal : torch.Tensor
            Input signal of shape [batch_size, signal_length]

        target_signal : torch.Tensor
            Target signal of shape [batch_size, signal_length]

        Returns:
        --------
        loss : torch.Tensor
            Computed loss focused on main spectral peak
        """
        # Handle multi-dimensional signals
        original_shape = input_signal.shape
        if len(original_shape) > 2:
            # Reshape to [batch_size, -1]
            input_signal = input_signal.reshape(original_shape[0], -1)
            target_signal = target_signal.reshape(target_signal.shape[0], -1)

        # Check shape match
        if input_signal.shape != target_signal.shape:
            raise ValueError(
                f"Input shape {input_signal.shape} must match target shape {target_signal.shape}"
            )

        batch_size, signal_length = input_signal.shape

        # Compute FFT
        input_fft = torch.fft.fft(input_signal)
        target_fft = torch.fft.fft(target_signal)

        # Extract positive frequencies
        n = signal_length
        input_fft = input_fft[:, : n // 2 + 1]
        target_fft = target_fft[:, : n // 2 + 1]

        # Calculate magnitude spectra
        input_mag = torch.abs(input_fft)
        target_mag = torch.abs(target_fft)

        # Apply log if requested
        if self.log_magnitude:
            input_mag = torch.log(input_mag + self.epsilon)
            target_mag = torch.log(target_mag + self.epsilon)

        # Find main peaks
        input_peak_indices, input_peak_widths = self.find_main_peak(input_mag)
        target_peak_indices, target_peak_widths = self.find_main_peak(target_mag)

        # Extract regions around peaks
        input_regions, input_region_indices = self.extract_peak_regions(
            input_mag, input_peak_indices, input_peak_widths
        )
        target_regions, target_region_indices = self.extract_peak_regions(
            target_mag, target_peak_indices, target_peak_widths
        )

        # Compute loss components
        position_loss, shape_loss = self.compute_peak_loss(
            input_regions,
            input_region_indices,
            target_regions,
            target_region_indices,
            input_mag.shape[-1],
        )

        # Combine losses
        total_loss = (
            self.position_weight * position_loss + self.shape_weight * shape_loss
        )

        return total_loss


class MultiscaleChaoticPeakLoss(nn.Module):
    """
    A comprehensive loss function that combines:
    1. Main peak analysis
    2. Multiple peak analysis
    3. Overall spectral shape
    4. Time-domain comparison

    This is ideal for chaotic neural ODEs where both the dominant modes and
    overall dynamical behavior need to be matched.

    Parameters:
    -----------
    main_peak_weight : float
        Weight for the main peak component.
        Default: 0.4

    multi_peak_weight : float
        Weight for the multiple peaks component.
        Default: 0.3

    spectral_shape_weight : float
        Weight for the overall spectral shape component.
        Default: 0.2

    time_domain_weight : float
        Weight for the time-domain comparison component.
        Default: 0.1

    n_peaks : int
        Number of peaks to consider for multi-peak analysis.
        Default: 3
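
    Extra keyword arguments are forwarded to the component losses. Example
    (a minimal sketch; shapes and weights are illustrative):
    >>> loss_fn = MultiscaleChaoticPeakLoss(n_peaks=3, log_magnitude=True)
    >>> pred = torch.randn(4, 512)     # [batch, signal_length]
    >>> target = torch.randn(4, 512)
    >>> loss = loss_fn(pred, target)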
    """

    def __init__(
        self,
        main_peak_weight=0.4,
        multi_peak_weight=0.3,
        spectral_shape_weight=0.2,
        time_domain_weight=0.1,
        n_peaks=3,
        **kwargs,
    ):
        super(MultiscaleChaoticPeakLoss, self).__init__()

        self.main_peak_weight = main_peak_weight
        self.multi_peak_weight = multi_peak_weight
        self.spectral_shape_weight = spectral_shape_weight
        self.time_domain_weight = time_domain_weight

        # Create individual loss components
        self.main_peak_loss = MainPeakLoss(**kwargs)
        self.multi_peak_loss = PeakSpectraLoss(n_peaks=n_peaks, **kwargs)
        self.spectral_loss = ChaoticFFTLoss(**kwargs)

    def forward(self, input_signal, target_signal):
        """
        Compute the multi-component loss between input and target chaotic signals.

        Parameters:
        -----------
        input_signal : torch.Tensor
            Predicted signal from neural ODE, shape [batch_size, signal_length]

        target_signal : torch.Tensor
            Target chaotic signal, shape [batch_size, signal_length]

        Returns:
        --------
        loss : torch.Tensor
            The computed multi-component loss
        """
        # Compute individual loss components
        main_loss = self.main_peak_loss(input_signal, target_signal)
        multi_loss = self.multi_peak_loss(input_signal, target_signal)
        spectral_loss = self.spectral_loss(input_signal, target_signal)

        # Time-domain loss (basic MSE)
        time_loss = F.mse_loss(input_signal, target_signal)

        # Combine all loss components
        total_loss = (
            self.main_peak_weight * main_loss
            + self.multi_peak_weight * multi_loss
            + self.spectral_shape_weight * spectral_loss
            + self.time_domain_weight * time_loss
        )

        return total_loss


# Example usage with a neural ODE chaotic circuit
def example_with_neural_ode_chaotic_circuit():
    """
    Example showing how to use the peak-focused loss functions with a neural ODE
    model that simulates a chaotic circuit.
    """
    # torch and torch.nn are already imported at module level
    from torchdiffeq import odeint

    class ChaoticCircuitODE(nn.Module):
        """
        Neural ODE model of a chaotic circuit (e.g., Chua's circuit).
        """

        def __init__(self, hidden_dim=32):
            super(ChaoticCircuitODE, self).__init__()

            # Neural network part for learning chaotic dynamics
            self.net = nn.Sequential(
                nn.Linear(3, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Tanh(),
                nn.Linear(hidden_dim, 3),
            )

            # Parameters similar to Chua's circuit
            self.alpha = nn.Parameter(torch.tensor(9.0))
            self.beta = nn.Parameter(torch.tensor(14.0))
            self.gamma = nn.Parameter(torch.tensor(0.1))

        def chua_nonlinearity(self, x):
            """Nonlinear function similar to Chua's diode"""
            return torch.tanh(x) - 0.2 * torch.tanh(3 * x)

        def forward(self, t, state):
            """ODE function for the chaotic circuit"""
            x, y, z = state[:, 0], state[:, 1], state[:, 2]

            # Chua's circuit-like dynamics
            dx = self.alpha * (y - x - self.chua_nonlinearity(x))
            dy = x - y + z
            dz = -self.beta * y - self.gamma * z

            # Base dynamics
            dstate = torch.stack([dx, dy, dz], dim=1)

            # Neural network contribution (can learn additional dynamics)
            nn_contrib = self.net(state)

            return dstate + 0.01 * nn_contrib

    # Create the model
    model = ChaoticCircuitODE()

    # Generate some example data
    batch_size = 8
    initial_state = torch.randn(batch_size, 3) * 0.1
    t = torch.linspace(0, 10, 1000)

    # Generate target trajectories
    with torch.no_grad():
        trajectories = odeint(model, initial_state, t)
        # Shape: [time_steps, batch_size, 3]
        trajectories = trajectories.permute(1, 0, 2)
        # Now shape: [batch_size, time_steps, 3]

    # Use the first component as the signal
    target_signals = trajectories[:, :, 0]

    # Create slightly perturbed initial conditions
    perturbed_initial = initial_state + torch.randn_like(initial_state) * 0.01

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Create the loss function - focusing on peaks
    peak_loss = MultiscaleChaoticPeakLoss(
        main_peak_weight=0.4,
        multi_peak_weight=0.3,
        spectral_shape_weight=0.2,
        time_domain_weight=0.1,
        n_peaks=3,
        log_magnitude=True,
    )

    # Training loop (simplified example)
    for epoch in range(10):
        optimizer.zero_grad()

        # Forward pass with perturbed initial conditions
        pred_trajectories = odeint(model, perturbed_initial, t)
        pred_trajectories = pred_trajectories.permute(1, 0, 2)
        pred_signals = pred_trajectories[:, :, 0]

        # Compute loss
        loss = peak_loss(pred_signals, target_signals)

        # Backward pass
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

    return model, peak_loss


class PeakSpectraLoss(nn.Module):
    """
    A specialized loss function that focuses on comparing multiple peaks
    in frequency spectra of signals from chaotic systems.

    Parameters:
    -----------
    n_peaks : int
        Number of dominant peaks to consider in the comparison.
        Default: 3

    peak_match_type : str
        How to match peaks between input and target:
        'position': Pair peaks in ascending frequency order (lowest to lowest, etc.)
        'magnitude': Match peaks by magnitude (largest to largest, etc.)
        'nearest': Match each input peak to the nearest target peak by frequency
        Default: 'nearest'

    position_weight : float
        Weight for comparing peak positions.
        Default: 0.7

    magnitude_weight : float
        Weight for comparing peak magnitudes.
        Default: 0.3

    log_magnitude : bool
        Whether to use log magnitudes for comparison.
        Default: True

    epsilon : float
        Small value added to avoid log(0).
        Default: 1e-8
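
    Example (a minimal sketch; shapes are illustrative):
    >>> loss_fn = PeakSpectraLoss(n_peaks=3, peak_match_type="nearest")
    >>> pred = torch.randn(4, 512)     # [batch, signal_length]
    >>> target = torch.randn(4, 512)
    >>> loss = loss_fn(pred, target)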
    """

    def __init__(
        self,
        n_peaks=3,
        peak_match_type="nearest",
        position_weight=0.7,
        magnitude_weight=0.3,
        log_magnitude=True,
        epsilon=1e-8,
        *args,
        **kwargs,
    ):
        super(PeakSpectraLoss, self).__init__()
        self.n_peaks = n_peaks
        self.peak_match_type = peak_match_type
        self.position_weight = position_weight
        self.magnitude_weight = magnitude_weight
        self.log_magnitude = log_magnitude
        self.epsilon = epsilon

        # Validate inputs
        if peak_match_type not in ["position", "magnitude", "nearest"]:
            raise ValueError(f"Invalid peak_match_type: {peak_match_type}")

    def find_peaks(self, spectrum):
        """
        Find the dominant peaks in a spectrum.

        Parameters:
        -----------
        spectrum : torch.Tensor
            Spectrum with shape [batch_size, freq_bins]

        Returns:
        --------
        peak_indices : torch.Tensor
            Indices of peaks with shape [batch_size, n_peaks]

        peak_values : torch.Tensor
            Magnitude of peaks with shape [batch_size, n_peaks]

        peak_mask : torch.Tensor
            Boolean mask with shape [batch_size, n_peaks]; True where a slot
            holds a detected peak (remaining slots are zero padding)
        """
        batch_size, freq_bins = spectrum.shape
        device = spectrum.device

        # Initialize outputs with default values; peak_mask marks which of
        # the n_peaks slots hold a real detected peak
        peak_indices = torch.zeros(
            (batch_size, self.n_peaks), dtype=torch.long, device=device
        )
        peak_values = torch.zeros((batch_size, self.n_peaks), device=device)
        peak_mask = torch.zeros(
            (batch_size, self.n_peaks), dtype=torch.bool, device=device
        )

        # Process each batch item separately
        for b in range(batch_size):
            # Get this spectrum
            spec = spectrum[b]

            # Need at least 3 points to detect a peak
            if freq_bins < 3:
                continue

            # Mark local maxima (strictly greater than both neighbors),
            # vectorized over the whole spectrum
            is_peak = torch.zeros(freq_bins, dtype=torch.bool, device=device)
            is_peak[1:-1] = (spec[1:-1] > spec[:-2]) & (spec[1:-1] > spec[2:])

            # Get peak indices and values
            peak_idx = torch.where(is_peak)[0]
            if len(peak_idx) == 0:
                continue

            peak_vals = spec[peak_idx]

            # Keep peaks in the top 10% of the spectrum's dynamic range.
            # Measuring from the minimum keeps the threshold meaningful for
            # log magnitudes (which can be negative); for linear spectra with
            # min ~ 0 it reduces to the usual "at least 10% of max" rule.
            max_val = torch.max(spec)
            threshold = max_val - 0.9 * (max_val - torch.min(spec))
            mask = peak_vals >= threshold
            peak_idx = peak_idx[mask]
            peak_vals = peak_vals[mask]

            if len(peak_idx) == 0:
                continue

            # Sort by magnitude (descending)
            sorted_indices = torch.argsort(peak_vals, descending=True)
            peak_idx = peak_idx[sorted_indices]
            peak_vals = peak_vals[sorted_indices]

            # Keep top n_peaks and mark the filled slots as valid
            n_found = min(len(peak_idx), self.n_peaks)
            peak_indices[b, :n_found] = peak_idx[:n_found]
            peak_values[b, :n_found] = peak_vals[:n_found]
            peak_mask[b, :n_found] = True

        return peak_indices, peak_values, peak_mask

    def compute_peak_loss(
        self,
        input_indices,
        input_values,
        input_mask,
        target_indices,
        target_values,
        target_mask,
        max_freq_bin,
    ):
        """
        Compute loss between peaks based on matching strategy.

        Parameters:
        -----------
        input_indices, target_indices : torch.Tensor
            Indices of peaks, shape [batch_size, n_peaks]

        input_values, target_values : torch.Tensor
            Values of peaks, shape [batch_size, n_peaks]

        input_mask, target_mask : torch.Tensor
            Boolean validity masks for the peak slots, shape [batch_size, n_peaks]

        max_freq_bin : int
            Maximum frequency bin index for normalization

        Returns:
        --------
        loss : torch.Tensor
            Computed loss between peak features
        """
        batch_size = input_indices.shape[0]
        device = input_indices.device

        # Initialize loss components
        position_loss = torch.zeros(batch_size, device=device)
        magnitude_loss = torch.zeros(batch_size, device=device)

        for b in range(batch_size):
            # Get peaks and validity masks for this batch item
            in_idx = input_indices[b]
            in_val = input_values[b]
            in_mask = input_mask[b]

            tgt_idx = target_indices[b]
            tgt_val = target_values[b]
            tgt_mask = target_mask[b]

            # Normalize peak indices by max frequency for position loss
            norm_in_idx = in_idx.float() / max_freq_bin
            norm_tgt_idx = tgt_idx.float() / max_freq_bin

            # Skip items where either spectrum yielded no peaks
            if not (torch.any(in_mask) and torch.any(tgt_mask)):
                continue

            # Match peaks according to strategy
            if self.peak_match_type in ("position", "magnitude"):
                # find_peaks returns peaks sorted by magnitude, which directly
                # gives the 'magnitude' pairing (largest to largest, etc.).
                # For 'position', both lists are re-sorted by frequency first,
                # with padding slots pushed past the last real bin.
                if self.peak_match_type == "position":
                    pad = max_freq_bin + 1
                    in_order = torch.argsort(in_idx.masked_fill(~in_mask, pad))
                    tgt_order = torch.argsort(tgt_idx.masked_fill(~tgt_mask, pad))
                else:
                    in_order = tgt_order = torch.arange(len(in_idx), device=device)

                # Only slots that are valid on both sides contribute
                pair_mask = (in_mask[in_order] & tgt_mask[tgt_order]).float()
                idx_diff = (
                    torch.abs(norm_in_idx[in_order] - norm_tgt_idx[tgt_order])
                    * pair_mask
                )
                val_diff = torch.abs(in_val[in_order] - tgt_val[tgt_order]) * pair_mask
                valid_count = pair_mask.sum()

            elif self.peak_match_type == "nearest":
                # Match each valid input peak to the nearest valid target peak
                idx_diff = torch.zeros_like(in_val)
                val_diff = torch.zeros_like(in_val)

                for i in range(len(in_idx)):
                    if in_mask[i]:
                        # Distances to valid target peaks only
                        distances = torch.abs(norm_in_idx[i] - norm_tgt_idx)
                        distances = distances.masked_fill(~tgt_mask, float("inf"))
                        nearest_idx = torch.argmin(distances)

                        idx_diff[i] = distances[nearest_idx]
                        val_diff[i] = torch.abs(in_val[i] - tgt_val[nearest_idx])

                valid_count = in_mask.sum().float()

            # Compute mean losses for this batch item
            if valid_count > 0:
                position_loss[b] = torch.sum(idx_diff) / valid_count
                magnitude_loss[b] = torch.sum(val_diff) / valid_count

        # Compute weighted average loss across batch
        total_loss = self.position_weight * torch.mean(
            position_loss
        ) + self.magnitude_weight * torch.mean(magnitude_loss)

        return total_loss

    def forward(self, input_signal, target_signal):
        """
        Compute the peak-focused loss between input and target signals.

        Parameters:
        -----------
        input_signal : torch.Tensor
            Input signal of shape [batch_size, signal_length]

        target_signal : torch.Tensor
            Target signal of shape [batch_size, signal_length]

        Returns:
        --------
        loss : torch.Tensor
            Computed loss focused on spectral peaks
        """
        # Handle multi-dimensional signals
        original_shape = input_signal.shape
        if len(original_shape) > 2:
            # Reshape to [batch_size, -1]
            input_signal = input_signal.reshape(original_shape[0], -1)
            target_signal = target_signal.reshape(target_signal.shape[0], -1)

        # Check shape match
        if input_signal.shape != target_signal.shape:
            raise ValueError(
                f"Input shape {input_signal.shape} must match target shape {target_signal.shape}"
            )

        batch_size, signal_length = input_signal.shape

        # Compute FFT
        input_fft = torch.fft.fft(input_signal)
        target_fft = torch.fft.fft(target_signal)

        # Extract positive frequencies
        n = signal_length
        input_fft = input_fft[:, : n // 2 + 1]
        target_fft = target_fft[:, : n // 2 + 1]

        # Calculate magnitude spectra
        input_mag = torch.abs(input_fft)
        target_mag = torch.abs(target_fft)

        # Apply log if requested
        if self.log_magnitude:
            input_mag = torch.log(input_mag + self.epsilon)
            target_mag = torch.log(target_mag + self.epsilon)

        # Find peaks in both spectra (validity masks are needed because log
        # magnitudes can legitimately be zero or negative)
        input_peak_indices, input_peak_values, input_peak_mask = self.find_peaks(
            input_mag
        )
        target_peak_indices, target_peak_values, target_peak_mask = self.find_peaks(
            target_mag
        )

        # Compute peak-based loss
        max_freq_bin = n // 2
        total_loss = self.compute_peak_loss(
            input_peak_indices,
            input_peak_values,
            input_peak_mask,
            target_peak_indices,
            target_peak_values,
            target_peak_mask,
            max_freq_bin,
        )

        return total_loss


class AttractorHistogramLoss(nn.Module):
    """
    A loss function that compares chaotic attractors by generating and comparing
    2D histograms of their state space distributions.

    This approach is especially effective for chaotic systems where the exact
    trajectories may diverge due to sensitivity to initial conditions, but the
    overall attractor shape and density distribution remain similar.
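
    Note: the histogram uses hard binning (integer bin indices), which does
    not propagate gradients back to the trajectory points. As written, the
    loss is therefore mainly suited to evaluation or gradient-free
    optimization; a soft/kernel binning would be needed for gradient-based
    training.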

    Parameters:
    -----------
    bins : int or tuple(int, int)
        Number of bins for the histogram in each dimension.
        If int: same number of bins used for both dimensions
        If tuple: (bins_x, bins_y) for different binning in each dimension
        Default: 50

    range_mode : str
        How to determine the histogram range:
        'adaptive': Automatically determined from data (each batch may have different ranges)
        'fixed': Use fixed_range parameter
        'combined': Use the union of adaptive ranges from both input and target
        Default: 'combined'

    fixed_range : tuple of tuple of float
        Only used when range_mode='fixed'
        Format: ((x_min, x_max), (y_min, y_max))
        Default: ((-2.0, 2.0), (-2.0, 2.0))

    normalize : bool
        Whether to normalize histograms to sum to 1 (probability distribution)
        Default: True

    smooth : bool
        Whether to apply Gaussian smoothing to histograms
        Default: True

    smooth_sigma : float
        Standard deviation for Gaussian smoothing kernel
        Default: 1.0

    epsilon : float
        Small value added to histograms to avoid log(0) and division by zero
        Default: 1e-8

    distance_metric : str
        Metric to compare histograms:
        'kl': Kullback-Leibler divergence
        'js': Jensen-Shannon divergence
        'wasserstein': Wasserstein distance (Earth Mover's Distance)
        'l2': L2 distance (Mean Squared Error)
        'l1': L1 distance (Mean Absolute Error)
        'cosine': Cosine similarity
        'bhattacharyya': Bhattacharyya distance
        Default: 'l2'
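
    Example (a minimal sketch; shapes are illustrative):
    >>> loss_fn = AttractorHistogramLoss(bins=64, distance_metric="js")
    >>> pred = torch.randn(4, 2000, 2)   # [batch, n_steps, 2]
    >>> target = torch.randn(4, 2000, 2)
    >>> loss = loss_fn(pred, target)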
    """

    def __init__(
        self,
        bins=50,
        range_mode="combined",
        fixed_range=((-2.0, 2.0), (-2.0, 2.0)),
        normalize=True,
        smooth=True,
        smooth_sigma=1.0,
        epsilon=1e-8,
        distance_metric="l2",
    ):
        super(AttractorHistogramLoss, self).__init__()

        # Set parameters
        if isinstance(bins, int):
            self.bins = (bins, bins)
        else:
            self.bins = bins

        self.range_mode = range_mode
        self.fixed_range = fixed_range
        self.normalize = normalize
        self.smooth = smooth
        self.smooth_sigma = smooth_sigma
        self.epsilon = epsilon
        self.distance_metric = distance_metric

        # Validate parameters
        if self.range_mode not in ["adaptive", "fixed", "combined"]:
            raise ValueError(f"Invalid range_mode: {range_mode}")

        if self.distance_metric not in [
            "kl",
            "js",
            "wasserstein",
            "l2",
            "l1",
            "cosine",
            "bhattacharyya",
        ]:
            raise ValueError(f"Invalid distance_metric: {distance_metric}")

    def compute_histogram_range(self, x, y=None):
        """
        Compute appropriate range for histogram based on the specified mode.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape [batch_size, n_steps, 2]

        y : torch.Tensor, optional
            Target tensor of shape [batch_size, n_steps, 2]
            Only used when range_mode='combined'

        Returns:
        --------
        ranges : tuple of tuples
            ((x_min, x_max), (y_min, y_max))
        """
        if self.range_mode == "fixed":
            return self.fixed_range

        # Compute range from x
        x_min_0, _ = torch.min(x[..., 0], dim=1)
        x_max_0, _ = torch.max(x[..., 0], dim=1)
        x_min_1, _ = torch.min(x[..., 1], dim=1)
        x_max_1, _ = torch.max(x[..., 1], dim=1)

        # Take min/max across batch
        range_x0 = (torch.min(x_min_0), torch.max(x_max_0))
        range_x1 = (torch.min(x_min_1), torch.max(x_max_1))

        if self.range_mode == "adaptive":
            return (range_x0, range_x1)

        # If mode is 'combined', combine with y's range
        if y is not None:
            y_min_0, _ = torch.min(y[..., 0], dim=1)
            y_max_0, _ = torch.max(y[..., 0], dim=1)
            y_min_1, _ = torch.min(y[..., 1], dim=1)
            y_max_1, _ = torch.max(y[..., 1], dim=1)

            # Take overall min/max
            range_0 = (
                min(torch.min(x_min_0), torch.min(y_min_0)),
                max(torch.max(x_max_0), torch.max(y_max_0)),
            )
            range_1 = (
                min(torch.min(x_min_1), torch.min(y_min_1)),
                max(torch.max(x_max_1), torch.max(y_max_1)),
            )

            return (range_0, range_1)

        # Fallback
        return (range_x0, range_x1)

    def compute_2d_histogram(self, x, hist_range):
        """
        Compute 2D histogram of the attractor points.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape [batch_size, n_steps, 2]

        hist_range : tuple of tuples
            ((x_min, x_max), (y_min, y_max))

        Returns:
        --------
        histograms : torch.Tensor
            Batch of histograms with shape [batch_size, bins[0], bins[1]]
        """
        batch_size, n_steps, _ = x.shape
        device = x.device

        # Create meshgrid for binning
        range_x, range_y = hist_range
        x_min, x_max = range_x
        y_min, y_max = range_y

        # Initialize histograms
        histograms = torch.zeros(
            (batch_size, self.bins[0], self.bins[1]), device=device
        )

        # Process each batch item separately
        for b in range(batch_size):
            # Extract 2D points for this batch
            points = x[b]  # [n_steps, 2]

            # Skip if no points (zeros histogram will be returned)
            if n_steps == 0:
                continue

            # Compute bin indices for each point
            x_indices = torch.clamp(
                ((points[:, 0] - x_min) / (x_max - x_min) * self.bins[0]).long(),
                0,
                self.bins[0] - 1,
            )
            y_indices = torch.clamp(
                ((points[:, 1] - y_min) / (y_max - y_min) * self.bins[1]).long(),
                0,
                self.bins[1] - 1,
            )

            # Accumulate counts in the histogram via a flattened bincount
            # (equivalent to incrementing bin-by-bin, but vectorized)
            flat_idx = x_indices * self.bins[1] + y_indices
            counts = torch.bincount(flat_idx, minlength=self.bins[0] * self.bins[1])
            histograms[b] = counts.reshape(self.bins[0], self.bins[1]).float()

        # Normalize if requested
        if self.normalize:
            # Add epsilon to avoid division by zero
            sums = (
                torch.sum(histograms.view(batch_size, -1), dim=1).view(batch_size, 1, 1)
                + self.epsilon
            )
            histograms = histograms / sums

        # Apply smoothing if requested
        if self.smooth:
            # Create Gaussian kernel
            kernel_size = min(13, min(self.bins) // 2 * 2 + 1)  # Odd kernel size
            sigma = torch.tensor([self.smooth_sigma], device=device)

            # Apply 2D Gaussian filter
            padding = kernel_size // 2

            # Reshape for conv2d [batch, channels, height, width]
            histograms_reshaped = histograms.unsqueeze(1)  # [batch, 1, bins, bins]

            # Create 1D Gaussian kernels and apply separably for efficiency
            coords = torch.arange(kernel_size, device=device) - padding
            kernel_1d = torch.exp(-(coords**2) / (2 * sigma**2))
            kernel_1d = kernel_1d / kernel_1d.sum()

            # Apply horizontal and vertical Gaussian blur
            kernel_h = kernel_1d.view(1, 1, 1, kernel_size)
            kernel_v = kernel_1d.view(1, 1, kernel_size, 1)

            # Apply horizontal blur
            smoothed = F.conv2d(
                F.pad(histograms_reshaped, (padding, padding, 0, 0), mode="reflect"),
                kernel_h,
                padding=(0, 0),
            )

            # Apply vertical blur
            smoothed = F.conv2d(
                F.pad(smoothed, (0, 0, padding, padding), mode="reflect"),
                kernel_v,
                padding=(0, 0),
            )

            # Return to original shape
            histograms = smoothed.squeeze(1)

        return histograms

    def compute_histogram_distance(self, hist1, hist2):
        """
        Compute distance between two histograms based on the selected metric.

        Parameters:
        -----------
        hist1, hist2 : torch.Tensor
            Histograms to compare with shape [batch_size, bins[0], bins[1]]

        Returns:
        --------
        distances : torch.Tensor
            Computed distances with shape [batch_size]
        """
        batch_size = hist1.shape[0]
        device = hist1.device

        # Reshape for certain metrics
        flat_hist1 = hist1.reshape(batch_size, -1)
        flat_hist2 = hist2.reshape(batch_size, -1)

        # Initialize distances
        distances = torch.zeros(batch_size, device=device)

        # Compute selected distance metric
        if self.distance_metric == "kl":
            # KL divergence (add epsilon to avoid log(0))
            kl_div = flat_hist1 * (
                torch.log(flat_hist1 + self.epsilon)
                - torch.log(flat_hist2 + self.epsilon)
            )
            distances = torch.sum(kl_div, dim=1)

        elif self.distance_metric == "js":
            # Jensen-Shannon divergence
            m = 0.5 * (flat_hist1 + flat_hist2)
            js_div = 0.5 * (
                flat_hist1
                * (torch.log(flat_hist1 + self.epsilon) - torch.log(m + self.epsilon))
                + flat_hist2
                * (torch.log(flat_hist2 + self.epsilon) - torch.log(m + self.epsilon))
            )
            distances = torch.sum(js_div, dim=1)

        elif self.distance_metric == "wasserstein":
            # Simple approximation of 2D Wasserstein distance using cumulative histograms
            # This is a simplified version and not the true 2D EMD

            # Compute marginal cumulative histograms
            cum_x1 = torch.cumsum(torch.sum(hist1, dim=2), dim=1)
            cum_x2 = torch.cumsum(torch.sum(hist2, dim=2), dim=1)
            cum_y1 = torch.cumsum(torch.sum(hist1, dim=1), dim=1)
            cum_y2 = torch.cumsum(torch.sum(hist2, dim=1), dim=1)

            # Normalize
            cum_x1 = cum_x1 / (torch.max(cum_x1, dim=1, keepdim=True)[0] + self.epsilon)
            cum_x2 = cum_x2 / (torch.max(cum_x2, dim=1, keepdim=True)[0] + self.epsilon)
            cum_y1 = cum_y1 / (torch.max(cum_y1, dim=1, keepdim=True)[0] + self.epsilon)
            cum_y2 = cum_y2 / (torch.max(cum_y2, dim=1, keepdim=True)[0] + self.epsilon)

            # Compute L1 distances between cumulative histograms
            x_dist = torch.sum(torch.abs(cum_x1 - cum_x2), dim=1)
            y_dist = torch.sum(torch.abs(cum_y1 - cum_y2), dim=1)

            # Combine X and Y distances
            distances = x_dist + y_dist

        elif self.distance_metric == "l2":
            # L2 distance (MSE)
            distances = torch.sum((flat_hist1 - flat_hist2) ** 2, dim=1)

        elif self.distance_metric == "l1":
            # L1 distance (MAE)
            distances = torch.sum(torch.abs(flat_hist1 - flat_hist2), dim=1)

        elif self.distance_metric == "cosine":
            # Cosine distance (1 - cosine similarity)
            dot_product = torch.sum(flat_hist1 * flat_hist2, dim=1)
            norm1 = torch.sqrt(torch.sum(flat_hist1**2, dim=1) + self.epsilon)
            norm2 = torch.sqrt(torch.sum(flat_hist2**2, dim=1) + self.epsilon)
            cosine_sim = dot_product / (norm1 * norm2)
            distances = 1.0 - cosine_sim

        elif self.distance_metric == "bhattacharyya":
            # Bhattacharyya distance
            bc = torch.sqrt(flat_hist1 * flat_hist2)
            bc_coeff = torch.sum(bc, dim=1)
            distances = -torch.log(bc_coeff + self.epsilon)

        return distances

    def forward(self, input_traj, target_traj):
        """
        Compute the loss between input and target attractors.

        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, 2]
            Represents the 2D attractor from the model

        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, 2]
            Represents the 2D attractor to match

        Returns:
        --------
        loss : torch.Tensor
            Computed histogram-based loss
        """
        # Validate input shapes
        if input_traj.dim() != 3 or target_traj.dim() != 3:
            raise ValueError(
                "Input and target trajectories must be 3D tensors [batch, steps, 2]"
            )

        if input_traj.shape[0] != target_traj.shape[0]:
            raise ValueError(
                f"Batch sizes don't match: {input_traj.shape[0]} vs {target_traj.shape[0]}"
            )

        if input_traj.shape[2] != 2 or target_traj.shape[2] != 2:
            raise ValueError("Last dimension must be 2 for 2D attractors")

        # Compute histogram range
        hist_range = self.compute_histogram_range(
            input_traj, target_traj if self.range_mode == "combined" else None
        )

        # Compute histograms
        input_hist = self.compute_2d_histogram(input_traj, hist_range)
        target_hist = self.compute_2d_histogram(target_traj, hist_range)

        # Compute distances between histograms
        distances = self.compute_histogram_distance(input_hist, target_hist)

        # Return mean distance as the loss
        return torch.mean(distances)


class AttractorMomentLoss(nn.Module):
    """
    A loss function that compares chaotic attractors using multivariate statistical moments
    and distribution properties.

    This approach captures the overall shape, spread, and orientation of the attractor
    without being sensitive to exact point positions. Uses multivariate skewness and
    kurtosis (Mardia's measures) for rotation-invariant comparison.

    Parameters:
    -----------
    mean_weight : float
        Weight for the mean position component
        Default: 0.1

    cov_weight : float
        Weight for the covariance matrix component
        Default: 0.4

    skewness_weight : float
        Weight for the multivariate skewness component
        Default: 0.3

    kurtosis_weight : float
        Weight for the multivariate kurtosis component
        Default: 0.2

    normalize_components : bool
        Whether to normalize each loss component by target magnitude
        Recommended: True for scale-invariant comparison
        Default: True

    epsilon : float
        Small value for numerical stability
        Default: 1e-6

    max_samples_for_skewness : int
        Maximum number of samples to use for skewness computation (memory optimization)
        If trajectory has more steps, random sampling is used
        Default: 1000

    regularization_strength : float
        Strength of regularization added to covariance matrix for numerical stability
        Default: 1e-4
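
    Example (a minimal sketch; shapes are illustrative):
    >>> loss_fn = AttractorMomentLoss()
    >>> pred = torch.randn(4, 1500, 3)   # [batch, n_steps, dim]
    >>> target = torch.randn(4, 1500, 3)
    >>> loss = loss_fn(pred, target)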
    """

    def __init__(
        self,
        mean_weight=0.1,
        cov_weight=0.4,
        skewness_weight=0.3,
        kurtosis_weight=0.2,
        normalize_components=True,
        epsilon=1e-6,
        max_samples_for_skewness=1000,
        regularization_strength=1e-4,
    ):
        super(AttractorMomentLoss, self).__init__()
        self.mean_weight = mean_weight
        self.cov_weight = cov_weight
        self.skewness_weight = skewness_weight
        self.kurtosis_weight = kurtosis_weight
        self.normalize_components = normalize_components
        self.epsilon = epsilon
        self.max_samples_for_skewness = max_samples_for_skewness
        self.regularization_strength = regularization_strength

        # Validate weights
        total_weight = mean_weight + cov_weight + skewness_weight + kurtosis_weight
        if abs(total_weight - 1.0) > 0.01:
            print(
                f"Warning: Weights sum to {total_weight:.3f}, not 1.0. "
                f"Consider normalizing for better interpretability."
            )

    def compute_moments(self, traj):
        """
        Compute multivariate statistical moments of the attractor.

        Parameters:
        -----------
        traj : torch.Tensor
            Trajectory tensor of shape [batch_size, n_steps, dim]

        Returns:
        --------
        means : torch.Tensor
            Mean positions with shape [batch_size, dim]

        covs : torch.Tensor
            Covariance matrices with shape [batch_size, dim, dim]

        skews : torch.Tensor
            Mardia's multivariate skewness with shape [batch_size]

        kurts : torch.Tensor
            Mardia's multivariate kurtosis with shape [batch_size]
        """
        batch_size, n_steps, dim = traj.shape
        device = traj.device

        # Validate minimum steps
        if n_steps < 2:
            raise ValueError(
                f"Need at least 2 time steps to compute moments, got {n_steps}"
            )

        # Check for NaN or Inf
        if torch.isnan(traj).any() or torch.isinf(traj).any():
            raise ValueError("Input trajectory contains NaN or Inf values")

        # Vectorized mean computation
        means = torch.mean(traj, dim=1)  # [batch_size, dim]

        # Center the data
        centered = traj - means.unsqueeze(1)  # [batch_size, n_steps, dim]

        # Vectorized covariance computation
        # covs[b] = (1/(n-1)) * centered[b].T @ centered[b]
        covs = torch.bmm(centered.transpose(1, 2), centered) / (
            n_steps - 1
        )  # [batch_size, dim, dim]

        # Add regularization for numerical stability
        eye = torch.eye(dim, device=device).unsqueeze(0)  # [1, dim, dim]
        cov_reg = covs + self.regularization_strength * eye  # [batch_size, dim, dim]

        # Compute inverse covariance (with error handling)
        try:
            cov_inv = torch.linalg.inv(cov_reg)
        except RuntimeError as e:
            # Fallback to pseudo-inverse if singular
            print(
                f"Warning: Covariance matrix inversion failed, using pseudo-inverse. Error: {e}"
            )
            cov_inv = torch.linalg.pinv(cov_reg)

        # Compute Mahalanobis-transformed centered data
        # centered_transformed[b] = centered[b] @ cov_inv[b]
        centered_transformed = torch.bmm(
            centered, cov_inv
        )  # [batch_size, n_steps, dim]

        # Mardia's multivariate kurtosis (4th moment)
        # β₂,d = (1/n) Σᵢ [(xᵢ - μ)ᵀ Σ⁻¹ (xᵢ - μ)]²
        mahal_sq = torch.sum(
            centered * centered_transformed, dim=2
        )  # [batch_size, n_steps]
        kurts = torch.mean(mahal_sq**2, dim=1)  # [batch_size]

        # Mardia's multivariate skewness (3rd moment)
        # β₁,d = (1/n²) Σᵢ Σⱼ [(xᵢ - μ)ᵀ Σ⁻¹ (xⱼ - μ)]³
        # This requires an n×n matrix, so we sample for memory efficiency

        if n_steps > self.max_samples_for_skewness:
            # Random sampling for efficiency
            indices = torch.randperm(n_steps, device=device)[
                : self.max_samples_for_skewness
            ]
            centered_sample = centered[:, indices, :]  # [batch_size, n_sample, dim]
            # Use full transformed data for the second part
            # [batch_size, n_sample, dim] @ [batch_size, dim, n_steps]
            mahal_matrix = torch.bmm(
                centered_sample, centered_transformed.transpose(1, 2)
            )
        else:
            # Use all data
            # [batch_size, n_steps, dim] @ [batch_size, dim, n_steps]
            mahal_matrix = torch.bmm(
                centered, centered_transformed.transpose(1, 2)
            )  # [batch_size, n_steps, n_steps]

        # Average the cubed terms over both point indices; when sampling was
        # used above, this is a Monte Carlo estimate of Mardia's skewness
        skews = torch.mean(mahal_matrix**3, dim=(1, 2))  # [batch_size]

        return means, covs, skews, kurts

    def forward(self, input_traj, target_traj):
        """
        Compute the moment-based loss between input and target attractors.

        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, dim]

        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, dim]

        Returns:
        --------
        loss : torch.Tensor
            Computed moment-based loss (see ``get_component_losses`` for a
            per-component breakdown)
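
        Example (illustrative sketch):
        ------------------------------
        >>> _ = torch.manual_seed(0)
        >>> loss_fn = AttractorMomentLoss()
        >>> a, b = torch.randn(2, 200, 3), torch.randn(2, 200, 3)
        >>> loss_fn(a, b).ndim  # scalar loss
        0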
        """
        # Validate input shapes
        if input_traj.dim() != 3 or target_traj.dim() != 3:
            raise ValueError(
                "Input and target trajectories must be 3D tensors [batch, steps, dim]"
            )

        if input_traj.shape[0] != target_traj.shape[0]:
            raise ValueError(
                f"Batch sizes don't match: {input_traj.shape[0]} vs {target_traj.shape[0]}"
            )

        if input_traj.shape[2] != target_traj.shape[2]:
            raise ValueError(
                f"Dimension mismatch: {input_traj.shape[2]} vs {target_traj.shape[2]}"
            )

        # Compute moments
        input_means, input_covs, input_skews, input_kurts = self.compute_moments(
            input_traj
        )
        target_means, target_covs, target_skews, target_kurts = self.compute_moments(
            target_traj
        )

        # Compute loss components
        # Mean loss (MSE between attractor mean positions)
        mean_loss = F.mse_loss(input_means, target_means, reduction="mean")

        # Covariance loss (Frobenius norm via MSE)
        cov_loss = F.mse_loss(input_covs, target_covs, reduction="mean")

        # Multivariate skewness loss
        skew_loss = F.mse_loss(input_skews, target_skews, reduction="mean")

        # Multivariate kurtosis loss
        kurt_loss = F.mse_loss(input_kurts, target_kurts, reduction="mean")

        # Normalize components if requested (makes loss scale-invariant)
        if self.normalize_components:
            # Normalize by target magnitude to make scale-invariant
            mean_scale = torch.mean(target_means**2) + self.epsilon
            cov_scale = torch.mean(target_covs**2) + self.epsilon
            skew_scale = torch.mean(target_skews**2) + self.epsilon
            kurt_scale = torch.mean(target_kurts**2) + self.epsilon

            mean_loss = mean_loss / mean_scale
            cov_loss = cov_loss / cov_scale
            skew_loss = skew_loss / skew_scale
            kurt_loss = kurt_loss / kurt_scale

        # Combine components with weights
        total_loss = (
            self.mean_weight * mean_loss
            + self.cov_weight * cov_loss
            + self.skewness_weight * skew_loss
            + self.kurtosis_weight * kurt_loss
        )

        return total_loss

    def get_component_losses(self, input_traj, target_traj):
        """
        Compute and return individual loss components for analysis.

        Useful for debugging and understanding which moments are contributing most to the loss.

        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, dim]

        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, dim]

        Returns:
        --------
        loss_dict : dict
            Dictionary containing:
            - 'total': total weighted loss
            - 'mean': mean loss component
            - 'cov': covariance loss component
            - 'skewness': skewness loss component
            - 'kurtosis': kurtosis loss component
            - 'mean_raw': unweighted mean loss
            - 'cov_raw': unweighted covariance loss
            - 'skewness_raw': unweighted skewness loss
            - 'kurtosis_raw': unweighted kurtosis loss
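
        Example (illustrative sketch):
        ------------------------------
        >>> _ = torch.manual_seed(0)
        >>> loss_fn = AttractorMomentLoss()
        >>> comps = loss_fn.get_component_losses(
        ...     torch.randn(1, 300, 3), torch.randn(1, 300, 3)
        ... )
        >>> sorted(comps) == sorted(
        ...     ["total", "mean", "cov", "skewness", "kurtosis",
        ...      "mean_raw", "cov_raw", "skewness_raw", "kurtosis_raw"]
        ... )
        True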
        """
        # Compute moments
        input_means, input_covs, input_skews, input_kurts = self.compute_moments(
            input_traj
        )
        target_means, target_covs, target_skews, target_kurts = self.compute_moments(
            target_traj
        )

        # Compute raw loss components
        mean_loss = F.mse_loss(input_means, target_means, reduction="mean")
        cov_loss = F.mse_loss(input_covs, target_covs, reduction="mean")
        skew_loss = F.mse_loss(input_skews, target_skews, reduction="mean")
        kurt_loss = F.mse_loss(input_kurts, target_kurts, reduction="mean")

        # Store raw losses
        loss_dict = {
            "mean_raw": mean_loss.item(),
            "cov_raw": cov_loss.item(),
            "skewness_raw": skew_loss.item(),
            "kurtosis_raw": kurt_loss.item(),
        }

        # Normalize if requested
        if self.normalize_components:
            mean_scale = torch.mean(target_means**2) + self.epsilon
            cov_scale = torch.mean(target_covs**2) + self.epsilon
            skew_scale = torch.mean(target_skews**2) + self.epsilon
            kurt_scale = torch.mean(target_kurts**2) + self.epsilon

            mean_loss = mean_loss / mean_scale
            cov_loss = cov_loss / cov_scale
            skew_loss = skew_loss / skew_scale
            kurt_loss = kurt_loss / kurt_scale

        # Weighted losses
        loss_dict.update(
            {
                "mean": (self.mean_weight * mean_loss).item(),
                "cov": (self.cov_weight * cov_loss).item(),
                "skewness": (self.skewness_weight * skew_loss).item(),
                "kurtosis": (self.kurtosis_weight * kurt_loss).item(),
            }
        )

        # Total loss
        total_loss = (
            self.mean_weight * mean_loss
            + self.cov_weight * cov_loss
            + self.skewness_weight * skew_loss
            + self.kurtosis_weight * kurt_loss
        )
        loss_dict["total"] = total_loss.item()

        return loss_dict


class CombinedAttractorLoss(nn.Module):
    """
    A comprehensive loss function that combines histogram-based and moment-based
    approaches for comparing chaotic attractors.

    Parameters:
    -----------
    histogram_weight : float
        Weight for the histogram-based component
        Default: 0.6

    moment_weight : float
        Weight for the moment-based component
        Default: 0.4

    histogram_kwargs : dict
        Arguments for the histogram loss component

    moment_kwargs : dict
        Arguments for the moment loss component
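
    Example (illustrative sketch):
    ------------------------------
    Assuming AttractorHistogramLoss accepts [batch_size, n_steps, 2] inputs,
    forward() sums the combined 2-D loss over every coordinate-plane projection:

    >>> loss_fn = CombinedAttractorLoss()
    >>> a, b = torch.randn(2, 400, 3), torch.randn(2, 400, 3)
    >>> loss = loss_fn(a, b)  # planes (0, 1), (0, 2), (1, 2)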
    """

    def __init__(
        self,
        histogram_weight=0.6,
        moment_weight=0.4,
        histogram_kwargs=None,
        moment_kwargs=None,
    ):
        super(CombinedAttractorLoss, self).__init__()

        self.histogram_weight = histogram_weight
        self.moment_weight = moment_weight

        # Initialize components with default parameters if none provided
        histogram_kwargs = histogram_kwargs or {}
        moment_kwargs = moment_kwargs or {}

        self.histogram_loss = AttractorHistogramLoss(**histogram_kwargs)
        self.moment_loss = AttractorMomentLoss(**moment_kwargs)

    def compute_single(self, input_traj, target_traj):
        """
        Compute the combined loss for a single 2-D projection.

        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, 2]

        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, 2]

        Returns:
        --------
        loss : torch.Tensor
            Weighted sum of the histogram and moment losses for this projection
        """

        # Compute individual loss components
        hist_loss = self.histogram_loss(input_traj, target_traj)
        moment_loss = self.moment_loss(input_traj, target_traj)

        # Combine components
        total_loss = (
            self.histogram_weight * hist_loss + self.moment_weight * moment_loss
        )

        return total_loss

    def forward(self, input_traj: torch.Tensor, target_traj: torch.Tensor):
        """
        Compute the combined loss between input and target attractors.

        The trajectories are projected onto every 2-D coordinate plane and the
        per-projection losses are summed.

        Parameters:
        -----------
        input_traj : torch.Tensor
            Input trajectory of shape [batch_size, n_steps, dim]

        target_traj : torch.Tensor
            Target trajectory of shape [batch_size, n_steps, dim]

        Returns:
        --------
        loss : torch.Tensor
            Combined loss summed over all 2-D coordinate-plane projections
        """
        total_loss = 0.0

        # Iterate over all unordered pairs of state dimensions; combinations()
        # never repeats an index, so no i1 == i2 check is needed
        for i1, i2 in combinations(range(input_traj.shape[2]), 2):
            in_tr = input_traj[:, :, [i1, i2]]  # [batch_size, n_steps, 2]
            sol_tr = target_traj[:, :, [i1, i2]]
            total_loss += self.compute_single(in_tr, sol_tr)

        return total_loss


class CombinedLoss(nn.Module):
    """
    Weighted sum of an arbitrary list of loss modules.

    Parameters:
    -----------
    losses : List[nn.Module]
        Loss modules, each called as loss(true_vals, pred_vals)

    weights : np.ndarray
        One scalar weight per loss module
    """

    def __init__(self, losses: List[nn.Module], weights: np.ndarray):
        super(CombinedLoss, self).__init__()
        assert len(losses) == len(weights), (
            "Number of losses must match number of weights"
        )
        self.losses = nn.ModuleList(losses)
        # Register as a buffer so the weights move with the module across devices
        self.register_buffer("weights", torch.tensor(weights, dtype=torch.float32))

    def forward(self, true_vals, pred_vals):
        """
        Compute the combined loss based on the individual losses and their weights.

        Parameters:
        -----------
        true_vals : torch.Tensor
            True values of shape [batch_size, n_steps, n_features]

        pred_vals : torch.Tensor
            Predicted values of shape [batch_size, n_steps, n_features]

        Returns:
        --------
        total_loss : torch.Tensor
            Combined loss value
        """
        total_loss = 0.0
        for loss, weight in zip(self.losses, self.weights):
            total_loss += weight * loss(true_vals, pred_vals)

        return total_loss

    def get_vector(self, true_vals, pred_vals) -> np.ndarray:
        """
        Compute the individual (unweighted) loss values.

        Parameters:
        -----------
        true_vals : torch.Tensor
            True values of shape [batch_size, n_steps, n_features]

        pred_vals : torch.Tensor
            Predicted values of shape [batch_size, n_steps, n_features]

        Returns:
        --------
        loss_vector : np.ndarray
            Array of individual loss values, one entry per loss module
        """
        loss_vector = [loss(true_vals, pred_vals) for loss in self.losses]
        return torch.stack(loss_vector).detach().cpu().numpy()
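

# Minimal smoke-test sketch (illustrative only): combine two stand-in criteria
# with fixed weights and inspect both the scalar loss and the per-loss vector.
if __name__ == "__main__":
    criterion = CombinedLoss(
        losses=[nn.MSELoss(), nn.L1Loss()],
        weights=np.array([0.7, 0.3]),
    )
    true_vals = torch.randn(2, 100, 3)
    pred_vals = torch.randn(2, 100, 3)
    print("combined loss:", criterion(true_vals, pred_vals).item())
    print("per-loss vector:", criterion.get_vector(true_vals, pred_vals))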