# jax_generalGoff.py
import os

# Configure XLA's CPU threading BEFORE `jax` is imported below: XLA reads
# XLA_FLAGS at initialisation time, so only env vars set prior to the first
# jax import take effect in this process.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20"
)
# os.environ['JAX_NUM_CPU_DEVICES'] = '10'
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Union
import gc
from loguru import logger
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax  # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit  # JAX's sigmoid function
from loguru import logger
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter

from utils import (generate_neural_stack, Dopant, FileMeta, find_periods, average_data, mean_data, unpack_data, export_signals, MemData, get_train_test_indices_from_req, plot_results, get_batches)



import os

# NOTE(review): `jax` has already been imported above, so re-assigning
# XLA_FLAGS here is too late to affect this process's XLA runtime — only the
# first assignment (before the jax import) is effective. Presumably this block
# was meant to replace the earlier one; confirm and keep exactly one.
# Also note: the backslash keeps this a single string literal, so the second
# line's leading indentation ends up inside the flag string (harmless to the
# flag parser, but surprising).
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20\
    --xla_cpu_enable_fast_math=true "
)


matplotlib.use("svg")  # Non-interactive backend: safe on headless machines / no GUI
print(f"JAX devices: {jax.devices()}")
# exit()


class MLPG(eqx.Module):
    """State-to-conductance map: interpolates between off/on conductances.

    Both conductances are stored in log space so that gradient updates can
    never drive them negative.
    """

    log_Goff: jax.Array
    log_Gon: jax.Array

    def __init__(self, Goff=1 / jnp.array(60.0), Gon=1 / jnp.array(0.5), **kwargs):
        super().__init__(**kwargs)
        # Parameterise in log space; the properties below undo the transform.
        self.log_Goff = jnp.log(Goff)
        self.log_Gon = jnp.log(Gon)

    @property
    def Goff(self):
        """Off-state conductance (strictly positive by construction)."""
        return jnp.exp(self.log_Goff)

    @property
    def Gon(self):
        """On-state conductance (strictly positive by construction)."""
        return jnp.exp(self.log_Gon)

    def __call__(self, x):
        """Linearly blend the two conductances according to state `x`."""
        g_on = self.Gon
        g_off = self.Goff
        return x * g_on + (1 - x) * g_off


