# neural_memristor_autoencoder_rnn.py
from pathlib import Path
from typing import List
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pydantic import BaseModel, Field
from scipy.interpolate import Akima1DInterpolator
from torch import nn
from torch.utils.data import (
DataLoader, # noqa
)
from torchdiffeq import odeint as odeint
from typing import Generator, Optional
# from normalization_predictor import NRME  # noqa  # superseded by the local NRME class defined below
from copy import deepcopy
import os
import gc
os.environ["OMP_NUM_THREADS"] = "40" # For macOS users with certain MKL versions
class NRME(nn.Module):
def __init__(
self,
normalization_method="range", # 'range' is usually safer for circuits than 'std'
epsilon=1e-8,
reduction="mean", # 'mean' gives standard NMSE, 'root_mean' gives NRMSE
norm_axis=0,
):
super(NRME, self).__init__()
self.normalization_method = normalization_method
self.epsilon = epsilon
self.reduction = reduction
self.norm_axis = norm_axis
valid_methods = ["std", "range", "mean"]
if normalization_method not in valid_methods:
raise ValueError(f"Method must be one of {valid_methods}")
def get_normalization_factor(self, true_values):
        # Detach true_values so no gradients flow through the normalization factor
        # (we don't want the network trying to minimize the target's variance!)
        target = true_values.detach()
if self.normalization_method == "std":
factor = torch.std(target, dim=self.norm_axis, keepdim=True)
elif self.normalization_method == "range":
            # torch.amax/amin accept an int or a tuple of dims, so a tuple norm_axis
            # (e.g. (0, 1) for 3D trajectories) is handled directly below.
col_max = torch.amax(target, dim=self.norm_axis, keepdim=True)
col_min = torch.amin(target, dim=self.norm_axis, keepdim=True)
factor = col_max - col_min
elif self.normalization_method == "mean":
factor = torch.abs(torch.mean(target, dim=self.norm_axis, keepdim=True))
return factor + self.epsilon
def forward(self, predictions, true_values):
if predictions.shape != true_values.shape:
raise ValueError(
f"Shape mismatch: {predictions.shape} vs {true_values.shape}"
)
norm_factors = self.get_normalization_factor(true_values)
# Calculate error first, THEN normalize
# (Mathematically equivalent but numerically safer)
error = predictions - true_values
norm_error = error / norm_factors
squared_norm_error = norm_error**2
if self.reduction == "mean":
return torch.mean(squared_norm_error)
elif self.reduction == "root_mean":
return torch.sqrt(torch.mean(squared_norm_error))
elif self.reduction == "sum":
return torch.sum(squared_norm_error)
return squared_norm_error
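# Illustrative usage (a sketch): NRME as a range-normalized MSE criterion over a
# batch-first [batch, features] tensor.
#   criterion = NRME(normalization_method="range", reduction="mean", norm_axis=0)
#   nmse = criterion(torch.rand(32, 3), torch.rand(32, 3))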
class ExponentialDecayLoss(nn.Module):
def __init__(
self, decay_rate=0.1, eps=1e-8, criterion=NRME("range", reduction="none")
):
"""
Loss function that penalizes errors more heavily at the start of the time series.
Args:
decay_rate: Controls how quickly the weight decreases along the time dimension
            eps: Small constant for numerical stability
            criterion: Elementwise loss applied before weighting (defaults to NRME with reduction="none")
        """
super(ExponentialDecayLoss, self).__init__()
self.decay_rate = decay_rate
self.eps = eps
self.criterion = criterion
def forward(self, y_pred, y_true):
"""
Forward pass of the loss function.
Args:
y_pred: Predicted trajectory from Neural ODE, shape [batch_size, time_steps, dim]
y_true: Ground truth trajectory, shape [batch_size, time_steps, dim]
Returns:
Weighted MSE loss with higher penalties for earlier time points
"""
# Calculate squared error
squared_error = self.criterion(y_pred, y_true)
        # Time dimension length (trajectories are batch-first: [batch, time, dim])
        time_steps = y_pred.shape[1]
# Create exponentially decaying weights
time_indices = torch.arange(time_steps, device=y_pred.device).float()
weights = torch.exp(-self.decay_rate * time_indices)
weights = weights / (torch.sum(weights) + self.eps) # Normalize weights
# Reshape weights for broadcasting - PyTorch will handle batch dimension automatically
weights = weights.view(1, time_steps, 1)
# Apply weights to squared error
weighted_error = squared_error * weights
loss = torch.mean(weighted_error)
return loss
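# Illustrative usage (sketch, assuming batch-first trajectories as in the docstring):
#   crit = ExponentialDecayLoss(decay_rate=0.1)
#   loss = crit(torch.rand(8, 50, 3), torch.rand(8, 50, 3))  # early time steps weighted most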
def get_device_name():
if torch.cuda.is_available():
return f"cuda:{torch.cuda.current_device()}"
elif torch.backends.mps.is_available(): # For Apple M1/M2 chips
return "mps"
else:
return "cpu"
precision = torch.float64
device = get_device_name()
device = "cpu"
matplotlib.use("svg")
class OscParams(BaseModel):
R1: float = Field(..., ge=0.0, description="Resistor R1 value in ohms")
R2: float = Field(default=10e3, ge=0.0, description="Resistor R2 value in ohms")
R3: float = Field(default=3.6e3, ge=0.0, description="Resistor R3 value in ohms")
R4: float = Field(default=1e3, ge=0.0, description="Resistor R4 value in ohms")
R5: float = Field(default=1e3, ge=0.0, description="Resistor R5 value in ohms")
R7: float = Field(default=1e3, ge=0.0, description="Resistor R7 value in ohms")
R9: float = Field(default=1e3, ge=0.0, description="Resistor R9 value in ohms")
R6: float = Field(default=22e3, ge=0.0, description="Resistor R6 value in ohms")
R8: float = Field(default=33e3, ge=0.0, description="Resistor R8 value in ohms")
R10: float = Field(default=47e3, ge=0.0, description="Resistor R10 value in ohms")
R11: float = Field(default=100e3, ge=0.0, description="Resistor R11 value in ohms")
C1: float = Field(default=50e-6, ge=0.0, description="Capacitor C1 value in farads")
C3: float = Field(default=50e-6, ge=0.0, description="Capacitor C3 value in farads")
C2: float = Field(
default=23.5e-6, ge=0.0, description="Capacitor C2 value in farads"
)
class AutoEncoderConfig(BaseModel):
input_size: int = Field(default=10, description="Size of the input vector")
output_size: int = Field(default=1, description="Size of the output vector")
hidden_size: int = Field(default=64, description="Size of the hidden layers")
lookback: int = Field(default=100, description="Lookback window length")
class NNMemristorConfig(BaseModel):
input_size: int = Field(default=3, description="Input size for the neural network")
hidden_size: int = Field(
default=64, description="Hidden layer size for the neural network"
)
output_size: int = Field(
default=1, description="Output size for the neural network"
)
R_on: float = Field(default=1e3, description="Initial Ron value")
R_off: float = Field(default=100e3, description="Initial Roff value")
class NeuralDiffeqConfig(BaseModel):
osc_params: OscParams
auto_encoder_config: AutoEncoderConfig
nn_memristor_config: NNMemristorConfig
vm_std: float = Field(
default=1,
description="Standard deviation for memristor voltage normalization",
)
im_std: float = Field(
default=1, description="Standard deviation for memristor current normalization"
)
class Sin(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.sin(input)
class HiddenStateAutoEncode(nn.Module):
def __init__(
self,
input_size: int = 10,
output_size: int = 1,
hidden_size: int = 64,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.encoder = (
nn.Sequential(
nn.Linear(self.input_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.hidden_size),
nn.SiLU(),
nn.Linear(self.hidden_size, self.hidden_size),
nn.SiLU(),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.output_size),
nn.Sigmoid(),
)
.to(device)
.to(precision)
)
self.decoder = (
nn.Sequential(
nn.Linear(self.output_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.hidden_size),
nn.SiLU(),
nn.Linear(self.hidden_size, self.hidden_size),
nn.SiLU(),
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.input_size),
)
.to(device)
.to(precision)
)
def encode(self, x):
encoded = self.encoder(x)
return encoded
def decode(self, z):
decoded = self.decoder(z)
return decoded
def forward(self, x):
z = self.encode(x)
reconstructed = self.decode(z)
return reconstructed
class RecurrentAutoEncoder(nn.Module):
def __init__(
self,
input_size: int = 10,
output_size: int = 1,
hidden_size: int = 64,
num_layers: int = 2,
lookback: int = 100,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
        self.lookback = lookback  # used later for the consistency-loss sliding windows
self.input_size = input_size
self.output_size = output_size
self.latent_size = output_size
self.hidden_size = hidden_size
self.num_layers = num_layers
# Encoder: GRU that compresses a sequence into a latent vector
self.encoder_rnn = nn.GRU(
input_size=self.input_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True,
)
self.enc_fc = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.latent_size),
nn.Sigmoid(),
)
# Decoder: maps latent back to hidden, then GRU reconstructs the sequence
self.dec_fc = nn.Sequential(
nn.Linear(self.latent_size, self.hidden_size // 2),
nn.SiLU(),
nn.Linear(self.hidden_size // 2, self.hidden_size),
)
self.decoder_rnn = nn.GRU(
input_size=self.hidden_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
batch_first=True,
)
self.output_fc = nn.Linear(self.hidden_size, self.input_size)
def encode(self, x):
# x: (batch, seq_len, input_size)
_, h_n = self.encoder_rnn(x) # h_n: (num_layers, batch, hidden)
z = self.enc_fc(h_n[-1]) # use last layer's hidden state
return z
def decode(self, z, seq_len):
# z: (batch, latent_size)
h = self.dec_fc(z) # (batch, hidden)
# Repeat as input for each timestep
dec_input = h.unsqueeze(1).repeat(1, seq_len, 1) # (batch, seq_len, hidden)
# Init decoder hidden state from the same projection, reshaped for all layers
h0 = h.unsqueeze(0).repeat(self.num_layers, 1, 1) # (num_layers, batch, hidden)
out, _ = self.decoder_rnn(dec_input, h0)
reconstructed = self.output_fc(out) # (batch, seq_len, input_size)
return reconstructed
def forward(self, x):
# x: (batch, seq_len, input_size)
seq_len = x.size(1)
z = self.encode(x)
reconstructed = self.decode(z, seq_len)
return reconstructed
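# Illustrative usage (sketch): compress a lookback window of [v_m, i_m] pairs into a
# single bounded latent value (interpreted later as the memristor state X).
#   ae = RecurrentAutoEncoder(input_size=2, output_size=1, hidden_size=64, lookback=100)
#   z = ae.encode(torch.rand(8, 100, 2))   # (8, 1), in (0, 1) via the Sigmoid head
#   x_rec = ae.decode(z, seq_len=100)      # (8, 100, 2) reconstruction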
class MemristorNN(nn.Module):
def __init__(
self,
Ron: float = 1_000,
Roff: float = 100_000,
R: float = 1_000,
activation: type[nn.Module] = nn.SiLU,
hidden_size: int = 64,
input_size: int = 3,
output_size: int = 1,
v_std: float = 1,
i_std: float = 1,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.R = R
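        # Ron/Roff are stored as log-resistances scaled down by R so the raw parameters
        # stay small; the Ron/Roff properties below undo this scaling via exp(log_R * R).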
self.log_Roff = nn.Parameter(
torch.tensor(np.log(Roff) / self.R, dtype=precision, device=device)
)
self.log_Ron = nn.Parameter(
torch.tensor(np.log(Ron) / self.R, dtype=precision, device=device)
)
self.out_scale = nn.Parameter(torch.tensor(1.0, dtype=precision, device=device))
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.activation = activation
self.v_std = v_std
self.i_std = i_std
self.ann = (
nn.Sequential(
nn.Linear(self.input_size, self.hidden_size // 2),
self.activation(),
nn.Linear(self.hidden_size // 2, self.hidden_size),
self.activation(),
nn.Linear(self.hidden_size, self.hidden_size),
self.activation(),
nn.Linear(self.hidden_size, self.hidden_size // 2),
self.activation(),
nn.Linear(self.hidden_size // 2, self.output_size),
nn.Tanh(),
)
.to(device)
.to(precision)
)
@property
def Ron(self):
return torch.exp(self.log_Ron * self.R)
@property
def Roff(self):
return torch.exp(self.log_Roff * self.R)
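    # Linear conductance mixing: G(X) = X / Ron + (1 - X) / Roff, so X = 1 corresponds
    # to the low-resistance (Ron) state and X = 0 to the high-resistance (Roff) state.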
def G(self, X):
return X / self.Ron + (1 - X) / self.Roff
def current(self, X, v_m):
G = self.G(X)
i_m = G * v_m
return i_m
def forward(self, X, v_m):
nn_input = torch.stack(
[
X,
v_m / self.v_std,
self.current(X, v_m) / self.i_std,
] if self.input_size == 3 else [
X,
v_m / self.v_std,
],
dim=-1,
)
dX = self.ann(nn_input) * self.out_scale
# print(f"MemristorNN output shape: {dX.shape}")
return dX.squeeze(0)
class NeuralDiffEq(nn.Module):
def __init__(
self,
config: OscParams,
x0: float = 0.0,
Ron: float = 1_000,
Roff: float = 100_000,
activation: type[nn.Module] = nn.SiLU,
v_m_std: float = 0.2,
i_m_std: float = 1e-3,
input_size: int = 3,
output_size: int = 1,
hidden_size: int = 128,
) -> None:
super().__init__()
        # Neural network that models dX/dt for the memristor state
        self.dim = 3  # exposed node voltages [v1, v2, v3]; the full state also carries X
self.p = config
self.v_m_std = v_m_std
self.i_m_std = i_m_std
self.mem_model = MemristorNN(
Ron=Ron,
Roff=Roff,
R=1e3,
activation=activation,
hidden_size=hidden_size,
input_size=input_size,
output_size=output_size,
v_std=v_m_std,
i_std=i_m_std,
)
# self.out_scale = nn.Parameter(torch.tensor(1.0, dtype=precision, device=device))
@property
def Ron(self):
return self.mem_model.Ron
@property
def Roff(self):
return self.mem_model.Roff
def forward(self, t, state):
v1, v2, v3, X = state[..., 0], state[..., 1], state[..., 2], state[..., 3]
# Compute memristor voltage and current
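        # v_m is taken as -R11/R10 * v3 (the same relation used in extract_vm_im to
        # recover v_m from measurements); i_m then follows from the memristor model.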
vx = -self.p.R11 / self.p.R10 * v3
v_m = vx # - 0.040
i_m = self.mem_model.current(X, v_m)
v4 = vx + self.p.R1 * i_m
# Compute state derivatives
dv1 = 1 / self.p.C1 * (-v1 / self.p.R4 - v2 / self.p.R9 - v4 / self.p.R3)
dv2 = 1 / self.p.C3 * (-v2 / self.p.R8 - v3 / self.p.R7)
dv3 = 1 / self.p.C2 * (-v1 / self.p.R5 - v3 / self.p.R6)
dX = self.mem_model(X, v_m).squeeze(-1)
# added constraints
dX = torch.where(X <= 0, torch.maximum(dX, torch.zeros_like(dX)), dX)
dX = torch.where(X >= 1, torch.minimum(dX, torch.zeros_like(dX)), dX)
# Combine all derivatives
dfunc = torch.zeros_like(state)
dfunc[..., 0] = dv1
dfunc[..., 1] = dv2
dfunc[..., 2] = dv3
dfunc[..., 3] = dX
return dfunc
class AutoencoderMemmodel(nn.Module):
def __init__(
self,
config: NeuralDiffeqConfig,
) -> None:
super().__init__()
# Physics Engine
self.neural_diffeq = NeuralDiffEq(
config.osc_params,
v_m_std=config.vm_std,
i_m_std=config.im_std,
Ron=config.nn_memristor_config.R_on,
Roff=config.nn_memristor_config.R_off,
input_size=config.nn_memristor_config.input_size,
)
# State Inference Network
# Input: [v_m, i_m] (Size 2)
# Output: [X] (Size 1)
self.autoencoder = RecurrentAutoEncoder(
input_size=config.auto_encoder_config.input_size, # MUST be 2
output_size=config.auto_encoder_config.output_size, # MUST be 1
hidden_size=config.auto_encoder_config.hidden_size,
lookback=config.auto_encoder_config.lookback,
)
def forward(self, vm_backward, im_backward, exposed_states, t_forward):
"""
vm_backward, im_backward: Observations used to infer the hidden state X.
exposed_states: The known initial voltages [v1, v2, v3] at t=0.
t_forward: The timesteps to simulate.
"""
# 1. Infer the hidden memristor state X
x_state = self.predict_states(vm_backward, im_backward)
# print(f"Inferred X state shape: {x_state.shape}")
# print(f"x_state sample: {x_state}")
# print(f"Exposed states shape: {exposed_states.shape}")
# 2. Combine Known Physics (v1..v3) with Inferred Physics (X)
# exposed_states: [Batch, 3]
# x_state: [Batch, 1]
# state: [Batch, 4]
state = torch.cat([exposed_states, x_state], dim=-1)
# 3. Integrate Forward
# Solves the system starting from the hybrid state
sol = odeint(self.neural_diffeq, state, t_forward, method="dopri5")
return sol
def predict_states(self, v_m, i_m):
# v_m: (batch, lookback), i_m: (batch, lookback)
# Stack to (batch, lookback, 2) — sequence of [vm, im] pairs
x = torch.stack([v_m, i_m], dim=-1) # (batch, lookback, 2)
results = self.autoencoder.encode(x)
return results
    def reconstruct_states(self, z, seq_len: Optional[int] = None):
        # decode() needs a target sequence length; default to the encoder's lookback window
        return self.autoencoder.decode(z, seq_len if seq_len is not None else self.autoencoder.lookback)
def calculate_consistency_loss(self, sol, vm_backward, im_backward):
"""
sol: [Time, Batch, 4] -> Starts at t0
vm_backward: [Batch, Lookback] -> Ends at t0
"""
# 1. Unpack Trajectory
v3 = sol[..., 2]
X_ode = sol[..., 3]
# 2. Physics (Prediction)
p = self.neural_diffeq.p
v_m_pred = -p.R11 / p.R10 * v3
i_m_pred = self.neural_diffeq.mem_model.current(X_ode, v_m_pred)
# 3. Stitch Timeline
# vm_backward: [Batch, Lookback] -> Transpose to [Lookback, Batch]
v_m_hist_T = vm_backward.T
i_m_hist_T = im_backward.T
# CRITICAL FIX: The ODE solution 'sol' includes t0 at index 0.
# 'vm_backward' also includes t0 at the last index.
# We slice 'v_m_pred' to start from index 1 (t1) to avoid duplicate t0.
full_vm = torch.cat([v_m_hist_T, v_m_pred[1:]], dim=0)
full_im = torch.cat([i_m_hist_T, i_m_pred[1:]], dim=0)
# 4. Create Sliding Windows
window_size = self.autoencoder.lookback
# Shape: [Num_Windows, Batch, Window_Size]
vm_windows = full_vm.unfold(0, window_size, 1)
im_windows = full_im.unfold(0, window_size, 1)
# 5. Align with Prediction
# We need exactly as many windows as we have time steps in 'sol'.
# The ODE produced 'time_steps' outputs (e.g., 50).
time_steps = X_ode.shape[0]
# We take the last N windows.
# Each window 'i' ends at time step 'i' of the simulation.
vm_seq = vm_windows[-time_steps:]
im_seq = im_windows[-time_steps:]
# 6. Predict X
flat_vm = vm_seq.reshape(-1, window_size)
flat_im = im_seq.reshape(-1, window_size)
X_encoded = self.predict_states(flat_vm, flat_im)
X_encoded = X_encoded.view_as(X_ode)
return X_ode, X_encoded
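# Data flow (sketch): a lookback window of [v_m, i_m] -> encoder -> X(t0);
# [v1, v2, v3](t0) + X(t0) -> odeint over t_forward -> trajectory [Time, Batch, 4];
# calculate_consistency_loss then re-encodes stitched history+prediction windows so the
# encoder's X and the ODE's X can be compared.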
def smooth_range_penalty(x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0):
# Create smooth transition zones
lower_margin = min_val + margin
upper_margin = max_val - margin
# Smooth penalties using sigmoid
below_penalty = torch.sigmoid(-(x - lower_margin) / (margin / 4)) * (
lower_margin - x
)
above_penalty = torch.sigmoid((x - upper_margin) / (margin / 4)) * (
x - upper_margin
)
# Combine penalties
total_penalty = below_penalty + above_penalty
return penalty_weight * torch.mean(total_penalty**2)
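# Illustrative usage (sketch): soft penalty for a quantity (e.g. the state X) drifting
# outside [0, 1] with a smooth margin.
#   reg = smooth_range_penalty(torch.linspace(-0.2, 1.2, 100), margin=0.1, penalty_weight=1.0)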
def extract_vm_im(
path: Path,
pp: OscParams,
):
df = pd.read_csv(path)
df["time"] = df["time"] - df["time"].min()
v4 = df["v4"].to_numpy()
v3 = df["v3"].to_numpy()
v_m = -pp.R11 / pp.R10 * v3
i_m = (v4 - v_m) / pp.R1
return v_m, i_m
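# The v_m expression mirrors the -R11/R10 gain applied to v3 inside NeuralDiffEq.forward;
# i_m is recovered from Ohm's law across the series resistor R1: i_m = (v4 - v_m) / R1.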
class AutoencoderMemristorData:
def __init__(
self,
path: Path,
osc_params: OscParams,
fs: float = 1e3,
horizon: int = 50,
lookback: int = 10,
train_split: float = 0.7, # 70% Training
val_split: float = 0.2, # 20% Validation (Leaving 10% for Test)
device=None,
precision=None,
        columns: Optional[List[str]] = None,
):
self.device = device if device else torch.device("cpu")
self.precision = precision if precision else torch.float32
        self.columns = columns if columns is not None else ["v1", "v2", "v3"]
self.pp = osc_params
self.fs = fs
self.horizon = horizon
self.lookback = lookback
# Calculate boundaries
if train_split + val_split >= 1.0:
raise ValueError("train_split + val_split must be < 1.0 to leave room for test set.")
self.train_split = train_split
self.val_split = val_split
# 1. Load & Process Data
try:
self.df = pd.read_csv(path)
            req_cols = self.columns + ["time", "v4"]
for col in req_cols:
if col not in self.df.columns:
raise ValueError(f"Required column '{col}' missing.")
# Normalize time & Interpolate
self.df["time"] = self.df["time"] - self.df["time"].min()
self.interpolate_df()
# Physics Pre-calculation
v3_arr = self.df["v3"].to_numpy()
v4_arr = self.df["v4"].to_numpy()
vm_arr = -self.pp.R11 / self.pp.R10 * v3_arr
im_arr = (v4_arr - vm_arr) / self.pp.R1
# 2. Create Full Tensors
self.state_tensor = torch.tensor(
self.df[self.columns].to_numpy(),
device=self.device, dtype=self.precision,
)
self.vm_tensor = torch.tensor(
vm_arr, device=self.device, dtype=self.precision
)
self.im_tensor = torch.tensor(
im_arr, device=self.device, dtype=self.precision
)
dt = 1.0 / self.fs
self.t_span = (
torch.linspace(0, dt * (self.horizon - 1), self.horizon)
.to(self.device).to(self.precision)
)
self.data_len = len(self.df)
# 3. Calculate Split Indices
# Train: [0 ... idx_train]
# Val: [idx_train ... idx_val]
# Test: [idx_val ... end]
self.idx_train_end = int(self.data_len * self.train_split)
self.idx_val_end = int(self.data_len * (self.train_split + self.val_split))
# Print stats so you know what you're working with
print(f"Data Split Loaded:")
print(f" Train: 0 -> {self.idx_train_end} ({self.idx_train_end} pts)")
print(f" Val: {self.idx_train_end} -> {self.idx_val_end} ({self.idx_val_end - self.idx_train_end} pts)")
print(f" Test: {self.idx_val_end} -> {self.data_len} ({self.data_len - self.idx_val_end} pts)")
except Exception as e:
raise RuntimeError(f"Error loading data: {str(e)}")
def interpolate_df(self):
"""Interpolate to regular grid."""
x = self.df["time"].to_numpy()
x_new = np.arange(x[0], x[-1], 1 / self.fs)
new_df = pd.DataFrame({"time": x_new})
cols = self.columns + ["v4"]
for col in cols:
akima = Akima1DInterpolator(x, self.df[col].to_numpy())
new_df[col] = akima(x_new)
self.df = new_df
def _get_indices(self, mode: str, batch_size: int, random: bool = True):
"""Internal helper to get indices for a specific split."""
# Define ranges based on mode
if mode == 'train':
start = self.lookback
end = self.idx_train_end - self.horizon
elif mode == 'val':
start = self.idx_train_end
end = self.idx_val_end - self.horizon
elif mode == 'test':
start = self.idx_val_end
end = self.data_len - self.horizon
else:
raise ValueError(f"Unknown mode: {mode}")
# Safety check
if start >= end:
raise ValueError(f"Not enough data for {mode} split (Start: {start}, End: {end}, Horizon: {self.horizon})")
# Select indices
if random:
return torch.randint(start, end, (batch_size,), device=self.device)
else:
# Sequential (full range)
return torch.arange(start, end, device=self.device)
def _gather_batch(self, indices):
"""Extracts tensors from indices."""
# History (Backwards from t0)
offsets = torch.arange(-self.lookback + 1, 1, device=self.device)
lookback_indices = indices.unsqueeze(1) + offsets
vm_hist = self.vm_tensor[lookback_indices]
im_hist = self.im_tensor[lookback_indices]
state_0 = self.state_tensor[indices]
# Future (Forwards from t0)
batch_indices = indices.unsqueeze(1) + torch.arange(
self.horizon, device=self.device
)
targets = self.state_tensor[batch_indices]
return vm_hist, im_hist, state_0, self.t_span, targets
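    # _gather_batch returns: vm_hist/im_hist (batch, lookback), state_0 (batch, len(columns)),
    # t_span (horizon,), targets (batch, horizon, len(columns)).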
# --- PUBLIC INTERFACE ---
def get_train_batch(self, batch_size: int = 32):
"""Random batch from Training set"""
indices = self._get_indices('train', batch_size, random=True)
return self._gather_batch(indices)
def get_val_batch(self, batch_size: int = 32):
"""Random batch from Validation set"""
indices = self._get_indices('val', batch_size, random=True)
return self._gather_batch(indices)
def get_test_batch(self, batch_size: int = 32):
"""Random batch from Test set"""
indices = self._get_indices('test', batch_size, random=True)
return self._gather_batch(indices)
    def get_sequential_batches(self, batch_size: int = 32, mode: str = 'train', stride: Optional[int] = None):
        """Iterate through a specific split sequentially."""
        if stride is None:
            stride = self.horizon
# Get all valid indices for this split
all_indices = self._get_indices(mode, batch_size=0, random=False)
# Apply stride
# We only take every Nth index to create the sliding window effect
all_indices = all_indices[::stride]
num_batches = (len(all_indices) + batch_size - 1) // batch_size
for i in range(num_batches):
batch_start = i * batch_size
batch_end = min((i + 1) * batch_size, len(all_indices))
if batch_start >= batch_end: break
yield self._gather_batch(all_indices[batch_start:batch_end])
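# Illustrative usage (sketch; assumes a CSV with time, v1..v4 columns such as data/w6456.csv):
#   loader = AutoencoderMemristorData(Path("data/w6456.csv"), OscParams(R1=34e3),
#                                     horizon=50, lookback=100, device=device, precision=precision)
#   vm, im, s0, t_span, targets = loader.get_train_batch(batch_size=32)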
def set_seed(seed: int):
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def plot_data(i, t, y_pred, y_true):
    Path("neural_plot").mkdir(exist_ok=True)  # avoid failing if the output folder is missing
    plt.plot(t, y_true[:, 2], label="true")
plt.plot(t, y_pred[:, 2], label="pred")
plt.legend(loc="upper right")
plt.savefig(f"neural_plot/{i}.png", bbox_inches="tight")
plt.close("all")
def smoth_criterion(y_pred, y_true, criterion):
total_loss = 0
for i in range(3):
mean_i = torch.mean(y_true[:, i])
std_i = torch.std(y_true[:, i])
total_loss += criterion(
(y_pred[:, i] - mean_i) / std_i, (y_true[:, i] - mean_i) / std_i
)
return total_loss / 3
def get_gausian_prediction(y_pred, y_true, criterion2: callable):
total_loss = 0
for i in range(3):
total_loss += criterion2(
torch.mean(y_pred[:, i]), y_true[:, i], torch.var(y_pred[:, i])
)
return total_loss / 3
def visualize_performance(model, data_loader, save_path=None, epoch=None):
"""
Generates a 3-panel dashboard:
1. Time Domain (Voltage vs Time)
2. Phase Space (Attractor V1 vs V2)
3. Memristor Physics (Pinched Hysteresis Vm vs Im)
"""
model.eval()
    # Build a longer continuous test trajectory by stitching consecutive
    # non-overlapping horizons (stride=None -> stride == horizon).
    iterator = data_loader.get_sequential_batches(batch_size=1, stride=None, mode="test")
pred_v1, pred_v2, pred_v3, pred_x = [], [], [], []
real_v1, real_v2, real_v3 = [], [], []
with torch.no_grad():
for i in range(4): # Stitch 4 horizons together (e.g. 4x50 = 200 steps)
vm, im, s0, t, tgt = next(iterator)
# Forward Pass
sol = model(vm, im, s0, t) # [Time, Batch, 4]
# Extract Data (Batch 0)
p = sol[:, 0, :].cpu().numpy() # [Time, 4]
t_vals = tgt[0, :, :].cpu().numpy() # [Time, 3]
pred_v1.extend(p[:, 0])
pred_v2.extend(p[:, 1])
pred_v3.extend(p[:, 2])
pred_x.extend(p[:, 3])
real_v1.extend(t_vals[:, 0])
real_v2.extend(t_vals[:, 1])
real_v3.extend(t_vals[:, 2])
# Convert to numpy arrays
pred_v1, pred_v2, pred_v3 = np.array(pred_v1), np.array(pred_v2), np.array(pred_v3)
real_v1, real_v2, real_v3 = np.array(real_v1), np.array(real_v2), np.array(real_v3)
pred_x = np.array(pred_x)
# --- CALCULATE MEMRISTOR PHYSICS ---
# We need to derive Vm and Im from the predicted voltages to plot hysteresis
# Using the learned parameters from the model
p = model.neural_diffeq.p # OscParams
# Vm = -R11/R10 * V3
pred_vm = -p.R11 / p.R10 * pred_v3
# We can't use Kirchhoff for Im easily without derivative of V1,
# so let's use the Memristor Equation: I = G(x) * V
# Get G(x) from the model's internal function
# Note: We need to convert back to tensor for the model function
x_tensor = torch.tensor(pred_x, device=device, dtype=precision)
vm_tensor = torch.tensor(pred_vm, device=device, dtype=precision)
# Calculate Current
pred_im = (
model.neural_diffeq.mem_model.current(x_tensor, vm_tensor)
.detach()
.cpu()
.numpy()
)
# --- PLOTTING ---
fig = plt.figure(figsize=(18, 10))
plt.suptitle(
f"Neural ODE Performance (Epoch {epoch if epoch else 'Final'})",
fontsize=16,
fontweight="bold",
)
# PANEL 1: Time Series (V1 and V2)
ax1 = plt.subplot2grid((2, 3), (0, 0), colspan=2)
steps = np.arange(len(real_v1))
ax1.plot(steps, real_v1, "k-", alpha=0.3, lw=3, label="Ground Truth (V1)")
ax1.plot(steps, pred_v1, "r--", lw=1.5, label="Prediction (V1)")
ax1.plot(steps, pred_v2, "b--", lw=1.0, alpha=0.5, label="Prediction (V2)")
ax1.set_title("Time Domain Reconstruction", fontsize=12)
ax1.set_xlabel("Time Steps")
ax1.set_ylabel("Voltage (V)")
ax1.legend(loc="upper right")
ax1.grid(True, alpha=0.3)
# PANEL 2: Internal State X
ax2 = plt.subplot2grid((2, 3), (1, 0), colspan=2)
ax2.plot(steps, pred_x, "g-", lw=2, label="Inferred State X")
ax2.set_title("Latent Memristor State (The 'Hidden' Variable)", fontsize=12)
ax2.set_ylabel("State X (0=Ron, 1=Roff)")
# ax2.set_ylim(-0.1, 1.1)
ax2.grid(True, alpha=0.3)
ax2.legend()
# PANEL 3: Phase Portrait (The Chaos)
ax3 = plt.subplot2grid((2, 3), (0, 2))
ax3.plot(real_v1, real_v2, "k-", alpha=0.2, lw=1, label="Real")
ax3.plot(pred_v1, pred_v2, "r-", alpha=0.8, lw=1, label="Pred")
ax3.set_title("Phase Portrait (Attractor)", fontsize=12)
ax3.set_xlabel("V1")
ax3.set_ylabel("V2")
ax3.legend()
# PANEL 4: Hysteresis Loop (The Fingerprint)
ax4 = plt.subplot2grid((2, 3), (1, 2))
ax4.plot(pred_vm, pred_im * 1000, "purple", lw=1.5) # mA
ax4.set_title("Learned Pinched Hysteresis", fontsize=12)
ax4.set_xlabel("Memristor Voltage (V)")
ax4.set_ylabel("Memristor Current (mA)")
ax4.grid(True, alpha=0.3)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
if save_path:
plt.savefig(save_path, dpi=150)
plt.close()
else:
plt.show()
model.train() # Set back to train mode after visualization
def get_mutual_characteristic(
y_pred, y_true, i, folder: Path = Path("neural_mutual_tanh")
):
    folder.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists
    fig, ax = plt.subplots(1, 3, layout="constrained", figsize=(15, 5))
ax[0].plot(y_true[:, 0], y_true[:, 1], label="true")
ax[0].plot(y_pred[:, 0], y_pred[:, 1], label="pred")
ax[0].legend(loc="upper right")
ax[0].set_title("V1-V2")
ax[1].plot(y_true[:, 0], y_true[:, 2], label="true")
ax[1].plot(y_pred[:, 0], y_pred[:, 2], label="pred")
ax[1].legend(loc="upper right")
ax[1].set_title("V1-V3")
ax[2].plot(y_true[:, 1], y_true[:, 2], label="true")
ax[2].plot(y_pred[:, 1], y_pred[:, 2], label="pred")
ax[2].legend(loc="upper right")
ax[2].set_title("V2-V3")
    fig.savefig(folder / f"{i}.png")
    plt.close(fig)
def calc_adaptive_lambdas(*losses, eps=1e-8):
"""
Computes dynamic weights (lambdas) for multiple loss components
using the Geometric Mean strategy.
Returns:
weighted_loss (scalar): The sum of weighted losses
lambdas (list): The calculated weights for logging
"""
# Stack losses into a tensor [N]
loss_tensor = torch.stack(losses)
# 1. Calculate Geometric Mean of the losses
# log_mean = sum(log(Li)) / N
# geo_mean = exp(log_mean)
log_losses = torch.log(loss_tensor + eps)
loss_mean = torch.exp(torch.mean(log_losses))
# 2. Calculate Inverse Proportional Weights
# lambda_i = GeoMean / Loss_i
# If Loss_i is huge, lambda_i becomes small (down-weighting it)
# If Loss_i is tiny, lambda_i becomes large (up-weighting it)
lambdas = loss_mean / (loss_tensor + eps)
# 3. Clip weights to prevent instability (e.g. 0.1 to 10.0)
# Detach lambdas so we don't backprop through the weight calculation itself!
lambdas = torch.clamp(lambdas, 0.1, 10.0).detach()
# 4. Compute Final Weighted Sum
weighted_loss = torch.sum(loss_tensor * lambdas)
return weighted_loss, lambdas
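# Illustrative usage (sketch; not wired into the training loop below):
#   total, lambdas = calc_adaptive_lambdas(loss_traj, loss_consis)
#   total.backward()  # each term is weighted toward the geometric mean of the losses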
def train_within_weights(
weights: np.ndarray, epoch: int = 2000, ratio: float = 0.02
) -> np.ndarray: # type: ignore
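    # Note: `weights`, `epoch` and `ratio` are currently unused here; the function reads
    # the module-level `path` defined under __main__ and does not return a value.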
params = OscParams(R1=68e3 / 2, R3=6456)
v_m, i_m = extract_vm_im(path, params)
v_ms_std = np.std(v_m)
i_ms_std = np.std(i_m)
from initials import find_rss
    Ron, Roff, x = find_rss(v_m, i_m, 4)
print(f"Initial Ron = {Ron}, Roff = {Roff}, x0 = {x}")
base_loss = NRME(
normalization_method="range",
norm_axis=(0, 1), # Normalize across all data points for that feature
reduction="mean",
) # type: ignore
auto_encoder_loss = nn.MSELoss()
config = NeuralDiffeqConfig(
osc_params=params,
auto_encoder_config=AutoEncoderConfig(
input_size=2, output_size=1, hidden_size=128, lookback=100
),
nn_memristor_config=NNMemristorConfig(
input_size=2,
hidden_size=128,
output_size=1,
R_on=Ron,
R_off=Roff,
),
vm_std=v_ms_std, # type: ignore
im_std=i_ms_std, # type: ignore
)
model = AutoencoderMemmodel(config).to(device).to(precision)
optimizer = torch.optim.Adam(
[
{"params": model.neural_diffeq.mem_model.ann.parameters(), "lr": 1e-3}, # type: ignore
{
"params": [
model.neural_diffeq.mem_model.log_Ron,
model.neural_diffeq.mem_model.log_Roff,
],
"lr": 1e-5,
}, # type: ignore
{"params": [], "lr": 1e-4}, # type: ignore
]
)
optimizer_opt = torch.optim.Adam(model.autoencoder.parameters(), lr=1e-3)
data_loader = AutoencoderMemristorData(
path,
osc_params=params,
horizon=100,
fs=1e3,
precision=precision,
device=device,
lookback=config.auto_encoder_config.lookback, # type: ignore
train_split=0.6,
val_split=0.2,
)
# --- Configuration ---
model_folder = Path("checkpoints")
model_folder.mkdir(exist_ok=True)
best_model_path = model_folder / "best_model.pth"
# Define Curriculum: (Horizon Length, Epochs)
    horizons_epoch = [
        (25, 40),     # Warm-up: learn basic short-term dynamics
        (50, 200),    # Short horizons
        (100, 200),   # Horizon matches the lookback window
        (200, 200),   # Extend to medium range
        (300, 200),
        (500, 200),
        (1000, 200),  # Long-term chaotic tracking
        (2000, 200),  # Expert: full trajectory reconstruction
    ]
full_epoch_counter = 0
best_model = None
# --- OUTER LOOP: Curriculum Phases ---
for horizon, epochs in horizons_epoch:
print(f"\n{'='*60}")
print(f"STARTING PHASE: Horizon {horizon} | Training for {epochs} Epochs")
print(f"{'='*60}")
best_val_loss = float('inf')
if best_model is not None:
model.load_state_dict(best_model['model_state_dict'])
print(f"Loaded best model from previous phase with Val Loss: {best_model['loss']:.6f}")
        # 1. Re-initialize the data loader with the new horizon
        # (fixed split: 60% train, 20% validation, 20% test)
data_loader = AutoencoderMemristorData(
path,
osc_params=params,
horizon=horizon,
fs=1e3,
precision=precision,
device=device,
lookback=config.auto_encoder_config.lookback, # type: ignore
train_split=0.6,
val_split=0.2
)
# --- INNER LOOP: Epochs ---
for i in range(epochs):
# ==========================================
# TRAINING STEP
# ==========================================
model.train()
# Get Sequential Training Batches (High Overlap for Data Augmentation)
train_gen = data_loader.get_sequential_batches(batch_size=256, stride=50, mode='train')
train_traj_loss = 0.0
train_consis_loss = 0.0
train_count = 0
for vm, im, s0, t, tgt in train_gen:
# Forward
sol = model(vm, im, s0, t)
pred = sol[..., :3].permute(1, 0, 2)
X_ode, X_encode = model.calculate_consistency_loss(sol.detach(), vm, im)
# Loss Calculation
loss_traj = base_loss(pred, tgt)
                loss_consis = auto_encoder_loss(X_ode, X_encode)  # encoder/ODE consistency
full_loss = loss_traj + loss_consis
optimizer.zero_grad()
optimizer_opt.zero_grad()
full_loss.backward()
# Gradient Clipping for Stability
torch.nn.utils.clip_grad_norm_(model.neural_diffeq.mem_model.parameters(), max_norm=1.0)
# Update only the Physics
optimizer.step()
optimizer_opt.step()
# # ==========================================
# loss_traj = base_loss(pred, tgt)
# # CHANGE: We REMOVED the consistency term here.
# # The ODE now ignores the Encoder. It evolves X purely to satisfy Kirchhoff's laws.
# loss_physics = loss_traj
# optimizer.zero_grad()
# loss_physics.backward() # No retain_graph needed (Graphs are independent now)
# # Gradient Clipping
# torch.nn.utils.clip_grad_norm_(model.neural_diffeq.mem_model.parameters(), max_norm=1.0)
# optimizer.step()
# # ==========================================
# # STEP 2: Train Encoder (Student)
# # ==========================================
# optimizer_opt.zero_grad()
# # Recalculate with DETACHED solution
# # The Encoder treats the ODE's state as the "Label" it must learn.
# X_ode_fixed, X_encoded_new = model.calculate_consistency_loss(sol.detach(), vm, im)
# # Loss = MSE(Target=X_ode, Input=X_encoded)
# loss_consis = auto_encoder_loss(X_ode_fixed, X_encoded_new)
# loss_consis.backward()
# optimizer_opt.step()
# Accumulate
train_traj_loss += loss_traj.item()
train_consis_loss += loss_consis.item()
train_count += 1
avg_train_traj = train_traj_loss / train_count
avg_train_consis = train_consis_loss / train_count
# ==========================================
# VALIDATION STEP
# ==========================================
model.eval()
val_traj_loss = 0.0
val_count = 0
with torch.no_grad():
# Use larger stride for validation to speed it up (less overlap needed)
val_gen = data_loader.get_sequential_batches(batch_size=64, stride=horizon, mode='val')
for vm_v, im_v, s0_v, t_v, tgt_v in val_gen:
sol_v = model(vm_v, im_v, s0_v, t_v)
pred_v = sol_v[..., :3].permute(1, 0, 2)
loss_v = base_loss(pred_v, tgt_v)
val_traj_loss += loss_v.item()
val_count += 1
# Avoid division by zero if val set is small
avg_val_loss = val_traj_loss / val_count if val_count > 0 else 0.0
# ==========================================
# LOGGING & CHECKPOINTING
# ==========================================
# Checkpoint: Save if this is the best model so far
if avg_val_loss < best_val_loss:
best_val_loss = avg_val_loss
torch.save({
'epoch': full_epoch_counter,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': best_val_loss,
'horizon': horizon
}, best_model_path)
                best_model = {
                    'epoch': full_epoch_counter,
                    # deepcopy so this snapshot is not mutated by later training steps
                    'model_state_dict': deepcopy(model.state_dict()),
                    'optimizer_state_dict': deepcopy(optimizer.state_dict()),
                    'loss': best_val_loss,
                    'horizon': horizon
                }
save_msg = "(*)" # Marker for saved model
else:
save_msg = ""
# Print Status
            curr_Ron = model.neural_diffeq.mem_model.Ron.item()    # decoded from the scaled log-parameter
            curr_Roff = model.neural_diffeq.mem_model.Roff.item()  # decoded from the scaled log-parameter
print(
f"Epoch {full_epoch_counter:4d} (H={horizon}) | "
f"Train Traj: {avg_train_traj:.6g} | "
f"Train Consis: {avg_train_consis:.6g} | "
f"Val Loss: {avg_val_loss:.6g} {save_msg} | "
f"Ron: {curr_Ron:.1f} | Roff: {curr_Roff:.1f}"
)
# ==========================================
# VISUALIZATION
# ==========================================
if full_epoch_counter % 20 == 0:
visualize_performance(
model,
data_loader,
save_path=model_folder / f"viz_epoch_{full_epoch_counter:04d}_H{horizon}.png",
epoch=full_epoch_counter,
)
full_epoch_counter += 1
gc.collect() # Clean up memory after training
print(f"\nTraining Complete. Best Validation Loss: {best_val_loss:.6f}")
print(f"Best Model Saved to: {best_model_path}")
if __name__ == "__main__":
set_seed(42)
path = Path("data/w6456.csv")
batch_size = 1
model_folder = Path("torch_autoencoder_memmodel")
model_folder.mkdir(exist_ok=True)
from multi_loss import CombinedLoss # noqa
results = []
weights = np.array([0.8, 0.1, 0.1])
vec = train_within_weights(
weights=weights,
epoch=10_000,
ratio=0.02,
)
results.append(vec)