finish 09
This commit is contained in:
48
09_dataloaders.py
Normal file
48
09_dataloaders.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
class WineDataset(Dataset):
|
||||
def __init__(self):
|
||||
#data loading
|
||||
xy = np.loadtxt('./data/wine/wine.csv', delimiter=',', dtype=np.float32, skiprows=1)
|
||||
self.x = torch.from_numpy(xy[:,1:]) # n_samples x features
|
||||
self.y = torch.from_numpy(xy[:,[0]]) # n_samples x 1
|
||||
self.n_samples = xy.shape[0]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.x[idx], self.y[idx]
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
dataset = WineDataset()
|
||||
# first_data = dataset[0]
|
||||
# features, labels = first_data
|
||||
# print(features, labels)
|
||||
batch_size = 4
|
||||
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)
|
||||
|
||||
# dataiter = iter(dataloader)
|
||||
# data = dataiter.next()
|
||||
# features, labels = data
|
||||
# print(features, labels)
|
||||
|
||||
# dummy training loop
|
||||
n_epochs = 2
|
||||
total_samples = len(dataset)
|
||||
n_iter = math.ceil(total_samples/batch_size)
|
||||
print(total_samples, n_iter)
|
||||
|
||||
for epoch in range(n_epochs):
|
||||
for i, (inputs, labels) in enumerate(dataloader):
|
||||
# forward backward update
|
||||
if (i+1) % 5 == 0:
|
||||
print(f'Epoch {epoch+1}/{n_epochs}, step {i+1}/{n_iter}, inputs {inputs.shape}')
|
||||
|
||||
|
||||
# pytorch built in datasets
|
||||
# torchvision.datasets.MNIST()
|
||||
# fashion-mnist, cifar, coco
|
||||
Reference in New Issue
Block a user