# jax_general_optuna_3obj_test_set.py
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Union
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit # JAX's sigmoid function
from loguru import logger
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter
from typing import Tuple
import time
from utils import (generate_neural_stack, Dopant, FileMeta, find_periods, average_data, mean_data, unpack_data, export_signals, MemData, get_train_test_indices_from_req, Func, NeuralODE, plot_results, get_batches)
import os
# Configure XLA's CPU backend: threaded Eigen kernels, 20 intra-op threads,
# and fast-math. The two flag strings are joined by implicit concatenation
# with an explicit separating space (the previous backslash continuation
# glued "...threads=20" onto the next flag).
# NOTE(review): XLA reads XLA_FLAGS when the backend is first initialized;
# setting it after `import jax` may be too late if jax was already used —
# consider moving this above the jax import. TODO confirm.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20 "
    "--xla_cpu_enable_fast_math=true "
)
def get_trial_dict(trial: Trial) -> Dict[str, Any]:
    """Sample one hyper-parameter configuration from an Optuna trial.

    Categorical names are resolved to the actual callables (JAX activation
    functions and optax optimizer factories), and a training curriculum
    ``length_scales`` — a list of ``(fraction_of_sequence, epochs)`` pairs
    ending with ``(1.0, last_epochs)`` — is derived from ``length_step``.

    Args:
        trial: The Optuna trial used to draw every hyper-parameter.

    Returns:
        Dict with sampled values, resolved callables, the ``length_scales``
        curriculum, and ``trial_number``.
    """
    activation_dict = {
        "relu": jax.nn.relu,
        "gelu": jax.nn.gelu,
        "silu": jax.nn.silu,
        "tanh": jax.nn.tanh,
        "elu": jax.nn.elu,
        "leaky_relu": jax.nn.leaky_relu,
        "sigmoid": jax.nn.sigmoid,
        "sin": jnp.sin,
        "linear": jax.nn.identity,
    }
    optimizer_dict = {
        "adam": optax.adam,
        "adamw": optax.adamw,
        "nadamw": optax.nadamw,
        "adabelief": optax.adabelief,
    }
    trial_dict = {
        "mlpx_width_size": trial.suggest_int("mlpx_width_size", 16, 512, step=16),
        "mlpx_depth": trial.suggest_int("mlpx_depth", 1, 5, step=1),
        "mlpx_activation": trial.suggest_categorical(
            "mlpx_activation", list(activation_dict.keys())
        ),
        "mlpx_final_activation": trial.suggest_categorical(
            "mlpx_final_activation", ["linear", "tanh", "sigmoid"]
        ),
        "mlpg_width_size": trial.suggest_int("mlpg_width_size", 16, 512, step=16),
        "mlpg_depth": trial.suggest_int("mlpg_depth", 1, 5, step=1),
        "mlpg_activation": trial.suggest_categorical(
            "mlpg_activation", list(activation_dict.keys())
        ),
        # Single-choice categorical kept so the study's search space stays
        # compatible with earlier trials stored in the same database.
        "mlpg_final_activation": trial.suggest_categorical(
            "mlpg_final_activation", ["sigmoid"]
        ),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
        "clip_value": trial.suggest_float("clip_value", 0.1, 10.0, step=0.1),
        "eps": trial.suggest_float("eps", 1e-8, 1e-4, log=True),
        "with_regularization": trial.suggest_categorical(
            "with_regularization", [True, False]
        ),
        "regularization_weight": trial.suggest_float(
            "regularization_weight", 1e-6, 1e-2, log=True
        ),
        # reduce_on_plateau scheduler parameters.
        "patience": trial.suggest_int("patience", 10, 200, step=10),
        "cooldown": trial.suggest_int("cooldown", 0, 100, step=10),
        "factor": trial.suggest_float("factor", 0.1, 0.99, step=0.1),
        "rtol": trial.suggest_float("rtol", 1e-6, 1e-2, log=True),
        "accumulation_size": trial.suggest_int("accumulation_size", 1, 100, step=1),
        # Curriculum parameters.
        "length_step": trial.suggest_float("length_step", 0.2, 0.6),
        "intermediate_epochs": trial.suggest_int("intermediate_epochs", 1, 100, step=1),
        "last_epochs": trial.suggest_int("last_epochs", 1, 1000, step=10),
        "optimizer": trial.suggest_categorical(
            "optimizer", list(optimizer_dict.keys())
        ),
        "latent_states": trial.suggest_int("latent_states", 1, 6, step=1),
        "batch_size": trial.suggest_int("batch_size", 1, 18, step=1),
        "diamond_structure": trial.suggest_categorical(
            "diamond_structure", [True, False]
        ),
        "with_norm": trial.suggest_categorical("with_norm", [True, False]),
    }
    # Resolve categorical names to callables. "..._final_activation" also ends
    # with "_activation", so one suffix check covers both key families.
    for key in trial_dict:
        if key.endswith("_activation"):
            trial_dict[key] = activation_dict[trial_dict[key]]
        elif key == "optimizer":
            trial_dict[key] = optimizer_dict[trial_dict[key]]
    # Curriculum: train on growing prefixes of the sequence. Float arange can
    # produce a value >= 1.0 through rounding, hence the explicit guard; the
    # full-length stage with its own epoch budget is always appended last.
    length_scale = np.arange(trial_dict["length_step"], 1.0, trial_dict["length_step"])
    trial_dict["length_scales"] = [
        (l_s, trial_dict["intermediate_epochs"]) for l_s in length_scale if l_s < 1.0
    ]
    trial_dict["length_scales"] += [(1.0, trial_dict["last_epochs"])]
    trial_dict["trial_number"] = trial.number
    return trial_dict
