From d405216a55a8fdf691cf70a80d60d0811e637857 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Joseph=20Hopfm=C3=BCller?=
Date: Mon, 17 Oct 2022 22:07:07 +0200
Subject: [PATCH] in chapter 14 (3:23:19)

---
 .gitignore |   1 +
 14_cnn     | 128 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 129 insertions(+)
 create mode 100644 14_cnn

diff --git a/.gitignore b/.gitignore
index 6d4ba14..ff0d053 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,4 @@ share/
 pyvenv.cfg
 .python-version
 data/MNIST/
+data/CIFAR10/
\ No newline at end of file
diff --git a/14_cnn b/14_cnn
new file mode 100644
index 0000000..b3be0ba
--- /dev/null
+++ b/14_cnn
@@ -0,0 +1,128 @@
+# cnn on cifar-10
+'''
+convolutional net
+similar to a feed-forward net, but applies convolutional filters (mainly used on images)
+
+also includes pooling layers,
+specifically max pooling:
+downsamples the image by taking the max value in each region
+
+  12  20  30   0
+   8  12   2   0                        20  30
+  34  70  37   4  -- 2x2 max-pool -->  112  37
+ 112 100  25  12
+
+helps avoid overfitting by providing an abstract form of the input
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+import numpy as np
+
+# device config
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f'Device is {device}')
+
+# hyperparameters
+num_epochs = 4
+batch_size = 4
+learning_rate = 0.001
+
+# dataset has PILImage images of range [0, 1]
+# transform to tensors of normalized range [-1, 1]
+transform = transforms.Compose([
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+    ])
+
+train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
+                                             download=True, transform=transform)
+
+test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
+                                            download=False, transform=transform)
+
+train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
+
+classes = ('plane', 'car', 'bird', 'cat', 'deer',
+           'dog', 'frog', 'horse', 'ship', 'truck')  # could also be read from the dataset itself
+
+# implement conv net
+class ConvNet(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        # nn.Sequential takes modules as positional arguments, not a list
+        self.layers = nn.Sequential(
+            nn.Conv2d(3, 6, 5),  # 3 color channels, 6 output channels, kernel size 5
+            nn.MaxPool2d(2, 2),  # kernel size 2, stride 2 (stride = how far the window moves between applications)
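+            # output size of a conv/pool layer: (W - F + 2P) / S + 1
+            # with the 32x32 CIFAR-10 inputs: Conv2d(3, 6, 5) -> 6 x 28x28, MaxPool2d(2, 2) -> 6 x 14x14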
+
+            # 3:23:19
+
+        )
+
+    def forward(self, x):
+        # return self.layers(x)
+        pass
+
+model = ConvNet().to(device)
+
+criterion = nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
+
+n_total_steps = len(train_loader)
+for epoch in range(num_epochs):
+    for i, (images, labels) in enumerate(train_loader):
+        # original shape: [4, 3, 32, 32] = batch of 4 images, 3 channels, 32x32 = 1024 pixels
+        # input layer: 3 input channels, 6 output channels, kernel size 5
+        images = images.to(device)
+        labels = labels.to(device)
+
+        # forward pass
+        outputs = model(images)
+        loss = criterion(outputs, labels)
+
+        # backward pass + update
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        if (i+1) % 2000 == 0:
+            print(f'Epoch {epoch+1}/{num_epochs}, Step {i+1}/{n_total_steps}, loss = {loss.item():.4f}')
+
+print('Finished Training')
+
+with torch.no_grad():
+    n_correct = 0
+    n_samples = 0
+    n_class_correct = [0 for _ in range(10)]
+    n_class_samples = [0 for _ in range(10)]
+    for images, labels in test_loader:
+        images = images.to(device)
+        labels = labels.to(device)
+        outputs = model(images)
+        # max returns (value, index)
+        _, predicted = torch.max(outputs, 1)
+        n_samples += labels.size(0)
+        n_correct += (predicted == labels).sum().item()
+
+        for i in range(batch_size):
+            label = labels[i]
+            pred = predicted[i]
+            if label == pred:
+                n_class_correct[label] += 1
+            n_class_samples[label] += 1  # count every sample of the class, not only the correctly predicted ones
+
+    acc = 100. * n_correct / n_samples
+    print(f'Accuracy of network: {acc:.1f}%')
+
+    for i in range(10):
+        acc = 100. * n_class_correct[i] / n_class_samples[i]
+        print(f'Accuracy of {classes[i]}: {acc:.1f}%')
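Note on finishing the model: the patch stops at the 3:23:19 mark, so ConvNet still has a stub
forward(), and the training loop below it would fail as-is. A minimal sketch of how this kind of
CIFAR-10 net is usually completed follows; the second conv layer and the fully connected sizes
(6 -> 16 channels, 16*5*5 -> 120 -> 84 -> 10) are assumptions, not part of this patch. It is
written with individual layers instead of the nn.Sequential used above, since the pooling layer
is reused between the two conv layers.

import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)        # 3x32x32 -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)         # halves height and width
        self.conv2 = nn.Conv2d(6, 16, 5)       # 6x14x14 -> 16x10x10 (sizes assumed)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of 5x5 after the second pool
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)           # 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # [N, 6, 14, 14]
        x = self.pool(F.relu(self.conv2(x)))   # [N, 16, 5, 5]
        x = x.view(-1, 16 * 5 * 5)             # flatten for the linear layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                     # raw logits; nn.CrossEntropyLoss expects unnormalized scores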