# jax_general_optuna_3obj_test_set.py
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Union

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax  # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit  # JAX's sigmoid function
from loguru import logger
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter
from typing import Tuple
import time

from utils import (generate_neural_stack, Dopant, FileMeta, find_periods, average_data, mean_data, unpack_data, export_signals, MemData, get_train_test_indices_from_req, Func, NeuralODE, plot_results, get_batches)



import os

# Flags handed to XLA through the environment. The value is assembled from two
# pieces; the second piece deliberately starts with four spaces so the final
# string is identical to the historical one.
# NOTE(review): this assignment runs after `import jax` above — confirm that
# XLA still reads the variable at that point.
os.environ["XLA_FLAGS"] = "".join(
    (
        "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20",
        "    --xla_cpu_enable_fast_math=true ",
    )
)



def get_trial_dict(trial: Trial) -> Dict[str, Any]:
    """Sample one hyper-parameter configuration from ``trial``.

    Every value is drawn through ``trial.suggest_*``; the order of the calls
    is kept stable because it determines the order in which the sampler is
    queried.

    Args:
        trial: Optuna trial used to draw the values.

    Returns:
        A dict holding the sampled values, with three post-processing steps:

        * every ``*_activation`` entry is replaced by the corresponding
          callable, and ``optimizer`` by the corresponding optax factory;
        * ``length_scales`` is added: a list of ``(fraction, epochs)`` pairs,
          one per multiple of ``length_step`` below 1.0 (trained for
          ``intermediate_epochs``) followed by ``(1.0, last_epochs)``;
        * ``trial_number`` is added from ``trial.number``.
    """
    activation_dict = {
        "relu": jax.nn.relu,
        "gelu": jax.nn.gelu,
        "silu": jax.nn.silu,
        "tanh": jax.nn.tanh,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
        "sigmoid": jax.nn.sigmoid,
        "sin": jnp.sin,
        "linear": jax.nn.identity,
    }
    optimizer_dict = {
        "adam": optax.adam,
        "adamw": optax.adamw,
        "nadamw": optax.nadamw,
        "adabelief": optax.adabelief,
    }
    activation_names = list(activation_dict.keys())

    trial_dict = {
        "mlpx_width_size": trial.suggest_int("mlpx_width_size", 16, 512, step=16),
        "mlpx_depth": trial.suggest_int("mlpx_depth", 1, 5, step=1),
        "mlpx_activation": trial.suggest_categorical(
            "mlpx_activation", activation_names
        ),
        "mlpx_final_activation": trial.suggest_categorical(
            "mlpx_final_activation", ["linear", "tanh", "sigmoid"]
        ),
        "mlpg_width_size": trial.suggest_int("mlpg_width_size", 16, 512, step=16),
        "mlpg_depth": trial.suggest_int("mlpg_depth", 1, 5, step=1),
        "mlpg_activation": trial.suggest_categorical(
            "mlpg_activation", activation_names
        ),
        "mlpg_final_activation": trial.suggest_categorical(
            "mlpg_final_activation", ["sigmoid"]
        ),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "clip_value": trial.suggest_float("clip_value", 0.1, 10.0, step=0.1),
        "eps": trial.suggest_float("eps", 1e-8, 1e-4, log=True),
        "with_regularization": trial.suggest_categorical(
            "with_regularization", [True, False]
        ),
        "regularization_weight": trial.suggest_float(
            "regularization_weight", 1e-6, 1e-2, log=True
        ),
        "patience": trial.suggest_int("patience", 10, 200, step=10),
        "cooldown": trial.suggest_int("cooldown", 0, 100, step=10),
        # The upper bound used to be written as 0.99, which is not reachable
        # from 0.1 in steps of 0.1: Optuna truncated it to 0.9 and warned on
        # every trial. 0.9 is that same effective bound, stated explicitly.
        "factor": trial.suggest_float("factor", 0.1, 0.9, step=0.1),
        "rtol": trial.suggest_float("rtol", 1e-6, 1e-2, log=True),
        "accumulation_size": trial.suggest_int("accumulation_size", 1, 100, step=1),
        "length_step": trial.suggest_float("length_step", 0.2, 0.6),
        "intermediate_epochs": trial.suggest_int("intermediate_epochs", 1, 100, step=1),
        # Same situation as "factor": 1000 is not reachable from 1 in steps of
        # 10, so Optuna used 991 as the real upper bound (with a warning).
        "last_epochs": trial.suggest_int("last_epochs", 1, 991, step=10),
        "optimizer": trial.suggest_categorical(
            "optimizer", list(optimizer_dict.keys())
        ),
        "latent_states": trial.suggest_int("latent_states", 1, 6, step=1),
        "batch_size": trial.suggest_int("batch_size", 1, 18, step=1),
        "diamond_structure": trial.suggest_categorical(
            "diamond_structure", [True, False]
        ),
        "with_norm": trial.suggest_categorical("with_norm", [True, False]),
    }
    logger.debug(
        f"Trial {trial.number}: learning_rate={trial_dict['learning_rate']}"
    )

    # Swap the sampled names for the objects they stand for. "_activation"
    # also matches the "*_final_activation" keys.
    for name in tuple(trial_dict):
        if name.endswith("_activation"):
            trial_dict[name] = activation_dict[trial_dict[name]]
        elif name == "optimizer":
            trial_dict[name] = optimizer_dict[trial_dict[name]]

    # Training schedule: growing fractions of the series, then the full series.
    # The explicit `< 1.0` filter guards against float round-off in arange.
    step = trial_dict["length_step"]
    schedule = [
        (fraction, trial_dict["intermediate_epochs"])
        for fraction in np.arange(step, 1.0, step)
        if fraction < 1.0
    ]
    schedule.append((1.0, trial_dict["last_epochs"]))
    trial_dict["length_scales"] = schedule
    trial_dict["trial_number"] = trial.number
    return trial_dict