class Func(eqx.Module):
    """Right-hand side of the memristor ODE: dx/dt = MLP(v_m / std_v, x).

    `mlpg` maps the latent state x to a device conductance G; the series
    resistance `Rs` then forms a voltage divider that determines the
    device voltage v_m driving the state dynamics.
    """

    out_scale: jax.Array  # learnable scale applied to the MLP output
    x0: jax.Array  # initial latent state handed to the ODE solver
    Rs: float  # series resistance of the measurement circuit
    mlp: eqx.nn.MLP
    std_v: float  # normalisation scale for the voltage input
    mlpg: eqx.Module  # state -> conductance map

    def __init__(
        self,
        *,
        key,
        width_mplx,
        depth_mplx,
        mlpx_activation,
        mlpx_final_activation,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.out_scale = jnp.array([1.0], dtype=jnp.float32)
        if diamond_structure:
            # Project-specific "diamond" stack (optionally layer-normalised).
            self.mlp = generate_neural_stack(
                n_layers=depth_mplx,
                width=width_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                in_size=latent_states + 1,
                out_size=latent_states,
                key=key,
                layer_norm=with_norm,
            )
        else:
            self.mlp = eqx.nn.MLP(
                in_size=latent_states + 1,  # v_m + x
                out_size=latent_states,
                width_size=width_mplx,
                depth=depth_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                key=key,
            )

        self.mlpg = MLPG()
        self.Rs = 5.11  # fixed series resistance; presumably measured — TODO confirm
        self.x0 = jnp.array([0.001] * latent_states)
        self.std_v = std_v

    @property
    def Ron(self):
        # BUGFIX: `Func` has no `Gon` attribute — the on-conductance lives on
        # the state-to-conductance module, mirroring `Roff` below. The
        # original `1 / self.Gon` raised AttributeError when accessed.
        return 1 / self.mlpg.Gon

    @property
    def Roff(self):
        return 1 / self.mlpg.Goff

    def __call__(self, t, y, args):
        """Evaluate dx/dt at time `t` for state `y`; `args` is (amp, freq)."""
        amp, freq = args
        vs = jnp.sin(freq * t * 2 * jnp.pi) * amp  # sinusoidal source voltage
        X = y

        G = self.mlpg(X)
        # Voltage divider: Rs in series with the device of conductance G.
        v_m = vs / (self.Rs * G + 1)

        dX_input = jnp.hstack([v_m / self.std_v, X])
        return self.mlp(dX_input) * self.out_scale


class NeuralODE(eqx.Module):
    # Wraps `Func` in a diffrax ODE solve and maps the latent trajectory back
    # to the measurable quantities (device voltage, current, conductance).
    func: Func

    def __init__(
        self,
        *,
        width_mplx,
        depth_mplx,
        mlpx_activation,
        mlpx_final_activation,
        key,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=True,
        **kwargs,
    ):
        """Build the ODE right-hand side; all hyperparameters forward to `Func`."""
        super().__init__(**kwargs)
        self.func = Func(
            key=key,
            width_mplx=width_mplx,
            depth_mplx=depth_mplx,
            mlpx_activation=mlpx_activation,
            mlpx_final_activation=mlpx_final_activation,
            latent_states=latent_states,
            std_v=std_v,
            with_norm=with_norm,
            diamond_structure=diamond_structure,
        )

    def __call__(self, ts, amp, freq):
        """Integrate the latent state over `ts` and return (v_m, i_m, x, G).

        `ts` is assumed to be a 1-D, monotonically increasing time grid —
        TODO confirm against callers in `train_mem`.
        """
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),  # 5th-order explicit Runge–Kutta
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],  # initial step = first grid spacing
            y0=self.func.x0,
            # stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-9),
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
            args=(amp, freq),
            max_steps=2048 * 8,
            # adjoint=diffrax.BacksolveAdjoint(solver=diffrax.Tsit5()) # BacksolveAdjoint, InterpolationAdjoint
            # Checkpointed reverse-mode adjoint: memory/compute trade-off for backprop.
            adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints=2048* 3),
        )
        # jax.debug.print("Function done ")
        x = solution.ys
        # Recompute the source voltage on the saved grid and derive the
        # measurable device voltage and current from the conductance.
        G = self.func.mlpg(x)
        v_s = jnp.reshape(jnp.sin(freq * ts * 2 * jnp.pi) * amp, (-1, 1))
        v_m = v_s / (self.func.Rs * G + 1)
        i_m = v_m * G
        return v_m, i_m, x, G


