# neural_memristor_autoencoder_rnn.py
from pathlib import Path
from typing import List

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pydantic import BaseModel, Field
from scipy.interpolate import Akima1DInterpolator
from torch import nn
from torch.utils.data import (
    DataLoader,  # noqa
)
from torchdiffeq import odeint as odeint
from typing import Generator, Optional
from normalization_predictor import NRME  # noqa
from copy import deepcopy
import os
import gc

# Cap OpenMP/MKL worker threads before torch/numpy create their thread pools.
# NOTE(review): 40 threads seems high for the stated macOS/MKL workaround — confirm.
os.environ["OMP_NUM_THREADS"] = "40"  # For macOS users with certain MKL versions
class NRME(nn.Module):
    """Normalized (root) mean squared error loss.

    The squared error is divided by a per-feature normalization factor
    derived from the *target* values, making the loss scale-invariant
    across features with very different magnitudes.

    NOTE(review): this class shadows the `NRME` imported from
    `normalization_predictor` at the top of the file — confirm which one
    is intended to be used.
    """

    def __init__(
        self,
        normalization_method="range",  # 'range' is usually safer for circuits than 'std'
        epsilon=1e-8,
        reduction="mean",  # 'mean' gives standard NMSE, 'root_mean' gives NRMSE
        norm_axis=0,
    ):
        """
        Args:
            normalization_method: One of 'std', 'range', 'mean' — how the
                per-feature normalization factor is computed from the targets.
            epsilon: Small constant added to the factor to avoid division by zero.
            reduction: 'mean' (NMSE), 'root_mean' (NRMSE), 'sum', or any other
                value (e.g. 'none') to return the unreduced error tensor.
            norm_axis: Dimension (or tuple of dimensions) over which the
                normalization factor is computed.
        """
        super(NRME, self).__init__()
        self.normalization_method = normalization_method
        self.epsilon = epsilon
        self.reduction = reduction
        self.norm_axis = norm_axis

        valid_methods = ["std", "range", "mean"]
        if normalization_method not in valid_methods:
            raise ValueError(f"Method must be one of {valid_methods}")

    def get_normalization_factor(self, true_values):
        """Compute the per-feature normalization factor from the targets.

        The targets are detached so the network cannot 'cheat' by shrinking
        the target's variance through the normalization factor.
        """
        target = true_values.detach()

        if self.normalization_method == "std":
            factor = torch.std(target, dim=self.norm_axis, keepdim=True)
        elif self.normalization_method == "range":
            # torch.amax/amin accept a tuple of dims directly, so no special
            # handling is needed for multi-axis normalization. (Removed the
            # previous dead branch that computed an unused flattened tensor.)
            col_max = torch.amax(target, dim=self.norm_axis, keepdim=True)
            col_min = torch.amin(target, dim=self.norm_axis, keepdim=True)
            factor = col_max - col_min
        else:  # "mean" — guaranteed by the validation in __init__
            factor = torch.abs(torch.mean(target, dim=self.norm_axis, keepdim=True))

        return factor + self.epsilon

    def forward(self, predictions, true_values):
        """Compute the normalized error between predictions and targets.

        Raises:
            ValueError: If the two tensors differ in shape.
        """
        if predictions.shape != true_values.shape:
            raise ValueError(
                f"Shape mismatch: {predictions.shape} vs {true_values.shape}"
            )

        norm_factors = self.get_normalization_factor(true_values)

        # Normalize the error, not the inputs — mathematically equivalent
        # but numerically safer.
        error = predictions - true_values
        norm_error = error / norm_factors
        squared_norm_error = norm_error**2

        if self.reduction == "mean":
            return torch.mean(squared_norm_error)
        elif self.reduction == "root_mean":
            return torch.sqrt(torch.mean(squared_norm_error))
        elif self.reduction == "sum":
            return torch.sum(squared_norm_error)

        # Any other reduction (e.g. 'none') returns the unreduced tensor.
        return squared_norm_error


