diff --git a/src/single-core-regen/sliced_dataset_test.py b/src/single-core-regen/sliced_dataset_test.py new file mode 100644 index 0000000..284885f --- /dev/null +++ b/src/single-core-regen/sliced_dataset_test.py @@ -0,0 +1,88 @@ +# move into dir single-core-regen before running + +from util.datasets import FiberRegenerationDataset +from torch.utils.data import DataLoader +from matplotlib import pyplot as plt +import numpy as np + +# def eye_dataset(dataset, no_symbols=None, offset=False, show=True): +# if no_symbols is None: +# no_symbols = len(dataset) +# _, axs = plt.subplots(2,2, sharex=True, sharey=True) + +# xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice) +# roll = dataset.samples_per_symbol//2 if offset else 0 +# for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]: +# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1] +# axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0') +# axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0') +# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0') +# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0') + +# if show: +# plt.show() + +# # def plt_dataloader(dataloader, show=True): +# # _, axs = plt.subplots(2,2, sharex=True, sharey=True) + +# # E_outs, E_ins = next(iter(dataloader)) +# # for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)): +# # xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice) +# # E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1] +# # axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2) +# # axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2) +# # axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2) +# # axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2) + +# # if show: +# # plt.show() + +if __name__ == "__main__": + + dataset = FiberRegenerationDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=13, drop_first=100, output_dim=26, num_symbols=100) + + loader = DataLoader(dataset, batch_size=10, shuffle=True) + + x = [] + y_fiber_in = [] + y_fiber_out = [] + + for i, batch in enumerate(loader): + # if i > 128: + # break + + fiber_in, fiber_out, timestamp = batch + + fiber_out = fiber_out.reshape(fiber_out.shape[0], -1, 2) + fiber_out = fiber_out[:,fiber_out.shape[1]//2, :] + + # input_data = input_data.reshape(-1,2) + # target = target.reshape(-1,2).squeeze() + # timestamp = timestamp.reshape(-1,1).squeeze() + + x.append(timestamp.detach().numpy()) + y_fiber_in.append(fiber_in.abs().square().detach().numpy()) + y_fiber_out.append(fiber_out.abs().square().detach().numpy()) + + x = np.concat(x) + y_fiber_in = np.concat(y_fiber_in) + y_fiber_out = np.concat(y_fiber_out) + + # order = np.argsort(x) + # x = x[order] + # y = y[order] + + fig, axs = plt.subplots(2,2, sharex=True, sharey=True) + axs[0,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,0], s=1, alpha=0.1) + axs[1,0].scatter((x/dataset.samples_per_symbol)%2, y_fiber_in[:,1], s=1, alpha=0.1) + axs[0,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,0], s=1, alpha=0.1) + axs[1,1].scatter((x/dataset.samples_per_symbol)%2, y_fiber_out[:,1], s=1, alpha=0.1) + plt.show() + + # eye_dataset(dataset, 1000, offset=True, show=False) + + # train_loader = DataLoader(dataset, batch_size=10, shuffle=False) + + # plt_dataloader(train_loader, show=False) + + # plt.show() diff --git a/src/single-core-regen/testing/sliced_dataset_test.py b/src/single-core-regen/testing/sliced_dataset_test.py deleted file mode 100644 index 85d635a..0000000 --- a/src/single-core-regen/testing/sliced_dataset_test.py +++ /dev/null @@ -1,51 +0,0 @@ -# move into dir single-core-regen before running - -from util.dataset import SlicedDataset -from torch.utils.data import DataLoader -from matplotlib import pyplot as plt -import numpy as np - -def eye_dataset(dataset, no_symbols=None, offset=False, show=True): - if no_symbols is None: - no_symbols = len(dataset) - _, axs = plt.subplots(2,2, sharex=True, sharey=True) - - xaxis = np.linspace(0,dataset.symbols_per_slice,dataset.samples_per_slice) - roll = dataset.samples_per_symbol//2 if offset else 0 - for E_out, E_in in dataset[roll:dataset.samples_per_symbol*no_symbols+roll:dataset.samples_per_symbol]: - E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1] - axs[0,0].plot(xaxis, np.abs( E_in_x.numpy())**2, alpha=0.05, color='C0') - axs[1,0].plot(xaxis, np.abs( E_in_y.numpy())**2, alpha=0.05, color='C0') - axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2, alpha=0.05, color='C0') - axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2, alpha=0.05, color='C0') - - if show: - plt.show() - -# def plt_dataloader(dataloader, show=True): -# _, axs = plt.subplots(2,2, sharex=True, sharey=True) - -# E_outs, E_ins = next(iter(dataloader)) -# for i, (E_out, E_in) in enumerate(zip(E_outs, E_ins)): -# xaxis = np.linspace(dataset.symbols_per_slice*i,dataset.symbols_per_slice+dataset.symbols_per_slice*i,dataset.samples_per_slice) -# E_in_x, E_in_y, E_out_x, E_out_y = E_in[0], E_in[1], E_out[0], E_out[1] -# axs[0,0].plot(xaxis, np.abs(E_in_x.numpy())**2) -# axs[1,0].plot(xaxis, np.abs(E_in_y.numpy())**2) -# axs[0,1].plot(xaxis, np.abs(E_out_x.numpy())**2) -# axs[1,1].plot(xaxis, np.abs(E_out_y.numpy())**2) - -# if show: -# plt.show() - -if __name__ == "__main__": - - dataset = SlicedDataset("data/20241115-175517-128-16384-10000-0-0-17-0-PAM4-0.ini", symbols=1, drop_first=100) - print(dataset[0][0].shape) - - eye_dataset(dataset, 1000, offset=True, show=False) - - train_loader = DataLoader(dataset, batch_size=10, shuffle=False) - - # plt_dataloader(train_loader, show=False) - - plt.show()