# utils.py
import os
import re
from collections import Counter
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Tuple, Union

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import jax.random as jr
import joblib
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax  # https://github.com/deepmind/optax
import optax.contrib as contrib
import optuna
import pandas as pd
from jax.scipy.special import expit  # JAX's sigmoid function
from loguru import logger
from optax.tree_utils import tree_norm
from optuna import Trial
from pydantic import BaseModel
from scipy.signal import savgol_filter

matplotlib.use("svg")  # Use the non-interactive SVG backend to avoid GUI issues


os.environ["XLA_FLAGS"] = (
    "--xla_cpu_multi_thread_eigen=true "
    "intra_op_parallelism_threads=20 "
    "--xla_cpu_enable_fast_math=true"
)



def generate_neural_stack(
    n_layers: int,
    width: int,
    activation: Callable,
    final_activation: Callable,
    key: jax.Array,
    layer_norm: bool = True,
    out_size: int = 1,
    in_size: int = 2,
) -> eqx.Module:
    """Generate a neural network stack using JAX Equinox.

    Args:
        trial_dict: Dictionary containing network configuration
        key: JAX random key for parameter initialization

    Returns:
        An Equinox module representing the neural network
    """
    try:
        layers = []
        keys = jax.random.split(key, 100)  # Generate enough keys
        key_idx = 0

        for i in range(n_layers):
            if i == 0:
                # Input block: in_size -> width // 2 -> width
                layers.append(eqx.nn.Linear(in_size, width // 2, key=keys[key_idx]))
                key_idx += 1
                layers.append(eqx.nn.Lambda(activation))

                layers.append(
                    eqx.nn.Linear(
                        width // 2,
                        width,
                        key=keys[key_idx],
                    )
                )
                key_idx += 1

                if layer_norm:
                    layers.append(eqx.nn.LayerNorm(width))
                layers.append(eqx.nn.Lambda(activation))
            else:
                # Hidden block: width -> width
                layers.append(eqx.nn.Linear(width, width, key=keys[key_idx]))
                key_idx += 1

                if layer_norm:
                    layers.append(eqx.nn.LayerNorm(width))
                layers.append(eqx.nn.Lambda(activation))

        # Output block: width -> width // 2 -> out_size
        layers.extend(
            [
                eqx.nn.Linear(width, width // 2, key=keys[key_idx]),
                eqx.nn.Lambda(activation),
            ]
        )
        key_idx += 1

        layers.append(eqx.nn.Linear(width // 2, out_size, key=keys[key_idx]))
        layers.append(eqx.nn.Lambda(final_activation))

        # Create sequential model
        model = eqx.nn.Sequential(layers)

        logger.info(f"Created neural network with {len(layers)} layers \n{model}")
        # exit()
    except Exception as e:
        logger.exception(e)
        raise
    return model
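
# Example usage of generate_neural_stack (a minimal sketch; the sizes and the
# jax.nn.tanh / identity activations are illustrative, not prescribed by this module):
#     key = jr.PRNGKey(0)
#     net = generate_neural_stack(
#         n_layers=2, width=32,
#         activation=jax.nn.tanh, final_activation=lambda x: x,
#         key=key, in_size=2, out_size=1,
#     )
#     y = net(jnp.ones(2))  # -> shape (1,)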


class Dopant(Enum):
    Tungsten = 1
    Tin = 2
    Chromium = 3
    Carbon = 4

    @property
    def Rs(self):
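        """Series (sense) resistor value associated with the dopant's measurement setup.

        Used in unpack_data() to convert the measured voltage drop into current;
        Carbon samples use 47.5, all other dopants 5.11.
        """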
        match self:
            case Dopant.Carbon:
                return 47.5
            case _:
                return 5.11


class FileMeta(BaseModel):
    dop: Dopant
    amp: float
    freq: float
    path: Path
    i_m: Union[np.ndarray, None] = None
    v_m: Union[np.ndarray, None] = None
    t: Union[np.ndarray, None] = None

    class Config:
        arbitrary_types_allowed = True


def find_periods(
    series: pd.Series,
    step: float,
):
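    """Split `series` into one array per full period, cutting at rising zero
    crossings of the Savitzky-Golay smoothed signal."""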
    results = []
    y = series.to_numpy()
    y = savgol_filter(y, 11, 2)

    # Cut the smoothed signal at rising zero crossings, one slice per full period.
    zero_crossing_filter = y[1:] * y[:-1] < 0
    pos_slopes = y[1:] - y[:-1] > 0
    mask = zero_crossing_filter & pos_slopes
    indices = np.where(mask)[0]
    x_left = indices[0]
    indices = indices[1:]
    for ind in indices:
        results.append(y[x_left:ind])
        x_left = ind

    return results


def average_data(series: pd.Series, period: float):
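    """Average `series` over consecutive windows of length `period`.

    The signal is first smoothed, aligned to its first rising zero crossing and
    re-indexed on a uniform time grid; it is then chopped into period-length chunks,
    which are truncated/padded to a common length and averaged element-wise.
    """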
    temp = []
    _series = series.copy()
    size = 0
    step = np.mean(_series.index[1:] - _series.index[:-1])
    y = _series.to_numpy()
    y = savgol_filter(y, 11, 2)
    zero_crossing_filter = y[1:] * y[:-1] < 0  # | np.isclose(y[1:] *  y[:-1], 0)
    pos_slopes = np.diff(y) > 0
    mask = zero_crossing_filter & pos_slopes
    indices = np.where(mask > 0)[0]
    _series = _series.iloc[indices[0] :]
    logger.debug(f"{step=}")
    _series.index = np.arange(0, len(_series) * step, step)[: len(_series)]
    while len(_series) > size:
        filtr = _series.index < period
        temp.append(_series[filtr].to_numpy())
        size = len(_series[filtr])
        _series = _series[~filtr]
        _series.index = _series.index - _series.index.min()

    most_common = Counter(map(len, temp)).most_common(1)[0][0]
    print(f"{most_common=}")
    temp = list(
        map(
            lambda x: x[:most_common]
            if most_common < len(x)
            else np.array([x[i] if i < len(x) else x[-1] for i in range(most_common)]),
            temp,
        )
    )

    return np.array(temp).mean(axis=0)


def mean_data(value: np.ndarray, ppp: int):
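    """Average `value` over consecutive periods of `ppp` samples.

    Example: a 5000-sample trace with ppp=1000 is reshaped to (5, 1000) and averaged
    over the 5 periods, returning a single 1000-sample period.
    """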
    sz = len(value)
    new_size = (sz // ppp) * ppp
    value = value[:new_size]
    value = value.reshape((new_size // ppp, ppp)).mean(axis=0)
    return value


def unpack_data(config: FileMeta, ppp: int = 1000):
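    """Load one tab-separated measurement file and return (t, i_m, v_m).

    The raw columns appear to be: u = applied source voltage, i = voltage drop over
    the series resistor, t = time. The trace is aligned to its first rising zero
    crossing, then i_m = i / Rs and v_m = u - i are each averaged down to `ppp`
    points per period via mean_data(); t is the first `ppp` time stamps.
    """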
    df = pd.read_csv(config.path, sep="\t", decimal=",", names=["u", "i", "t"])
    t_min = df.t.min()
    df.t = df.t.map(lambda t: t - t_min)
    df = df.set_index("t", drop=True)
    # print(average_data(df.u, 1/ config.freq))
    logger.debug(f"\n{df}")

    y = df.u.to_numpy()
    y = savgol_filter(y, 11, 2)

    zero_crossing_filter = y[1:] * y[:-1] < 0
    pos_slopes = y[1:] - y[:-1] > 0
    mask = zero_crossing_filter & pos_slopes
    indices = np.where(mask)[0]
    x_left = indices[0]
    df = df.iloc[x_left:, :]
    t_min = df.index.min()
    df.index = df.index.map(lambda t: t - t_min)
    # df = df.set_index("t", drop=True)
    v_m = df.u - df.i
    i_m = df.i / config.dop.Rs

    return (
        df.index[:ppp],
        mean_data(i_m.to_numpy(), ppp),
        mean_data(v_m.to_numpy(), ppp),
    )


def export_signals(folder: Path, req_dopant: Dopant):
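    """Collect FileMeta entries for all measurement files of `req_dopant` in `folder`.

    File names are expected to match e.g. "mem1_sine_0,5V_ 5Hz.txt": dopant index,
    amplitude with a comma as decimal separator, and frequency (note the space
    before the frequency, as required by the regex below).
    """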
    pattern = re.compile(r"mem([1-9])_sine_([0-9,]+)V_ ([0-9]+)Hz")
    results: List[FileMeta] = []
    for file in folder.glob("*.txt"):
        logger.debug(file.stem)

        if res := pattern.match(file.stem):
            dopant = res.group(1)
            amp = res.group(2)
            freq = res.group(3)
            results.append(
                FileMeta(
                    dop=Dopant(int(dopant)),
                    amp=float(amp.replace(",", ".")),
                    freq=int(freq),
                    path=file,
                )
            )
    return list(filter(lambda x: x.dop == req_dopant, results))


class MemData(object):
    def __init__(
        self, dopant: Dopant, ppp: int = 1000, folder=Path("./ac_measurements")
    ):
        self.dopant = dopant
        self.Rs = dopant.Rs
        self.file_meta = export_signals(folder=folder, req_dopant=self.dopant)
        self.period = 1.0 / min(self.file_meta, key=lambda x: x.freq).freq
        self.ppp = ppp

        for fm in self.file_meta:
            t, i_m, v_m = unpack_data(fm, self.ppp)
            fm.i_m = i_m
            fm.v_m = v_m
            fm.t = t

    def __len__(self):
        return len(self.file_meta)

    def __getitem__(self, index):
        if isinstance(index, int):
            return self.file_meta[index]
        elif isinstance(index, slice):
            return self.file_meta[index]
        else:
            raise TypeError(f"Index must be int or slice, not {type(index)}")

    def get_vms(self):
        return jnp.array(list(map(lambda x: x.v_m, self.file_meta)))

    def get_ims(self):
        return jnp.array(list(map(lambda x: x.i_m, self.file_meta)))

    def get_t(self):
        return jnp.array(list(map(lambda x: x.t, self.file_meta)))
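
# Example usage of MemData (a minimal sketch; the dopant and folder are illustrative):
#     md = MemData(Dopant.Carbon, ppp=1000, folder=Path("./ac_measurements"))
#     ts = md.get_t()      # (n_files, ppp) time vectors
#     v_ms = md.get_vms()  # (n_files, ppp) memristor voltages
#     i_ms = md.get_ims()  # (n_files, ppp) memristor currents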


def get_train_test_indices_from_req(
    md: MemData,
    test_req: List[Tuple[float, float]] = [(1.5, 5), (1.0, 1), (1.0, 100), (0.5, 5)],
) -> Tuple[List[int], List[int]]:
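    """Split MemData file indices into train/test sets.

    Files whose (amplitude, frequency) matches an entry in `test_req` (by default
    (1.5, 5), (1.0, 1), (1.0, 100) and (0.5, 5)) go to the test split; all other
    files are used for training.
    """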
    train_ind = []
    test_ind = []
    for i, fm in enumerate(md.file_meta):
        is_test = False
        for amp, freq in test_req:
            if np.isclose(fm.amp, amp) and fm.freq == freq:
                is_test = True
                break
        if is_test:
            test_ind.append(i)
        else:
            train_ind.append(i)

    return train_ind, test_ind




class Func(eqx.Module):
    out_scale: jax.Array
    x0: jax.Array
    Rs: float
    mlp: eqx.Module  # eqx.nn.MLP or the Sequential built by generate_neural_stack
    mlgp: eqx.nn.MLP
    std_v: float

    def __init__(
        self,
        *,
        key,
        width_mplx,
        depth_mplx,
        width_mlpg,
        depth_mlpg,
        mlpx_activation,
        mlpx_final_activation,
        mlpg_activation,
        mlpg_final_activation,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.out_scale = jnp.array([1.0], dtype=jnp.float32)
        self.mlp = (
            eqx.nn.MLP(
                in_size=latent_states + 1,  # [v_m, latent state(s)]
                out_size=latent_states,
                width_size=width_mplx,
                depth=depth_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                key=key,
            )
            if diamond_structure
            else generate_neural_stack(
                n_layers=depth_mplx,
                width=width_mplx,
                activation=mlpx_activation,
                final_activation=mlpx_final_activation,
                in_size=latent_states + 1,
                out_size=latent_states,
                key=key,
                layer_norm=with_norm,
            )
        )
        # self.mlp = generate_neural_stack(
        #     n_layers=depth_mplx,
        #     width=width_mplx,
        #     activation=mlpx_activation,
        #     final_activation=mlpx_final_activation,
        #     in_size=latent_states + 1,
        #     out_size=latent_states,
        #     key=key,
        #     layer_norm=with_norm,
        # )

        # self.mlgp = generate_neural_stack(
        #     n_layers=depth_mlpg,
        #     width=width_mlpg,
        #     activation=mlpg_activation,
        #     final_activation=mlpg_final_activation,
        #     in_size=latent_states,
        #     out_size=1,
        #     key=jr.fold_in(key, 1),
        #     layer_norm=with_norm,
        # )
        self.mlgp = (
            eqx.nn.MLP(
                in_size=latent_states,
                out_size=1,
                width_size=width_mlpg,
                depth=depth_mlpg,
                activation=mlpg_activation,
                final_activation=mlpg_final_activation,
                key=jr.fold_in(key, 1),
            )
        )
        # Alternative (currently disabled): mirror the diamond_structure switch used
        # for self.mlp above.
        #     if diamond_structure
        #     else generate_neural_stack(
        #         n_layers=depth_mlpg,
        #         width=width_mlpg,
        #         activation=mlpg_activation,
        #         final_activation=mlpg_final_activation,
        #         in_size=latent_states,
        #         out_size=1,
        #         key=jr.fold_in(key, 1),
        #         layer_norm=with_norm,
        #     )
        # )
        # self.Gon = 1 / jnp.array(0.5)
        # self.Goff = 1 / jnp.array(20.0)
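        # Series (sense) resistance; hard-coded here to the non-Carbon default from Dopant.Rs.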
        self.Rs = 5.11
        self.x0 = jnp.array([0.001] * latent_states)
        self.std_v = std_v

    # NOTE: Ron/Roff depend on self.Gon / self.Goff, which are commented out in
    # __init__ above; re-enable them together.
    # @property
    # def Ron(self):
    #     return 1 / self.Gon
    #
    # @property
    # def Roff(self):
    #     return 1 / self.Goff

    def __call__(self, t, y, args):
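        """ODE right-hand side.

        `y` is the latent memristor state and `args = (amp, freq)` describes the
        sinusoidal source; returns dy/dt = mlp([v_m / std_v, y]).
        """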
        amp, freq = args
        vs = jnp.sin(freq * t * 2 * jnp.pi) * amp
        X = y
        # print(f"{X.shape=}")
        G = self.mlgp(X) * self.out_scale
        # print(f"{G.shape=}, {X.shape=}, {t.shape=}")
        v_m = vs / (self.Rs * G + 1)

        # dX_input = jnp.concatenate(
        #     [jnp.atleast_1d(v_m).flatten() / self.std_v, jnp.atleast_1d(X).flatten()]
        # )
        dX_input = jnp.hstack([v_m / self.std_v, X])

        # print(f"{dX_input.shape=}, {self.mlp(dX_input).shape=}")

        # Now this will return the same shape as y

        return self.mlp(dX_input)


class NeuralODE(eqx.Module):
    func: Func
    atol: float
    rtol: float
    max_steps: int

    def __init__(
        self,
        *,
        width_mplx,
        depth_mplx,
        width_mlpg,
        depth_mlpg,
        mlpx_activation,
        mlpx_final_activation,
        mlpg_activation,
        mlpg_final_activation,
        key,
        latent_states=1,
        std_v=1.0,
        with_norm=True,
        diamond_structure=False,
        rtol=1e-3,
        atol=1e-6,
        max_steps=4096,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.func = Func(
            key=key,
            width_mplx=width_mplx,
            depth_mplx=depth_mplx,
            width_mlpg=width_mlpg,
            depth_mlpg=depth_mlpg,
            mlpx_activation=mlpx_activation,
            mlpx_final_activation=mlpx_final_activation,
            mlpg_activation=mlpg_activation,
            mlpg_final_activation=mlpg_final_activation,
            latent_states=latent_states,
            std_v=std_v,
            with_norm=with_norm,
            diamond_structure=diamond_structure,
        )
        self.atol = atol
        self.rtol = rtol
        self.max_steps = max_steps

    def __call__(self, ts, amp, freq):
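        """Integrate the latent state over `ts` for a sinusoidal source of amplitude
        `amp` and frequency `freq`; returns (v_m, i_m, x, G) sampled at `ts`."""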
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=self.func.x0,
            stepsize_controller=diffrax.PIDController(rtol=self.rtol, atol=self.atol),
            saveat=diffrax.SaveAt(ts=ts),
            args=(amp, freq),
            max_steps=self.max_steps,
            # adjoint=diffrax.BacksolveAdjoint(solver=diffrax.Tsit5()) # BacksolveAdjoint, InterpolationAdjoint
            adjoint=diffrax.RecursiveCheckpointAdjoint(checkpoints=2048*32),
            # max_steps=2**20,
        )

        x = solution.ys
        # print(f"{x.shape=}, {ts.shape=}")
        # print(f"{x.shape=}, {ts.shape=}")
        G = jax.vmap(self.func.mlgp)(x) * self.func.out_scale
        # print(f"{G.shape=}, {x.shape=}, {ts.shape=}")

        # G = self.func.mlgp(x) * self.func.out_scale
        # print(f"{self.func.Rs=}, {self.func.Ron=}, {self.func.Roff=}")
        v_s = jnp.reshape(jnp.sin(freq * ts * 2 * jnp.pi) * amp, (-1, 1))
        v_m = v_s / (self.func.Rs * G + 1)
        # print(f"{v_m.shape=}, {x.shape=}, {v_s.shape=}")
        # print(G)
        i_m = v_m * G

        return v_m, i_m, x, G


def plot_results(
    t: jax.Array,
    v_m: jax.Array,
    i_m: jax.Array,
    v_m_pred: jax.Array,
    i_m_pred: jax.Array,
    G_: Union[jax.Array, None] = None,
):
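    """Plot measured vs. predicted voltage, current and conductance traces against
    normalised time; returns (fig, ax)."""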
    fig, ax = plt.subplots(3, 1, figsize=(15, 8), sharex=True)
    for i in range(v_m.shape[0]):
        line = ax[0].plot(t[i, :] / t[i, :].max(), v_m[i, :], label=f"v_m_{i}")
        ax[0].plot(
            t[i, :] / t[i, :].max(),
            v_m_pred[i, :],
            label=f"v_m_pred_{i}",
            linestyle="--",
            color=line[0].get_color(),
        )
        ax[1].plot(
            t[i, :] / t[i, :].max(),
            i_m[i, :],
            label=f"i_m_{i}",
            color=line[0].get_color(),
        )
        ax[1].plot(
            t[i, :] / t[i, :].max(),
            i_m_pred[i, :],
            label=f"i_m_pred_{i}",
            linestyle="--",
            color=line[0].get_color(),
        )
        if G_ is not None:
            ax[2].plot(
                t[i, :] / t[i, :].max(),
                G_[i, :],
                label=f"G_{i}",
                color=line[0].get_color(),
            )

    ax[0].set_ylabel("Voltage (V)")
    ax[1].set_ylabel("Current (mA)")
    ax[2].set_ylabel("G (mS)")
    # ax[0].legend()

    # ax[1].legend()
    ax[2].set_xlabel("Time (1/T)")

    return fig, ax


def get_batches(
    length: int,
    batch_size: int,
    shuffle: bool = True,
    dataloader_key: jax.Array = jr.PRNGKey(0),
) -> Generator[jax.Array, None, None]:
    if shuffle:
        indices = jr.permutation(dataloader_key, jnp.arange(length))
    else:
        indices = jnp.arange(length)

    for i in range(0, length, batch_size):
        batch_indices = indices[i : i + batch_size]
        yield batch_indices
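
# Example usage of get_batches (a minimal sketch; the batch size and key are illustrative):
#     for idx in get_batches(len(md), batch_size=4, dataloader_key=jr.PRNGKey(1)):
#         v_batch = md.get_vms()[idx]  # (batch, ppp)
#         i_batch = md.get_ims()[idx]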