training loop speedup
This commit is contained in:
@@ -1,11 +1,28 @@
|
||||
from pathlib import Path
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
# from torch.utils.data import Sampler
|
||||
import numpy as np
|
||||
import configparser
|
||||
|
||||
# class SubsetSampler(Sampler[int]):
|
||||
# """
|
||||
# Samples elements from a given list of indices.
|
||||
|
||||
def load_data(config_path, skipfirst=0, num_symbols=None):
|
||||
# :param indices: List of indices to sample from.
|
||||
# :type indices: list[int]
|
||||
# """
|
||||
|
||||
# def __init__(self, indices):
|
||||
# self.indices = indices
|
||||
|
||||
# def __iter__(self):
|
||||
# return iter(self.indices)
|
||||
|
||||
# def __len__(self):
|
||||
# return len(self.indices)
|
||||
|
||||
def load_data(config_path, skipfirst=0, symbols=None, real=False, normalize=False, device=None, dtype=None):
|
||||
filepath = Path(config_path)
|
||||
filepath = filepath.parent.glob(filepath.name)
|
||||
config = configparser.ConfigParser()
|
||||
@@ -18,15 +35,25 @@ def load_data(config_path, skipfirst=0, num_symbols=None):
|
||||
datapath = Path("/".join(path_elements).replace('"', ""))
|
||||
sps = int(config["glova"]["sps"])
|
||||
|
||||
if num_symbols is None:
|
||||
num_symbols = int(config["glova"]["nos"]) - skipfirst
|
||||
if symbols is None:
|
||||
symbols = int(config["glova"]["nos"]) - skipfirst
|
||||
|
||||
data = np.load(datapath)[skipfirst * sps : num_symbols * sps + skipfirst * sps]
|
||||
config["glova"]["nos"] = str(num_symbols)
|
||||
data = np.load(datapath)[skipfirst * sps : symbols * sps + skipfirst * sps]
|
||||
|
||||
if normalize:
|
||||
a, b, c, d = data.T
|
||||
a, b, c, d = a/np.max(np.abs(a)), b/np.max(np.abs(b)), c/np.max(np.abs(c)), d/np.max(np.abs(d))
|
||||
data = np.array([a, b, c, d]).T
|
||||
|
||||
if real:
|
||||
data = np.abs(data)
|
||||
|
||||
config["glova"]["nos"] = str(symbols)
|
||||
|
||||
data = torch.tensor(data, device=device, dtype=dtype)
|
||||
|
||||
return data, config
|
||||
|
||||
|
||||
def roll_along(arr, shifts, dim):
|
||||
# https://stackoverflow.com/a/76920720
|
||||
# (c) Mateen Ulhaq, 2023
|
||||
@@ -39,7 +66,6 @@ def roll_along(arr, shifts, dim):
|
||||
indices = (dim_indices - shifts.unsqueeze(dim)) % arr.shape[dim]
|
||||
return torch.gather(arr, dim, indices)
|
||||
|
||||
|
||||
class FiberRegenerationDataset(Dataset):
|
||||
"""
|
||||
Dataset for fiber regeneration training.
|
||||
@@ -76,6 +102,9 @@ class FiberRegenerationDataset(Dataset):
|
||||
target_delay: float | int = 0,
|
||||
xy_delay: float | int = 0,
|
||||
drop_first: float | int = 0,
|
||||
dtype: torch.dtype = None,
|
||||
real: bool = False,
|
||||
device = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -123,13 +152,16 @@ class FiberRegenerationDataset(Dataset):
|
||||
[[i + 0.1j, i + 0.2j, i + 1.1j, i + 1.2j] for i in range(12800)],
|
||||
dtype=np.complex128,
|
||||
)
|
||||
data_raw = torch.tensor(data_raw, device=device, dtype=dtype)
|
||||
self.config = {
|
||||
"data": {"dir": '"."', "npy_dir": '"."', "file": "faux"},
|
||||
"glova": {"sps": 128},
|
||||
}
|
||||
else:
|
||||
data_raw, self.config = load_data(file_path)
|
||||
data_raw, self.config = load_data(file_path, skipfirst=drop_first, real=real, normalize=True, device=device, dtype=dtype)
|
||||
|
||||
self.device = data_raw.device
|
||||
|
||||
self.samples_per_symbol = int(self.config["glova"]["sps"])
|
||||
self.samples_per_slice = int(symbols * self.samples_per_symbol)
|
||||
self.symbols_per_slice = self.samples_per_slice / self.samples_per_symbol
|
||||
@@ -140,7 +172,6 @@ class FiberRegenerationDataset(Dataset):
|
||||
|
||||
ovrd_target_delay_samples = kwargs.pop("ovrd_target_delay_samples", None)
|
||||
ovrd_xy_delay_samples = kwargs.pop("ovrd_xy_delay_samples", None)
|
||||
ovrd_drop_first_samples = kwargs.pop("ovrd_drop_first_samples", None)
|
||||
|
||||
self.target_delay_samples = (
|
||||
ovrd_target_delay_samples
|
||||
@@ -152,14 +183,8 @@ class FiberRegenerationDataset(Dataset):
|
||||
if ovrd_xy_delay_samples is not None
|
||||
else int(self.xy_delay * self.samples_per_symbol)
|
||||
)
|
||||
drop_first_samples = (
|
||||
ovrd_drop_first_samples
|
||||
if ovrd_drop_first_samples is not None
|
||||
else int(drop_first * self.samples_per_symbol)
|
||||
)
|
||||
|
||||
# drop samples from the beginning
|
||||
data_raw = data_raw[drop_first_samples:]
|
||||
# data_raw = torch.tensor(data_raw, dtype=dtype)
|
||||
|
||||
# data layout
|
||||
# [ [E_in_x0, E_in_y0, E_out_x0, E_out_y0],
|
||||
@@ -240,10 +265,10 @@ class FiberRegenerationDataset(Dataset):
|
||||
data = data.view(data.shape[0], self.data_size, -1)
|
||||
data = data[:, :, 0]
|
||||
|
||||
# target is corresponding to the latest data point -> try to regenerate that
|
||||
# target is corresponding to the middle of the data as the output sample is influenced by the data before and after it
|
||||
target = target[:, : target.shape[1] // self.data_size * self.data_size]
|
||||
target = target.view(target.shape[0], self.data_size, -1)
|
||||
target = target[:, 0, 0]
|
||||
target = target[:, 0, target.shape[2] // 2]
|
||||
|
||||
data = data.transpose(0, 1).flatten().squeeze()
|
||||
target = target.flatten().squeeze()
|
||||
|
||||
Reference in New Issue
Block a user