# neural_memristor_all_feature.py
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, List, Tuple

import joblib  # noqa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy.interpolate import Akima1DInterpolator
from torch import nn
from torch.utils.data import (
    DataLoader,  # noqa
    Dataset,
)
from torchdiffeq import odeint

from normalization_predictor import NRME  # noqa


class ExponentialDecayLoss(nn.Module):
    def __init__(
        self, decay_rate=0.1, eps=1e-8, criterion=NRME("range", reduction="none")
    ):
        """
        Loss function that penalizes errors more heavily at the start of the time series.

        Args:
            decay_rate: Controls how quickly the weight decays along the time dimension
            eps: Small constant for numerical stability
            criterion: Elementwise loss applied before weighting (must use reduction="none")
        """
        super().__init__()
        self.decay_rate = decay_rate
        self.eps = eps
        self.criterion = criterion

    def forward(self, y_pred, y_true):
        """
        Forward pass of the loss function.

        Args:
            y_pred: Predicted trajectory from the neural ODE, shape [time_steps, dim]
            y_true: Ground truth trajectory, shape [time_steps, dim]

        Returns:
            Weighted loss with higher penalties for earlier time points
        """
        # Elementwise error from the wrapped criterion
        error = self.criterion(y_pred, y_true)

        # Get time dimension length
        time_steps = y_pred.shape[0]

        # Create exponentially decaying weights
        time_indices = torch.arange(time_steps, device=y_pred.device).float()
        weights = torch.exp(-self.decay_rate * time_indices)
        weights = weights / (torch.sum(weights) + self.eps)  # Normalize weights

        # Reshape weights to broadcast over the feature dimension
        weights = weights.view(time_steps, 1)

        # Apply weights to the elementwise error
        weighted_error = error * weights
        loss = torch.mean(weighted_error)

        return loss
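

# Illustrative sketch of the decay weighting (assumption: a plain elementwise
# MSE stands in for the NRME default criterion; this helper is not used by the
# training run below).
def _decay_loss_example() -> torch.Tensor:
    """With decay_rate=0.2 and 5 time steps the normalized weights are roughly
    [0.29, 0.23, 0.19, 0.16, 0.13], so early time points dominate the loss."""
    loss_fn = ExponentialDecayLoss(
        decay_rate=0.2, criterion=nn.MSELoss(reduction="none")
    )
    y_pred = torch.zeros(5, 1, dtype=torch.float64)
    y_true = torch.ones(5, 1, dtype=torch.float64)
    return loss_fn(y_pred, y_true)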


def get_device_name():
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    elif torch.backends.mps.is_available():  # For Apple M1/M2 chips
        return "mps"
    else:
        return "cpu"


precision = torch.float64
device = get_device_name()
device = "cpu"
matplotlib.use("svg")


@dataclass
class OscParams:
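    """Component values of the oscillator circuit: resistances in ohms,
    capacitances in farads. R1 (the series sense resistor used to recover the
    memristor current) has no default and must be supplied."""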
    R1: float
    R2: float = 10e3
    R3: float = 3.6e3
    R4: float = 1e3
    R5: float = 1e3
    R7: float = 1e3
    R9: float = 1e3
    R6: float = 22e3
    R8: float = 33e3
    R10: float = 47e3
    R11: float = 100e3
    C1: float = 50e-6
    C3: float = 50e-6
    C2: float = 23.5e-6


class Sin(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.sin(input)


class NeuralDiffEq(nn.Module):
    def __init__(
        self,
        config: OscParams,
        x0: float = 0.0,
        Ron: float = 1_000,
        Roff: float = 100_000,
        activation: nn.Module = nn.Sigmoid(),
        v_m_std: float = 0.2,
        i_m_std: float = 1e-3,
    ) -> None:
        super().__init__()
        # Neural network for dX/dt
        self.dim = 3  # ANN input: [X, v_m, i_m]
        self.ann = (
            nn.Sequential(
                nn.Linear(self.dim, 64),
                activation,
                nn.Linear(64, 128),
                activation,
                nn.Linear(128, 128),
                activation,
                nn.Linear(128, 64),
                activation,
                nn.Linear(64, 1),
                # nn.Tanh(),
            )
            .to(device)
            .to(precision)
        )

        # self.ann2 = self.ann
        self.R = 1000  # scaling constant for the log-resistance parameterization
        # Parameters
        self.x0 = nn.Parameter(torch.tensor(x0, dtype=precision, device=device))
        # Store resistances in scaled log space (raw value = log(R_ohm) / self.R) so
        # the trainable parameters stay small; the Ron/Roff properties invert this.
        self.log_Ron = nn.Parameter(
            torch.tensor(np.log(Ron) / self.R, dtype=precision, device=device)
        )

        self.log_Roff = nn.Parameter(
            torch.tensor(np.log(Roff) / self.R, dtype=precision, device=device)
        )
        self.p = config
        self.v_m_std = v_m_std
        self.i_m_std = i_m_std
        # self.out_scale = nn.Parameter(torch.tensor(1.0, dtype=precision, device=device))

    @property
    def Ron(self):
        return torch.exp(self.log_Ron * self.R)

    @property
    def Roff(self):
        return torch.exp(self.log_Roff * self.R)

    def compute_memristor_current(self, X, vx):
        v_m = vx  # - 0.040
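        # Linear dopant-drift (HP-style) conductance model, an assumption about
        # the device physics: G interpolates between 1/Ron (X=1) and 1/Roff (X=0).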
        G = X / self.Ron + (1 - X) / self.Roff
        i_m = G * v_m
        return v_m, i_m

    def forward(self, t, state):
        v1, v2, v3, X = state[..., 0], state[..., 1], state[..., 2], state[..., 3]

        # Compute memristor voltage and current
        vx = -self.p.R11 / self.p.R10 * v3
        v_m, i_m = self.compute_memristor_current(X, vx)
        v4 = vx + self.p.R1 * i_m

        # Compute state derivatives
        dv1 = 1 / self.p.C1 * (-v1 / self.p.R4 - v2 / self.p.R9 - v4 / self.p.R3)
        dv2 = 1 / self.p.C3 * (-v2 / self.p.R8 - v3 / self.p.R7)
        dv3 = 1 / self.p.C2 * (-v1 / self.p.R5 - v3 / self.p.R6)

        # Neural network input for dX/dt
        dX_input = torch.cat(
            [
                X.unsqueeze(-1),
                v_m.unsqueeze(-1) / self.v_m_std,
                i_m.unsqueeze(-1) / self.i_m_std,
            ],
            dim=-1,
        ).unsqueeze(0)
        dX = self.ann(dX_input)  # * self.out_scale
        # Keep the internal state X inside [0, 1]: at the lower bound only
        # non-negative derivatives pass, at the upper bound only non-positive ones
        dX = torch.where(X <= 0, torch.maximum(dX, torch.zeros_like(dX)), dX)
        dX = torch.where(X >= 1, torch.minimum(dX, torch.zeros_like(dX)), dX)

        # Combine all derivatives
        dfunc = torch.zeros_like(state)
        dfunc[..., 0] = dv1
        dfunc[..., 1] = dv2
        dfunc[..., 2] = dv3
        dfunc[..., 3] = dX.squeeze(0)

        return dfunc
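

# Illustrative smoke test of a single right-hand-side evaluation (assumptions:
# the R1 value is an arbitrary placeholder and this helper is never called by
# the training run below).
def _rhs_example() -> torch.Tensor:
    osc = NeuralDiffEq(OscParams(R1=34e3))
    # State layout is [v1, v2, v3, X]; torchdiffeq invokes osc(t, state) this way.
    state0 = torch.zeros(4, dtype=precision, device=device)
    return osc(torch.tensor(0.0), state0)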


def smooth_range_penalty(x, min_val=0.0, max_val=1.0, margin=0.1, penalty_weight=1.0):
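    """
    Differentiable range penalty: near zero inside [min_val + margin,
    max_val - margin], growing smoothly outside it (sigmoid-gated distances,
    squared and averaged).
    """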
    # Create smooth transition zones
    lower_margin = min_val + margin
    upper_margin = max_val - margin

    # Smooth penalties using sigmoid
    below_penalty = torch.sigmoid(-(x - lower_margin) / (margin / 4)) * (
        lower_margin - x
    )
    above_penalty = torch.sigmoid((x - upper_margin) / (margin / 4)) * (
        x - upper_margin
    )

    # Combine penalties
    total_penalty = below_penalty + above_penalty

    return penalty_weight * torch.mean(total_penalty**2)


class MemristorDataset(Dataset):
    def __init__(
        self,
        file: Path,
        ds: int = 100,
        chunk_size: float = 0.05,
    ):
        super().__init__()
        df = pd.read_csv(file)
        df["time"] = df["time"] - df["time"].min()
        self.df = df
        self.l_t = len(df["time"])
        self.ds = ds
        self.chunk_size = chunk_size

    def __getitem__(self, idx):
        cur_slice = slice(idx * self.ds, idx * self.ds + self.ds)
        tmp_df = self.df.iloc[cur_slice, :]

        return (
            torch.tensor(tmp_df["time"].to_numpy(), device=device, dtype=precision),
            torch.tensor(
                tmp_df[["v1", "v2", "v3"]].to_numpy(),
                device=device,
                dtype=precision,
            ),
        )

    def __len__(self):
        # Number of full ds-sized chunks available
        return len(self.df) // self.ds
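

# Illustrative usage sketch (assumptions: the CSV path is caller-supplied and
# contains "time", "v1", "v2", "v3" columns; the training run below uses
# MemData instead of this Dataset).
def _example_memristor_loader(csv_path: Path) -> DataLoader:
    dataset = MemristorDataset(csv_path, ds=100)
    return DataLoader(dataset, batch_size=1, shuffle=False)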



class MemData:
    """
    Memory Data handler for neural ODE training with progressive data inclusion.

    This class implements a curriculum learning approach for training neural ODEs:
    1. Start with small time chunks to learn local dynamics
    2. Gradually increase time window as the model improves
    3. Smooth transitions between data lengths to maintain training stability
    """

    def __init__(
        self,
        path: Path,
        fs: float = 1e3,
        dt: float = 0.05,
        patience: int = 100,
        dry_run_epoch: int = 1000,
        device=device,
        precision=precision,
        columns: Tuple[str, ...] = ("v1", "v2", "v3"),
        transition_steps: int = 20,  # More steps for smoother transitions
        min_improvement: float = 1e-3,  # Minimum relative improvement to reset patience
    ):
        """
        Initialize the MemData object for neural ODE training.

        Args:
            path: Path to the CSV file containing measurement time series data
            fs: Sampling frequency in Hz
            dt: Time step for neural ODE simulation
            patience: Number of epochs to wait before increasing data length
            dry_run_epoch: Number of epochs to run with minimal data chunk
            device: Torch device to use (defaults to cuda if available)
            precision: Torch data type precision
            columns: Data columns to use (must exist in the CSV)
            transition_steps: Number of steps for smooth transitions between data lengths
            min_improvement: Minimum relative improvement to consider progress
        """
        # Set default device if not provided
        self.device = (
            device
            if device is not None
            else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.precision = precision
        self.columns = list(columns)

        # Training parameters
        self.dt = dt
        self.fs = fs
        self.di = dt * fs  # points per training chunk (dt seconds sampled at fs Hz)
        self.counter = 0
        self.base_mult = 1.0
        self.target_mult = 1.0
        self.current_mult = 1.0
        self.patience_counter = 0
        self.curr_loss = 1e6
        self.best_loss = 1e6
        self.dry_run_epoch = dry_run_epoch
        self.patience = patience
        self.transition_steps = transition_steps
        self.transition_counter = 0
        self.in_transition = False
        self.min_improvement = min_improvement
        self.epoch = 0

        # Load and preprocess data
        try:
            self.df = pd.read_csv(path)
            if "time" not in self.df.columns:
                raise ValueError("CSV must contain a 'time' column")

            # Validate all required columns exist
            for col in self.columns:
                if col not in self.df.columns:
                    raise ValueError(f"Required column '{col}' not found in the CSV")

            # Normalize time to start at 0
            self.df["time"] = self.df["time"] - self.df["time"].min()

            # Interpolate to regular intervals
            self.interpolate_df()

            # Create tensors
            self.l_t = len(self.df["time"])
            self.T = torch.tensor(
                self.df["time"].to_numpy(), device=self.device, dtype=self.precision
            )
            self.X = torch.tensor(
                self.df[self.columns].to_numpy(),
                device=self.device,
                dtype=self.precision,
            )

            # Calculate max multiplier value based on dataset size
            self.max_multiplier = max(1.0, self.l_t / self.di)

            print(f"Loaded dataset with {self.l_t} time points")
            print(f"Initial chunk: {int(self.di)} points")
            print(f"Maximum multiplier: {self.max_multiplier:.2f}")

        except Exception as e:
            raise RuntimeError(f"Error loading data: {e}") from e

    def interpolate_df(self):
        """
        Interpolate the dataframe to regular time intervals using Akima interpolation.
        """
        x = self.df["time"].to_numpy()
        x_new = np.arange(x[0], x[-1], 1 / self.fs)
        new_df = pd.DataFrame({"time": x_new})

        # Interpolate all data columns
        data_columns = (
            self.columns + ["v4"] if "v4" in self.df.columns else self.columns
        )

        for col in data_columns:
            if col in self.df.columns:
                y = self.df[col].to_numpy()
                akima_interp = Akima1DInterpolator(x, y)
                y_new = akima_interp(x_new)
                new_df[col] = y_new

        # Normalize time to start at 0
        new_df["time"] = new_df["time"] - new_df["time"].min()
        self.df = new_df

    def load_data(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load a subset of data based on the current multiplier.

        Returns:
            Tuple of (time_tensor, values_tensor)
        """
        # Calculate the number of points to include
        end_idx = min(int(self.di * self.current_mult), self.l_t)

        # Ensure end_idx is at least 1
        end_idx = max(1, end_idx)

        # Return the time series subset
        return (
            self.T[:end_idx],
            self.X[:end_idx, :],
        )

    def update_multiplier(self):
        """
        Update the multiplier with smooth transition.
        Uses cosine annealing for smoother transitions than linear interpolation.
        """
        if self.in_transition:
            # Smooth S-curve transition (cosine easing): smooth_t runs from 0 at
            # the start of the transition to 1 at the end, with gentle slopes at
            # both ends compared to linear interpolation.
            t = self.transition_counter / self.transition_steps
            smooth_t = 0.5 * (1.0 - np.cos(np.pi * t))

            self.current_mult = (
                self.base_mult + (self.target_mult - self.base_mult) * smooth_t
            )

            self.transition_counter += 1
            if self.transition_counter >= self.transition_steps:
                # Transition complete
                self.in_transition = False
                self.current_mult = self.target_mult
                self.base_mult = self.target_mult
                print(
                    f"Transition complete. Current data length: {int(self.di * self.current_mult)} points"
                )

    def start_transition(self, new_target_mult):
        """
        Start a smooth transition to a new multiplier target.

        Args:
            new_target_mult: The target multiplier to transition to
        """
        # Cap the target multiplier to the maximum possible
        new_target_mult = min(new_target_mult, self.max_multiplier)

        # Only start transition if the target is different
        if abs(new_target_mult - self.current_mult) > 1e-4:
            self.base_mult = self.current_mult
            self.target_mult = new_target_mult
            self.transition_counter = 0
            self.in_transition = True
            print(
                f"Starting transition from {int(self.di * self.base_mult)} to {int(self.di * self.target_mult)} points"
            )

    def get_data(self, loss: float) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get data with adaptive strategy based on neural ODE training progress.

        This method implements curriculum learning for neural ODEs:
        1. Initial dry run with small time window to learn basic dynamics
        2. Gradual increase of time window as model improves
        3. Uses patience mechanism with smooth transitions

        Args:
            loss: Current training loss

        Returns:
            Tuple of (time_tensor, values_tensor)
        """
        self.epoch += 1

        # Update multiplier if in transition mode
        if self.in_transition:
            self.update_multiplier()
            return self.load_data()

        # Track best loss
        if loss < self.best_loss:
            improvement = (self.best_loss - loss) / self.best_loss
            self.best_loss = loss

            # Log significant improvements
            if improvement > 0.05:  # 5% improvement
                print(
                    f"Epoch {self.epoch}: Loss improved by {improvement * 100:.2f}% to {loss:.6f}"
                )

        # Initial dry run with minimal data
        if self.base_mult == 1.0 and self.counter < self.dry_run_epoch:
            self.counter += 1
            if self.counter % 100 == 0:
                print(
                    f"Dry run epoch {self.counter}/{self.dry_run_epoch}, loss: {loss:.6f}"
                )
            return self.load_data()
        elif self.base_mult == 1.0 and self.counter >= self.dry_run_epoch:
            self.counter = 0
            print(f"Dry run complete. Initial loss: {loss:.6f}")
            # Start transition to multiplier 2
            self.start_transition(2.0)
            return self.load_data()

        # Patience-based strategy for subsequent epochs
        relative_improvement = (self.curr_loss - loss) / self.curr_loss

        if relative_improvement > self.min_improvement:
            # Sufficient progress - reset patience counter
            prev_loss = self.curr_loss
            self.curr_loss = loss
            self.patience_counter = 0
            if self.epoch % 50 == 0:
                print(
                    f"Epoch {self.epoch}: Loss improved from {prev_loss:.6f} to {loss:.6f} ({relative_improvement * 100:.2f}%)"
                )
        else:
            # Insufficient or no progress - increment patience counter
            self.patience_counter += 1

            # Log patience status periodically
            if self.patience_counter % 10 == 0:
                print(
                    f"Patience: {self.patience_counter}/{self.patience}, current loss: {loss:.6f}"
                )

            if self.patience_counter >= self.patience:
                # Patience exceeded - start transition to next multiplier
                # Increase by smaller steps at higher multipliers for stability
                if self.base_mult < 5:
                    next_mult = self.base_mult + 1.0
                elif self.base_mult < 10:
                    next_mult = self.base_mult + 0.5
                else:
                    next_mult = self.base_mult * 1.2

                self.start_transition(next_mult)
                self.curr_loss = loss
                self.patience_counter = 0

        return self.load_data()

    def reset(self):
        """Reset the data loader state."""
        self.counter = 0
        self.base_mult = 1.0
        self.target_mult = 1.0
        self.current_mult = 1.0
        self.patience_counter = 0
        self.curr_loss = 1e6
        self.best_loss = 1e6
        self.in_transition = False
        self.transition_counter = 0
        self.epoch = 0
        print("Data loader reset to initial state")

    def get_current_data_length(self):
        """Get the current number of data points being used."""
        return int(self.di * self.current_mult)

    def get_max_data_length(self):
        """Get the maximum number of data points available."""
        return self.l_t

    def get_full_data(self):
        """Get the full dataset."""
        return self.T, self.X
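

# Illustrative sketch of the curriculum schedule (assumptions: a loss that never
# improves, so every `patience` epochs triggers a window transition; csv_path is
# caller-supplied). Not called by the training run below.
def _curriculum_schedule_example(csv_path: Path) -> List[int]:
    data = MemData(csv_path, fs=1e2, dt=0.2, patience=5, dry_run_epoch=10)
    lengths = []
    for _ in range(100):
        data.get_data(loss=1.0)  # constant loss counts as "no improvement"
        lengths.append(data.get_current_data_length())
    return lengths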


def set_seed(seed: int):
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def plot_data(i, t, y_pred, y_true):
    plt.plot(t, y_true[:, 2], label="true")
    plt.plot(t, y_pred[:, 2], label="pred")

    plt.legend(loc="upper right")
    plt.savefig(f"neural_plot/{i}.png", bbox_inches="tight")
    plt.close("all")


def smooth_criterion(y_pred, y_true, criterion):
    """Per-channel criterion on y_true-standardized signals, averaged over 3 channels."""
    total_loss = 0
    for i in range(3):
        mean_i = torch.mean(y_true[:, i])
        std_i = torch.std(y_true[:, i])
        total_loss += criterion(
            (y_pred[:, i] - mean_i) / std_i, (y_true[:, i] - mean_i) / std_i
        )
    return total_loss / 3


def get_gaussian_prediction(y_pred, y_true, criterion2: Callable):
    """Average a Gaussian NLL-style criterion(mean, target, var) over 3 channels."""
    total_loss = 0
    for i in range(3):
        total_loss += criterion2(
            torch.mean(y_pred[:, i]), y_true[:, i], torch.var(y_pred[:, i])
        )
    return total_loss / 3
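

# Illustrative pairing (assumption: torch.nn.GaussianNLLLoss, whose
# (input, target, var) call signature matches the loop above, is the intended
# criterion2):
#     nll = get_gaussian_prediction(y_pred, y_true, torch.nn.GaussianNLLLoss())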


def extract_vm_im(
    path: Path,
    pp: OscParams,
):
    df = pd.read_csv(path)
    df["time"] = df["time"] - df["time"].min()

    v4 = df["v4"].to_numpy()
    v3 = df["v3"].to_numpy()
    v_m = -pp.R11 / pp.R10 * v3
    i_m = (v4 - v_m) / pp.R1
    return v_m, i_m


def get_mutual_characteristic(
    y_pred, y_true, i, folder: Path = Path("neural_mutual_tanh")
):
    folder.mkdir(exist_ok=True)
    fig, ax = plt.subplots(1, 3, layout="constrained", figsize=(15, 5))
    ax[0].plot(y_true[:, 0], y_true[:, 1], label="true")
    ax[0].plot(y_pred[:, 0], y_pred[:, 1], label="pred")
    ax[0].legend(loc="upper right")
    ax[0].set_title("V1-V2")
    ax[1].plot(y_true[:, 0], y_true[:, 2], label="true")
    ax[1].plot(y_pred[:, 0], y_pred[:, 2], label="pred")
    ax[1].legend(loc="upper right")
    ax[1].set_title("V1-V3")
    ax[2].plot(y_true[:, 1], y_true[:, 2], label="true")
    ax[2].plot(y_pred[:, 1], y_pred[:, 2], label="pred")
    ax[2].legend(loc="upper right")
    ax[2].set_title("V2-V3")
    fig.savefig(folder / f"{i}.png")
    plt.close(fig)


def train_within_weights(
    weights: np.ndarray, epoch: int = 2000, ratio: float = 0.02
) -> np.ndarray:
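    """
    Fit the neural memristor ODE to the measured waveforms and return the
    vector produced by main_criterion.get_vector on the final data window.

    Relies on module-level names bound in __main__ (path, model_folder,
    iteration2save, and the multi_loss imports); `weights` and `ratio` are
    currently unused.
    """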
    criterion1 = NRME("range", reduction="mean", norm_axis=0)

    criterion2 = MultiscaleChaoticPeakLoss(
        main_peak_weight=0.05,  # Emphasis on the main spectral peak
        multi_peak_weight=0.05,  # Emphasis on the harmonic structure
        spectral_shape_weight=0.99,  # General frequency distribution
        time_domain_weight=0,  # Time-domain matching disabled
        n_peaks=1,  # Consider only the dominant peak
        log_magnitude=True,  # Use log scale for better dynamic range
        divergence_type="mse",  # Use L2 divergence for spectral shape
        spectral_norm="l1",
    )
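    # NOTE: criterion2 (the spectral loss) is constructed here but not included
    # in main_criterion below; only criterion1, criterion3 and loss_fn are combined.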
    # criterion2 = torch.nn.MSELoss()
    criterion3 = CombinedAttractorLoss(
        histogram_weight=0.8,
        moment_weight=0.2,
        histogram_kwargs=dict(distance_metric="l2"),
    )
    loss_fn = ExponentialDecayLoss(decay_rate=0.2)


    main_criterion = CombinedLoss(
        losses=[criterion1, criterion3, loss_fn], weights=[0.6, 0.2, 0.2]
    )
    params = OscParams(R1=68e3 / 2, R3=6456)
    v_m, i_m = extract_vm_im(path, params)
    v_ms_std = np.std(v_m)
    i_ms_std = np.std(i_m)
    from initials import find_rss  # deferred import: initial-guess helper

    Ron, Roff, x = find_rss(v_m, i_m, 4)
    print(f"Initial Ron = {Ron}, Roff = {Roff}, x0 = {x}")
    osc = NeuralDiffEq(
        params,
        Ron=Ron,
        Roff=Roff,
        x0=x,
        activation=nn.SiLU(),
        v_m_std=v_ms_std,
        i_m_std=i_ms_std,
    )
    # osc = NeuralDiffEq(OscParams(R1=68e3), x0=0.2, activation=nn.SiLU())

    param_groups = [
        {"params": osc.ann.parameters(), "lr": 1e-3},
        # {"params": osc.ann2.parameters(), "lr": 1e-3},
        {"params": [osc.log_Roff, osc.log_Ron], "lr": 1e-4},
        {"params": [osc.x0], "lr": 1e-3},
    ]
    print(f"Trainable parameters: {sum(p.numel() for p in osc.parameters()):,}")
    optimizer = torch.optim.Adam(param_groups)

    osc.train()
    data_loader = MemData(
        path,
        fs=1e2,
        dt=0.2,  # 1 / 7 / 2,
        # min_improvement=0.01,
    )

    # Closure computing the training loss (L-BFGS-style; here invoked manually with Adam)

    def closure():
        optimizer.zero_grad()
        state = torch.cat([v[0, 0:3], osc.x0.unsqueeze(-1)])

        sol = odeint(osc, state, t, method="rk4")

        loss = main_criterion(sol[:, [0, 1, 2]], v)
        # loss += smooth_criterion(sol[:, [0, 1, 2]], v, criterion)
        loss += smooth_range_penalty(sol[:, -1], 0.0, 1.0, 0.1, 1.0)
        loss += smooth_range_penalty(osc.Ron, 0.0, osc.Roff, 0.2, 1.0)
        loss += smooth_range_penalty(osc.Roff, osc.Ron, 1e6, 0.2, 1.0)
        loss += smooth_range_penalty(osc.x0, 0.0, 1, 0.1, 1.0)
        loss.backward()
        return loss

    loss = 1e6
    for i in range(epoch):
        t, v = data_loader.get_data(loss=float(loss))

        loss = closure()
        if np.isnan(loss.item()):
            print("NaN loss, stopping training.")
            break
        optimizer.step()
        print(f"iteration = {i}\tloss = {loss.item():.2g}")

        # optimizer.step()
        if i % 50 == 0:
            if i in iteration2save:
                torch.save(osc, model_folder / f"model_{i}.pth")

            print(f"Ron= {osc.Ron.item()}")
            print(f"Roff= {osc.Roff.item()}")
            print(f"loss = {loss.item()}")
            print(f"x0 = {osc.x0.item()}")

            # For plotting, we need to run the forward pass again
            with torch.no_grad():
                # t, v = data_loader.get_full_data()
                state = torch.cat([v[0, 0:3], osc.x0.unsqueeze(-1)])
                sol = odeint(osc, state, t, method="dopri5")

            fig, ax = plt.subplots(4, 1, sharex=True)
            # osc.export()
            for j in range(3):
                ax[j].plot(
                    t.cpu().detach().numpy(),
                    v.cpu().detach().numpy()[:, j],
                    label="true",
                )
                ax[j].plot(
                    t.cpu().detach().numpy(),
                    sol.cpu().detach().numpy()[:, j],
                    label="pred",
                )
                ax[j].legend(loc="upper right")
            ax[3].plot(
                t.cpu().detach().numpy(), sol.cpu().detach().numpy()[:, 3], label="X"
            )
            fig.savefig(f"neural_plot_tanh/{i}.png", bbox_inches="tight")
            plt.close("all")
            get_mutual_characteristic(
                sol.cpu().detach().numpy(), v.cpu().detach().numpy(), i
            )

    # Final evaluation on the current data window
    with torch.no_grad():
        state = torch.cat([v[0, 0:3], osc.x0.unsqueeze(-1)])
        sol = odeint(osc, state, t, method="rk4")
        return main_criterion.get_vector(sol[:, [0, 1, 2]], v)


if __name__ == "__main__":
    set_seed(42)

    path = Path(
        "/Users/karol/Documents/Studia Doktoranckie/mineti_knowm_memristor/preprocessed_11_05/w6456.csv"
    )
    batch_size = 1
    model_folder = Path("torch_models_vi_all_features")
    model_folder.mkdir(exist_ok=True)
    # Ensure the plot output directory exists before the training loop writes to it
    Path("neural_plot_tanh").mkdir(exist_ok=True)
    # iteration2save = list(range(2400, 3500, 50)) + list(range(6000, 8500, 50))
    # iteration2save = list(
    #     range(6050, 6140, 10),
    # )  # list(range(100, 2000, 100)) + list(range(2000, 10000, 200))
    iteration2save = [6100]

    from multi_loss import (  # noqa
        CombinedAttractorLoss,
        CombinedLoss,
        MultiscaleChaoticPeakLoss,
    )

    results = []
    weights = np.array([0.8, 0.1, 0.1])
    vec = train_within_weights(
        weights=weights,
        epoch=10_000,
        ratio=0.02,
    )
    results.append(vec)