training loop speedup

Author: Joseph Hopfmüller
Date:   2024-11-20 11:29:18 +01:00
parent 1622c38582
commit cdca5de473

11 changed files with 1026 additions and 151 deletions

@@ -1,11 +1,28 @@
 from pathlib import Path
 import torch
 from torch.utils.data import Dataset
+# from torch.utils.data import Sampler
 import numpy as np
 import configparser
 
+# class SubsetSampler(Sampler[int]):
+#     """
+#     Samples elements from a given list of indices.
+#     :param indices: List of indices to sample from.
+#     :type indices: list[int]
+#     """
+#     def __init__(self, indices):
+#         self.indices = indices
+#     def __iter__(self):
+#         return iter(self.indices)
+#     def __len__(self):
+#         return len(self.indices)
+
-def load_data(config_path, skipfirst=0, num_symbols=None):
+def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
     filepath = Path(config_path)
     filepath = filepath.parent.glob(filepath.name)
     config = configparser.ConfigParser()
@@ -18,15 +35,25 @@ def load_data(config_path, skipfirst=0, num_symbols=None):
     datapath = Path("/".join(path_elements).replace('"', ""))
     sps = int(config["glova"]["sps"])
-    if num_symbols is None:
-        num_symbols = int(config["glova"]["nos"]) - skipfirst
+    if symbols is None:
+        symbols = int(config["glova"]["nos"]) - skipfirst
-    data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
-    config["glova"]["nos"] = str(num_symbols)
+    data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
+    if normalize:
+        a, b, c, d = data.T
+        a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
+        data = np.array([a, b, c, d]).T
+    if real:
+        data = np.abs(data)
+    config["glova"]["nos"] = str(symbols)
+    data = torch.tensor(data, device=device, dtype=dtype)
     return data, config
 
 def roll_along(arr, shifts, dim):
     # https://stackoverflow.com/a/76920720
     # (c) Mateen Ulhaq, 2023
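
Note: normalization, optional magnitude-only output, and tensor creation are now folded into load_data, so callers get a device-resident tensor in one step. A rough usage sketch of the new signature (the config path and parameter values here are made up for illustration):

    data, config = load_data(
        "run/config.ini",            # hypothetical configparser file with [glova] and the data paths
        skipfirst=100,               # symbols dropped from the start of the recording
        normalize=True,              # scale each of the four field columns to unit peak magnitude
        real=False,                  # keep complex fields; True returns magnitudes only
        device="cuda" if torch.cuda.is_available() else "cpu",
        dtype=torch.complex64,
    )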
@@ -39,7 +66,6 @@ def roll_along(arr, shifts, dim):
     indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
     return torch.gather(arr, dim, indices)
 
-
 class FiberRegenerationDataset(Dataset):
     """
     Dataset for fiber regeneration training.
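
Note: roll_along (untouched here apart from whitespace) is a batched torch.roll: every slice along dim gets its own shift via torch.gather. A quick illustration with toy values:

    x = torch.arange(12).reshape(3, 4)
    shifts = torch.tensor([0, 1, 2])      # one shift per row
    roll_along(x, shifts, dim=1)
    # tensor([[ 0,  1,  2,  3],           # shift 0: unchanged
    #         [ 7,  4,  5,  6],           # shift 1: rolled right by one
    #         [10, 11,  8,  9]])          # shift 2: rolled right by two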
@@ -76,6 +102,9 @@ class FiberRegenerationDataset(Dataset):
         target_delay: float | int = 0,
         xy_delay: float | int = 0,
         drop_first: float | int = 0,
+        dtype: torch.dtype = None,
+        real: bool = False,
+        device = None,
         **kwargs,
     ):
         """
@@ -123,13 +152,16 @@
                 [[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
                 dtype=np.complex128,
             )
+            data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
             self.config = {
                 "data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
                 "glova": {"sps": 128},
             }
         else:
-            data_raw, self.config = load_data(file_path)
+            data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
+        self.device = data_raw.device
         self.samples_per_symbol = int(self.config["glova"]["sps"])
         self.samples_per_slice = int(symbols * self.samples_per_symbol)
         self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
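
Note: the constructor now forwards drop_first, real, device, and dtype straight to load_data, so the whole preprocessing chain runs once at construction time; presumably this up-front device placement is where the commit's training-loop speedup comes from, since per-batch numpy-to-tensor conversion disappears. A sketch of how the new arguments might be passed (path and values are placeholders):

    dataset = FiberRegenerationDataset(
        file_path="data/fiber_run.ini",   # placeholder config path
        symbols=8,                        # window length in symbols (constructor argument used above)
        drop_first=100,                   # forwarded to load_data as skipfirst
        real=False,
        dtype=torch.complex64,
        device="cuda",
    )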
@@ -140,7 +172,6 @@
         ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
         ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
         ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
-
         self.target_delay_samples = (
             ovrd_target_delay_samples
@@ -152,14 +183,8 @@ class FiberRegenerationDataset(Dataset):
if ovrd_xy_delay_samples is not None
else int(self.xy_delay * self.samples_per_symbol)
)
drop_first_samples = (
ovrd_drop_first_samples
if ovrd_drop_first_samples is not None
else int(drop_first * self.samples_per_symbol)
)
# drop samples from the beginning
data_raw = data_raw[drop_first_samples:]
# data_raw = torch.tensor(data_raw, dtype=dtype)
# data layout
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
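
Note: this block became redundant rather than being lost — drop_first is now handled inside load_data (as skipfirst, see the constructor hunk above), which slices the array before the tensor is ever created.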
@@ -240,10 +265,10 @@
         data = data.view(data.shape[0], self.data_size, -1)
         data = data[:, :, 0]
-        # target is corresponding to the latest data point -> try to regenerate that
+        # target corresponds to the middle of the data, as the output sample is influenced by the samples before and after it
         target = target[:, : target.shape[1] // self.data_size * self.data_size]
         target = target.view(target.shape[0], self.data_size, -1)
-        target = target[:, 0, 0]
+        target = target[:, 0, target.shape[2] // 2]
 
         data = data.transpose(0, 1).flatten().squeeze()
         target = target.flatten().squeeze()
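
Note: the target moves from the first to the centre sample of each window, which matches the physics: the fiber's output at a given instant is shaped by input samples both before and after it. Toy shapes to show what the new indexing picks (values made up):

    target = torch.arange(40).view(2, 4, 5)    # (slices, data_size, samples)
    target[:, 0, 0]                            # old: tensor([ 0, 20]) – first sample
    target[:, 0, target.shape[2] // 2]         # new: tensor([ 2, 22]) – centre sample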