# utils.py
import os
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Union
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit # JAX's sigmoid function
from loguru import logger
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter
from typing import Tuple
matplotlib.use("svg")  # Non-interactive backend so plotting works headless (original comment said "Agg", but the backend actually set is svg)
# NOTE(review): in the flag string below, "intra_op_parallelism_threads" is
# missing its leading "--", and the backslash line-continuation runs inside the
# string literal -- confirm XLA actually parses these flags as intended.
os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=20\
--xla_cpu_enable_fast_math=true "
)
def generate_neural_stack(
    n_layers: int,
    width: int,
    activation: eqx.Module,
    final_activation: eqx.Module,
    key: jax.Array,
    layer_norm: bool = True,
    out_size: int = 1,
    in_size: int = 2,
) -> eqx.Module:
    """Build a feed-forward stack as an Equinox ``Sequential``.

    Layout: ``in_size -> width // 2 -> width`` (then ``width -> width``
    hidden blocks), finally ``width -> width // 2 -> out_size``.
    ``activation`` is applied between layers, an optional ``LayerNorm``
    follows each full-width Linear, and ``final_activation`` is applied
    to the output.

    Args:
        n_layers: Number of hidden blocks; the first block also contains the
            ``in_size -> width`` expansion.
        width: Hidden-layer width.
        activation: Activation callable used between hidden layers.
        final_activation: Activation callable applied to the output.
        key: JAX PRNG key used to initialize all Linear layers.
        layer_norm: If True, insert ``LayerNorm`` after each full-width Linear.
        out_size: Output dimensionality.
        in_size: Input dimensionality.

    Returns:
        An ``eqx.nn.Sequential`` module.

    Raises:
        Any construction error propagates. (The original wrapped the body in
        ``try/except`` that only logged, then crashed with UnboundLocalError
        on ``return model`` -- that swallow-and-crash pattern is removed.)
    """
    # Exact key budget: 2 keys for the input block, 1 per remaining hidden
    # block, 2 for the output block. The old code split a fixed 100 keys,
    # which would overflow for very deep stacks.
    keys = iter(jax.random.split(key, n_layers + 3))
    layers: List[eqx.Module] = []
    for i in range(n_layers):
        if i == 0:
            # Input block: in_size -> width // 2 -> width
            layers.append(eqx.nn.Linear(in_size, width // 2, key=next(keys)))
            layers.append(eqx.nn.Lambda(activation))
            layers.append(eqx.nn.Linear(width // 2, width, key=next(keys)))
        else:
            # Hidden block: width -> width
            layers.append(eqx.nn.Linear(width, width, key=next(keys)))
        if layer_norm:
            layers.append(eqx.nn.LayerNorm(width))
        layers.append(eqx.nn.Lambda(activation))
    # Output block: width -> width // 2 -> out_size
    layers.append(eqx.nn.Linear(width, width // 2, key=next(keys)))
    layers.append(eqx.nn.Lambda(activation))
    layers.append(eqx.nn.Linear(width // 2, out_size, key=next(keys)))
    layers.append(eqx.nn.Lambda(final_activation))
    model = eqx.nn.Sequential(layers)
    logger.info(f"Created neural network with {len(layers)} layers \n{model}")
    return model
class Dopant(Enum):
    """Dopant species encoded in measurement filenames (``mem1`` .. ``mem4``)."""

    Tungsten = 1
    Tin = 2
    Chromium = 3
    Carbon = 4

    @property
    def Rs(self):
        """Series (shunt) resistance used with this dopant's measurement setup."""
        # Carbon devices were measured with a different shunt resistor.
        return 47.5 if self is Dopant.Carbon else 5.11
class FileMeta(BaseModel):
    """Metadata for one measurement file, plus its signal arrays once loaded.

    The first four fields are parsed from the filename by ``export_signals``;
    ``i_m`` / ``v_m`` / ``t`` stay ``None`` until filled in (e.g. by
    ``MemData.__init__`` after ``unpack_data`` runs).
    """

    dop: Dopant
    amp: float  # excitation amplitude (the "...V" part of the filename)
    freq: float  # excitation frequency (the "...Hz" part of the filename)
    path: Path
    i_m: Union[np.ndarray, None] = None  # measured current, set after loading
    v_m: Union[np.ndarray, None] = None  # measured device voltage, set after loading
    t: Union[np.ndarray, None] = None  # time axis, set after loading

    class Config:
        # Required so pydantic accepts np.ndarray fields. NOTE(review): this
        # is pydantic-v1 style config; v2 uses model_config instead -- confirm
        # the installed pydantic version.
        arbitrary_types_allowed = True
def find_periods(
    series: pd.Series,
    step: float,
):
    """Split a periodic signal into single-period segments.

    The series values are Savitzky-Golay smoothed, then cut at every
    positive-slope zero crossing; each segment between consecutive crossings
    (one full period of the smoothed signal) is returned.

    Args:
        series: Signal indexed by time.
        step: Unused; kept for backward compatibility with existing callers.

    Returns:
        List of np.ndarray segments of the *smoothed* signal, one per period.
        Empty list when fewer than two crossings are found (the original
        raised IndexError on such input).
    """
    y = savgol_filter(series.to_numpy(), 11, 2)
    # A positive-slope sign change between consecutive samples marks the
    # start of a period.
    rising = (y[1:] * y[:-1] < 0) & (np.diff(y) > 0)
    indices = np.where(rising)[0]
    if len(indices) < 2:
        return []
    results = []
    x_left = indices[0]
    for ind in indices[1:]:
        results.append(y[x_left:ind])
        x_left = ind
    return results
def average_data(series: pd.Series, period: float):
    """Average a periodic signal over its repetitions of length *period*.

    Smooths the signal, aligns it to its first positive-slope zero crossing,
    re-indexes onto a uniform grid, peels off consecutive windows of
    *period* (index units), equalizes their lengths and returns the
    element-wise mean window.

    NOTE(review): the two ``print`` calls look like debug leftovers.
    """
    temp = []
    _series = series.copy()
    size = 0
    # Mean sample spacing -- assumes the index is (close to) uniformly sampled.
    step = np.mean(_series.index[1:] - _series.index[:-1])
    y = _series.to_numpy()
    y = savgol_filter(y, 11, 2)
    # Positive-slope zero crossings of the smoothed signal mark period starts.
    zero_crossing_filter = y[1:] * y[:-1] < 0  # | np.isclose(y[1:] * y[:-1], 0)
    pos_slopes = np.diff(y) > 0
    mask = zero_crossing_filter & pos_slopes
    indices = np.where(mask > 0)[0]
    # Drop everything before the first period start so all windows share phase.
    _series = _series.iloc[indices[0] :]
    print(step)
    # Re-index onto a uniform grid starting at zero.
    _series.index = np.arange(0, len(_series) * step, step)[: len(_series)]
    # Peel off one window of *period* at a time; the loop ends when the
    # remaining tail is no longer than the last window taken.
    while len(_series) > size:
        filtr = _series.index < period
        temp.append(_series[filtr].to_numpy())
        size = len(_series[filtr])
        _series = _series[~filtr]
        # Shift the remaining index back to start at zero for the next window.
        _series.index = _series.index - _series.index.min()
    # Windows can differ by a sample or two: normalize all to the most common
    # length (truncate longer ones, pad shorter ones with their last value).
    most_common = Counter(map(len, temp)).most_common(1)[0][0]
    print(f"{most_common=}")
    temp = list(
        map(
            lambda x: x[:most_common]
            if most_common < len(x)
            else np.array([x[i] if i < len(x) else x[-1] for i in range(most_common)]),
            temp,
        )
    )
    return np.array(temp).mean(axis=0)
def mean_data(value: np.ndarray, ppp: int):
    """Average an array over complete periods of ``ppp`` samples.

    Trailing samples that do not fill a whole period are discarded; the
    remaining full periods are averaged element-wise, yielding an array of
    length ``ppp``.
    """
    n_periods = len(value) // ppp
    trimmed = value[: n_periods * ppp]
    return trimmed.reshape((n_periods, ppp)).mean(axis=0)
def unpack_data(config: "FileMeta", ppp: int = 1000):
    """Load one measurement file and return period-averaged signals.

    Reads a tab-separated file with comma decimal separator and columns
    (u, i, t), re-zeroes the time axis, trims everything before the first
    positive-slope zero crossing of the (smoothed) source voltage so every
    file starts at the same phase, then period-averages device current and
    device voltage with ``mean_data``.

    Args:
        config: File metadata; only ``path`` and ``dop.Rs`` are used.
        ppp: Points per period passed to ``mean_data``.

    Returns:
        Tuple of (time index of the first ``ppp`` samples,
        period-averaged i_m, period-averaged v_m).
    """
    df = pd.read_csv(config.path, sep="\t", decimal=",", names=["u", "i", "t"])
    # Re-zero the time axis (vectorized; the original used map + lambda).
    df.t = df.t - df.t.min()
    df = df.set_index("t", drop=True)
    # Locate the first positive-slope zero crossing of the smoothed source
    # voltage.
    y = savgol_filter(df.u.to_numpy(), 11, 2)
    zero_crossings = y[1:] * y[:-1] < 0
    pos_slopes = np.diff(y) > 0
    indices = np.where(zero_crossings & pos_slopes)[0]
    df = df.iloc[indices[0] :, :]
    df.index = df.index - df.index.min()
    # v_m: voltage across the device (source minus shunt drop);
    # i_m: device current inferred from the shunt resistance Rs.
    v_m = df.u - df.i
    i_m = df.i / config.dop.Rs
    return (
        df.index[:ppp],
        mean_data(i_m.to_numpy(), ppp),
        mean_data(v_m.to_numpy(), ppp),
    )
def export_signals(folder: Path, req_dopant: Dopant):
    """Collect metadata for all measurement files of one dopant.

    Scans ``folder`` for ``*.txt`` files named like
    ``mem<dopant>_sine_<amp>V_ <freq>Hz`` (amplitude uses a comma as decimal
    separator) and returns the parsed ``FileMeta`` entries whose dopant
    matches ``req_dopant``.
    """
    # NOTE(review): the literal space before the frequency group mirrors the
    # actual filenames -- confirm before changing.
    pattern = re.compile(r"mem([1-9])_sine_([0-9,]+)V_ ([0-9]+)Hz")
    results: List[FileMeta] = []
    for file in folder.glob("*.txt"):
        if res := pattern.match(file.stem):
            dopant_id, amp, freq = res.groups()
            results.append(
                FileMeta(
                    dop=Dopant(int(dopant_id)),
                    amp=float(amp.replace(",", ".")),  # comma decimal -> float
                    freq=int(freq),
                    path=file,
                )
            )
    return list(filter(lambda x: x.dop == req_dopant, results))
class MemData(object):
    """All AC measurements for one dopant, loaded eagerly.

    On construction every matching file in *folder* is discovered via
    ``export_signals`` and unpacked into (t, i_m, v_m) arrays via
    ``unpack_data``; the arrays are stored back onto each ``FileMeta``.
    """

    def __init__(
        self, dopant: Dopant, ppp: int = 1000, folder=Path("./ac_measurements")
    ):
        """Load and unpack all files for *dopant*.

        Args:
            dopant: Dopant whose measurements to load.
            ppp: Points per period used when averaging each signal.
            folder: Directory containing the raw ``*.txt`` measurements.
        """
        self.dopant = dopant
        self.Rs = dopant.Rs
        self.file_meta = export_signals(folder=folder, req_dopant=self.dopant)
        # Longest period among the loaded signals (lowest frequency file).
        self.period = 1.0 / min(self.file_meta, key=lambda x: x.freq).freq
        self.ppp = ppp
        for fm in self.file_meta:
            t, i_m, v_m = unpack_data(fm, self.ppp)
            fm.i_m = i_m
            fm.v_m = v_m
            fm.t = t

    def __len__(self):
        """Number of loaded measurement files."""
        return len(self.file_meta)

    def __getitem__(self, index):
        """Index into the file list; supports int and slice.

        (Fix: the original had two identical branches for int and slice.)
        """
        if isinstance(index, (int, slice)):
            return self.file_meta[index]
        raise TypeError(f"Index must be int or slice, not {type(index)}")

    def get_vms(self):
        """Stack all measured device voltages into one jnp array."""
        return jnp.array([fm.v_m for fm in self.file_meta])

    def get_ims(self):
        """Stack all measured device currents into one jnp array."""
        return jnp.array([fm.i_m for fm in self.file_meta])

    def get_t(self):
        """Stack all time axes into one jnp array."""
        return jnp.array([fm.t for fm in self.file_meta])
def get_train_test_indices_from_req(
    md: "MemData",
    test_req: Union[List[Tuple[float, float]], None] = None,
) -> Tuple[List[int], List[int]]:
    """Split measurement indices into train/test sets by (amp, freq) requests.

    Args:
        md: Loaded measurement container; only ``md.file_meta`` entries'
            ``amp`` and ``freq`` attributes are read.
        test_req: (amplitude, frequency) pairs that should go to the test
            set. Defaults to [(1.5, 5), (1.0, 1), (1.0, 100), (0.5, 5)].
            (Fix: the original used a mutable list as the default argument.)

    Returns:
        (train_indices, test_indices) into ``md.file_meta``.
    """
    if test_req is None:
        test_req = [(1.5, 5), (1.0, 1), (1.0, 100), (0.5, 5)]
    train_ind: List[int] = []
    test_ind: List[int] = []
    for i, fm in enumerate(md.file_meta):
        # Amplitudes are floats parsed from filenames, so compare with
        # tolerance; frequencies are exact ints.
        is_test = any(
            np.isclose(fm.amp, amp) and fm.freq == freq for amp, freq in test_req
        )
        (test_ind if is_test else train_ind).append(i)
    return train_ind, test_ind
class Func(eqx.Module):
    """Right-hand side of the memristor neural ODE.

    The latent state ``X`` evolves as ``dX/dt = mlp([v_m / std_v, X])``,
    where ``v_m = v_s / (Rs * G + 1)`` is the device voltage obtained from
    the sinusoidal source ``v_s`` via the resistive divider with series
    resistance ``Rs``, and ``G = mlgp(X) * out_scale`` is the
    state-dependent conductance.
    """

    out_scale: jax.Array  # learnable scale on the conductance output
    x0: jax.Array  # initial latent state (used as y0 by NeuralODE)
    Rs: float  # series resistance of the measurement circuit
    # NOTE(review): declared as eqx.nn.MLP, but when diamond_structure is
    # False, mlp is actually an eqx.nn.Sequential from generate_neural_stack.
    mlp: eqx.nn.MLP  # derivative network: [v_m/std_v, X] -> dX/dt
    mlgp: eqx.nn.MLP  # conductance network: X -> G
    std_v: float  # normalisation factor for the voltage input

    def __init__(
        self,
        *,
        key,
        width_mplx,
        depth_mplx,
        width_mlpg,
        depth_mlpg,
        mlpx_activation,
        mlpx_final_activation,
        mlpg_activation,
        mlpg_final_activation,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=False,
        **kwargs,
    ):
        """Build the derivative and conductance networks.

        When ``diamond_structure`` is True the derivative network is a plain
        ``eqx.nn.MLP``; otherwise it is the custom stack from
        ``generate_neural_stack`` (with optional LayerNorm). The conductance
        network is always a plain MLP.
        """
        super().__init__(**kwargs)
        self.out_scale = jnp.array([1.0], dtype=jnp.float32)
        self.mlp = (
            eqx.nn.MLP(
                in_size=latent_states + 1,  # v_m plus the latent state(s)
                out_size=latent_states,
                width_size=width_mplx,
                depth=depth_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                key=key,
            )
            if diamond_structure
            else generate_neural_stack(
                n_layers=depth_mplx,
                width=width_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                in_size=latent_states + 1,
                out_size=latent_states,
                key=key,
                layer_norm=with_norm,
            )
        )
        # fold_in derives an independent key for the second network.
        self.mlgp = (
            eqx.nn.MLP(
                in_size=latent_states,
                out_size=1,
                width_size=width_mlpg,
                depth=depth_mlpg,
                activation=mlpg_activation,
                final_activation=mlpg_final_activation,
                key=jr.fold_in(key, 1),
            )
        )
        # NOTE(review): hard-coded Rs matches the non-Carbon Dopant value
        # (5.11), not Carbon's 47.5 -- confirm this is intended.
        self.Rs = 5.11
        self.x0 = jnp.array([0.001] * latent_states)
        self.std_v = std_v

    # NOTE(review): self.Gon / self.Goff are never defined anywhere in this
    # class (their assignments existed only as commented-out code), so
    # accessing Ron/Roff raises AttributeError. Dead code -- confirm before
    # relying on these properties.
    @property
    def Ron(self):
        return 1 / self.Gon

    @property
    def Roff(self):
        return 1 / self.Goff

    def __call__(self, t, y, args):
        """Evaluate dX/dt at time *t* for state *y*; ``args = (amp, freq)``."""
        amp, freq = args
        # Sinusoidal source voltage at time t.
        vs = jnp.sin(freq * t * 2 * jnp.pi) * amp
        X = y
        G = self.mlgp(X) * self.out_scale
        # Resistive divider: memristor (conductance G) in series with Rs.
        v_m = vs / (self.Rs * G + 1)
        dX_input = jnp.hstack([v_m / self.std_v, X])
        # Returns the same shape as y.
        return self.mlp(dX_input)
class NeuralODE(eqx.Module):
    """Neural-ODE memristor model.

    Integrates ``Func`` over a time grid and maps the latent trajectory back
    to the measurable device voltage/current via the conductance network.
    """

    func: Func  # ODE right-hand side (holds both sub-networks)
    atol: float  # absolute solver tolerance
    rtol: float  # relative solver tolerance
    max_steps: int  # solver step budget

    def __init__(
        self,
        *,
        width_mplx,
        depth_mplx,
        width_mlpg,
        depth_mlpg,
        mlpx_activation,
        mlpx_final_activation,
        mlpg_activation,
        mlpg_final_activation,
        key,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=False,
        rtol=1e-3,
        atol=1e-6,
        max_steps=4096,
        **kwargs,
    ):
        """Forward all network hyper-parameters to ``Func`` and store the
        diffrax solver settings."""
        super().__init__(**kwargs)
        self.func = Func(
            key=key,
            width_mplx=width_mplx,
            depth_mplx=depth_mplx,
            width_mlpg=width_mlpg,
            depth_mlpg=depth_mlpg,
            mlpx_activation=mlpx_activation,
            mlpx_final_activation=mlpx_final_activation,
            mlpg_activation=mlpg_activation,
            mlpg_final_activation=mlpg_final_activation,
            latent_states=latent_states,
            std_v=std_v,
            with_norm=with_norm,
            diamond_structure=diamond_structure,
        )
        self.atol = atol
        self.rtol = rtol
        self.max_steps = max_steps

    def __call__(self, ts, amp, freq):
        """Solve the ODE on grid *ts* for a sine drive (amp, freq).

        Returns:
            Tuple ``(v_m, i_m, x, G)``: device voltage, device current,
            latent trajectory and conductance, each evaluated at *ts*.
        """
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),  # explicit Runge-Kutta 5(4) solver
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],  # initial step = grid spacing
            y0=self.func.x0,
            stepsize_controller=diffrax.PIDController(rtol=self.rtol, atol=self.atol),
            saveat=diffrax.SaveAt(ts=ts),
            args=(amp, freq),
            max_steps=self.max_steps,
            # Checkpointed reverse-mode autodiff through the solve
            # (BacksolveAdjoint was tried previously and abandoned).
            adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints=2048*32),
        )
        x = solution.ys
        # Conductance along the trajectory (vmap over the time axis).
        G = jax.vmap(self.func.mlgp)(x) * self.func.out_scale
        # Source voltage reshaped to a column to broadcast against G.
        v_s = jnp.reshape(jnp.sin(freq * ts * 2 * jnp.pi) * amp, (-1, 1))
        # Same resistive divider as in Func.__call__.
        v_m = v_s / (self.func.Rs * G + 1)
        i_m = v_m * G
        return v_m, i_m, x, G
def plot_results(
    t: jax.Array,
    v_m: jax.Array,
    i_m: jax.Array,
    v_m_pred: jax.Array,
    i_m_pred: jax.Array,
    G_: jax.Array = None,
):
    """Plot measured vs. predicted voltage, current and conductance.

    Each row of the 2-D inputs is one experiment; the time axis of each row
    is normalised to [0, 1]. Predictions are drawn dashed in the same color
    as the corresponding measurement.

    Args:
        t: Time grid, shape (n_experiments, n_samples).
        v_m, i_m: Measured voltage / current, same shape as *t*.
        v_m_pred, i_m_pred: Predicted voltage / current, same shape.
        G_: Optional conductance trajectories. If None the third axis is
            left empty (fix: the original indexed G_ unconditionally and
            crashed on the default None).

    Returns:
        (fig, ax): the matplotlib figure and its three axes.
    """
    fig, ax = plt.subplots(3, 1, figsize=(15, 8), sharex=True)
    for i in range(v_m.shape[0]):
        # Hoisted: normalised time axis for this experiment.
        t_norm = t[i, :] / t[i, :].max()
        line = ax[0].plot(t_norm, v_m[i, :], label=f"v_m_{i}")
        color = line[0].get_color()
        ax[0].plot(
            t_norm,
            v_m_pred[i, :],
            label=f"v_m_pred_{i}",
            linestyle="--",
            color=color,
        )
        ax[1].plot(t_norm, i_m[i, :], label=f"i_m_{i}", color=color)
        ax[1].plot(
            t_norm,
            i_m_pred[i, :],
            label=f"i_m_pred_{i}",
            linestyle="--",
            color=color,
        )
        if G_ is not None:
            ax[2].plot(t_norm, G_[i, :], label=f"G_{i}", color=color)
    ax[0].set_ylabel("Voltage (V)")
    ax[1].set_ylabel("Current (mA)")
    ax[2].set_ylabel("G (mS)")
    ax[1].set_xlabel("Time (1/T)")
    ax[0].set_xlabel("Time (1/T)")
    return fig, ax
def get_batches(
    length: int,
    batch_size: int,
    shuffle: bool = True,
    dataloader_key: Union[jax.Array, None] = None,
) -> Generator[List[int], None, None]:
    """Yield index batches over ``range(length)``.

    Args:
        length: Total number of samples.
        batch_size: Maximum batch size (the final batch may be smaller).
        shuffle: If True, permute the indices with *dataloader_key*.
        dataloader_key: PRNG key for the permutation; defaults to
            ``jr.PRNGKey(0)``. (Fix: the original evaluated ``jr.PRNGKey(0)``
            as a default argument at import time -- the computed-default
            pitfall -- and used ``jr.PRNGKey`` as a type annotation.)

    Yields:
        jnp index arrays, one batch at a time.
    """
    if dataloader_key is None:
        dataloader_key = jr.PRNGKey(0)  # same default key as before
    if shuffle:
        indices = jr.permutation(dataloader_key, jnp.arange(length))
    else:
        indices = jnp.arange(length)
    for start in range(0, length, batch_size):
        yield indices[start : start + batch_size]