def get_trial_dict(trial: Trial) -> Dict[str, Any]:
    """Sample one hyperparameter configuration from an Optuna trial.

    Categorical activation/optimizer names are resolved to the actual
    callables before returning, and a curriculum of sequence-length scales
    (``trial_dict["length_scales"]``) is derived from ``length_step``,
    ``intermediate_epochs`` and ``last_epochs``.

    Parameters
    ----------
    trial : optuna.Trial
        Trial to draw suggestions from.

    Returns
    -------
    Dict[str, Any]
        Hyperparameters keyed by name, plus "length_scales" and
        "trial_number".
    """
    activation_dict = {
        "relu": jax.nn.relu,
        "gelu": jax.nn.gelu,
        "silu": jax.nn.silu,
        "tanh": jax.nn.tanh,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
        "sigmoid": jax.nn.sigmoid,
        "sin": jnp.sin,
        "linear": jax.nn.identity,
    }
    optimizer_dict = {
        "adam": optax.adam,
        # "sgd": optax.sgd,
        # "adamw": optax.adamw,
        "nadamw": optax.nadamw,
        "adabelief": optax.adabelief,
        "nadam": optax.nadam,
    }

    trial_dict = {
        "mlpx_width_size": trial.suggest_int("mlpx_width_size", 16, 512, step=16),
        "mlpx_depth": trial.suggest_int("mlpx_depth", 1, 5, step=1),
        "mlpx_activation": trial.suggest_categorical(
            "mlpx_activation", list(activation_dict.keys())
        ),
        "mlpx_final_activation": trial.suggest_categorical(
            "mlpx_final_activation", ["linear", "tanh"]
        ),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "mlpg_learning_rate": trial.suggest_float(
            "mlpg_learning_rate", 1e-4, 1e-1, log=True
        ),
        "clip_value": trial.suggest_float("clip_value", 0.1, 10.0, step=0.1),
        "eps": trial.suggest_float("eps", 1e-8, 1e-4, log=True),
        "with_regularization": trial.suggest_categorical(
            "with_regularization", [True, False]
        ),
        "mlpx_with_norm": trial.suggest_categorical("mlpx_with_norm", [True, False]),
        "regularization_weight": trial.suggest_float(
            "regularization_weight", 1e-6, 1e-2, log=True
        ),
        "patience": trial.suggest_int("patience", 10, 200, step=10),
        "cooldown": trial.suggest_int("cooldown", 0, 100, step=10),
        "factor": trial.suggest_float("factor", 0.1, 0.99, step=0.1),
        "rtol": trial.suggest_float("rtol", 1e-6, 1e-2, log=True),
        "accumulation_size": trial.suggest_int("accumulation_size", 1, 100, step=1),
        "length_step": trial.suggest_float("length_step", 0.2, 0.6),
        "intermediate_epochs": trial.suggest_int("intermediate_epochs", 1, 100, step=1),
        "last_epochs": trial.suggest_int("last_epochs", 1, 1000, step=10),
        "optimizer": trial.suggest_categorical(
            "optimizer", list(optimizer_dict.keys())
        ),
        # "latent_states": trial.suggest_int("latent_states", 1, 6, step=1),
        "batch_size": trial.suggest_int("batch_size", 1, 18, step=1),
        "diamond_structure": trial.suggest_categorical(
            "diamond_structure", [True, False]
        ),
    }
    # FIX: use the logger instead of a stray stdout `print` for debug output.
    logger.debug(f"learning_rate={trial_dict['learning_rate']}")
    # Resolve categorical names to callables. Only values are mutated, never
    # keys, so iterating the dict while assigning is safe here.
    for key in trial_dict:
        if key.endswith("_activation") or key.endswith("_final_activation"):
            trial_dict[key] = activation_dict[trial_dict[key]]
        elif key == "optimizer":
            trial_dict[key] = optimizer_dict[trial_dict[key]]

    # Curriculum: train on growing fractions of each sequence, ending with
    # the full length for `last_epochs` epochs.
    length_scale = np.arange(trial_dict["length_step"], 1.0, trial_dict["length_step"])
    trial_dict["length_scales"] = [
        (l_s, trial_dict["intermediate_epochs"]) for l_s in length_scale if l_s < 1.0
    ]
    trial_dict["length_scales"] += [(1.0, trial_dict["last_epochs"])]
    trial_dict["trial_number"] = trial.number
    return trial_dict




def get_batches(
    length: int,
    batch_size: int,
    shuffle: bool = True,
    dataloader_key: jr.PRNGKey = jr.PRNGKey(0),
) -> Generator[jax.Array, None, None]:
    """Yield successive batches of indices into a dataset of size `length`.

    NOTE(review): this local definition shadows the `get_batches` imported
    from `utils` at the top of the file — confirm which one is intended.

    Parameters
    ----------
    length : int
        Number of samples to index.
    batch_size : int
        Maximum size of each yielded batch (the last batch may be smaller).
    shuffle : bool
        When True, indices are permuted with `dataloader_key` first.
    dataloader_key : jax PRNG key
        Key for the permutation; the default makes shuffling deterministic.

    Yields
    ------
    jax.Array
        1-D integer arrays of indices (fix: the original annotation claimed
        ``List[int]``, but jax arrays are yielded).
    """
    if shuffle:
        indices = jr.permutation(dataloader_key, jnp.arange(length))
    else:
        indices = jnp.arange(length)

    for start in range(0, length, batch_size):
        yield indices[start : start + batch_size]

