# jax_generalGoff.py
import os

# XLA_FLAGS must be set before importing jax for these options to take effect.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20 "
    "--xla_cpu_enable_fast_math=true"
)
# os.environ['JAX_NUM_CPU_DEVICES'] = '10'
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Union
import gc
from loguru import logger
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit # JAX's sigmoid function
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter
from utils import (
    Dopant,
    FileMeta,
    MemData,
    average_data,
    export_signals,
    find_periods,
    generate_neural_stack,
    get_train_test_indices_from_req,
    mean_data,
    plot_results,
    unpack_data,
)
# Note: get_batches is defined locally below, so it is not imported from utils.
matplotlib.use("svg") # Use Agg backend for matplotlib to avoid GUI issues
print(f"JAX devices: {jax.devices()}")
# exit()
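# MLPG maps the internal state x (nominally in [0, 1]) to a conductance by linear
# mixing, G(x) = x * Gon + (1 - x) * Goff. Both conductances are stored in
# log-space so gradient updates keep them strictly positive.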
class MLPG(eqx.Module):
log_Goff: jax.Array
log_Gon: jax.Array
def __init__(self, Goff=1 / jnp.array(60.0), Gon=1 / jnp.array(0.5), **kwargs):
super().__init__(**kwargs)
self.log_Goff = jnp.log(Goff)
self.log_Gon = jnp.log(Gon)
@property
def Goff(self):
return jnp.exp(self.log_Goff)
@property
def Gon(self):
return jnp.exp(self.log_Gon)
def __call__(self, x):
return x * self.Gon + (1 - x) * self.Goff
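# Func is the ODE right-hand side passed to diffrax: it reconstructs the source
# voltage vs = amp * sin(2*pi*freq*t), applies the resistive divider
# v_m = vs / (Rs * G + 1) with G = mlpg(x), and feeds (v_m / std_v, x) through
# the MLP to obtain dx/dt (scaled by out_scale).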
class Func(eqx.Module):
out_scale: jax.Array
x0: jax.Array
Rs: float
mlp: eqx.nn.MLP
std_v: float
mlpg: eqx.Module
def __init__(
self,
*,
key,
width_mplx,
depth_mplx,
mlpx_activation,
mlpx_final_activation,
latent_states=1,
std_v=1.0,
with_norm=True,
diamond_structure=True,
**kwargs,
):
super().__init__(**kwargs)
self.out_scale = jnp.array([1.0], dtype=jnp.float32)
if diamond_structure:
self.mlp = generate_neural_stack(
n_layers=depth_mplx,
width=width_mplx,
activation=mlpx_activation,
final_activation=mlpx_final_activation,
in_size=latent_states + 1,
out_size=latent_states,
key=key,
layer_norm=with_norm,
)
else:
self.mlp = eqx.nn.MLP(
in_size=latent_states + 1, # v_m + x
out_size=latent_states,
width_size=width_mplx,
depth=depth_mplx,
activation=mlpx_activation,
final_activation=mlpx_final_activation,
key=key,
)
# self.mlp = eqx.nn.MLP(
# in_size=latent_states + 1, # v_m + x + t
# out_size=latent_states,
# width_size=width_mplx,
# depth=depth_mplx,
# activation=mlpx_activation,
# final_activation=mlpx_final_activation,
# key=key,
# )
# self.mlgp = generate_neural_stack(
# n_layers=depth_mlpg,
# width=width_mlpg,
# activation=mlpg_activation,
# final_activation=mlpg_final_activation,
# in_size=latent_states,
# out_size=1,
# key=jr.fold_in(key, 1),
# layer_norm=with_norm,
# )
self.mlpg = MLPG()
self.Rs = 5.11
self.x0 = jnp.array([0.001] * latent_states)
self.std_v = std_v
    @property
    def Ron(self):
        return 1 / self.mlpg.Gon

    @property
    def Roff(self):
        return 1 / self.mlpg.Goff
def __call__(self, t, y, args):
amp, freq = args
vs = jnp.sin(freq * t * 2 * jnp.pi) * amp
X = y
G = self.mlpg(X)
v_m = vs / (self.Rs * G + 1)
dX_input = jnp.hstack([v_m / self.std_v, X])
return self.mlp(dX_input) * self.out_scale
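# NeuralODE wraps Func in a diffrax solve (Tsit5 with a PID step-size controller)
# and post-processes the latent trajectory x(t) into the observable signals:
# v_m(t), i_m(t) = v_m(t) * G(t), the state x, and the conductance G.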
class NeuralODE(eqx.Module):
func: Func
def __init__(
self,
*,
width_mplx,
depth_mplx,
mlpx_activation,
mlpx_final_activation,
key,
latent_states=1,
std_v=1.0,
with_norm=True,
diamond_structure=True,
**kwargs,
):
super().__init__(**kwargs)
self.func = Func(
key=key,
width_mplx=width_mplx,
depth_mplx=depth_mplx,
mlpx_activation=mlpx_activation,
mlpx_final_activation=mlpx_final_activation,
latent_states=latent_states,
std_v=std_v,
with_norm=with_norm,
diamond_structure=diamond_structure,
)
def __call__(self, ts, amp, freq):
solution = diffrax.diffeqsolve(
diffrax.ODETerm(self.func),
diffrax.Tsit5(),
t0=ts[0],
t1=ts[-1],
dt0=ts[1] - ts[0],
y0=self.func.x0,
# stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-9),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=diffrax.SaveAt(ts=ts),
args=(amp, freq),
max_steps=2048 * 8,
# adjoint=diffrax.BacksolveAdjoint(solver=diffrax.Tsit5()) # BacksolveAdjoint, InterpolationAdjoint
adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints=2048* 3),
)
# jax.debug.print("Function done ")
x = solution.ys
G = self.func.mlpg(x)
v_s = jnp.reshape(jnp.sin(freq * ts * 2 * jnp.pi) * amp, (-1, 1))
v_m = v_s / (self.func.Rs * G + 1)
i_m = v_m * G
return v_m, i_m, x, G
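# get_trial_dict translates an Optuna trial into a flat config dict: it samples
# architecture, optimizer and schedule hyperparameters, resolves activation and
# optimizer names into callables, and builds the training curriculum as a list of
# (length_scale, epochs) pairs ending with (1.0, last_epochs).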
def get_trial_dict(trial: Trial) -> Dict[str, Any]:
activation_dict = {
"relu": jax.nn.relu,
"gelu": jax.nn.gelu,
"silu": jax.nn.silu,
"tanh": jax.nn.tanh,
"elu": jax.nn.elu,
"leaky_relu": jax.nn.leaky_relu,
"sigmoid": jax.nn.sigmoid,
"sin": jnp.sin,
"linear": jax.nn.identity,
}
optimizer_dict = {
"adam": optax.adam,
# "sgd": optax.sgd,
# "adamw": optax.adamw,
"nadamw": optax.nadamw,
"adabelief": optax.adabelief,
"nadam": optax.nadam,
}
trial_dict = {
"mlpx_width_size": trial.suggest_int("mlpx_width_size", 16, 512, step=16),
"mlpx_depth": trial.suggest_int("mlpx_depth", 1, 5, step=1),
"mlpx_activation": trial.suggest_categorical(
"mlpx_activation", list(activation_dict.keys())
),
"mlpx_final_activation": trial.suggest_categorical(
"mlpx_final_activation", ["linear", "tanh"]
),
"learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True),
"mlpg_learning_rate": trial.suggest_float(
"mlpg_learning_rate", 1e-4, 1e-1, log=True
),
"clip_value": trial.suggest_float("clip_value", 0.1, 10.0, step=0.1),
"eps": trial.suggest_float("eps", 1e-8, 1e-4, log=True),
"with_regularization": trial.suggest_categorical(
"with_regularization", [True, False]
),
"mlpx_with_norm": trial.suggest_categorical("mlpx_with_norm", [True, False]),
"regularization_weight": trial.suggest_float(
"regularization_weight", 1e-6, 1e-2, log=True
),
"patience": trial.suggest_int("patience", 10, 200, step=10),
"cooldown": trial.suggest_int("cooldown", 0, 100, step=10),
"factor": trial.suggest_float("factor", 0.1, 0.99, step=0.1),
"rtol": trial.suggest_float("rtol", 1e-6, 1e-2, log=True),
"accumulation_size": trial.suggest_int("accumulation_size", 1, 100, step=1),
"length_step": trial.suggest_float("length_step", 0.2, 0.6),
"intermediate_epochs": trial.suggest_int("intermediate_epochs", 1, 100, step=1),
"last_epochs": trial.suggest_int("last_epochs", 1, 1000, step=10),
"optimizer": trial.suggest_categorical(
"optimizer", list(optimizer_dict.keys())
),
# "latent_states": trial.suggest_int("latent_states", 1, 6, step=1),
"batch_size": trial.suggest_int("batch_size", 1, 18, step=1),
"diamond_structure": trial.suggest_categorical(
"diamond_structure", [True, False]
),
}
    print(f"learning_rate: {trial_dict['learning_rate']}")
for key in trial_dict:
if key.endswith("_activation") or key.endswith("_final_activation"):
trial_dict[key] = activation_dict[trial_dict[key]]
elif key == "optimizer":
trial_dict[key] = optimizer_dict[trial_dict[key]]
length_scale = np.arange(trial_dict["length_step"], 1.0, trial_dict["length_step"])
trial_dict["length_scales"] = [
(l_s, trial_dict["intermediate_epochs"]) for l_s in length_scale if l_s < 1.0
]
trial_dict["length_scales"] += [(1.0, trial_dict["last_epochs"])]
trial_dict["trial_number"] = trial.number
return trial_dict
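# Local mini-batch index generator (used instead of the utils helper of the same
# name): shuffles sample indices with a JAX PRNG key and yields them in chunks of
# batch_size.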
def get_batches(
length: int,
batch_size: int,
shuffle: bool = True,
dataloader_key: jr.PRNGKey = jr.PRNGKey(0),
) -> Generator[List[int], None, None]:
if shuffle:
indices = jr.permutation(dataloader_key, jnp.arange(length))
else:
indices = jnp.arange(length)
for i in range(0, length, batch_size):
batch_indices = indices[i : i + batch_size]
yield batch_indices
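# train_mem runs one Optuna trial end to end: it stacks the per-file signals into
# (n_files, n_samples) arrays, splits them into train/test sets, builds the
# NeuralODE model and a partitioned optimizer (per-group learning rates, global
# gradient clipping, reduce-on-plateau), then trains on progressively longer
# signal prefixes and returns the three test losses (signal, 1st and 2nd
# difference) as the multi-objective value.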
def train_mem(
md: FileMeta,
trial_dict: Dict[str, Any],
):
key = jr.PRNGKey(5678)
data_key, model_key, loader_key = jr.split(key, 3)
    amps = jnp.array(
        [fm.amp for fm in md.file_meta], dtype=jnp.float32
    ).reshape(-1, 1)
    freqs = jnp.array(
        [fm.freq for fm in md.file_meta], dtype=jnp.float32
    ).reshape(-1, 1)
    v_ms = jnp.array([fm.v_m for fm in md.file_meta], dtype=jnp.float32)
    i_ms = jnp.array([fm.i_m for fm in md.file_meta], dtype=jnp.float32)
    tss = jnp.array([fm.t for fm in md.file_meta], dtype=jnp.float32)
train_ind, test_ind = get_train_test_indices_from_req(md)
test_amps = amps[test_ind, :]
test_freqs = freqs[test_ind, :]
test_v_ms = v_ms[test_ind, :]
test_i_ms = i_ms[test_ind, :]
test_tss = tss[test_ind, :]
amps = amps[train_ind, :]
freqs = freqs[train_ind, :]
v_ms = v_ms[train_ind, :]
i_ms = i_ms[train_ind, :]
tss = tss[train_ind, :]
std_v = jnp.mean(jnp.std(v_ms, axis=1))
model = NeuralODE(
key=model_key,
width_mplx=trial_dict["mlpx_width_size"],
depth_mplx=trial_dict["mlpx_depth"],
mlpx_activation=trial_dict["mlpx_activation"],
mlpx_final_activation=trial_dict["mlpx_final_activation"],
std_v=std_v,
with_norm=trial_dict["mlpx_with_norm"],
diamond_structure=trial_dict["diamond_structure"],
)
    def create_param_labels(params):
        """Create labels for different parameter groups."""

        def label_fn(path, param):
            path_str = "/".join(str(p) for p in path)
            # Check "mlpg" before "mlp": "mlp" is a substring of "mlpg".
            if "mlpg" in path_str:
                return "mlpg"
            elif "mlp" in path_str:
                return "mlp"
            elif "out_scale" in path_str:
                return "scaling"
            elif "x0" in path_str:
                return "init"
            else:
                return "other"

        return jax.tree_util.tree_map_with_path(label_fn, params)
    base_transforms = {
        "mlpg": trial_dict["optimizer"](
            learning_rate=trial_dict.get("mlpg_learning_rate", trial_dict["learning_rate"]),
        ),
        "mlp": trial_dict["optimizer"](
            learning_rate=trial_dict.get("mlp_learning_rate", trial_dict["learning_rate"]),
        ),
        "init": trial_dict["optimizer"](
            learning_rate=trial_dict.get("init_learning_rate", trial_dict["learning_rate"]),
        ),
        "other": trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
        "scaling": trial_dict["optimizer"](
            learning_rate=trial_dict["learning_rate"],
        ),
    }
partitioned_optim = optax.partition(base_transforms, create_param_labels)
    optim = optax.chain(
        optax.clip_by_global_norm(trial_dict["clip_value"]),  # clip gradients first
partitioned_optim,
contrib.reduce_on_plateau(
patience=trial_dict["patience"],
cooldown=trial_dict["cooldown"],
factor=trial_dict["factor"],
rtol=trial_dict["rtol"],
accumulation_size=trial_dict["accumulation_size"],
),
)
params = eqx.filter(model, eqx.is_inexact_array)
# print(f"Model parameters: {params}")
opt_state = optim.init(params)
print(f"{amps.shape=}, {freqs.shape=}, {v_ms.shape=}, {i_ms.shape=}, {tss.shape=}")
@eqx.filter_jit
def get_dx_dt(
vs: jax.Array, xs: jax.Array, ts: jax.Array, amps: jax.Array, freqs: jax.Array
) -> jax.Array:
        def _get_dx(v_row, x_row, t_row, amp, freq):
            # v_row is kept for a consistent signature; model.func recomputes v_m internally.
            return jax.vmap(lambda v, x, t: model.func(t, x, (amp, freq)))(
                v_row, x_row, t_row
            )
result = jax.vmap(_get_dx, in_axes=(0, 0, 0, 0, 0))(vs, xs, ts, amps, freqs)
return result.squeeze() if result.ndim > 2 else result
@eqx.filter_jit
def dx_dt_loss(
v_m: jax.Array, i_m: jax.Array, dx_dt: jax.Array, penalty_weight: float = 1.0
) -> jax.Array:
"""
Enforce monotonicity constraint:
- When v_m > 0: require dx_dt >= 0 (x should increase)
- When v_m < 0: require dx_dt <= 0 (x should decrease)
- When v_m = 0: no constraint (dx_dt can be anything)
Penalizes violations: (v_m > 0 and dx_dt < 0) or (v_m < 0 and dx_dt > 0)
"""
# The product v_m * dx_dt should be non-negative
# When both have same sign: product > 0 (good)
# When opposite signs: product < 0 (bad, penalize)
product = v_m * dx_dt
# Penalize negative products (sign mismatch)
violation = jnp.maximum(0.0, -product)
return penalty_weight * jnp.mean(jnp.square(violation))
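    # dx_dt_loss reduces to penalty_weight * mean(max(0, -(v_m * dx_dt))**2), so only
    # sign mismatches between v_m and dx/dt contribute. It is currently unused (the
    # corresponding calls in improved_loss_fn are commented out).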
@eqx.filter_jit
def smooth_range_penalty_jit(
x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0
):
"""JIT-compiled version of smooth_range_penalty for faster execution."""
lower_margin = min_val + margin
upper_margin = max_val - margin
below_penalty = expit(-(x - lower_margin) / (margin / 4)) * (lower_margin - x)
above_penalty = expit((x - upper_margin) / (margin / 4)) * (x - upper_margin)
total_penalty = below_penalty + above_penalty
return penalty_weight * jnp.mean(total_penalty**2)
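    # The penalty is a soft hinge: a sigmoid gate (width ~ margin/4) multiplied by the
    # distance outside [min_val + margin, max_val - margin], squared and averaged, so
    # values well inside the range contribute approximately zero and the penalty grows
    # smoothly outside it.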
@eqx.filter_jit
def improved_loss_fn(
model: NeuralODE,
ts,
v_m,
i_m,
freqs,
amps,
        lambdas=(1.0, 1.0, 1.0),
eps=1e-8,
adaptive_weights=True,
params: eqx.Module = None,
with_regularization=True,
regularization_weight=1e-4,
):
# batched_model = eqx.filter_jit(model)
batched_forward = eqx.filter_vmap(model, in_axes=(0, 0, 0))(ts, amps, freqs)
_v_m_pred, _i_m_pred, _x, _G = batched_forward
v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
x = _x.squeeze() if _x.ndim > 2 else _x
_G = _G.squeeze() if _G.ndim > 2 else _G
if v_m_pred.ndim < 2:
v_m_pred = v_m_pred.reshape(1, -1)
i_m_pred = i_m_pred.reshape(1, -1)
x = x.reshape(1, -1)
_G = _G.reshape(1, -1)
targets = jnp.vstack([v_m, i_m]) # Shape: (2 * batch_size, time_steps)
preds = jnp.vstack([v_m_pred, i_m_pred]) # Shape: (2 * batch_size, time_steps)
targets_d1 = jnp.diff(targets, axis=1)
preds_d1 = jnp.diff(preds, axis=1)
targets_d2 = jnp.diff(targets_d1, axis=1)
preds_d2 = jnp.diff(preds_d1, axis=1)
# dx_dt = get_dx_dt(v_m_pred, x, ts, amps, freqs)
# Vectorized standardization
# @eqx.filter_jit
def standardize_loss(target, pred, with_center=True):
center = target.shape[0] // 2
target_std = jnp.std(target, axis=1, keepdims=True)
target_mean = jnp.mean(target, axis=1, keepdims=True)
target_norm = (target - target_mean) / (target_std + 1e-8)
pred_norm = (pred - target_mean) / (target_std + 1e-8)
return (
jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
+ jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
if not with_center
else jnp.mean((target_norm - pred_norm) ** 2)
)
# Multi-scale loss
loss_0 = standardize_loss(targets, preds)
loss_1 = standardize_loss(targets_d1, preds_d1)
loss_2 = standardize_loss(targets_d2, preds_d2)
# print(f"{loss_0.item()=}, {loss_1.item()=}, {loss_2.item()=}")
        if adaptive_weights:
            losses = jnp.array([loss_0, loss_1, loss_2])
            loss_mean = jnp.exp(jnp.mean(jnp.log(losses + eps)))
            lambdas = loss_mean / (losses + eps)
            lambdas = jnp.clip(lambdas, 0.1, 10.0)
        total_loss = lambdas[0] * loss_0 + lambdas[1] * loss_1 + lambdas[2] * loss_2
total_loss += smooth_range_penalty_jit(
_G, min_val=0.0, max_val=5.0, margin=0.1, penalty_weight=10.0
)
total_loss += smooth_range_penalty_jit(
x, min_val=0.0, max_val=1.0, margin=0.01, penalty_weight=10.0
)
# total_loss += dx_dt_loss(v_m_pred, i_m_pred, dx_dt, penalty_weight=10.0)
sqnorm = tree_norm(params, squared=True)
# jax.debug.print("Regularization weight: {regularization_weight}, Squared norm: {sqnorm}", regularization_weight=regularization_weight, sqnorm=sqnorm)
return total_loss + regularization_weight * sqnorm * with_regularization
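    # final_loss evaluates the held-out set: the same multi-scale standardized MSE as
    # improved_loss_fn but without the range penalties or regularization, and it also
    # returns the predicted v_m, i_m and G for plotting.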
@eqx.filter_jit
    def final_loss(
model: NeuralODE,
ts,
v_m,
i_m,
freqs,
amps,
):
_v_m_pred, _i_m_pred, _x, _G = jax.vmap(model, in_axes=(0, 0, 0))(
ts, amps, freqs
)
# print(f"{_v_m_pred.shape=}, {_i_m_pred.shape=}, {_x.shape=}, {_G.shape=}")
v_m_pred = _v_m_pred.squeeze() if _v_m_pred.ndim > 2 else _v_m_pred
i_m_pred = _i_m_pred.squeeze() if _i_m_pred.ndim > 2 else _i_m_pred
x = _x.squeeze() if _x.ndim > 2 else _x
_G = _G.squeeze() if _G.ndim > 2 else _G
targets = jnp.vstack([v_m, i_m]) # Shape: (2 * batch_size, time_steps)
preds = jnp.vstack([v_m_pred, i_m_pred]) # Shape: (2 * batch_size, time_steps)
# Compute all derivatives at once using jnp.diff with axis parameter
targets_d1 = jnp.diff(targets, axis=1)
preds_d1 = jnp.diff(preds, axis=1)
targets_d2 = jnp.diff(targets_d1, axis=1)
preds_d2 = jnp.diff(preds_d1, axis=1)
def standardize_loss(target, pred, with_center=True):
center = target.shape[0] // 2
target_std = jnp.std(target, axis=1, keepdims=True)
target_mean = jnp.mean(target, axis=1, keepdims=True)
target_norm = (target - target_mean) / (target_std + 1e-8)
pred_norm = (pred - target_mean) / (target_std + 1e-8)
return (
jnp.mean((target_norm[:center] - pred_norm[:center]) ** 2)
+ jnp.mean((target_norm[center:] - pred_norm[center:]) ** 2)
if not with_center
else jnp.mean((target_norm - pred_norm) ** 2)
)
# Multi-scale loss
loss_0 = standardize_loss(targets, preds)
loss_1 = standardize_loss(targets_d1, preds_d1)
loss_2 = standardize_loss(targets_d2, preds_d2)
return (
loss_0,
loss_1,
loss_2,
_v_m_pred,
_i_m_pred,
_G,
)
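    # make_step performs one optimization step: value-and-grad of improved_loss_fn on a
    # batch, followed by the chained optax update (global-norm clipping, per-group
    # optimizers, reduce-on-plateau driven by the loss value).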
@eqx.filter_jit
def make_step(model, ts, v_m, i_m, freqs, amps, opt_state, params):
loss, grads = eqx.filter_value_and_grad(improved_loss_fn)(
model,
ts,
v_m,
i_m,
freqs,
amps,
params=params,
with_regularization=trial_dict["with_regularization"],
regularization_weight=trial_dict["regularization_weight"],
)
updates, opt_state = optim.update(grads, opt_state, params=params, value=loss)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
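    # Curriculum over signal length: each (length_scale, epochs) stage trains on the
    # first length_scale fraction of every waveform, warm-starting from the best model
    # of the previous stage and re-initializing the optimizer state.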
best_loss = float("inf")
best_model = None
for length_scale, epochs in trial_dict["length_scales"]:
print(f"Training with length scale: {length_scale}, epochs: {epochs}")
model = (
best_model if best_model is not None else model
) # Use the best model if available
params = eqx.filter(model, eqx.is_inexact_array)
opt_state = optim.init(params)
length = int(v_ms.shape[1] * length_scale)
v_m = v_ms[:, :length]
i_m = i_ms[:, :length]
ts = tss[:, :length]
best_loss = float("inf")
best_model = None
for i in range(epochs):
try:
epoch_loss = 0.0
for batch_indices in get_batches(
v_m.shape[0],
trial_dict["batch_size"],
dataloader_key=data_key,
):
batch_v_m = v_m[batch_indices, :]
batch_i_m = i_m[batch_indices, :]
batch_ts = ts[batch_indices, :]
batch_freqs = freqs[batch_indices, :]
batch_amps = amps[batch_indices, :]
model, opt_state, batch_loss = make_step(
model,
batch_ts,
batch_v_m,
batch_i_m,
batch_freqs,
batch_amps,
opt_state,
params,
)
epoch_loss += batch_loss / float(len(batch_indices))
# print(f"Step {i}, Batch Loss: {batch_loss:.4f}")
                *losses, _v_m_pred, _i_m_pred, _G = final_loss(
                    model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
                )
if i % 100 == 0:
print(f"Step {i}, Loss: {epoch_loss:.4f}")
epoch_loss = losses[0] # Use primary loss for best model tracking
if epoch_loss < best_loss:
best_loss = epoch_loss
best_model = deepcopy(model)
# print(f"New best model at step {i} with loss {best_loss:.4f}")
# except eqx.EquinoxRuntimeError:
# losses = (jnp.inf, jnp.inf, jnp.inf)
# return tuple(
# [loss.item() if hasattr(loss, "item") else loss for loss in losses]
# )
# logger.debug(
# f"Epoch {i}, Loss: {epoch_loss}, Gon: {model.func.mlpg.Goff}, Goff: {model.func.mlpg.Gon}"
# )
except Exception as e:
print(f"Error at step {i}: {e}")
logger.exception(e)
break
try:
        *losses, _v_m_pred, _i_m_pred, _G = final_loss(
            best_model, test_tss, test_v_ms, test_i_ms, test_freqs, test_amps
        )
try:
fig, ax = plot_results(
test_tss,
test_v_ms,
test_i_ms,
_v_m_pred,
_i_m_pred,
_G,
)
fig.savefig(
f"mutli_neural_plot_fixed_mlpg_Ron/{trial_dict['trial_number']}.pdf",
bbox_inches="tight",
)
plt.close()
except Exception as e:
print(f"Error saving figure: {e}")
except Exception as e:
print(f"Error in final loss computation: {e}")
losses = (jnp.inf, jnp.inf, jnp.inf)
return tuple([loss.item() if hasattr(loss, "item") else loss for loss in losses])
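# Optuna driver: a persistent TPE sampler (pickled to disk so an interrupted study can
# resume with the same sampler state) drives a three-objective study stored in SQLite;
# each trial calls train_mem, and JAX caches are cleared between trials.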
md = MemData(Dopant.Tungsten, ppp=2000)
def objective(file_meta):
def _objective(trial: Trial) -> float:
trial_dict = get_trial_dict(trial)
joblib.dump(sampler, SAMPLER_PATH)
return train_mem(file_meta, trial_dict)
return _objective
SAMPLER_PATH = Path("jax_sampler_3obj_Goff.pkl")
study_name = f"NeuralODE_{md.dopant.name}_ppp{md.ppp}_Goff_scale"
storage_name = f"sqlite:///optuna_studies/{study_name}.db"
if not SAMPLER_PATH.exists():
sampler = optuna.samplers.TPESampler(
seed=42,
multivariate=True,
n_startup_trials=20,
n_ei_candidates=44,
)
joblib.dump(sampler, SAMPLER_PATH)
else:
sampler = joblib.load(SAMPLER_PATH)
class MemoryCleanupCallback:
def __call__(self, study, trial):
# Clear JAX caches after each trial
jax.clear_caches()
gc.collect()
# Optional: Print memory usage
import psutil
process = psutil.Process()
mem_mb = process.memory_info().rss / 1024 / 1024
logger.info(f"Trial {trial.number} completed. Memory: {mem_mb:.2f} MB")
study = optuna.create_study(
directions=["minimize", "minimize", "minimize"],
sampler=sampler,
study_name=study_name,
storage=storage_name,
load_if_exists=True,
)
study.optimize(
objective(md),
gc_after_trial=True,
n_trials=200,
n_jobs=1,
callbacks=[MemoryCleanupCallback()],
)
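# To inspect results later, the study can be reloaded from the same storage, e.g.
# (sketch): optuna.load_study(study_name=study_name, storage=storage_name), with
# study.best_trials giving the Pareto-optimal trials of this multi-objective study.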