add SlicedDataset class and utility scripts; refactor: remove _path_fix.py and update imports;

This commit is contained in:
Joseph Hopfmüller
2024-11-17 01:04:33 +01:00
parent 90aa6dbaf8
commit 87f40fc37c
7 changed files with 172 additions and 11 deletions

View File

@@ -0,0 +1,51 @@
# move into dir single-core-regen before running
from util.dataset import SlicedDataset
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import numpy as np
def eye_dataset(dataset, no_symbols=None, offset=False, show=True):
if no_symbols is None:
no_symbols = len(dataset)
_, axs = plt.subplots(2,2, sharex=True, sharey=True)
xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice)
roll = dataset.samples_per_symbol//2 if offset else 0
for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]:
E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0')
axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0')
axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0')
axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0')
if show:
plt.show()
# def plt_dataloader(dataloader, show=True):
# _, axs = plt.subplots(2,2, sharex=True, sharey=True)
# E_outs, E_ins = next(iter(dataloader))
# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)):
# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice)
# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1]
# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2)
# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2)
# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2)
# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2)
# if show:
# plt.show()
if __name__ == "__main__":
dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100)
print(dataset[0][0].shape)
eye_dataset(dataset, 1000, offset=True, show=False)
train_loader = DataLoader(dataset, batch_size=10, shuffle=False)
# plt_dataloader(train_loader, show=False)
plt.show()

View File

@@ -0,0 +1,68 @@
import torch
import time
def print_torch_env():
print("Torch version: ", torch.__version__)
print("CUDA available: ", torch.cuda.is_available())
print("CUDA version: ", torch.version.cuda)
print("CUDNN version: ", torch.backends.cudnn.version())
print("Device count: ", torch.cuda.device_count())
print("Current device: ", torch.cuda.current_device())
print("Device name: ", torch.cuda.get_device_name(0))
print("Device capability: ", torch.cuda.get_device_capability(0))
print("Device memory: ", torch.cuda.get_device_properties(0).total_memory)
def measure_runtime(func):
"""
Measure the runtime of a function.
:param func: Function to measure
:type func: function
:return: Wrapped function with runtime measurement
:rtype: function
"""
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"Runtime: {end_time - start_time:.6f} seconds")
return result, end_time - start_time
return wrapper
@measure_runtime
def tensor_addition(a, b):
"""
Perform tensor addition.
:param a: First tensor
:type a: torch.Tensor
:param b: Second tensor
:type b: torch.Tensor
:return: Sum of tensors
:rtype: torch.Tensor
"""
return a + b
def runtime_test():
x = torch.rand(2**18, 2**10)
y = torch.rand(2**18, 2**10)
print("Tensor addition on CPU")
_, cpu_time = tensor_addition(x, y)
print()
print("Tensor addition on GPU")
if not torch.cuda.is_available():
print("CUDA is not available")
return
_, gpu_time = tensor_addition(x.cuda(), y.cuda())
print()
print(f"Speedup: {cpu_time / gpu_time *100:.2f}%")
if __name__ == "__main__":
print_torch_env()
print()
runtime_test()