def train_mem(
    md: FileMeta,
    trial_dict: Dict[str, Any],
):
    """Train a NeuralODE memristor model and return its test-set losses.

    Stacks the per-file measurements in `md` into arrays, splits them into
    train/test sets, builds a `NeuralODE` from the hyperparameters in
    `trial_dict`, and trains it with a partitioned optax optimizer over a
    curriculum of growing sequence lengths (`trial_dict["length_scales"]`).

    Parameters
    ----------
    md : FileMeta
        Measurement container; each entry provides amp, freq, v_m, i_m, t.
    trial_dict : Dict[str, Any]
        Hyperparameters as produced by `get_trial_dict`.

    Returns
    -------
    tuple
        The three standardized test losses (signal, 1st diff, 2nd diff) as
        Python floats, or (inf, inf, inf) if the final evaluation fails.
    """
    key = jr.PRNGKey(5678)

    data_key, model_key, loader_key = jr.split(key, 3)

    # Stack the per-file measurements into (n_files, ...) float32 arrays.
    amps = jnp.array(
        [md.file_meta[file_index].amp for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    ).reshape(-1, 1)
    freqs = jnp.array(
        [md.file_meta[file_index].freq for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    ).reshape(-1, 1)
    v_ms = jnp.array(
        [md.file_meta[file_index].v_m for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )
    i_ms = jnp.array(
        [md.file_meta[file_index].i_m for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )
    tss = jnp.array(
        [md.file_meta[file_index].t for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )

    train_ind, test_ind = get_train_test_indices_from_req(md)

    # Hold-out split: `test_*` arrays are never used for gradient updates.
    test_amps = amps[test_ind, :]
    test_freqs = freqs[test_ind, :]
    test_v_ms = v_ms[test_ind, :]
    test_i_ms = i_ms[test_ind, :]
    test_tss = tss[test_ind, :]
    amps = amps[train_ind, :]
    freqs = freqs[train_ind, :]
    v_ms = v_ms[train_ind, :]
    i_ms = i_ms[train_ind, :]
    tss = tss[train_ind, :]

    # Average per-sequence voltage spread, used to normalise the MLP input.
    std_v = jnp.mean(jnp.std(v_ms, axis=1))

    model = NeuralODE(
        key=model_key,
        width_mplx=trial_dict["mlpx_width_size"],
        depth_mplx=trial_dict["mlpx_depth"],
        mlpx_activation=trial_dict["mlpx_activation"],
        mlpx_final_activation=trial_dict["mlpx_final_activation"],
        std_v=std_v,
        with_norm=trial_dict["mlpx_with_norm"],
        diamond_structure=trial_dict["diamond_structure"],
    )

    def create_param_labels(params):
        """Label each parameter leaf with its optimizer group."""

        def label_fn(path, param):
            path_str = "/".join(str(p) for p in path)

            # BUGFIX: test "mlpg" before "mlp". Every mlpg parameter path
            # also contains the substring "mlp", so the original ordering
            # routed the conductance parameters to the "mlp" optimizer and
            # the tuned `mlpg_learning_rate` was never applied.
            if "mlpg" in path_str:
                return "mlpg"
            elif "mlp" in path_str:
                return "mlp"
            elif "out_scale" in path_str:
                return "scaling"
            elif "x0" in path_str:
                return "init"
            else:
                return "other"

        return jax.tree_util.tree_map_with_path(label_fn, params)

    # One optimizer instance per parameter group; each group falls back to
    # the base learning rate when no group-specific rate is in `trial_dict`.
    base_transforms = {
        "mlpg": trial_dict["optimizer"](
            learning_rate=trial_dict["mlpg_learning_rate"]
            if "mlpg_learning_rate" in trial_dict
            else trial_dict["learning_rate"],
        ),
        "mlp": trial_dict["optimizer"](
            learning_rate=trial_dict["mlp_learning_rate"]
            if "mlp_learning_rate" in trial_dict
            else trial_dict["learning_rate"],
        ),
        "init": trial_dict["optimizer"](
            learning_rate=trial_dict["init_learning_rate"]
            if "init_learning_rate" in trial_dict
            else trial_dict["learning_rate"],
        ),
        "other": trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
        "scaling": trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
    }

    partitioned_optim = optax.partition(base_transforms, create_param_labels)

    optim = optax.chain(
        optax.clip_by_global_norm(
            trial_dict["clip_value"]
        ),  # clip gradients first, then apply the per-group optimizers
        partitioned_optim,
        # LR is scaled down when the loss plateaus (see optax.contrib docs).
        contrib.reduce_on_plateau(
            patience=trial_dict["patience"],
            cooldown=trial_dict["cooldown"],
            factor=trial_dict["factor"],
            rtol=trial_dict["rtol"],
            accumulation_size=trial_dict["accumulation_size"],
        ),
    )
    params = eqx.filter(model, eqx.is_inexact_array)
    opt_state = optim.init(params)

    print(f"{amps.shape=}, {freqs.shape=}, {v_ms.shape=}, {i_ms.shape=}, {tss.shape=}")

    # Currently only used by the (disabled) monotonicity penalty below.
    @eqx.filter_jit
    def get_dx_dt(
        vs: jax.Array, xs: jax.Array, ts: jax.Array, amps: jax.Array, freqs: jax.Array
    ) -> jax.Array:
        def _get_dx(v_rov, x_row, t_row, amp, freq):
            return jax.vmap(lambda v, x, t: model.func(t, x, (amp, freq)))(
                v_rov, x_row, t_row
            )

        result = jax.vmap(_get_dx, in_axes=(0, 0, 0, 0, 0))(vs, xs, ts, amps, freqs)
        return result.squeeze() if result.ndim > 2 else result

    @eqx.filter_jit
    def dx_dt_loss(
        v_m: jax.Array, i_m: jax.Array, dx_dt: jax.Array, penalty_weight: float = 1.0
    ) -> jax.Array:
        """
        Enforce monotonicity constraint:
        - When v_m > 0: require dx_dt >= 0 (x should increase)
        - When v_m < 0: require dx_dt <= 0 (x should decrease)
        - When v_m = 0: no constraint (dx_dt can be anything)

        Penalizes violations: (v_m > 0 and dx_dt < 0) or (v_m < 0 and dx_dt > 0)
        """
        # The product v_m * dx_dt should be non-negative
        # When both have same sign: product > 0 (good)
        # When opposite signs: product < 0 (bad, penalize)
        product = v_m * dx_dt

        # Penalize negative products (sign mismatch)
        violation = jnp.maximum(0.0, -product)

        return penalty_weight * jnp.mean(jnp.square(violation))

    @eqx.filter_jit
    def smooth_range_penalty_jit(
        x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0
    ):
        """Smoothly penalise values of `x` outside [min_val, max_val].

        Sigmoid gates activate the penalty only beyond the margins, keeping
        the loss differentiable everywhere.
        """
        lower_margin = min_val + margin
        upper_margin = max_val - margin

        below_penalty = expit(-(x - lower_margin) / (margin / 4)) * (lower_margin - x)
        above_penalty = expit((x - upper_margin) / (margin / 4)) * (x - upper_margin)

        total_penalty = below_penalty + above_penalty

        return penalty_weight * jnp.mean(total_penalty**2)

    @eqx.filter_jit
    def improved_loss_fn(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
        lambdas=[],
        eps=1e-8,
        adaptive_weights=True,
        params: eqx.Module = None,
        with_regularization=True,
        regularization_weight=1e-4,
    ):
        # NOTE(review): the mutable default `lambdas=[]` is never mutated and
        # is always replaced when `adaptive_weights=True` (the only mode used
        # here); with adaptive_weights=False it would raise IndexError.
        batched_forward = eqx.filter_vmap(model, in_axes=(0, 0, 0))(ts, amps, freqs)
        _v_m_pred, _i_m_pred, _x, _G = batched_forward

        # Drop trailing singleton dims; restore a leading batch dim if a
        # single sample collapsed everything to 1-D.
        v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
        i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
        x = _x.squeeze() if _x.ndim > 2 else _x
        _G = _G.squeeze() if _G.ndim > 2 else _G

        if v_m_pred.ndim < 2:
            v_m_pred = v_m_pred.reshape(1, -1)
            i_m_pred = i_m_pred.reshape(1, -1)
            x = x.reshape(1, -1)
            _G = _G.reshape(1, -1)

        targets = jnp.vstack([v_m, i_m])  # Shape: (2 * batch_size, time_steps)
        preds = jnp.vstack([v_m_pred, i_m_pred])  # Shape: (2 * batch_size, time_steps)

        # First and second finite differences along time.
        targets_d1 = jnp.diff(targets, axis=1)
        preds_d1 = jnp.diff(preds, axis=1)
        targets_d2 = jnp.diff(targets_d1, axis=1)
        preds_d2 = jnp.diff(preds_d1, axis=1)

        # dx_dt = get_dx_dt(v_m_pred, x, ts, amps, freqs)

        def standardize_loss(target, pred, with_center=True):
            # MSE on signals standardized with the *target's* mean/std so the
            # voltage and current channels contribute on comparable scales.
            center = target.shape[0] // 2
            target_std = jnp.std(target, axis=1, keepdims=True)
            target_mean = jnp.mean(target, axis=1, keepdims=True)
            target_norm = (target - target_mean) / (target_std + 1e-8)
            pred_norm = (pred - target_mean) / (target_std + 1e-8)
            return (
                jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
                + jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
                if not with_center
                else jnp.mean((target_norm - pred_norm) ** 2)
            )

        # Multi-scale loss: signal, slope and curvature.
        loss_0 = standardize_loss(targets, preds)
        loss_1 = standardize_loss(targets_d1, preds_d1)
        loss_2 = standardize_loss(targets_d2, preds_d2)

        if adaptive_weights:
            # Balance the three terms around their geometric mean, clipped to
            # avoid any single term dominating or vanishing.
            loses = jnp.array([loss_0, loss_1, loss_2])
            loss_mean = jnp.exp((jnp.mean(jnp.log(loses + eps))))
            lambdas = loss_mean / (loses + eps)
            lambdas = jnp.clip(lambdas, 0.1, 10.0)

        total_loss = lambdas[0] * loss_0 + lambdas[1] * loss_1 + lambdas[2] * loss_2
        # Soft box constraints: conductance in [0, 5], state in [0, 1].
        total_loss += smooth_range_penalty_jit(
            _G, min_val=0.0, max_val=5.0, margin=0.1, penalty_weight=10.0
        )
        total_loss += smooth_range_penalty_jit(
            x, min_val=0.0, max_val=1.0, margin=0.01, penalty_weight=10.0
        )
        # total_loss += dx_dt_loss(v_m_pred, i_m_pred, dx_dt, penalty_weight=10.0)
        # Optional L2 regularization (the boolean multiplies the term away).
        sqnorm = tree_norm(params, squared=True)
        return total_loss + regularization_weight * sqnorm * with_regularization

    @eqx.filter_jit
    def final_loss(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
    ):
        """Evaluation-only loss: same multi-scale terms, no penalties."""
        _v_m_pred, _i_m_pred, _x, _G = jax.vmap(model, in_axes=(0, 0, 0))(
            ts, amps, freqs
        )

        v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
        i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
        x = _x.squeeze() if _x.ndim > 2 else _x
        _G = _G.squeeze() if _G.ndim > 2 else _G

        targets = jnp.vstack([v_m, i_m])  # Shape: (2 * batch_size, time_steps)
        preds = jnp.vstack([v_m_pred, i_m_pred])  # Shape: (2 * batch_size, time_steps)

        # Compute all derivatives at once using jnp.diff with axis parameter
        targets_d1 = jnp.diff(targets, axis=1)
        preds_d1 = jnp.diff(preds, axis=1)
        targets_d2 = jnp.diff(targets_d1, axis=1)
        preds_d2 = jnp.diff(preds_d1, axis=1)

        def standardize_loss(target, pred, with_center=True):
            center = target.shape[0] // 2
            target_std = jnp.std(target, axis=1, keepdims=True)
            target_mean = jnp.mean(target, axis=1, keepdims=True)
            target_norm = (target - target_mean) / (target_std + 1e-8)
            pred_norm = (pred - target_mean) / (target_std + 1e-8)
            return (
                jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
                + jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
                if not with_center
                else jnp.mean((target_norm - pred_norm) ** 2)
            )

        # Multi-scale loss
        loss_0 = standardize_loss(targets, preds)
        loss_1 = standardize_loss(targets_d1, preds_d1)
        loss_2 = standardize_loss(targets_d2, preds_d2)

        return (
            loss_0,
            loss_1,
            loss_2,
            _v_m_pred,
            _i_m_pred,
            _G,
        )

    @eqx.filter_jit
    def make_step(model, ts, v_m, i_m, freqs, amps, opt_state, params):
        """One gradient step; returns (model, opt_state, loss)."""
        loss, grads = eqx.filter_value_and_grad(improved_loss_fn)(
            model,
            ts,
            v_m,
            i_m,
            freqs,
            amps,
            params=params,
            with_regularization=trial_dict["with_regularization"],
            regularization_weight=trial_dict["regularization_weight"],
        )
        updates, opt_state = optim.update(grads, opt_state, params=params, value=loss)
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    best_loss = float("inf")
    best_model = None

    # Curriculum loop: each stage trains on a longer prefix of the sequences,
    # warm-starting from the best model of the previous stage.
    for length_scale, epochs in trial_dict["length_scales"]:
        print(f"Training with length scale: {length_scale}, epochs: {epochs}")
        model = (
            best_model if best_model is not None else model
        )  # Use the best model if available
        # NOTE(review): `params` is captured here and NOT refreshed as `model`
        # is updated below, so optimizers that consume `params` (e.g. nadamw's
        # weight decay) and the L2 term see the stage's *initial* weights.
        # Confirm whether this is intentional before changing it.
        params = eqx.filter(model, eqx.is_inexact_array)
        opt_state = optim.init(params)
        length = int(v_ms.shape[1] * length_scale)
        v_m = v_ms[:, :length]
        i_m = i_ms[:, :length]
        ts = tss[:, :length]
        best_loss = float("inf")
        best_model = None

        for i in range(epochs):
            try:
                epoch_loss = 0.0
                # NOTE(review): `data_key` is reused every epoch, so batches
                # are shuffled in the same order each epoch; consider folding
                # in the epoch index for fresh permutations.
                for batch_indices in get_batches(
                    v_m.shape[0],
                    trial_dict["batch_size"],
                    dataloader_key=data_key,
                ):
                    batch_v_m = v_m[batch_indices, :]
                    batch_i_m = i_m[batch_indices, :]
                    batch_ts = ts[batch_indices, :]
                    batch_freqs = freqs[batch_indices, :]
                    batch_amps = amps[batch_indices, :]

                    model, opt_state, batch_loss = make_step(
                        model,
                        batch_ts,
                        batch_v_m,
                        batch_i_m,
                        batch_freqs,
                        batch_amps,
                        opt_state,
                        params,
                    )
                    epoch_loss += batch_loss / float(len(batch_indices))

                *losses, _v_m_pred, _i_m_pred, _G = final_loss(
                    model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
                )

                if i % 100 == 0:
                    print(f"Step {i}, Loss: {epoch_loss:.4f}")

                # NOTE(review): best-model selection uses the *test* loss,
                # which leaks test information into model choice; a separate
                # validation split would be cleaner.
                epoch_loss = losses[0]  # Use primary loss for best model tracking

                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model = deepcopy(model)
            except Exception as e:
                # Diverged/failed solves abort the remaining epochs of this
                # stage; the best model so far is kept.
                print(f"Error at step {i}: {e}")
                logger.exception(e)
                break

    try:
        *losses, _v_m_pred, _i_m_pred, _G = final_loss(
            best_model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
        )
        try:
            fig, ax = plot_results(
                test_tss,
                test_v_ms,
                test_i_ms,
                _v_m_pred,
                _i_m_pred,
                _G,
            )

            fig.savefig(
                f"mutli_neural_plot_fixed_mlpg_Ron/{trial_dict['trial_number']}.pdf",
                bbox_inches="tight",
            )
            plt.close()
        except Exception as e:
            # Plotting is best-effort; a failure must not discard the losses.
            print(f"Error saving figure: {e}")

    except Exception as e:
        print(f"Error in final loss computation: {e}")
        losses = (jnp.inf, jnp.inf, jnp.inf)
    return tuple([loss.item() if hasattr(loss, "item") else loss for loss in losses])


# Dataset: tungsten-doped device measurements.
# presumably `ppp` = points per period — TODO confirm against MemData.
md = MemData(Dopant.Tungsten, ppp=2000)


def objective(file_meta):
    """Return an Optuna objective closure bound to `file_meta`.

    The inner objective re-pickles the module-level `sampler` before each
    trial so the TPE state survives a crash or restart mid-study.
    """

    def _objective(trial: Trial) -> tuple:
        # FIX: the study is multi-objective (three "minimize" directions), so
        # this returns the 3-tuple of test losses from `train_mem`; the
        # original `-> float` annotation was wrong.
        trial_dict = get_trial_dict(trial)
        joblib.dump(sampler, SAMPLER_PATH)
        return train_mem(file_meta, trial_dict)

    return _objective


SAMPLER_PATH = Path("jax_sampler_3obj_Goff.pkl")


study_name = f"NeuralODE_{md.dopant.name}_ppp{md.ppp}_Goff_scale"
storage_name = f"sqlite:///optuna_studies/{study_name}.db"

# Reuse a previously pickled sampler when one exists so the TPE state
# survives restarts; otherwise create a fresh one and persist it.
if SAMPLER_PATH.exists():
    sampler = joblib.load(SAMPLER_PATH)
else:
    sampler = optuna.samplers.TPESampler(
        seed=42,
        multivariate=True,
        n_startup_trials=20,
        n_ei_candidates=44,
    )
    joblib.dump(sampler, SAMPLER_PATH)


class MemoryCleanupCallback:
    """Optuna callback that releases JAX/Python memory after each trial."""

    def __call__(self, study, trial):
        # Clear JAX's compilation caches after each trial to keep the long
        # study's memory footprint bounded.
        jax.clear_caches()
        gc.collect()

        # Memory reporting is best-effort: psutil is an optional dependency.
        # FIX: the original unguarded `import psutil` raised ImportError after
        # every trial (aborting the study) when psutil was not installed.
        try:
            import psutil
        except ImportError:
            logger.info(f"Trial {trial.number} completed.")
            return

        process = psutil.Process()
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"Trial {trial.number} completed. Memory: {mem_mb:.2f} MB")


# Multi-objective study: one "minimize" direction per loss scale
# (signal, 1st derivative, 2nd derivative) returned by train_mem.
study = optuna.create_study(
    directions=["minimize", "minimize", "minimize"],
    sampler=sampler,
    study_name=study_name,
    storage=storage_name,
    load_if_exists=True,  # resume from the SQLite store if the study exists
)

study.optimize(
    objective(md),
    gc_after_trial=True,  # extra GC pass on top of MemoryCleanupCallback
    n_trials=200,
    n_jobs=1,  # single worker: trials are GPU/CPU-heavy and not thread-safe
    callbacks=[MemoryCleanupCallback()],
)