class ExponentialDecayLoss(nn.Module):
    """Time-weighted loss that penalizes early-time errors most heavily.

    Wraps an elementwise criterion (default: unreduced range-normalized
    squared error) and multiplies it by exponentially decaying weights
    along the leading (time) dimension before averaging.
    """

    def __init__(self, decay_rate=0.1, eps=1e-8, criterion=None):
        """
        Args:
            decay_rate: Controls how quickly the weight decreases along the
                time dimension.
            eps: Small constant for numerical stability when normalizing
                the weights.
            criterion: Elementwise loss returning an unreduced tensor with
                the same shape as its inputs. Defaults to
                NRME("range", reduction="none").
                (FIX: the default was previously instantiated at function
                definition time and shared by every instance — the classic
                mutable-default-argument bug.)
        """
        super(ExponentialDecayLoss, self).__init__()
        self.decay_rate = decay_rate
        self.eps = eps
        self.criterion = (
            criterion if criterion is not None else NRME("range", reduction="none")
        )

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: Predicted trajectory, shape [time_steps, batch, dim]
                (the time-leading layout produced by torchdiffeq's odeint).
            y_true: Ground-truth trajectory, same shape.

        Returns:
            Scalar loss: the mean of the per-element criterion weighted by
            exponentially decaying per-timestep weights.
        """
        # Unreduced elementwise error, same shape as the inputs.
        squared_error = self.criterion(y_pred, y_true)

        # Time is the leading dimension (odeint convention; matches the
        # use of shape[0] below).
        time_steps = y_pred.shape[0]

        # Exponentially decaying, normalized weights over time.
        time_indices = torch.arange(time_steps, device=y_pred.device).float()
        weights = torch.exp(-self.decay_rate * time_indices)
        weights = weights / (torch.sum(weights) + self.eps)

        # FIX: broadcast along the *leading* time axis. The previous
        # view(1, time_steps, 1) tried to broadcast over dim 1 — the batch
        # dimension here — and failed whenever batch != time_steps.
        weights = weights.view(time_steps, *([1] * (y_pred.dim() - 1)))

        # Weighted average over all elements.
        return torch.mean(squared_error * weights)


def get_device_name(b=3, a=None):
    """Return the best available torch device string.

    Preference order: CUDA ('cuda:<idx>'), Apple-Silicon MPS ('mps'),
    then CPU ('cpu').

    Args:
        b, a: Unused; kept only for backward compatibility with any
            existing callers. Do not rely on them.

    Returns:
        A device string usable with torch (e.g. 'cuda:0', 'mps', 'cpu').
    """
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    if torch.backends.mps.is_available():  # Apple M1/M2 chips
        return "mps"
    return "cpu"


# Global run configuration: double precision, headless matplotlib.
precision = torch.float64
device = get_device_name()
device = "cpu"  # NOTE(review): hard-codes CPU, discarding the device detected above
matplotlib.use("svg")  # non-interactive backend; figures are only written to files


class OscParams(BaseModel):
    """Component values (resistors in ohms, capacitors in farads) of the
    memristor oscillator circuit.

    R1 is required (no default); all other components default to the
    nominal values used elsewhere in this file.
    """

    R1: float = Field(..., ge=0.0, description="Resistor R1 value in ohms")
    R2: float = Field(default=10e3, ge=0.0, description="Resistor R2 value in ohms")
    R3: float = Field(default=3.6e3, ge=0.0, description="Resistor R3 value in ohms")
    R4: float = Field(default=1e3, ge=0.0, description="Resistor R4 value in ohms")
    R5: float = Field(default=1e3, ge=0.0, description="Resistor R5 value in ohms")
    R7: float = Field(default=1e3, ge=0.0, description="Resistor R7 value in ohms")
    R9: float = Field(default=1e3, ge=0.0, description="Resistor R9 value in ohms")
    R6: float = Field(default=22e3, ge=0.0, description="Resistor R6 value in ohms")
    R8: float = Field(default=33e3, ge=0.0, description="Resistor R8 value in ohms")
    R10: float = Field(default=47e3, ge=0.0, description="Resistor R10 value in ohms")
    R11: float = Field(default=100e3, ge=0.0, description="Resistor R11 value in ohms")
    C1: float = Field(default=50e-6, ge=0.0, description="Capacitor C1 value in farads")
    C3: float = Field(default=50e-6, ge=0.0, description="Capacitor C3 value in farads")
    C2: float = Field(
        default=23.5e-6, ge=0.0, description="Capacitor C2 value in farads"
    )


class AutoEncoderConfig(BaseModel):
    """Hyperparameters of the state-inference auto-encoder network."""

    input_size: int = Field(default=10, description="Size of the input vector")
    output_size: int = Field(default=1, description="Size of the output vector")
    hidden_size: int = Field(default=64, description="Size of the hidden layers")
    lookback: int = Field(default=100, description="Lookback window length")


class NNMemristorConfig(BaseModel):
    """Hyperparameters of the learnable memristor network (MLP sizes and the
    initial on/off resistances)."""

    input_size: int = Field(default=3, description="Input size for the neural network")
    hidden_size: int = Field(
        default=64, description="Hidden layer size for the neural network"
    )
    output_size: int = Field(
        default=1, description="Output size for the neural network"
    )
    R_on: float = Field(default=1e3, description="Initial Ron value")
    R_off: float = Field(default=100e3, description="Initial Roff value")


class NeuralDiffeqConfig(BaseModel):
    """Top-level configuration: circuit parameters, auto-encoder and
    memristor-network settings, plus normalization scales for the
    memristor voltage/current."""

    osc_params: OscParams
    auto_encoder_config: AutoEncoderConfig
    nn_memristor_config: NNMemristorConfig
    vm_std: float = Field(
        default=1,
        description="Standard deviation for memristor voltage normalization",
    )
    im_std: float = Field(
        default=1, description="Standard deviation for memristor current normalization"
    )


class Sin(torch.nn.Module):
    """Element-wise sine activation module."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Apply sin element-wise; shape and dtype are preserved.
        return input.sin()


class HiddenStateAutoEncode(nn.Module):
    """Symmetric MLP auto-encoder compressing an input vector to a bottleneck.

    The encoder ends in a Sigmoid, so codes live in (0, 1); the decoder
    mirrors the encoder and reconstructs the input with no final
    non-linearity. Modules are moved to the module-level `device` and
    `precision` globals on construction.
    """

    def __init__(
        self,
        input_size: int = 10,
        output_size: int = 1,
        hidden_size: int = 64,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        half = hidden_size // 2
        # Encoder: input -> half -> hidden -> hidden -> half -> code (sigmoid).
        encoder_stack = [
            nn.Linear(input_size, half),
            nn.SiLU(),
            nn.Linear(half, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, half),
            nn.SiLU(),
            nn.Linear(half, output_size),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_stack).to(device).to(precision)

        # Decoder mirrors the encoder, minus the output activation.
        decoder_stack = [
            nn.Linear(output_size, half),
            nn.SiLU(),
            nn.Linear(half, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, half),
            nn.SiLU(),
            nn.Linear(half, input_size),
        ]
        self.decoder = nn.Sequential(*decoder_stack).to(device).to(precision)

    def encode(self, x):
        """Map inputs to bottleneck codes in (0, 1)."""
        return self.encoder(x)

    def decode(self, z):
        """Reconstruct inputs from bottleneck codes."""
        return self.decoder(z)

    def forward(self, x):
        """Encode then decode: returns the reconstruction of `x`."""
        return self.decode(self.encode(x))

class RecurrentAutoEncoder(nn.Module):
    """Sequence auto-encoder.

    A GRU encoder compresses a window of observations into a latent vector
    in (0, 1); a GRU decoder unrolls the latent back into a sequence of the
    same length.
    """

    def __init__(
        self,
        input_size: int = 10,
        output_size: int = 1,
        hidden_size: int = 64,
        num_layers: int = 2,
        lookback: int = 100,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Window length; read back by consumers (e.g. consistency-loss code).
        self.lookback = lookback

        self.input_size = input_size
        self.output_size = output_size
        self.latent_size = output_size  # latent width equals output width
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        half = hidden_size // 2

        # --- Encoder: GRU summarizes the sequence; an MLP head squashes the
        # final hidden state into a latent code in (0, 1).
        self.encoder_rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.enc_fc = nn.Sequential(
            nn.Linear(hidden_size, half),
            nn.SiLU(),
            nn.Linear(half, self.latent_size),
            nn.Sigmoid(),
        )

        # --- Decoder: MLP lifts the latent back to hidden width; a GRU
        # unrolls it into a sequence; a linear head maps back to input space.
        self.dec_fc = nn.Sequential(
            nn.Linear(self.latent_size, half),
            nn.SiLU(),
            nn.Linear(half, hidden_size),
        )
        self.decoder_rnn = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output_fc = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """x: (batch, seq_len, input_size) -> latent (batch, latent_size)."""
        _, final_hidden = self.encoder_rnn(x)  # (num_layers, batch, hidden)
        # Only the top layer's final hidden state feeds the latent head.
        return self.enc_fc(final_hidden[-1])

    def decode(self, z, seq_len):
        """z: (batch, latent_size) -> reconstruction (batch, seq_len, input_size)."""
        lifted = self.dec_fc(z)  # (batch, hidden)
        # The lifted latent is fed as the decoder input at every timestep...
        step_inputs = lifted.unsqueeze(1).repeat(1, seq_len, 1)
        # ...and also initializes every decoder layer's hidden state.
        init_hidden = lifted.unsqueeze(0).repeat(self.num_layers, 1, 1)
        rnn_out, _ = self.decoder_rnn(step_inputs, init_hidden)
        return self.output_fc(rnn_out)

    def forward(self, x):
        """Auto-encode a batch of sequences: x -> z -> reconstruction."""
        return self.decode(self.encode(x), x.size(1))


class MemristorNN(nn.Module):
    """Learnable memristor model.

    Conductance follows the linear mixture G = X / Ron + (1 - X) / Roff,
    with Ron/Roff stored as scaled log-parameters so they remain positive
    during optimization. A small MLP predicts the state derivative dX from
    (X, normalized v_m[, normalized i_m]).
    """

    def __init__(
        self,
        Ron: float = 1_000,
        Roff: float = 100_000,
        R: float = 1_000,
        activation: type[nn.Module] = nn.SiLU,
        hidden_size: int = 64,
        input_size: int = 3,
        output_size: int = 1,
        v_std: float = 1,
        i_std: float = 1,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Scaling constant for the log-resistance parametrization: the stored
        # parameter is log(R_x) / R, recovered via exp(param * R).
        self.R = R
        self.log_Roff = nn.Parameter(
            torch.tensor(np.log(Roff) / self.R, dtype=precision, device=device)
        )
        self.log_Ron = nn.Parameter(
            torch.tensor(np.log(Ron) / self.R, dtype=precision, device=device)
        )
        # Learnable global gain on the predicted dX.
        self.out_scale = nn.Parameter(torch.tensor(1.0, dtype=precision, device=device))

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.activation = activation
        self.v_std = v_std
        self.i_std = i_std

        half = hidden_size // 2
        # Tanh keeps the raw dX prediction in (-1, 1); out_scale rescales it.
        self.ann = (
            nn.Sequential(
                nn.Linear(input_size, half),
                activation(),
                nn.Linear(half, hidden_size),
                activation(),
                nn.Linear(hidden_size, hidden_size),
                activation(),
                nn.Linear(hidden_size, half),
                activation(),
                nn.Linear(half, output_size),
                nn.Tanh(),
            )
            .to(device)
            .to(precision)
        )

    @property
    def Ron(self):
        """On-state resistance, recovered from its scaled log-parameter."""
        return torch.exp(self.log_Ron * self.R)

    @property
    def Roff(self):
        """Off-state resistance, recovered from its scaled log-parameter."""
        return torch.exp(self.log_Roff * self.R)

    def G(self, X):
        """Conductance: linear mixture of on/off conductances weighted by X."""
        return X / self.Ron + (1 - X) / self.Roff

    def current(self, X, v_m):
        """Ohm's law: i_m = G(X) * v_m."""
        return self.G(X) * v_m

    def forward(self, X, v_m):
        """Predict the state derivative dX for state X and memristor voltage v_m."""
        if self.input_size == 3:
            features = [X, v_m / self.v_std, self.current(X, v_m) / self.i_std]
        else:
            # 2-input variant: the network does not see the current.
            features = [X, v_m / self.v_std]
        nn_input = torch.stack(features, dim=-1)
        dX = self.ann(nn_input) * self.out_scale
        # Drop a leading singleton dimension if one is present.
        return dX.squeeze(0)


class NeuralDiffEq(nn.Module):
    """ODE right-hand side for the memristor oscillator circuit.

    The state used in forward() is 4-dimensional, laid out [v1, v2, v3, X]:
    three capacitor voltages governed by fixed circuit equations, plus the
    memristor internal state X whose derivative comes from a learned
    MemristorNN.
    """

    def __init__(
        self,
        config: OscParams,
        x0: float = 0.0,
        Ron: float = 1_000,
        Roff: float = 100_000,
        activation: type[nn.Module] = nn.SiLU,
        v_m_std: float = 0.2,
        i_m_std: float = 1e-3,
        input_size: int = 3,
        output_size: int = 1,
        hidden_size: int = 128,
    ) -> None:
        """
        Args:
            config: Circuit component values.
            x0: Unused here; presumably an initial memristor state — TODO confirm.
            Ron/Roff: Initial on/off resistances for the memristor model.
            activation: Activation class for the memristor MLP.
            v_m_std/i_m_std: Normalization scales for memristor voltage/current.
            input_size/output_size/hidden_size: Memristor MLP dimensions.
        """
        super().__init__()
        # NOTE(review): forward() integrates a 4-D state [v1, v2, v3, X];
        # this attribute (previously labeled "[X, v_m]") looks stale —
        # confirm before relying on it.
        self.dim = 3

        self.p = config
        self.v_m_std = v_m_std
        self.i_m_std = i_m_std

        # Learned memristor: supplies dX and the device current i_m = G(X)*v_m.
        self.mem_model = MemristorNN(
            Ron=Ron,
            Roff=Roff,
            R=1e3,
            activation=activation,
            hidden_size=hidden_size,
            input_size=input_size,
            output_size=output_size,
            v_std=v_m_std,
            i_std=i_m_std,
        )

    @property
    def Ron(self):
        """Learned on-resistance (delegates to the memristor model)."""
        return self.mem_model.Ron

    @property
    def Roff(self):
        """Learned off-resistance (delegates to the memristor model)."""
        return self.mem_model.Roff

    def forward(self, t, state):
        """Compute d(state)/dt.

        Args:
            t: Time (unused by the equations; the system is autonomous).
            state: Tensor [..., 4] holding [v1, v2, v3, X].

        Returns:
            Tensor of the same shape as `state` containing the derivatives.
        """
        v1, v2, v3, X = state[..., 0], state[..., 1], state[..., 2], state[..., 3]

        # Memristor voltage from the inverting amplifier on v3, then the
        # device current and the node voltage v4 across R1.
        vx = -self.p.R11 / self.p.R10 * v3
        v_m = vx  # NOTE(review): a -0.040 offset appears to have been disabled here
        i_m = self.mem_model.current(X, v_m)
        v4 = vx + self.p.R1 * i_m

        # Kirchhoff equations for the three capacitor voltages.
        dv1 = 1 / self.p.C1 * (-v1 / self.p.R4 - v2 / self.p.R9 - v4 / self.p.R3)
        dv2 = 1 / self.p.C3 * (-v2 / self.p.R8 - v3 / self.p.R7)
        dv3 = 1 / self.p.C2 * (-v1 / self.p.R5 - v3 / self.p.R6)

        dX = self.mem_model(X, v_m).squeeze(-1)
        # Clamp the flow direction so X cannot leave [0, 1]: at the lower
        # bound only non-negative dX is allowed; at the upper bound only
        # non-positive dX.
        dX = torch.where(X <= 0, torch.maximum(dX, torch.zeros_like(dX)), dX)
        dX = torch.where(X >= 1, torch.minimum(dX, torch.zeros_like(dX)), dX)

        # Assemble the derivative vector in the same layout as `state`.
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dv1
        dfunc[..., 1] = dv2
        dfunc[..., 2] = dv3
        dfunc[..., 3] = dX

        return dfunc


class AutoencoderMemmodel(nn.Module):
    """Hybrid model: infer the hidden memristor state X from a lookback
    window of (v_m, i_m) observations via a recurrent auto-encoder, then
    integrate the physics-based ODE forward in time.
    """

    def __init__(
        self,
        config: "NeuralDiffeqConfig",
    ) -> None:
        super().__init__()

        # Physics engine (fixed circuit equations + learned memristor).
        self.neural_diffeq = NeuralDiffEq(
            config.osc_params,
            v_m_std=config.vm_std,
            i_m_std=config.im_std,
            Ron=config.nn_memristor_config.R_on,
            Roff=config.nn_memristor_config.R_off,
            input_size=config.nn_memristor_config.input_size,
        )

        # State Inference Network
        # Input per timestep: [v_m, i_m] (Size 2)
        # Output: [X] (Size 1)
        self.autoencoder = RecurrentAutoEncoder(
            input_size=config.auto_encoder_config.input_size,  # MUST be 2
            output_size=config.auto_encoder_config.output_size,  # MUST be 1
            hidden_size=config.auto_encoder_config.hidden_size,
            lookback=config.auto_encoder_config.lookback,
        )

    def forward(self, vm_backward, im_backward, exposed_states, t_forward):
        """
        Args:
            vm_backward, im_backward: Observations [Batch, Lookback] used
                to infer the hidden state X.
            exposed_states: Known initial voltages [v1, v2, v3] at t=0,
                shape [Batch, 3].
            t_forward: Timesteps to simulate.

        Returns:
            ODE solution [Time, Batch, 4].
        """
        # 1. Infer the hidden memristor state X from the history window.
        x_state = self.predict_states(vm_backward, im_backward)

        # 2. Combine known physics (v1..v3) with the inferred X -> [Batch, 4].
        state = torch.cat([exposed_states, x_state], dim=-1)

        # 3. Integrate forward from the hybrid initial state.
        sol = odeint(self.neural_diffeq, state, t_forward, method="dopri5")

        return sol

    def predict_states(self, v_m, i_m):
        """Encode a history window into the latent memristor state X.

        Args:
            v_m, i_m: Each [Batch, Lookback].

        Returns:
            Latent codes [Batch, latent_size].
        """
        # Stack to (batch, lookback, 2) — a sequence of [vm, im] pairs.
        x = torch.stack([v_m, i_m], dim=-1)
        return self.autoencoder.encode(x)

    def reconstruct_states(self, z, seq_len=None):
        """Decode latent codes back into (v_m, i_m) windows.

        FIX: RecurrentAutoEncoder.decode requires a sequence length; the
        previous implementation omitted it and always raised a TypeError.
        Defaults to the auto-encoder's lookback window length.
        """
        if seq_len is None:
            seq_len = self.autoencoder.lookback
        return self.autoencoder.decode(z, seq_len)

    def calculate_consistency_loss(self, sol, vm_backward, im_backward):
        """Re-encode the simulated trajectory and compare against the ODE's X.

        Args:
            sol: ODE solution [Time, Batch, 4], starting at t0.
            vm_backward, im_backward: History windows [Batch, Lookback]
                ending at t0.

        Returns:
            (X_ode, X_encoded): the trajectory's X and the auto-encoder's
            re-estimate of X from sliding windows over the stitched signal.
        """
        # 1. Unpack the trajectory.
        v3 = sol[..., 2]
        X_ode = sol[..., 3]

        # 2. Physics: derive (v_m, i_m) along the predicted trajectory.
        p = self.neural_diffeq.p
        v_m_pred = -p.R11 / p.R10 * v3
        i_m_pred = self.neural_diffeq.mem_model.current(X_ode, v_m_pred)

        # 3. Stitch history and prediction into one timeline.
        # vm_backward: [Batch, Lookback] -> [Lookback, Batch]
        v_m_hist_T = vm_backward.T
        i_m_hist_T = im_backward.T

        # Both the history (last index) and 'sol' (index 0) contain t0, so
        # the prediction is sliced from index 1 to avoid a duplicate t0.
        full_vm = torch.cat([v_m_hist_T, v_m_pred[1:]], dim=0)
        full_im = torch.cat([i_m_hist_T, i_m_pred[1:]], dim=0)

        # 4. Sliding windows over the stitched timeline.
        window_size = self.autoencoder.lookback

        # Shape: [Num_Windows, Batch, Window_Size]
        vm_windows = full_vm.unfold(0, window_size, 1)
        im_windows = full_im.unfold(0, window_size, 1)

        # 5. Keep exactly one window per simulated time step; window i ends
        # at simulation step i, so we take the last `time_steps` windows.
        time_steps = X_ode.shape[0]
        vm_seq = vm_windows[-time_steps:]
        im_seq = im_windows[-time_steps:]

        # 6. Re-encode each window into an X estimate, reshaped to match.
        flat_vm = vm_seq.reshape(-1, window_size)
        flat_im = im_seq.reshape(-1, window_size)

        X_encoded = self.predict_states(flat_vm, flat_im)
        X_encoded = X_encoded.view_as(X_ode)

        return X_ode, X_encoded


def smooth_range_penalty(x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0):
    """Softly penalize values of `x` outside [min_val + margin, max_val - margin].

    Each boundary contributes a linear overshoot term gated by a sigmoid,
    giving a smooth (differentiable) approximation of a hinge penalty; the
    combined term is squared and averaged.
    """
    lo = min_val + margin
    hi = max_val - margin
    sharpness = margin / 4  # transition width of the sigmoid gates

    # Gate each overshoot by how far x has crossed the corresponding boundary.
    under = torch.sigmoid((lo - x) / sharpness) * (lo - x)
    over = torch.sigmoid((x - hi) / sharpness) * (x - hi)

    combined = under + over
    return penalty_weight * torch.mean(combined**2)


def extract_vm_im(
    path: Path,
    pp: OscParams,
):
    """Load a CSV of circuit traces and derive the memristor voltage/current.

    v_m follows from the R11/R10 inverting-amplifier relation on v3; i_m
    from the voltage drop across R1 between v4 and v_m.

    Returns:
        (v_m, i_m) as numpy arrays.
    """
    frame = pd.read_csv(path)
    # Shift time so the trace starts at t = 0 (kept for parity with the
    # original preprocessing; not used in the returned values).
    frame["time"] = frame["time"] - frame["time"].min()

    v3_vals = frame["v3"].to_numpy()
    v4_vals = frame["v4"].to_numpy()
    v_m = -pp.R11 / pp.R10 * v3_vals
    i_m = (v4_vals - v_m) / pp.R1
    return v_m, i_m


from typing import Optional


import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Optional, Tuple
from scipy.interpolate import Akima1DInterpolator

# Assuming global device/precision are defined elsewhere, or pass them in.
# device = ...
# precision = ...
class AutoencoderMemristorData:
    """Loads an oscillator CSV trace, resamples it onto a uniform grid,
    derives the memristor voltage/current, and serves lookback/horizon
    batches from chronological train/validation/test segments.
    """

    def __init__(
        self,
        path: Path,
        osc_params: "OscParams",
        fs: float = 1e3,
        horizon: int = 50,
        lookback: int = 10,
        train_split: float = 0.7,  # 70% Training
        val_split: float = 0.2,    # 20% Validation (leaving 10% for Test)
        device=None,
        precision=None,
        columns: Optional[List[str]] = None,
    ):
        """
        Args:
            path: CSV file containing 'time', 'v4' and the state columns.
            osc_params: Circuit parameters (R1, R10, R11 are used here).
            fs: Resampling frequency in Hz.
            horizon: Number of forward timesteps per sample.
            lookback: Number of history timesteps per sample.
            train_split, val_split: Fractions for train/val; the remainder
                is the test set. Their sum must be < 1.0.
            device, precision: Torch placement; default cpu / float32.
            columns: State columns to load; defaults to ["v1", "v2", "v3"].
                (FIX: previously a mutable default list shared between calls.)

        Raises:
            RuntimeError: If loading/validating the CSV fails.
            ValueError: If the splits leave no room for a test set.
        """
        self.device = device if device else torch.device("cpu")
        self.precision = precision if precision else torch.float32
        # Copy to guard against caller-side mutation.
        self.columns = list(columns) if columns is not None else ["v1", "v2", "v3"]
        self.pp = osc_params
        self.fs = fs
        self.horizon = horizon
        self.lookback = lookback

        # Validate split boundaries up front.
        if train_split + val_split >= 1.0:
            raise ValueError("train_split + val_split must be < 1.0 to leave room for test set.")

        self.train_split = train_split
        self.val_split = val_split

        # 1. Load & Process Data
        try:
            self.df = pd.read_csv(path)
            req_cols = self.columns + ["time"] + ["v4"]
            for col in req_cols:
                if col not in self.df.columns:
                    raise ValueError(f"Required column '{col}' missing.")

            # Normalize time to start at zero, then resample to 1/fs spacing.
            self.df["time"] = self.df["time"] - self.df["time"].min()
            self.interpolate_df()

            # Physics pre-calculation: derive memristor voltage/current.
            v3_arr = self.df["v3"].to_numpy()
            v4_arr = self.df["v4"].to_numpy()
            vm_arr = -self.pp.R11 / self.pp.R10 * v3_arr
            im_arr = (v4_arr - vm_arr) / self.pp.R1

            # 2. Create full tensors on the target device/precision.
            self.state_tensor = torch.tensor(
                self.df[self.columns].to_numpy(),
                device=self.device, dtype=self.precision,
            )
            self.vm_tensor = torch.tensor(
                vm_arr, device=self.device, dtype=self.precision
            )
            self.im_tensor = torch.tensor(
                im_arr, device=self.device, dtype=self.precision
            )

            # Shared forward-time axis for every batch: horizon steps of 1/fs.
            dt = 1.0 / self.fs
            self.t_span = (
                torch.linspace(0, dt * (self.horizon - 1), self.horizon)
                .to(self.device).to(self.precision)
            )

            self.data_len = len(self.df)

            # 3. Chronological split indices:
            # Train: [0 ... idx_train_end)
            # Val:   [idx_train_end ... idx_val_end)
            # Test:  [idx_val_end ... data_len)
            self.idx_train_end = int(self.data_len * self.train_split)
            self.idx_val_end = int(self.data_len * (self.train_split + self.val_split))

            # Report the split so the user knows what they are working with.
            print(f"Data Split Loaded:")
            print(f"  Train: 0 -> {self.idx_train_end} ({self.idx_train_end} pts)")
            print(f"  Val:   {self.idx_train_end} -> {self.idx_val_end} ({self.idx_val_end - self.idx_train_end} pts)")
            print(f"  Test:  {self.idx_val_end} -> {self.data_len} ({self.data_len - self.idx_val_end} pts)")

        except Exception as e:
            # FIX: chain the original exception for a useful traceback.
            raise RuntimeError(f"Error loading data: {str(e)}") from e

    def interpolate_df(self):
        """Resample all columns onto a regular 1/fs time grid (Akima splines)."""
        x = self.df["time"].to_numpy()
        x_new = np.arange(x[0], x[-1], 1 / self.fs)
        new_df = pd.DataFrame({"time": x_new})
        cols = self.columns + ["v4"]
        for col in cols:
            akima = Akima1DInterpolator(x, self.df[col].to_numpy())
            new_df[col] = akima(x_new)
        self.df = new_df

    def _get_indices(self, mode: str, batch_size: int, random: bool = True):
        """Return sample start indices for a split.

        Args:
            mode: 'train', 'val' or 'test'.
            batch_size: Number of random indices to draw (ignored when
                random=False, which returns the full sequential range).
            random: Uniform random starts vs. the full ordered range.
        """
        # Valid index ranges per split; each leaves `horizon` steps of
        # future data, and the train start leaves `lookback` steps of history.
        if mode == 'train':
            start = self.lookback
            end = self.idx_train_end - self.horizon
        elif mode == 'val':
            start = self.idx_train_end
            end = self.idx_val_end - self.horizon
        elif mode == 'test':
            start = self.idx_val_end
            end = self.data_len - self.horizon
        else:
            raise ValueError(f"Unknown mode: {mode}")

        # Safety check: splits can be too small for the requested horizon.
        if start >= end:
            raise ValueError(f"Not enough data for {mode} split (Start: {start}, End: {end}, Horizon: {self.horizon})")

        if random:
            return torch.randint(start, end, (batch_size,), device=self.device)
        # Sequential (full range).
        return torch.arange(start, end, device=self.device)

    def _gather_batch(self, indices):
        """Assemble one batch from start indices.

        Returns:
            (vm_hist, im_hist, state_0, t_span, targets) where histories are
            [batch, lookback] windows ending at t0, state_0 is [batch, |columns|],
            and targets are [batch, horizon, |columns|] starting at t0.
        """
        # History (backwards from t0, inclusive of t0).
        offsets = torch.arange(-self.lookback + 1, 1, device=self.device)
        lookback_indices = indices.unsqueeze(1) + offsets

        vm_hist = self.vm_tensor[lookback_indices]
        im_hist = self.im_tensor[lookback_indices]
        state_0 = self.state_tensor[indices]

        # Future (forwards from t0).
        batch_indices = indices.unsqueeze(1) + torch.arange(
            self.horizon, device=self.device
        )
        targets = self.state_tensor[batch_indices]

        return vm_hist, im_hist, state_0, self.t_span, targets

    # --- PUBLIC INTERFACE ---

    def get_train_batch(self, batch_size: int = 32):
        """Random batch from the training segment."""
        indices = self._get_indices('train', batch_size, random=True)
        return self._gather_batch(indices)

    def get_val_batch(self, batch_size: int = 32):
        """Random batch from the validation segment."""
        indices = self._get_indices('val', batch_size, random=True)
        return self._gather_batch(indices)

    def get_test_batch(self, batch_size: int = 32):
        """Random batch from the test segment."""
        indices = self._get_indices('test', batch_size, random=True)
        return self._gather_batch(indices)

    def get_sequential_batches(self, batch_size: int = 32, mode: str = 'train', stride: Optional[int] = None):
        """Yield batches covering a split in order.

        Args:
            batch_size: Number of windows per yielded batch.
            mode: 'train', 'val' or 'test'.
            stride: Step between consecutive window starts; defaults to
                `horizon` (non-overlapping windows).
        """
        if stride is None:
            stride = self.horizon

        # All valid start indices for this split (batch_size is ignored by
        # _get_indices when random=False).
        all_indices = self._get_indices(mode, batch_size=0, random=False)

        # Subsample every `stride`-th start to control window overlap.
        all_indices = all_indices[::stride]

        num_batches = (len(all_indices) + batch_size - 1) // batch_size

        for i in range(num_batches):
            batch_start = i * batch_size
            batch_end = min((i + 1) * batch_size, len(all_indices))

            if batch_start >= batch_end:
                break

            yield self._gather_batch(all_indices[batch_start:batch_end])

def set_seed(seed: int):
    """Seed torch's RNG and force deterministic cuDNN behavior.

    Trades cuDNN autotuning speed for run-to-run reproducibility.
    """
    torch.manual_seed(seed)
    # Disable autotuning and pin deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def plot_data(i, t, y_pred, y_true):
    """Plot predicted vs. true v3 traces and save to neural_plot/<i>.png.

    Args:
        i: Identifier used in the output filename (e.g. iteration index).
        t: Time-axis values.
        y_pred, y_true: Arrays whose column 2 (v3) is plotted.
    """
    # FIX: ensure the output directory exists so savefig cannot fail.
    Path("neural_plot").mkdir(parents=True, exist_ok=True)

    plt.plot(t, y_true[:, 2], label="true")
    plt.plot(t, y_pred[:, 2], label="pred")

    plt.legend(loc="upper right")
    plt.savefig(f"neural_plot/{i}.png", bbox_inches="tight")
    # Close all figures to avoid accumulating open-figure memory.
    plt.close("all")


def smoth_criterion(y_pred, y_true, criterion):
    """Average `criterion` over the first three columns, each z-scored by the
    ground truth's per-column mean and std.

    (Function name kept for compatibility; 'smoth' is a historical typo
    of 'smooth'.)
    """
    losses = []
    for col in range(3):
        mu = torch.mean(y_true[:, col])
        sigma = torch.std(y_true[:, col])
        # Normalize both series with the *target's* statistics so the
        # criterion sees comparable scales per column.
        losses.append(
            criterion((y_pred[:, col] - mu) / sigma, (y_true[:, col] - mu) / sigma)
        )
    return sum(losses) / 3


def get_gausian_prediction(y_pred, y_true, criterion2: callable):
    """Average a Gaussian-style criterion over the first three columns.

    For each column, `criterion2` receives the prediction's mean, the raw
    target column, and the prediction's variance — the (input, target, var)
    signature of Gaussian NLL-style losses.
    """
    accumulated = 0
    for col in range(3):
        col_pred = y_pred[:, col]
        accumulated = accumulated + criterion2(
            torch.mean(col_pred), y_true[:, col], torch.var(col_pred)
        )
    return accumulated / 3


def visualize_performance(model, data_loader, save_path=None, epoch=None):
    """
    Generates a 3-panel dashboard:
    1. Time Domain (Voltage vs Time)
    2. Phase Space (Attractor V1 vs V2)
    3. Memristor Physics (Pinched Hysteresis Vm vs Im)

    Args:
        model: trained AutoencoderMemmodel; switched to eval() here and back
            to train() before returning.
        data_loader: loader exposing get_sequential_batches().
        save_path: if given, the figure is written there and closed;
            otherwise it is shown interactively.
        epoch: epoch number for the title; None means "Final".

    NOTE(review): relies on module-level globals `device` and `precision`
    when rebuilding tensors for the hysteresis computation — confirm they are
    defined before this is called.
    """
    model.eval()

    # Grab non-overlapping test windows sequentially (stride=None defaults to
    # the loader's horizon) and stitch several together for a longer plot.
    iterator = data_loader.get_sequential_batches(batch_size=1, stride=None, mode="test")

    pred_v1, pred_v2, pred_v3, pred_x = [], [], [], []
    real_v1, real_v2, real_v3 = [], [], []

    with torch.no_grad():
        for i in range(4):  # Stitch 4 horizons together (e.g. 4x50 = 200 steps)
            vm, im, s0, t, tgt = next(iterator)

            # Forward Pass
            sol = model(vm, im, s0, t)  # [Time, Batch, 4]

            # Extract Data (Batch 0)
            p = sol[:, 0, :].cpu().numpy()  # [Time, 4]
            t_vals = tgt[0, :, :].cpu().numpy()  # [Time, 3]

            pred_v1.extend(p[:, 0])
            pred_v2.extend(p[:, 1])
            pred_v3.extend(p[:, 2])
            pred_x.extend(p[:, 3])

            real_v1.extend(t_vals[:, 0])
            real_v2.extend(t_vals[:, 1])
            real_v3.extend(t_vals[:, 2])

    # Convert to numpy arrays
    pred_v1, pred_v2, pred_v3 = np.array(pred_v1), np.array(pred_v2), np.array(pred_v3)
    real_v1, real_v2, real_v3 = np.array(real_v1), np.array(real_v2), np.array(real_v3)
    pred_x = np.array(pred_x)

    # --- CALCULATE MEMRISTOR PHYSICS ---
    # Derive Vm and Im from the predicted voltages to plot hysteresis, using
    # the learned circuit parameters.  (Renamed from `p` so it no longer
    # shadows the prediction slice used in the loop above.)
    osc = model.neural_diffeq.p  # OscParams

    # Vm = -R11/R10 * V3
    pred_vm = -osc.R11 / osc.R10 * pred_v3

    # We can't use Kirchhoff for Im easily without the derivative of V1, so
    # use the memristor equation I = G(x) * V via the model's own function.
    # (Requires converting the numpy traces back to tensors.)
    x_tensor = torch.tensor(pred_x, device=device, dtype=precision)
    vm_tensor = torch.tensor(pred_vm, device=device, dtype=precision)

    # Calculate Current
    pred_im = (
        model.neural_diffeq.mem_model.current(x_tensor, vm_tensor)
        .detach()
        .cpu()
        .numpy()
    )

    # --- PLOTTING ---
    fig = plt.figure(figsize=(18, 10))
    # BUGFIX: `epoch if epoch` treated epoch 0 as falsy and labeled it
    # "Final" even though epoch 0 is visualized; compare against None.
    plt.suptitle(
        f"Neural ODE Performance (Epoch {epoch if epoch is not None else 'Final'})",
        fontsize=16,
        fontweight="bold",
    )

    # PANEL 1: Time Series (V1 and V2)
    ax1 = plt.subplot2grid((2, 3), (0, 0), colspan=2)
    steps = np.arange(len(real_v1))
    ax1.plot(steps, real_v1, "k-", alpha=0.3, lw=3, label="Ground Truth (V1)")
    ax1.plot(steps, pred_v1, "r--", lw=1.5, label="Prediction (V1)")
    ax1.plot(steps, pred_v2, "b--", lw=1.0, alpha=0.5, label="Prediction (V2)")
    ax1.set_title("Time Domain Reconstruction", fontsize=12)
    ax1.set_xlabel("Time Steps")
    ax1.set_ylabel("Voltage (V)")
    ax1.legend(loc="upper right")
    ax1.grid(True, alpha=0.3)

    # PANEL 2: Internal State X
    ax2 = plt.subplot2grid((2, 3), (1, 0), colspan=2)
    ax2.plot(steps, pred_x, "g-", lw=2, label="Inferred State X")
    ax2.set_title("Latent Memristor State (The 'Hidden' Variable)", fontsize=12)
    ax2.set_ylabel("State X (0=Ron, 1=Roff)")
    ax2.grid(True, alpha=0.3)
    ax2.legend()

    # PANEL 3: Phase Portrait (The Chaos)
    ax3 = plt.subplot2grid((2, 3), (0, 2))
    ax3.plot(real_v1, real_v2, "k-", alpha=0.2, lw=1, label="Real")
    ax3.plot(pred_v1, pred_v2, "r-", alpha=0.8, lw=1, label="Pred")
    ax3.set_title("Phase Portrait (Attractor)", fontsize=12)
    ax3.set_xlabel("V1")
    ax3.set_ylabel("V2")
    ax3.legend()

    # PANEL 4: Hysteresis Loop (The Fingerprint)
    ax4 = plt.subplot2grid((2, 3), (1, 2))
    ax4.plot(pred_vm, pred_im * 1000, "purple", lw=1.5)  # mA
    ax4.set_title("Learned Pinched Hysteresis", fontsize=12)
    ax4.set_xlabel("Memristor Voltage (V)")
    ax4.set_ylabel("Memristor Current (mA)")
    ax4.grid(True, alpha=0.3)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])

    if save_path:
        plt.savefig(save_path, dpi=150)
        plt.close()
    else:
        plt.show()

    model.train()  # Set back to train mode after visualization


def get_mutual_characteristic(
    y_pred, y_true, i, folder: Path = Path("neural_mutual_tanh")
):
    """Plot the three mutual voltage characteristics (V1-V2, V1-V3, V2-V3).

    Args:
        y_pred, y_true: arrays of shape [N, 3] whose columns are V1, V2, V3.
        i: index used as the output file-name stem.
        folder: destination directory for the PNG (created if missing).
    """
    # BUGFIX: savefig fails if the folder is absent; create it idempotently.
    folder.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(1, 3, layout="constrained", figsize=(15, 5))
    # (column_a, column_b, title) for each of the three pairwise panels
    pairs = [(0, 1, "V1-V2"), (0, 2, "V1-V3"), (1, 2, "V2-V3")]
    for axis, (a, b, title) in zip(ax, pairs):
        axis.plot(y_true[:, a], y_true[:, b], label="true")
        axis.plot(y_pred[:, a], y_pred[:, b], label="pred")
        axis.legend(loc="upper right")
        axis.set_title(title)
    fig.savefig(folder / f"{i}.png")
    # BUGFIX: the figure was never closed, leaking memory when this is
    # called repeatedly in a training loop (cf. plot_data's close).
    plt.close(fig)



def calc_adaptive_lambdas(*losses, eps=1e-8):
    """Balance multiple loss terms with geometric-mean-inverse weights.

    Each weight is lambda_i = geo_mean(losses) / loss_i, so a term that is
    large relative to the others is down-weighted and a small one is
    up-weighted.  Weights are clamped to [0.1, 10.0] and detached so the
    optimizer never back-propagates through the weighting itself.

    Returns:
        weighted_loss (scalar tensor): sum of lambda_i * loss_i.
        lambdas (tensor): the weights, for logging.
    """
    stacked = torch.stack(losses)

    # Geometric mean via exp(mean(log(L_i))); eps guards log(0).
    geo_mean = torch.mean(torch.log(stacked + eps)).exp()

    # Inverse-proportional weights, clipped for stability and detached so
    # the weight computation carries no gradient.
    lambdas = torch.clamp(geo_mean / (stacked + eps), 0.1, 10.0).detach()

    return torch.sum(stacked * lambdas), lambdas



def train_within_weights(
    weights: np.ndarray, epoch: int = 2000, ratio: float = 0.02
) -> None:
    """Run the curriculum training loop for the autoencoder memristor model.

    Trains through a fixed schedule of increasing prediction horizons,
    validating each epoch and checkpointing the best model per phase.

    Args:
        weights: loss-component weights.  NOTE(review): currently unused
            inside this function — kept for caller compatibility.
        epoch: requested epoch budget.  NOTE(review): unused; the phase
            schedule in `horizons_epoch` controls the epoch counts.
        ratio: NOTE(review): unused.

    NOTE(review): relies on module-level globals `path`, `device` and
    `precision` set by the __main__ block — confirm before reuse.
    """
    params = OscParams(R1=68e3 / 2, R3=6456)
    v_m, i_m = extract_vm_im(path, params)
    v_ms_std = np.std(v_m)
    i_ms_std = np.std(i_m)
    from initials import find_rss

    # Initial guesses for the memristor's on/off resistance and state.
    Ron, Roff, x = find_rss(v_m, i_m, 4)
    print(f"Initial Ron = {Ron}, Roff = {Roff}, x0 = {x}")

    # Trajectory loss: range-normalized error over all points per feature.
    base_loss = NRME(
        normalization_method="range",
        norm_axis=(0, 1),  # Normalize across all data points for that feature
        reduction="mean",
    )  # type: ignore
    auto_encoder_loss = nn.MSELoss()

    config = NeuralDiffeqConfig(
        osc_params=params,
        auto_encoder_config=AutoEncoderConfig(
            input_size=2, output_size=1, hidden_size=128, lookback=100
        ),
        nn_memristor_config=NNMemristorConfig(
            input_size=2,
            hidden_size=128,
            output_size=1,
            R_on=Ron,
            R_off=Roff,
        ),
        vm_std=v_ms_std,  # type: ignore
        im_std=i_ms_std,  # type: ignore
    )

    model = AutoencoderMemmodel(config).to(device).to(precision)

    # Physics optimizer: fast LR for the ANN, slow LR for the resistance
    # log-parameters.  NOTE(review): the third, empty param group is kept
    # only so saved optimizer state_dicts stay layout-compatible.
    optimizer = torch.optim.Adam(
        [
            {"params": model.neural_diffeq.mem_model.ann.parameters(), "lr": 1e-3},  # type: ignore
            {
                "params": [
                    model.neural_diffeq.mem_model.log_Ron,
                    model.neural_diffeq.mem_model.log_Roff,
                ],
                "lr": 1e-5,
            },  # type: ignore
            {"params": [], "lr": 1e-4},  # type: ignore
        ]
    )
    optimizer_opt = torch.optim.Adam(model.autoencoder.parameters(), lr=1e-3)
    # (CLEANUP: removed a redundant AutoencoderMemristorData construction here;
    # the loader is (re)built inside the phase loop before first use.)

    # --- Configuration ---
    model_folder = Path("checkpoints")
    model_folder.mkdir(exist_ok=True)
    best_model_path = model_folder / "best_model.pth"

    # Define Curriculum: (Horizon Length, Epochs)
    horizons_epoch = [
        (25, 40),     # Warm-up: learn basic short-term dynamics
        (50, 200),
        (100, 200),
        (200, 200),   # Extend to medium range
        (300, 200),
        (500, 200),
        (1000, 200),  # Long-term chaotic tracking
        (2000, 200),  # Expert: full trajectory reconstruction
    ]

    full_epoch_counter = 0
    best_model = None
    # --- OUTER LOOP: Curriculum Phases ---
    for horizon, epochs in horizons_epoch:
        print(f"\n{'='*60}")
        print(f"STARTING PHASE: Horizon {horizon} | Training for {epochs} Epochs")
        print(f"{'='*60}")
        best_val_loss = float('inf')
        if best_model is not None:
            model.load_state_dict(best_model['model_state_dict'])
            print(f"Loaded best model from previous phase with Val Loss: {best_model['loss']:.6f}")

        # 1. Re-initialize Data Loader (Updates Horizon & Split)
        data_loader = AutoencoderMemristorData(
            path,
            osc_params=params,
            horizon=horizon,
            fs=1e3,
            precision=precision,
            device=device,
            lookback=config.auto_encoder_config.lookback,  # type: ignore
            train_split=0.6,
            val_split=0.2,
        )

        # --- INNER LOOP: Epochs ---
        for i in range(epochs):
            # ==========================================
            #               TRAINING STEP
            # ==========================================
            model.train()

            # Sequential training batches with high overlap (stride < horizon)
            # for data augmentation.
            train_gen = data_loader.get_sequential_batches(batch_size=256, stride=50, mode='train')

            train_traj_loss = 0.0
            train_consis_loss = 0.0
            train_count = 0

            for vm, im, s0, t, tgt in train_gen:
                # Forward pass through the neural ODE.
                sol = model(vm, im, s0, t)
                pred = sol[..., :3].permute(1, 0, 2)
                # Consistency: encoder should reproduce the ODE's latent state
                # (the ODE solution is detached so only the encoder learns it).
                X_ode, X_encode = model.calculate_consistency_loss(sol.detach(), vm, im)

                # Loss Calculation
                loss_traj = base_loss(pred, tgt)
                loss_consis = auto_encoder_loss(X_ode, X_encode)
                full_loss = loss_traj + loss_consis

                optimizer.zero_grad()
                optimizer_opt.zero_grad()
                full_loss.backward()
                # Gradient clipping for stability of the physics parameters.
                torch.nn.utils.clip_grad_norm_(model.neural_diffeq.mem_model.parameters(), max_norm=1.0)
                optimizer.step()
                optimizer_opt.step()

                # Accumulate running epoch statistics.
                train_traj_loss += loss_traj.item()
                train_consis_loss += loss_consis.item()
                train_count += 1

            avg_train_traj = train_traj_loss / train_count
            avg_train_consis = train_consis_loss / train_count

            # ==========================================
            #             VALIDATION STEP
            # ==========================================
            model.eval()
            val_traj_loss = 0.0
            val_count = 0

            with torch.no_grad():
                # Larger stride (no overlap) to keep validation fast.
                val_gen = data_loader.get_sequential_batches(batch_size=64, stride=horizon, mode='val')

                for vm_v, im_v, s0_v, t_v, tgt_v in val_gen:
                    sol_v = model(vm_v, im_v, s0_v, t_v)
                    pred_v = sol_v[..., :3].permute(1, 0, 2)

                    loss_v = base_loss(pred_v, tgt_v)
                    val_traj_loss += loss_v.item()
                    val_count += 1

            # Avoid division by zero if val set is small
            avg_val_loss = val_traj_loss / val_count if val_count > 0 else 0.0

            # ==========================================
            #           LOGGING & CHECKPOINTING
            # ==========================================

            # Checkpoint: save if this is the best model so far.
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                # BUGFIX: state_dict() returns live references to parameter
                # tensors, which the optimizers keep mutating in place, so
                # the in-memory "best" snapshot silently tracked the CURRENT
                # weights and reloading it at phase start was a no-op.
                # deepcopy freezes a true snapshot (also dedups the dict).
                checkpoint = {
                    'epoch': full_epoch_counter,
                    'model_state_dict': deepcopy(model.state_dict()),
                    'optimizer_state_dict': deepcopy(optimizer.state_dict()),
                    'loss': best_val_loss,
                    'horizon': horizon,
                }
                torch.save(checkpoint, best_model_path)
                best_model = checkpoint
                save_msg = "(*)"  # Marker for saved model
            else:
                save_msg = ""

            # Print Status
            curr_Ron = model.neural_diffeq.mem_model.Ron.item()   # Assuming mapped correctly
            curr_Roff = model.neural_diffeq.mem_model.Roff.item() # Assuming mapped correctly

            print(
                f"Epoch {full_epoch_counter:4d} (H={horizon}) | "
                f"Train Traj: {avg_train_traj:.6g} | "
                f"Train Consis: {avg_train_consis:.6g} | "
                f"Val Loss: {avg_val_loss:.6g} {save_msg} | "
                f"Ron: {curr_Ron:.1f} | Roff: {curr_Roff:.1f}"
            )

            # ==========================================
            #             VISUALIZATION
            # ==========================================
            if full_epoch_counter % 20 == 0:
                visualize_performance(
                    model,
                    data_loader,
                    save_path=model_folder / f"viz_epoch_{full_epoch_counter:04d}_H{horizon}.png",
                    epoch=full_epoch_counter,
                )

            full_epoch_counter += 1
    gc.collect()  # Clean up memory after training
    # NOTE(review): best_val_loss here is the LAST phase's best, not the
    # global best across all phases.
    print(f"\nTraining Complete. Best Validation Loss: {best_val_loss:.6f}")
    print(f"Best Model Saved to: {best_model_path}")


if __name__ == "__main__":
    # Script entry point: seed RNGs, set up paths, then run the curriculum
    # training loop.
    set_seed(42)

    # NOTE: `path` is consumed as a module-level GLOBAL inside
    # train_within_weights(); renaming it here would break that function.
    path = Path("data/w6456.csv")
    batch_size = 1  # NOTE(review): appears unused in this chunk — confirm before removing
    model_folder = Path("torch_autoencoder_memmodel")
    model_folder.mkdir(exist_ok=True)
    from multi_loss import CombinedLoss  # noqa

    results = []
    # Loss-component weights handed to training (currently unused there).
    weights = np.array([0.8, 0.1, 0.1])
    vec = train_within_weights(
        weights=weights,
        epoch=10_000,
        ratio=0.02,
    )
    # NOTE(review): train_within_weights has no return statement, so `vec`
    # is None here — confirm whether a metrics vector was intended.
    results.append(vec)