def get_batches(
    length: int,
    batch_size: int,
    shuffle: bool = True,
    dataloader_key: jax.Array = jr.PRNGKey(0),
) -> Generator[jax.Array, None, None]:
    """Yield index batches covering ``0..length-1``.

    NOTE(review): this redefinition shadows the ``get_batches`` imported from
    ``utils`` at the top of the file — confirm which one is intended.

    Args:
        length: Number of samples to index.
        batch_size: Maximum batch size; the final batch may be smaller.
        shuffle: Permute indices before batching.
        dataloader_key: PRNG key used for the permutation (a JAX array;
            the previous ``jr.PRNGKey`` annotation named a function, not a
            type). The default key makes shuffling deterministic.

    Yields:
        1-D ``jax.Array`` slices of the (possibly permuted) index range.
    """
    if shuffle:
        indices = jr.permutation(dataloader_key, jnp.arange(length))
    else:
        indices = jnp.arange(length)
    for start in range(0, length, batch_size):
        yield indices[start : start + batch_size]
def train_mem(
    md: FileMeta,
    trial_dict: Dict[str, Any],
):
    """Train a NeuralODE on the measurement data in ``md`` and score it.

    Runs a curriculum over ``trial_dict["length_scales"]`` (growing prefixes
    of each time series), tracks the best model by held-out loss, saves a
    diagnostic plot, and returns the three multi-scale test losses (signal,
    1st- and 2nd-difference) as plain floats — one per objective of the
    3-objective Optuna study.
    """
    # Fixed seed: every trial trains from the same initialisation.
    key = jr.PRNGKey(5678)
    data_key, model_key, loader_key = jr.split(key, 3)  # NOTE(review): loader_key unused
    # Stack per-file metadata into arrays; amps/freqs get a trailing
    # singleton axis, presumably (n_files, 1) — TODO confirm against MemData.
    amps = jnp.array(
        [md.file_meta[file_index].amp for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    ).reshape(-1, 1)
    freqs = jnp.array(
        [md.file_meta[file_index].freq for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    ).reshape(-1, 1)
    v_ms = jnp.array(
        [md.file_meta[file_index].v_m for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )
    i_ms = jnp.array(
        [md.file_meta[file_index].i_m for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )
    tss = jnp.array(
        [md.file_meta[file_index].t for file_index in range(len(md.file_meta))],
        dtype=jnp.float32,
    )
    # Split files into train/test sets; the test slices are held fixed for
    # per-epoch evaluation below.
    train_ind, test_ind = get_train_test_indices_from_req(md)
    test_amps = amps[test_ind, :]
    test_freqs = freqs[test_ind, :]
    test_v_ms = v_ms[test_ind, :]
    test_i_ms = i_ms[test_ind, :]
    test_tss = tss[test_ind, :]
    amps = amps[train_ind, :]
    freqs = freqs[train_ind, :]
    v_ms = v_ms[train_ind, :]
    i_ms = i_ms[train_ind, :]
    tss = tss[train_ind, :]
    # Mean per-trace voltage std, passed to the model (presumably for output
    # scaling — confirm against NeuralODE in utils).
    std_v = jnp.mean(jnp.std(v_ms, axis=1))
    model = NeuralODE(
        key=model_key,
        width_mplx=trial_dict["mlpx_width_size"],
        depth_mplx=trial_dict["mlpx_depth"],
        width_mlpg=trial_dict["mlpg_width_size"],
        depth_mlpg=trial_dict["mlpg_depth"],
        mlpx_activation=trial_dict["mlpx_activation"],
        mlpx_final_activation=trial_dict["mlpx_final_activation"],
        mlpg_activation=trial_dict["mlpg_activation"],
        mlpg_final_activation=trial_dict["mlpg_final_activation"],
        latent_states=trial_dict["latent_states"],
        std_v=std_v,
        with_norm=trial_dict["with_norm"],
        diamond_structure=trial_dict["diamond_structure"],
    )
    optim = optax.chain(
        optax.clip_by_global_norm(
            trial_dict["clip_value"]
        ),  # clip gradients first
        trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
        contrib.reduce_on_plateau(
            patience=trial_dict["patience"],
            cooldown=trial_dict["cooldown"],
            factor=trial_dict["factor"],
            rtol=trial_dict["rtol"],
            accumulation_size=trial_dict["accumulation_size"],
        ),
    )
    # Only inexact (float) arrays are trainable parameters.
    params = eqx.filter(model, eqx.is_inexact_array)
    opt_state = optim.init(params)
    print(f"{amps.shape=}, {freqs.shape=}, {v_ms.shape=}, {i_ms.shape=}, {tss.shape=}")

    # NOTE(review): defined but never called anywhere in this function.
    @eqx.filter_jit
    def get_dx_dt(
        vs: jax.Array, xs: jax.Array, ts: jax.Array, amps: jax.Array, freqs: jax.Array
    ) -> jax.Array:
        # Evaluate the latent-state derivative, vmapped over traces and then
        # over time steps within each trace.
        def _get_dx(v_rov, x_row, t_row, amp, freq):
            return jax.vmap(lambda v, x, t: model.func(t, x, (amp, freq)))(
                v_rov, x_row, t_row
            )

        return jax.vmap(_get_dx, in_axes=(0, 0, 0, 0, 0))(vs, xs, ts, amps, freqs)

    @eqx.filter_jit
    def smooth_range_penalty_jit(
        x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0
    ):
        """JIT-compiled version of smooth_range_penalty for faster execution."""
        # Sigmoid-gated hinge penalties: differentiable everywhere, but only
        # meaningfully non-zero outside [min_val + margin, max_val - margin].
        lower_margin = min_val + margin
        upper_margin = max_val - margin
        below_penalty = expit(-(x - lower_margin) / (margin / 4)) * (lower_margin - x)
        above_penalty = expit((x - upper_margin) / (margin / 4)) * (x - upper_margin)
        total_penalty = below_penalty + above_penalty
        return penalty_weight * jnp.mean(total_penalty**2)

    @eqx.filter_jit
    def improved_loss_fn(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
        lambdas=[],  # NOTE(review): mutable default; only read if adaptive_weights=False
        eps=1e-8,  # NOTE(review): trial_dict["eps"] is sampled but never passed here
        adaptive_weights=True,
        params: eqx.Module = None,
        regularization_weight: float = 1e-4,
        with_regularization: bool = True,
    ):
        """Multi-scale training loss: standardized MSE on the signals and
        their 1st/2nd differences, plus a range penalty on G and an
        (optional) L2 term on ``params``."""
        _v_m_pred, _i_m_pred, _x, _G = jax.vmap(model, in_axes=(0, 0, 0))(
            ts, amps, freqs
        )
        # Normalise prediction shapes to (batch, time).
        v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
        i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
        if v_m_pred.ndim < 2:
            v_m_pred = v_m_pred.reshape(1, -1)
            i_m_pred = i_m_pred.reshape(1, -1)
        x = _x.squeeze() if _x.ndim > 2 else _x  # NOTE(review): x is unused
        _G = _G.squeeze() if _G.ndim > 2 else _G
        # Voltage traces stacked on top of current traces.
        targets = jnp.vstack([v_m, i_m])  # Shape: (2 * batch_size, time_steps)
        preds = jnp.vstack([v_m_pred, i_m_pred])  # Shape: (2 * batch_size, time_steps)
        # First and second finite differences along the time axis.
        targets_d1 = jnp.diff(targets, axis=1)
        preds_d1 = jnp.diff(preds, axis=1)
        targets_d2 = jnp.diff(targets_d1, axis=1)
        preds_d2 = jnp.diff(preds_d1, axis=1)

        # Vectorized standardization
        @eqx.filter_jit
        def standardize_loss(target, pred, with_center=True):
            # MSE after standardising BOTH series with the target's mean/std;
            # with_center=False would score the voltage and current halves
            # separately (callers always use the default).
            center = target.shape[0] // 2
            target_std = jnp.std(target, axis=1, keepdims=True)
            target_mean = jnp.mean(target, axis=1, keepdims=True)
            target_norm = (target - target_mean) / (target_std + 1e-8)
            pred_norm = (pred - target_mean) / (target_std + 1e-8)
            return (
                jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
                + jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
                if not with_center
                else jnp.mean((target_norm - pred_norm) ** 2)
            )

        # Multi-scale loss
        loss_0 = standardize_loss(targets, preds)
        loss_1 = standardize_loss(targets_d1, preds_d1)
        loss_2 = standardize_loss(targets_d2, preds_d2)
        if adaptive_weights:
            # Balance the three terms around their geometric mean; clipping
            # keeps any single term from dominating or vanishing.
            loses = jnp.array([loss_0, loss_1, loss_2])
            loss_mean = jnp.exp((jnp.mean(jnp.log(loses + eps))))
            lambdas = loss_mean / (loses + eps)
            lambdas = jnp.clip(lambdas, 0.1, 10.0)
        total_loss = lambdas[0] * loss_0 + lambdas[1] * loss_1 + lambdas[2] * loss_2
        # Penalise conductance values outside [0, 5].
        total_loss += smooth_range_penalty_jit(
            _G, min_val=0.0, max_val=5.0, margin=0.1, penalty_weight=10.0
        )
        # NOTE(review): make_step calls this without `params`, so tree_norm
        # sees an empty pytree and this L2 term is always zero — the sampled
        # regularization hyper-parameters have no effect. Confirm intended.
        sqnorm = tree_norm(params, squared=True)
        return total_loss + regularization_weight * sqnorm * with_regularization

    @eqx.filter_jit
    def finall_loss(
        model: NeuralODE,
        ts,
        v_m,
        i_m,
        freqs,
        amps,
    ):
        """Evaluation pass: return the three unweighted multi-scale losses
        plus the raw predictions (for plotting)."""
        _v_m_pred, _i_m_pred, _x, _G = jax.vmap(model, in_axes=(0, 0, 0))(
            ts, amps, freqs
        )
        # Normalise prediction shapes to (batch, time).
        v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
        i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
        x = _x.squeeze() if _x.ndim > 2 else _x  # NOTE(review): x is unused
        _G = _G.squeeze() if _G.ndim > 2 else _G
        targets = jnp.vstack([v_m, i_m])  # Shape: (2 * batch_size, time_steps)
        preds = jnp.vstack([v_m_pred, i_m_pred])  # Shape: (2 * batch_size, time_steps)
        # Compute all derivatives at once using jnp.diff with axis parameter
        targets_d1 = jnp.diff(targets, axis=1)
        preds_d1 = jnp.diff(preds, axis=1)
        targets_d2 = jnp.diff(targets_d1, axis=1)
        preds_d2 = jnp.diff(preds_d1, axis=1)

        def standardize_loss(target, pred, with_center=True):
            # Same standardized MSE as in improved_loss_fn (duplicated here so
            # the evaluation function stays self-contained under jit).
            center = target.shape[0] // 2
            target_std = jnp.std(target, axis=1, keepdims=True)
            target_mean = jnp.mean(target, axis=1, keepdims=True)
            target_norm = (target - target_mean) / (target_std + 1e-8)
            pred_norm = (pred - target_mean) / (target_std + 1e-8)
            return (
                jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
                + jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
                if not with_center
                else jnp.mean((target_norm - pred_norm) ** 2)
            )

        # Multi-scale loss
        loss_0 = standardize_loss(targets, preds)
        loss_1 = standardize_loss(targets_d1, preds_d1)
        loss_2 = standardize_loss(targets_d2, preds_d2)
        return (
            loss_0,
            loss_1,
            loss_2,
            _v_m_pred,
            _i_m_pred,
            _G,
        )

    @eqx.filter_jit
    def make_step(model, ts, v_m, i_m, freqs, amps, opt_state, params):
        # Loss + gradients w.r.t. the model's inexact arrays.
        loss, grads = eqx.filter_value_and_grad(improved_loss_fn)(
            model, ts, v_m, i_m, freqs, amps
        )
        # `value=loss` feeds reduce_on_plateau; the regularization_* kwargs are
        # presumably passed through optax's extra-args mechanism and ignored by
        # transforms that don't declare them — TODO confirm.
        updates, opt_state = optim.update(
            grads,
            opt_state,
            params=params,
            value=loss,
            regularization_weight=trial_dict["regularization_weight"],
            with_regularization=trial_dict["with_regularization"],
        )
        model = eqx.apply_updates(model, updates)
        return model, opt_state, loss

    best_loss = float("inf")
    best_model = None
    # Curriculum loop: each stage trains on a longer prefix of the series.
    for length_scale, epochs in trial_dict["length_scales"]:
        print(f"Training with length scale: {length_scale}, epochs: {epochs}")
        model = (
            best_model if best_model is not None else model
        )  # Use the best model if available
        # NOTE(review): `params` is not refreshed after each make_step, so
        # optimizers that consume `params` (e.g. adamw's weight decay) see the
        # stage-initial weights all stage long — confirm intended.
        params = eqx.filter(model, eqx.is_inexact_array)
        opt_state = optim.init(params)  # optimizer state reset every stage
        length = int(v_ms.shape[1] * length_scale)
        v_m = v_ms[:, :length]
        i_m = i_ms[:, :length]
        ts = tss[:, :length]
        # Best-model tracking restarts per stage.
        best_loss = float("inf")
        best_model = None
        for i in range(epochs):
            try:
                epoch_loss = 0.0
                # data_key is reused each epoch, so the shuffle order is
                # identical across epochs.
                for batch_indices in get_batches(
                    v_m.shape[0],
                    trial_dict["batch_size"],
                    dataloader_key=data_key,
                ):
                    batch_v_m = v_m[batch_indices, :]
                    batch_i_m = i_m[batch_indices, :]
                    batch_ts = ts[batch_indices, :]
                    batch_freqs = freqs[batch_indices, :]
                    batch_amps = amps[batch_indices, :]
                    model, opt_state, batch_loss = make_step(
                        model,
                        batch_ts,
                        batch_v_m,
                        batch_i_m,
                        batch_freqs,
                        batch_amps,
                        opt_state,
                        params,
                    )
                    epoch_loss += batch_loss / float(len(batch_indices))
                # Evaluate on the held-out files after every epoch.
                *losses, _v_m_pred, _i_m_pred, _G = finall_loss(
                    model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
                )
                if i % 100 == 0:
                    print(f"Step {i}, Loss: {epoch_loss:.4f}")
                # NOTE(review): from here on `epoch_loss` is the TEST loss, so
                # best-model selection is driven by the test set — confirm
                # that this selection-on-test is intended.
                epoch_loss = losses[0]  # Use primary loss for best model tracking
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model = deepcopy(model)
            except Exception as e:
                # A failed step (e.g. a diverged ODE solve) aborts this stage;
                # the best model found so far is still evaluated below.
                print(f"Error at step {i}: {e}")
                logger.exception(e)
                break
    try:
        # Final scoring of the best model on the test set.
        *losses, _v_m_pred, _i_m_pred, _G = finall_loss(
            best_model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
        )
        try:
            # Best-effort diagnostic plot; a plotting failure must not lose
            # the trial's losses.
            fig, ax = plot_results(
                test_tss,
                test_v_ms,
                test_i_ms,
                _v_m_pred,
                _i_m_pred,
                _G,
            )
            # The multi_neural_plot_3obj/ directory must already exist.
            fig.savefig(
                f"multi_neural_plot_3obj/{trial_dict['trial_number']}.pdf",
                bbox_inches="tight",
            )
            plt.close()
        except Exception as e:
            print(f"Error saving figure: {e}")
    except Exception as e:
        # Report infinite losses so Optuna ranks the trial as very bad
        # instead of crashing the whole study.
        print(f"Error in final loss computation: {e}")
        losses = (jnp.inf, jnp.inf, jnp.inf)
    # Convert JAX scalars to plain floats for Optuna's 3 objectives.
    return tuple([loss.item() if hasattr(loss, "item") else loss for loss in losses])
def objective(file_meta):
    """Build an Optuna objective closed over the dataset ``file_meta``.

    Returns:
        A callable mapping a Trial to the three test-set losses, matching
        the study's three "minimize" directions. (The previous ``-> float``
        annotation was wrong: ``train_mem`` returns a 3-tuple.)
    """

    def _objective(trial: Trial) -> Tuple[float, float, float]:
        trial_dict = get_trial_dict(trial)
        # Checkpoint the module-level sampler each trial so a restarted run
        # resumes from the same TPE state.
        joblib.dump(sampler, SAMPLER_PATH)
        return train_mem(file_meta, trial_dict)

    return _objective
# On-disk checkpoint of the Optuna sampler state, re-dumped after every trial.
SAMPLER_PATH = Path("jax_sampler_3obj.pkl")
class MemoryCleanupCallback:
    """Optuna callback that frees JAX memory and logs RSS after each trial."""

    def __call__(self, study, trial):
        import gc
        import psutil

        # Drop JAX's compiled-function caches, then force a collection so
        # per-trial allocations don't pile up over a long study.
        jax.clear_caches()
        gc.collect()
        # Report resident memory so slow leaks are visible in the run log.
        process = psutil.Process()
        mem_mb = process.memory_info().rss / 1024 / 1024
        logger.info(f"Trial {trial.number} completed. Memory: {mem_mb:.2f} MB")
# ---- Script entry: build the study and run the 3-objective optimisation ----
md = MemData(Dopant.Tungsten, ppp=2000)
study_name = f"NeuralODE_{md.dopant.name}_ppp{md.ppp}_3obj_test"
# SQLite storage; the optuna_studies/ directory must already exist.
storage_name = f"sqlite:///optuna_studies/{study_name}.db"
# Reload a previously checkpointed sampler so interrupted runs resume with
# the same TPE state; otherwise create (and checkpoint) a fresh one.
if not SAMPLER_PATH.exists():
    sampler = optuna.samplers.TPESampler(
        seed=42,
        multivariate=True,
        n_startup_trials=20,
        n_ei_candidates=44,
    )
    joblib.dump(sampler, SAMPLER_PATH)
else:
    sampler = joblib.load(SAMPLER_PATH)
# Three minimisation directions: the multi-scale test losses from train_mem.
study = optuna.create_study(
    directions=["minimize", "minimize", "minimize"],
    sampler=sampler,
    study_name=study_name,
    storage=storage_name,
    load_if_exists=True,
)
study.optimize(
    objective(md),
    gc_after_trial=True,
    n_trials=200,
    n_jobs=1,  # single worker: trials share the in-process sampler object
    callbacks=[MemoryCleanupCallback()],
)