diff --git a/src/single-core-regen/regen.py b/src/single-core-regen/regen.py
new file mode 100644
index 0000000..c334e51
--- /dev/null
+++ b/src/single-core-regen/regen.py
@@ -0,0 +1,287 @@
+from dataclasses import dataclass
+from datetime import datetime
+import time
+import matplotlib.pyplot as plt
+
+import numpy as np
+import optuna
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data
+
+import multiprocessing
+
+from util.datasets import FiberRegenerationDataset
+import util
+
+
+# global settings
+@dataclass
+class GlobalSettings:
+    seed: int = 42
+
+
+# data settings
+@dataclass
+class DataSettings:
+    config_path: str = "data/*-128-16384-10000-0-0-17-0-PAM4-0.ini"
+    symbols_range: tuple = (1, 100)
+    data_size_range: tuple = (1, 20)
+    target_delay: float = 0
+    xy_delay_range: tuple = (0, 1)
+    drop_first: int = 1000
+    train_split: float = 0.8
+
+
+# pytorch settings
+@dataclass
+class PytorchSettings:
+    device: str = "cuda"
+    batchsize: int = 128
+    epochs: int = 100
+
+
+# model settings
+@dataclass
+class ModelSettings:
+    output_size: int = 2
+    n_layer_range: tuple = (1, 3)
+    n_units_range: tuple = (4, 128)
+    activation_func_range: tuple = ("ReLU",)
+
+
+@dataclass
+class OptimizerSettings:
+    optimizer_range: tuple = ("Adam", "RMSprop", "SGD")
+    lr_range: tuple = (1e-5, 1e-1)
+
+
+# optuna settings
+@dataclass
+class OptunaSettings:
+    n_trials: int = 128
+    n_threads: int = 16
+    timeout: int = 600
+    directions: tuple = ("maximize",)
+
+    limit_examples: bool = True
+    n_train_examples: int = PytorchSettings.batchsize * 30
+    n_valid_examples: int = PytorchSettings.batchsize * 10
+    storage: str = "sqlite:///optuna_single_core_regen.db"
+    study_name: str = (
+        f"single_core_regen_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
+    )
+    metrics_names: tuple = ("accuracy",)
+
+
+class HyperTraining:
+    def __init__(self):
+        self.global_settings = GlobalSettings()
+        self.data_settings = DataSettings()
+        self.pytorch_settings = PytorchSettings()
+        self.model_settings = ModelSettings()
+        self.optimizer_settings = OptimizerSettings()
+        self.optuna_settings = OptunaSettings()
+
+        # set some extra settings to make the code more readable
+        self._extra_optuna_settings()
+
+    def setup_study(self):
+        self.study = optuna.create_study(
+            study_name=self.optuna_settings.study_name,
+            storage=self.optuna_settings.storage,
+            load_if_exists=True,
+            direction=self.optuna_settings.direction,
+            directions=self.optuna_settings.directions,
+        )
+
+        with warnings.catch_warnings(action="ignore"):
+            self.study.set_metric_names(self.optuna_settings.metrics_names)
+
+        self.n_threads = min(self.optuna_settings.n_trials, self.optuna_settings.n_threads)
+        self.processes = []
+        for _ in range(self.n_threads):
+            # each worker process runs an equal share of the trials
+            p = multiprocessing.Process(
+                target=self._run_optimize,
+                args=(self.optuna_settings.n_trials // self.n_threads,),
+            )
+            self.processes.append(p)
+
+    def run_study(self):
+        for p in self.processes:
+            p.start()
+
+        for p in self.processes:
+            p.join()
+
+        # trials that do not divide evenly over the workers run in this process
+        remaining_trials = self.optuna_settings.n_trials % self.n_threads
+        if remaining_trials:
+            self._run_optimize(remaining_trials)
+
+    def _run_optimize(self, n_trials):
+        self.study.optimize(self.objective, n_trials=n_trials, timeout=self.optuna_settings.timeout)
+
+    def eye(self, show=True):
+        util.plot.eye(self.data_settings.config_path, show=show)
+
+    def _extra_optuna_settings(self):
+        self.optuna_settings.multi_objective = len(self.optuna_settings.directions) > 1
+        if self.optuna_settings.multi_objective:
+            self.optuna_settings.direction = None
+        else:
+            self.optuna_settings.direction = self.optuna_settings.directions[0]
+            self.optuna_settings.directions = None
+
+        self.optuna_settings.n_train_examples = (
+            self.optuna_settings.n_train_examples
+            if self.optuna_settings.limit_examples
+            else float("inf")
+        )
+        self.optuna_settings.n_valid_examples = (
+            self.optuna_settings.n_valid_examples
+            if self.optuna_settings.limit_examples
+            else float("inf")
+        )
+
+    def define_model(self, trial: optuna.Trial):
+        n_layers = trial.suggest_int(
+            "model_n_layers", *self.model_settings.n_layer_range
+        )
+        layers = []
+
+        # "dataset_data_size" is suggested in get_sliced_data(), which objective()
+        # calls before define_model(), so the value is available here
+        in_features = trial.params["dataset_data_size"] * 2
+        for i in range(n_layers):
+            out_features = trial.suggest_int(
+                f"model_n_units_l{i}", *self.model_settings.n_units_range
+            )
+            activation_func = trial.suggest_categorical(
+                f"model_activation_func_l{i}", self.model_settings.activation_func_range
+            )
+
+            layers.append(nn.Linear(in_features, out_features))
+            layers.append(getattr(nn, activation_func)())
+            in_features = out_features
+
+        layers.append(nn.Linear(in_features, self.model_settings.output_size))
+
+        return nn.Sequential(*layers)
+
+    def get_sliced_data(self, trial: optuna.Trial):
+        symbols = trial.suggest_float(
+            "dataset_symbols", *self.data_settings.symbols_range, log=True
+        )
+        xy_delay = trial.suggest_float(
+            "dataset_xy_delay", *self.data_settings.xy_delay_range
+        )
+        data_size = trial.suggest_int(
+            "dataset_data_size", *self.data_settings.data_size_range
+        )
+        # get dataset; each slice yields data_size samples for each of the two
+        # polarisations (x, y), i.e. 2 * data_size model inputs
+        dataset = FiberRegenerationDataset(
+            file_path=self.data_settings.config_path,
+            symbols=symbols,
+            data_size=data_size,  # two channels (x,y)
+            target_delay=self.data_settings.target_delay,
+            xy_delay=xy_delay,
+            drop_first=self.data_settings.drop_first,
+        )
+
+        dataset_size = len(dataset)
+        indices = list(range(dataset_size))
+        split = int(np.floor(self.data_settings.train_split * dataset_size))
+        np.random.seed(self.global_settings.seed)
+        np.random.shuffle(indices)
+        train_indices, valid_indices = indices[:split], indices[split:]
+
+        train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
+        valid_sampler = torch.utils.data.SubsetRandomSampler(valid_indices)
+
+        train_loader = torch.utils.data.DataLoader(
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=train_sampler
+        )
+        valid_loader = torch.utils.data.DataLoader(
+            dataset, batch_size=self.pytorch_settings.batchsize, sampler=valid_sampler
+        )
+
+        return train_loader, valid_loader
+
+    def train_model(self, model, optimizer, train_loader):
+        model.train()
+        for batch_idx, (data, target) in enumerate(train_loader):
+            if (batch_idx * train_loader.batch_size
+                    >= self.optuna_settings.n_train_examples):
+                break
+            optimizer.zero_grad()
+            data, target = (
+                data.to(self.pytorch_settings.device),
+                target.to(self.pytorch_settings.device),
+            )
+            target_pred = model(data)
+            loss = F.mse_loss(target_pred, target)
+            loss.backward()
+            optimizer.step()
+
+    def eval_model(self, model, valid_loader):
+        model.eval()
+        correct = 0
+        with torch.no_grad():
+            for batch_idx, (data, target) in enumerate(valid_loader):
+                if (
+                    batch_idx * valid_loader.batch_size
+                    >= self.optuna_settings.n_valid_examples
+                ):
+                    break
+                data, target = (
+                    data.to(self.pytorch_settings.device),
+                    target.to(self.pytorch_settings.device),
+                )
+                target_pred = model(data)
+                pred = target_pred.argmax(dim=1, keepdim=True)
+                correct += pred.eq(target.view_as(pred)).sum().item()
+
+        accuracy = correct / len(valid_loader.dataset)
+        # num_params = sum(p.numel() for p in model.parameters())
+        return accuracy
+
+    def objective(self, trial: optuna.Trial):
+        # suggest the dataset parameters first so that define_model() can read
+        # "dataset_data_size" from trial.params
+        train_loader, valid_loader = self.get_sliced_data(trial)
+
+        model = self.define_model(trial).to(self.pytorch_settings.device)
+
+        optimizer_name = trial.suggest_categorical(
+            "optimizer", self.optimizer_settings.optimizer_range
+        )
+        lr = trial.suggest_float("lr", *self.optimizer_settings.lr_range, log=True)
+        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+
+        for epoch in range(self.pytorch_settings.epochs):
+            self.train_model(model, optimizer, train_loader)
+            accuracy = self.eval_model(model, valid_loader)
+
+            # pruning is only supported for single-objective studies
+            if not self.optuna_settings.multi_objective:
+                trial.report(accuracy, epoch)
+                if trial.should_prune():
+                    raise optuna.exceptions.TrialPruned()
+
+        return accuracy
+
+
+if __name__ == "__main__":
+    # plt.ion()
+    hyper_training = HyperTraining()
+    hyper_training.eye()
+
+    # hyper_training.setup_study()
+    # hyper_training.run_study()
+    for i in range(10):
+        # simulate some work
+        print(i)
+        time.sleep(0.2)
+
+    plt.show()
+    ...
diff --git a/src/single-core-regen/testing/learn_optuna.py b/src/single-core-regen/testing/learn_optuna.py
new file mode 100644
index 0000000..b700cc0
--- /dev/null
+++ b/src/single-core-regen/testing/learn_optuna.py
@@ -0,0 +1,194 @@
+from datetime import datetime
+from pathlib import Path
+
+import optuna
+import warnings
+from util.optuna_vis import show_figures
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data
+
+from torchvision import datasets
+from torchvision import transforms
+import multiprocessing
+
+# from util.dataset import SlicedDataset
+
+DEVICE = torch.device("cuda")
+BATCHSIZE = 128
+CLASSES = 10
+DIR = Path(__file__).parent
+EPOCHS = 100
+N_TRAIN_EXAMPLES = BATCHSIZE * 30
+N_VALID_EXAMPLES = BATCHSIZE * 10
+
+n_trials = 128
+n_threads = 16
+
+
+def define_model(trial):
+    n_layers = trial.suggest_int("n_layers", 1, 3)
+    layers = []
+
+    in_features = 28 * 28
+    for i in range(n_layers):
+        out_features = trial.suggest_int(f"n_units_l{i}", 4, 128)
+        layers.append(nn.Linear(in_features, out_features))
+        layers.append(nn.ReLU())
+        p = trial.suggest_float(f"dropout_l{i}", 0.2, 0.5)
+        layers.append(nn.Dropout(p))
+
+        in_features = out_features
+
+    layers.append(nn.Linear(in_features, CLASSES))
+    layers.append(nn.LogSoftmax(dim=1))
+
+    return nn.Sequential(*layers)
+
+
+def get_mnist():
+    # Load FashionMNIST dataset.
+    train_loader = torch.utils.data.DataLoader(
+        datasets.FashionMNIST(
+            DIR / ".data", train=True, download=True, transform=transforms.ToTensor()
+        ),
+        batch_size=BATCHSIZE,
+        shuffle=True,
+    )
+    valid_loader = torch.utils.data.DataLoader(
+        datasets.FashionMNIST(
+            DIR / ".data", train=False, transform=transforms.ToTensor()
+        ),
+        batch_size=BATCHSIZE,
+        shuffle=True,
+    )
+
+    return train_loader, valid_loader
+
+
+def objective(trial):
+    model = define_model(trial).to(DEVICE)
+
+    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
+    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
+    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+
+    train_loader, valid_loader = get_mnist()
+
+    for epoch in range(EPOCHS):
+        train_model(model, optimizer, train_loader)
+        accuracy, num_params = eval_model(model, valid_loader)
+
+    return accuracy, num_params
+
+
+def eval_model(model, valid_loader):
+    model.eval()
+    correct = 0
+    with torch.no_grad():
+        for batch_idx, (data, target) in enumerate(valid_loader):
+            if batch_idx * BATCHSIZE >= N_VALID_EXAMPLES:
+                break
+
+            data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
+            output = model(data)
+
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    accuracy = correct / min(len(valid_loader.dataset), N_VALID_EXAMPLES)
+    num_params = sum(p.numel() for p in model.parameters())
+    return accuracy, num_params
+
+
+def train_model(model, optimizer, train_loader):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        if batch_idx * BATCHSIZE >= N_TRAIN_EXAMPLES:
+            break
+
+        data, target = data.view(data.size(0), -1).to(DEVICE), target.to(DEVICE)
+
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+
+
+def run_optimize(n_trials, study):
+    study.optimize(objective, n_trials=n_trials, timeout=600)
+
+
+if __name__ == "__main__":
+    study_name = f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} mnist example"
+    storage = "sqlite:///db.sqlite3"
+    directions = ["maximize", "minimize"]
+
+    study = optuna.create_study(
+        directions=directions,
+        storage=storage,
+        study_name=study_name,
+    )
+
+    with warnings.catch_warnings(action="ignore"):
+        study.set_metric_names(["accuracy", "num params"])
+
+    n_threads = min(n_trials, n_threads)
+
+    processes = []
+    for _ in range(n_threads):
+        p = multiprocessing.Process(
+            target=run_optimize, args=(n_trials // n_threads, study)
+        )
+        p.start()
+        processes.append(p)
+
+    for p in processes:
+        p.join()
+
+    remaining_trials = n_trials - ((n_trials // n_threads) * n_threads)
+    if remaining_trials:
+        print(
+            f"\nRunning last {remaining_trials} trial{'s' if remaining_trials > 1 else ''}:"
+        )
+        run_optimize(remaining_trials, study)
+
+    print(f"Number of trials on the Pareto front: {len(study.best_trials)}")
+
+    # values[0] is accuracy, values[1] is the parameter count (see set_metric_names above)
+    trial_with_highest_accuracy = max(study.best_trials, key=lambda t: t.values[0])
+    print("Trial with highest accuracy: ")
+    print(f"\tnumber: {trial_with_highest_accuracy.number}")
+    print(f"\tparams: {trial_with_highest_accuracy.params}")
+    print(f"\tvalues: {trial_with_highest_accuracy.values}")
+
+    # for trial in trials:
+    #     print(f"Trial {trial.number}")
+    #     print(f"  Accuracy: {trial.values[0]}")
+    #     print(f"  n_params: {int(trial.values[1])}")
+    #     print( "  Params: ")
+    #     for key, value in trial.params.items():
+    #         print("    {}: {}".format(key, value))
+    #     print()
+
+    #     print("  Value: ", trial.value)
+
+    #     print("  Params: ")
+    #     for key, value in trial.params.items():
+    #         print("    {}: {}".format(key, value))
+
+    figures = []
+    figures.append(
+        optuna.visualization.plot_pareto_front(
+            study, target_names=["accuracy", "num_params"]
+        )
+    )
+    figures.append(optuna.visualization.plot_timeline(study))
+
+    plt = show_figures(*figures)
+
+    print()
+    # plt.show()
diff --git a/src/single-core-regen/util/__init__.py b/src/single-core-regen/util/__init__.py
new file mode 100644
index 0000000..9651fca
--- /dev/null
+++ b/src/single-core-regen/util/__init__.py
@@ -0,0 +1,3 @@
+from .datasets import FiberRegenerationDataset  # noqa: F401
+from .optuna_vis import show_figures  # noqa: F401
+from .plot import eye  # noqa: F401
\ No newline at end of file
diff --git a/src/single-core-regen/util/dataset.py b/src/single-core-regen/util/dataset.py
deleted file mode 100644
index 9537b80..0000000
--- a/src/single-core-regen/util/dataset.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from pathlib import Path
-import torch
-from torch.utils.data import Dataset
-import numpy as np
-import configparser
-
-class SlicedDataset(Dataset):
-    def __init__(self, config_path, symbols, drop_first=0):
-        """
-        Initialize the dataset.
-
-        :param config_path: Path to the configuration file
-        :type config_path: str
-        :param out_size: Output size in symbols
-        :type out_size: int
-        :param reduce: Reduce the dataset size by taking every reduce-th sample
-        :type reduce: int
-        """
-
-        self.config = configparser.ConfigParser()
-        self.config.read(Path(config_path))
-
-        self.data_path = (Path(self.config['data']['dir'].strip('"')) / (self.config['data']['npy_dir'].strip('"')) / self.config['data']['file'].strip('"'))
-
-        self.symbols_per_slice = symbols
-        self.samples_per_symbol = int(self.config['glova']['sps'])
-        self.samples_per_slice = self.symbols_per_slice * self.samples_per_symbol
-
-        data_raw = torch.tensor(np.load(self.data_path))[drop_first*self.samples_per_symbol:]
-        data_raw = data_raw.transpose(0,1)
-        data_raw = data_raw.view(2,2,-1)
-        # [no_samples, 4] -> [4, no_samples] -> [2, 2, no_samples]
-
-        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
-        self.data = self.data.movedim(-2, 0)
-        # -> [no_slices, 2, 2, samples_per_slice]
-        ...
-
-
-    def __len__(self):
-        return self.data.shape[0]
-
-    def __getitem__(self, idx):
-        if isinstance(idx, slice):
-            return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
-        else:
-            return (self.data[idx,1].squeeze(), self.data[idx,0].squeeze())
-
-
-if __name__ == "__main__":
-
-
-    pass
\ No newline at end of file
diff --git a/src/single-core-regen/util/datasets.py b/src/single-core-regen/util/datasets.py
new file mode 100644
index 0000000..7bbdec5
--- /dev/null
+++ b/src/single-core-regen/util/datasets.py
@@ -0,0 +1,256 @@
+from pathlib import Path
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+import configparser
+
+
+def load_data(config_path, skipfirst=0, num_symbols=None):
+    filepath = Path(config_path)
+    # the name may contain wildcards; use the first matching file
+    filepath = next(filepath.parent.glob(filepath.name))
+    config = configparser.ConfigParser()
+    config.read(filepath)
+    path_elements = (
+        config["data"]["dir"],
+        config["data"]["npy_dir"],
+        config["data"]["file"],
+    )
+    datapath = Path("/".join(path_elements).replace('"', ""))
+    sps = int(config["glova"]["sps"])
+
+    if num_symbols is None:
+        num_symbols = int(config["glova"]["nos"]) - skipfirst
+
+    data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
+    config["glova"]["nos"] = str(num_symbols)
+
+    return data, config
+
+
+def roll_along(arr, shifts, dim):
+    # https://stackoverflow.com/a/76920720
+    # (c) Mateen Ulhaq, 2023
+    # CC BY-SA 4.0
+    shifts = torch.tensor(shifts)
+    assert arr.ndim - 1 == shifts.ndim
+    dim %= arr.ndim
+    shape = (1,) * dim + (-1,) + (1,) * (arr.ndim - dim - 1)
+    dim_indices = torch.arange(arr.shape[dim]).reshape(shape)
+    indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
+    return torch.gather(arr, dim, indices)
+
+
+class FiberRegenerationDataset(Dataset):
+    """
+    Dataset for fiber regeneration training.
+
+    The dataset is loaded from a configuration file, which must contain (at least) the following sections:
+    ```
+    [data]
+    dir =
+    npy_dir =
+    file =
+
+    [glova]
+    sps =
+    ```
+    The data is loaded from the file `dir/npy_dir/file` and is assumed to be in the following format:
+    ```
+    [ E_in_x,
+      E_in_y,
+      E_out_x,
+      E_out_y ]
+    ```
+
+    The dataset is sliced into slices, where each slice consists of a (fractional) number of symbols.
+    The target can be delayed relative to the input data by a (fractional) number of symbols.
+    The x and y channels can be delayed relative to each other by a (fractional) number of symbols.
+    """
+
+    def __init__(
+        self,
+        file_path: str | Path,
+        symbols: int | float,
+        *,
+        data_size: int = None,
+        target_delay: float | int = 0,
+        xy_delay: float | int = 0,
+        drop_first: float | int = 0,
+        **kwargs,
+    ):
+        """
+        Initialize the dataset.
+
+        :param file_path: Path to the data file. Can contain wildcards (*). The first matching file is used.
+        :type file_path: str | pathlib.Path
+        :param symbols: Number of symbols in each slice. Can be a float to specify a fraction of a symbol.
+        :type symbols: float | int
+        :param data_size: Number of samples in each slice. The data is reduced by taking equally spaced samples. If unset, each slice will contain symbols*samples_per_symbol samples.
+        :type data_size: int, optional
+        :param target_delay: Delay (in fractional symbols) between data and target. A positive delay means the target is delayed relative to the data. Default is 0.
+        :type target_delay: float | int, optional
+        :param xy_delay: Delay (in fractional symbols) between the x and y channels. A positive delay means the y channel is delayed relative to the x channel. Default is 0.
+        :type xy_delay: float | int, optional
+        :param drop_first: Number of (fractional) symbols to drop from the beginning
+        :type drop_first: float | int
+        """
+
+        # check types
+        assert isinstance(file_path, (str, Path)), "file_path must be a string or Path"
+        assert isinstance(symbols, (float, int)), (
+            "symbols must be a float or an integer"
+        )
+        assert data_size is None or isinstance(data_size, int), (
+            "data_size must be an integer"
+        )
+        assert isinstance(target_delay, (float, int)), (
+            "target_delay must be a float or an integer"
+        )
+        assert isinstance(xy_delay, (float, int)), (
+            "xy_delay must be a float or an integer"
+        )
+        assert isinstance(drop_first, (float, int)), (
+            "drop_first must be a float or an integer"
+        )
+
+        # check values
+        assert symbols > 0, "symbols must be positive"
+        assert data_size is None or data_size > 0, "data_size must be positive or None"
+        assert drop_first >= 0, "drop_first must be non-negative"
+
+        faux = kwargs.pop("faux", False)
+
+        if faux:
+            # small synthetic signal for testing without the real data files
+            data_raw = np.array(
+                [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
+                dtype=np.complex128,
+            )
+            self.config = {
+                "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
+                "glova": {"sps": 128},
+            }
+        else:
+            data_raw, self.config = load_data(file_path)
+
+        self.samples_per_symbol = int(self.config["glova"]["sps"])
+        self.samples_per_slice = int(symbols * self.samples_per_symbol)
+        self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
+
+        self.data_size = data_size or self.samples_per_slice
+        self.target_delay = target_delay or 0
+        self.xy_delay = xy_delay or 0
+
+        ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
+        ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
+        ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
+
+        self.target_delay_samples = (
+            ovrd_target_delay_samples
+            if ovrd_target_delay_samples is not None
+            else int(self.target_delay * self.samples_per_symbol)
+        )
+        self.xy_delay_samples = (
+            ovrd_xy_delay_samples
+            if ovrd_xy_delay_samples is not None
+            else int(self.xy_delay * self.samples_per_symbol)
+        )
+        drop_first_samples = (
+            ovrd_drop_first_samples
+            if ovrd_drop_first_samples is not None
+            else int(drop_first * self.samples_per_symbol)
+        )
+
+        # drop samples from the beginning
+        data_raw = data_raw[drop_first_samples:]
+
+        # data layout
+        # [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
+        #   [E_in_x1, E_in_y1, E_out_x1, E_out_y1],
+        #   ...
+        #   [E_in_xN, E_in_yN, E_out_xN, E_out_yN] ]
+
+        # the slicing below uses torch ops (roll_along, unfold), so convert the
+        # numpy array to a tensor first
+        data_raw = torch.from_numpy(data_raw)
+        data_raw = data_raw.transpose(0, 1)
+
+        # data layout
+        # [ E_in_x[0:N],
+        #   E_in_y[0:N],
+        #   E_out_x[0:N],
+        #   E_out_y[0:N] ]
+
+        # shift x data by xy_delay_samples relative to the y data (example value: 3)
+        # [ E_in_x [0:N],      [ E_in_x [ 0:N ],      [ E_in_x [3:N  ],
+        #   E_in_y [0:N],  ->    E_in_y [-3:N-3],  ->   E_in_y [0:N-3],
+        #   E_out_x[0:N],        E_out_x[ 0:N ],        E_out_x[3:N  ],
+        #   E_out_y[0:N] ]       E_out_y[-3:N-3] ]      E_out_y[0:N-3] ]
+
+        if self.xy_delay_samples != 0:
+            data_raw = roll_along(
+                data_raw, [0, self.xy_delay_samples, 0, self.xy_delay_samples], dim=1
+            )
+            if self.xy_delay_samples > 0:
+                data_raw = data_raw[:, self.xy_delay_samples :]
+            elif self.xy_delay_samples < 0:
+                data_raw = data_raw[:, : self.xy_delay_samples]
+
+        # shift fiber input data (target) by target_delay_samples relative to the fiber output data (input)
+        # (example value: 5)
+        # [ E_in_x [0:N],      [ E_in_x [-5:N-5],      [ E_in_x [0:N-5],
+        #   E_in_y [0:N],  ->    E_in_y [-5:N-5],  ->    E_in_y [0:N-5],
+        #   E_out_x[0:N],        E_out_x[ 0:N ],         E_out_x[5:N  ],
+        #   E_out_y[0:N] ]       E_out_y[ 0:N ] ]        E_out_y[5:N  ] ]
+
+        if self.target_delay_samples != 0:
+            data_raw = roll_along(
+                data_raw,
+                [self.target_delay_samples, self.target_delay_samples, 0, 0],
+                dim=1,
+            )
+            if self.target_delay_samples > 0:
+                data_raw = data_raw[:, self.target_delay_samples :]
+            elif self.target_delay_samples < 0:
+                data_raw = data_raw[:, : self.target_delay_samples]
+
+        # reshape (not view): the tensor is no longer contiguous after transpose/slicing
+        data_raw = data_raw.reshape(2, 2, -1)
+        # data layout
+        # [ [E_in_x,  E_in_y],
+        #   [E_out_x, E_out_y] ]
+
+        self.data = data_raw.unfold(dimension=-1, size=self.samples_per_slice, step=1)
+        self.data = self.data.movedim(-2, 0)
+        # -> [no_slices, 2, 2, samples_per_slice]
+
+        # data layout
+        # [
+        #   [ [E_in_x[0:N+0], E_in_y[0:N+0] ], [ E_out_x[0:N+0], E_out_y[0:N+0] ] ],
+        #   [ [E_in_x[1:N+1], E_in_y[1:N+1] ], [ E_out_x[1:N+1], E_out_y[1:N+1] ] ],
+        #   ...
+        # ] -> [no_slices, 2, 2, samples_per_slice]
+
+        ...
+
+    def __len__(self):
+        return self.data.shape[0]
+
+    def __getitem__(self, idx):
+        if isinstance(idx, slice):
+            return [self.__getitem__(i) for i in range(*idx.indices(len(self)))]
+        else:
+            data, target = self.data[idx, 1].squeeze(), self.data[idx, 0].squeeze()
+
+            # reduce by taking self.data_size equally spaced samples
+            data = data[:, : data.shape[1] // self.data_size * self.data_size]
+            data = data.view(data.shape[0], self.data_size, -1)
+            data = data[:, :, 0]
+
+            # target corresponds to the latest data point -> try to regenerate that
+            target = target[:, : target.shape[1] // self.data_size * self.data_size]
+            target = target.view(target.shape[0], self.data_size, -1)
+            target = target[:, 0, 0]
+
+            data = data.transpose(0, 1).flatten().squeeze()
+            target = target.flatten().squeeze()
+
+            # data layout:
+            # [sample_x0, sample_y0, sample_x1, sample_y1, ...]
+            # target layout:
+            # [sample_x0, sample_y0]
+
+            return data, target
diff --git a/src/single-core-regen/util/plot.py b/src/single-core-regen/util/plot.py
new file mode 100644
index 0000000..8a3a196
--- /dev/null
+++ b/src/single-core-regen/util/plot.py
@@ -0,0 +1,34 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from .datasets import load_data
+
+def eye(path, title=None, head=1000, skipfirst=1000, show=True):
+    """Plot an eye diagram for the data given by filepath.
+
+    Args:
+        path (str): Path to the data description file.
+        title (str): Title of the plot.
+        head (int): Number of symbols to plot.
+        skipfirst (int): Number of symbols to skip.
+        show (bool): Whether to call plt.show().
+ """ + data, config = load_data(path, skipfirst, head) + sps = int(config["glova"]["sps"]) + + xaxis = np.linspace(0, 2, 2*sps, endpoint=False) + fig, axs = plt.subplots(2, 2, figsize=(10, 10), sharex=True, sharey=True) + for i in range(head-1): + inx, iny, outx, outy = data[i*sps:(i+2)*sps].T + axs[0, 0].plot(xaxis, np.abs(inx)**2, color="C0", alpha=0.1) + axs[0, 1].plot(xaxis, np.abs(outx)**2, color="C0", alpha=0.1) + axs[1, 0].plot(xaxis, np.abs(iny)**2, color="C0", alpha=0.1) + axs[1, 1].plot(xaxis, np.abs(outy)**2, color="C0", alpha=0.1) + + axs[0, 0].set_title("Input x") + axs[0, 1].set_title("Output x") + axs[1, 0].set_title("Input y") + axs[1, 1].set_title("Output y") + fig.suptitle(title) + + if show: + plt.show(block=False)