''' Transforms can be applied to PIL images, tensors, ndarrays, or custom data during creation of the DataSet complete list of built-in transforms: https://pytorch.org/docs/stable/torchvision/transforms.html On Images --------- CenterCrop, Grayscale, Pad, RandomAffine RandomCrop, RandomHorizontalFlip, RandomRotation Resize, Scale On Tensors ---------- LinearTransformation, Normalize, RandomErasing Conversion ---------- ToPILImage: from tensor or ndrarray ToTensor : from numpy.ndarray or PILImage Generic ------- Use Lambda Custom ------ Write own class Compose multiple Transforms --------------------------- composed = transforms.Compose([Rescale(256), RandomCrop(224)]) ''' import torch import torchvision from torch.utils.data import Dataset, DataLoader import numpy as np import math #example dataset = torchvision.datasets.MNIST(root='./data', transform=torchvision.transforms.ToTensor(), download=True) class WineDataset(Dataset): def __init__(self, transform=None): xy = np.loadtxt('./data/wine/wine.csv', delimiter=',', dtype=np.float32, skiprows=1) self.n_samples = xy.shape[0] self.x = xy[:,1:] # n_samples x features self.y = xy[:,[0]] # n_samples x 1 self.transform = transform def __getitem__(self, idx): sample = self.x[idx], self.y[idx] if self.transform: sample = self.transform(sample) return sample def __len__(self): return self.n_samples class ToTensor: def __call__(self, sample): x, y = sample return torch.from_numpy(x), torch.from_numpy(y) class MulTransform: def __init__(self, factor): self.factor = factor def __call__(self, sample): x, y = sample x *= self.factor return x, y dataset = WineDataset() first_data = dataset[0] features, labels = first_data print(type(features), type(labels)) dataset = WineDataset(transform=ToTensor()) first_data = dataset[0] features, labels = first_data print(features) print(type(features), type(labels)) composed = torchvision.transforms.Compose([ToTensor(), MulTransform(4.)]) dataset = WineDataset(transform=composed) first_data = dataset[0] features, labels = first_data print(features) print(type(features), type(labels))