def get_batches(
    length: int,
    batch_size: int,
    shuffle: bool = True,
    dataloader_key: Union[jax.Array, None] = None,
) -> Generator[jax.Array, None, None]:
    """Yield index arrays that split ``range(length)`` into mini-batches.

    Args:
        length: Number of samples to index.
        batch_size: Maximum number of indices per batch; the last batch is
            shorter when ``length`` is not a multiple of ``batch_size``.
        shuffle: When true, the indices are permuted before being split.
        dataloader_key: PRNG key used for the permutation. ``None`` means
            ``jr.PRNGKey(0)``; the key is created on first use instead of at
            import time. Pass a fresh key per epoch to get a different order
            each time.

    Yields:
        1-D integer JAX arrays of indices.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised when the
            generator is first advanced).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    indices = jnp.arange(length)
    if shuffle:
        if dataloader_key is None:
            dataloader_key = jr.PRNGKey(0)
        indices = jr.permutation(dataloader_key, indices)

    for start in range(0, length, batch_size):
        yield indices[start : start + batch_size]


def train_mem(
    md: FileMeta,
    trial_dict: Dict[str, Any],
):
    """Train a ``NeuralODE`` for one hyper-parameter set and score it on the test split.

    Training runs in stages (``trial_dict["length_scales"]``): each stage uses
    only the leading fraction of every training series and a freshly
    initialised optimizer state. After every epoch the model is evaluated on
    the full-length test series; the model with the lowest primary test loss
    of a stage is the starting point of the next stage and, for the last
    stage, the model that is finally scored and plotted.

    Args:
        md: Data container. ``md.file_meta`` must be indexable by
            ``range(len(md.file_meta))`` and each entry must expose ``amp``,
            ``freq``, ``v_m``, ``i_m`` and ``t``. It is also handed to
            ``get_train_test_indices_from_req``.
        trial_dict: Configuration as produced by ``get_trial_dict``.

    Returns:
        A 3-tuple with the standardized mean squared error on the test split
        of (signals, first differences, second differences).
        ``(inf, inf, inf)`` is returned when the final evaluation fails, e.g.
        because no epoch of the last stage produced a usable model.
    """
    key = jr.PRNGKey(5678)
    # The three-way split is kept so that model_key keeps its historical
    # value; the first sub-key is no longer needed.
    _, model_key, loader_key = jr.split(key, 3)

    n_files = len(md.file_meta)

    def _stack(attr):
        # Collect one attribute from every file entry into a float32 array.
        return jnp.array(
            [getattr(md.file_meta[file_index], attr) for file_index in range(n_files)],
            dtype=jnp.float32,
        )

    amps = _stack("amp").reshape(-1, 1)
    freqs = _stack("freq").reshape(-1, 1)
    v_ms = _stack("v_m")
    i_ms = _stack("i_m")
    tss = _stack("t")

    train_ind, test_ind = get_train_test_indices_from_req(md)

    test_amps = amps[test_ind, :]
    test_freqs = freqs[test_ind, :]
    test_v_ms = v_ms[test_ind, :]
    test_i_ms = i_ms[test_ind, :]
    test_tss = tss[test_ind, :]
    amps = amps[train_ind, :]
    freqs = freqs[train_ind, :]
    v_ms = v_ms[train_ind, :]
    i_ms = i_ms[train_ind, :]
    tss = tss[train_ind, :]

    # Mean per-series standard deviation of the training voltages.
    std_v = jnp.mean(jnp.std(v_ms, axis=1))

    model = NeuralODE(
        key=model_key,
        width_mplx=trial_dict["mlpx_width_size"],
        depth_mplx=trial_dict["mlpx_depth"],
        width_mlpg=trial_dict["mlpg_width_size"],
        depth_mlpg=trial_dict["mlpg_depth"],
        mlpx_activation=trial_dict["mlpx_activation"],
        mlpx_final_activation=trial_dict["mlpx_final_activation"],
        mlpg_activation=trial_dict["mlpg_activation"],
        mlpg_final_activation=trial_dict["mlpg_final_activation"],
        latent_states=trial_dict["latent_states"],
        std_v=std_v,
        with_norm=trial_dict["with_norm"],
        diamond_structure=trial_dict["diamond_structure"],
    )
    optim = optax.chain(
        # Clip gradients first, then apply the optimizer, then scale the
        # update by the reduce-on-plateau factor.
        optax.clip_by_global_norm(trial_dict["clip_value"]),
        trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
        contrib.reduce_on_plateau(
            patience=trial_dict["patience"],
            cooldown=trial_dict["cooldown"],
            factor=trial_dict["factor"],
            rtol=trial_dict["rtol"],
            accumulation_size=trial_dict["accumulation_size"],
        ),
    )
    print(f"{amps.shape=}, {freqs.shape=}, {v_ms.shape=}, {i_ms.shape=}, {tss.shape=}")

    @eqx.filter_jit
    def smooth_range_penalty_jit(
        x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0
    ):
        """Smoothly penalise values of ``x`` that approach or leave ``[min_val, max_val]``.

        A sigmoid gate of width ``margin / 4`` switches on a linear distance
        term once ``x`` crosses ``min_val + margin`` (from above) or
        ``max_val - margin`` (from below); the mean squared gated distance,
        scaled by ``penalty_weight``, is returned.
        """
        lower_margin = min_val + margin
        upper_margin = max_val - margin
        below_penalty = expit(-(x - lower_margin) / (margin / 4)) * (lower_margin - x)
        above_penalty = expit((x - upper_margin) / (margin / 4)) * (x - upper_margin)
        total_penalty = below_penalty + above_penalty
        return penalty_weight * jnp.mean(total_penalty**2)

    def _as_rows(arr):
        # Drop singleton axes of >2-D outputs and make sure the result is 2-D
        # (a batch of one would otherwise collapse to a 1-D vector).
        arr = arr.squeeze() if arr.ndim > 2 else arr
        return arr.reshape(1, -1) if arr.ndim < 2 else arr

    def standardize_loss(target, pred, with_center=True):
        """MSE between ``target`` and ``pred`` after scaling both by the target's row statistics.

        With ``with_center=False`` the first and second half of the rows are
        averaged separately and the two means are summed.
        """
        center = target.shape[0] // 2
        target_std = jnp.std(target, axis=1, keepdims=True)
        target_mean = jnp.mean(target, axis=1, keepdims=True)
        target_norm = (target - target_mean) / (target_std + 1e-8)
        pred_norm = (pred - target_mean) / (target_std + 1e-8)
        if with_center:
            return jnp.mean((target_norm - pred_norm) ** 2)
        return jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2) + jnp.mean(
            (target_norm[center:] - pred_norm[center:]) ** 2
        )

    def _multi_scale_losses(model, ts, v_m, i_m, freqs, amps):
        # Shared forward pass: run the model on every series and compare the
        # signals, their first and their second finite differences.
        raw_v, raw_i, _x, raw_g = jax.vmap(model, in_axes=(0, 0, 0))(ts, amps, freqs)

        targets = jnp.vstack([v_m, i_m])  # (2 * batch_size, time_steps)
        preds = jnp.vstack([_as_rows(raw_v), _as_rows(raw_i)])

        targets_d1 = jnp.diff(targets, axis=1)
        preds_d1 = jnp.diff(preds, axis=1)
        targets_d2 = jnp.diff(targets_d1, axis=1)
        preds_d2 = jnp.diff(preds_d1, axis=1)

        losses = (
            standardize_loss(targets, preds),
            standardize_loss(targets_d1, preds_d1),
            standardize_loss(targets_d2, preds_d2),
        )
        g = raw_g.squeeze() if raw_g.ndim > 2 else raw_g
        return losses, raw_v, raw_i, g

    @eqx.filter_jit
    def improved_loss_fn(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
        lambdas=None,
        eps=1e-8,
        adaptive_weights=True,
        params: eqx.Module = None,
        regularization_weight: float = 1e-4,
        with_regularization: bool = True,
    ):
        """Training loss: weighted multi-scale loss + range penalty on G + optional L2 term.

        ``params`` defaults to the inexact-array leaves of ``model`` so that
        the L2 term is part of the differentiated computation.
        """
        (loss_0, loss_1, loss_2), _, _, g = _multi_scale_losses(
            model, ts, v_m, i_m, freqs, amps
        )

        if adaptive_weights:
            # Weight every term by (geometric mean of the terms) / term, so
            # that the three terms contribute on a similar scale.
            loses = jnp.array([loss_0, loss_1, loss_2])
            loss_mean = jnp.exp((jnp.mean(jnp.log(loses + eps))))
            lambdas = loss_mean / (loses + eps)
            lambdas = jnp.clip(lambdas, 0.1, 10.0)
        elif lambdas is None:
            lambdas = (1.0, 1.0, 1.0)

        total_loss = lambdas[0] * loss_0 + lambdas[1] * loss_1 + lambdas[2] * loss_2
        total_loss += smooth_range_penalty_jit(
            g, min_val=0.0, max_val=5.0, margin=0.1, penalty_weight=10.0
        )
        if params is None:
            params = eqx.filter(model, eqx.is_inexact_array)
        sqnorm = tree_norm(params, squared=True)
        return total_loss + regularization_weight * sqnorm * with_regularization

    @eqx.filter_jit
    def finall_loss(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
    ):
        """Evaluation: the three unweighted losses plus raw predictions and G."""
        (loss_0, loss_1, loss_2), raw_v, raw_i, g = _multi_scale_losses(
            model, ts, v_m, i_m, freqs, amps
        )
        return (
            loss_0,
            loss_1,
            loss_2,
            raw_v,
            raw_i,
            g,
        )

    @eqx.filter_jit
    def make_step(model, ts, v_m, i_m, freqs, amps, opt_state):
        # The regularisation settings belong to the loss; previously they were
        # handed to `optim.update`, which does not use them, so the L2 term
        # never saw the model's weights.
        loss, grads = eqx.filter_value_and_grad(improved_loss_fn)(
            model,
            ts,
            v_m,
            i_m,
            freqs,
            amps,
            regularization_weight=trial_dict["regularization_weight"],
            with_regularization=trial_dict["with_regularization"],
        )
        # Use the parameters of the model being updated, not a snapshot taken
        # at the start of the stage.
        params = eqx.filter(model, eqx.is_inexact_array)
        updates, opt_state = optim.update(
            grads,
            opt_state,
            params=params,
            value=loss,
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    best_model = None

    for length_scale, epochs in trial_dict["length_scales"]:
        print(f"Training with length scale: {length_scale}, epochs: {epochs}")
        # Continue from the best model of the previous stage when there is one.
        if best_model is not None:
            model = best_model
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        length = int(v_ms.shape[1] * length_scale)
        v_m = v_ms[:, :length]
        i_m = i_ms[:, :length]
        ts = tss[:, :length]
        best_loss = float("inf")
        best_model = None

        for i in range(epochs):
            try:
                # A fresh key per epoch; reusing one key would reproduce the
                # same batch order in every epoch.
                loader_key, epoch_key = jr.split(loader_key)
                train_loss = 0.0
                for batch_indices in get_batches(
                    v_m.shape[0],
                    trial_dict["batch_size"],
                    dataloader_key=epoch_key,
                ):
                    model, opt_state, batch_loss = make_step(
                        model,
                        ts[batch_indices, :],
                        v_m[batch_indices, :],
                        i_m[batch_indices, :],
                        freqs[batch_indices, :],
                        amps[batch_indices, :],
                        opt_state,
                    )
                    train_loss += batch_loss / float(len(batch_indices))

                # NOTE(review): the model is selected on the test split, so
                # the reported losses are not an unbiased estimate.
                *losses, _v_m_pred, _i_m_pred, _G = finall_loss(
                    model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
                )

                if i % 100 == 0:
                    print(f"Step {i}, Loss: {train_loss:.4f}")

                eval_loss = losses[0]  # primary loss drives model selection
                if eval_loss < best_loss:
                    best_loss = eval_loss
                    best_model = deepcopy(model)
            except Exception as e:
                # Abandon the rest of this stage; later stages still run.
                print(f"Error at step {i}: {e}")
                logger.exception(e)
                break

    try:
        *losses, _v_m_pred, _i_m_pred, _G = finall_loss(
            best_model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
        )
    except Exception as e:
        # Also reached when best_model is None (no successful epoch).
        print(f"Error in final loss computation: {e}")
        losses = (jnp.inf, jnp.inf, jnp.inf)
    else:
        # Plotting is best effort and must never change the trial's result.
        try:
            Path("multi_neural_plot_3obj").mkdir(parents=True, exist_ok=True)
            fig, ax = plot_results(
                test_tss,
                test_v_ms,
                test_i_ms,
                _v_m_pred,
                _i_m_pred,
                _G,
            )
            try:
                fig.savefig(
                    f"multi_neural_plot_3obj/{trial_dict['trial_number']}.pdf",
                    bbox_inches="tight",
                )
            finally:
                plt.close(fig)
        except Exception as e:
            print(f"Error saving figure: {e}")

    return tuple(loss.item() if hasattr(loss, "item") else loss for loss in losses)



def objective(file_meta):
    """Build the Optuna objective for the data set ``file_meta``.

    Args:
        file_meta: Data container forwarded unchanged to ``train_mem``.

    Returns:
        A callable taking an Optuna ``Trial`` and returning the tuple of
        objective values produced by ``train_mem``.
    """

    def _objective(trial: Trial) -> Tuple[float, ...]:
        trial_dict = get_trial_dict(trial)
        # Persist the module-level sampler after the parameters were drawn and
        # before training starts.
        joblib.dump(sampler, SAMPLER_PATH)
        return train_mem(file_meta, trial_dict)

    return _objective


# File the Optuna sampler is pickled to (via joblib) and reloaded from when
# the script starts; relative to the current working directory.
SAMPLER_PATH = Path("jax_sampler_3obj.pkl")


class MemoryCleanupCallback:
    """Optuna callback that releases cached memory after every trial.

    It clears JAX's caches, runs the garbage collector and, when ``psutil``
    is installed, logs the resident memory of the current process.
    """

    def __call__(self, study, trial):
        """Clean up after ``trial``; ``study`` is accepted but not used."""
        import gc

        # Clear JAX caches after each trial
        jax.clear_caches()
        gc.collect()

        # The memory report is optional: a missing psutil must not raise from
        # inside the callback and abort the study.
        try:
            import psutil
        except ImportError:
            logger.info(f"Trial {trial.number} completed.")
            return

        mem_mb = psutil.Process().memory_info().rss / 1024 / 1024
        logger.info(f"Trial {trial.number} completed. Memory: {mem_mb:.2f} MB")

md = MemData(Dopant.Tungsten, ppp=2000)

study_name = f"NeuralODE_{md.dopant.name}_ppp{md.ppp}_3obj_test"
storage_name = f"sqlite:///optuna_studies/{study_name}.db"
# SQLite creates the database file but not its parent directory, so make sure
# the directory exists before the study is opened.
Path("optuna_studies").mkdir(parents=True, exist_ok=True)

if SAMPLER_PATH.exists():
    # Resume with the sampler state saved by a previous run.
    # NOTE: joblib.load unpickles the file — only load sampler files written
    # by this script.
    sampler = joblib.load(SAMPLER_PATH)
else:
    sampler = optuna.samplers.TPESampler(
        seed=42,
        multivariate=True,
        n_startup_trials=20,
        n_ei_candidates=44,
    )
    joblib.dump(sampler, SAMPLER_PATH)


# Three objectives, all minimised: the three losses returned by train_mem.
study = optuna.create_study(
    directions=["minimize", "minimize", "minimize"],
    sampler=sampler,
    study_name=study_name,
    storage=storage_name,
    load_if_exists=True,
)

study.optimize(
    objective(md),
    gc_after_trial=True,
    n_trials=200,
    n_jobs=1,
    callbacks=[MemoryCleanupCallback()],
)