finish chapter 12
This commit is contained in:
52
11_02_crossentropy.py
Normal file
52
11_02_crossentropy.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# loss function for multiclass problems -> labels must be one-hot encoded, predictions is a vector of probabilities (after applying softmax)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
|
||||
# numpy
|
||||
def cross_entropy(actual, predicted, normalize=False):
|
||||
loss = -np.sum(actual * np.log(predicted))
|
||||
if normalize:
|
||||
loss /= float(predicted.shape[0])
|
||||
return loss
|
||||
|
||||
# y must be one-hot encoded
|
||||
# class 0 [1 0 0]
|
||||
# class 1 [0 1 0]
|
||||
# class 2 [0 0 1]
|
||||
|
||||
Y = np.array([1, 0, 0]) # class 0
|
||||
|
||||
# y_pred has probabilities
|
||||
Y_pred_good = np.array([.7, .2, .1])
|
||||
Y_pred_bad = np.array([.1, .3, .6])
|
||||
l1 = cross_entropy(Y, Y_pred_good)
|
||||
# l1_norm = cross_entropy(Y, Y_pred_good, normalize=True)
|
||||
l2 = cross_entropy(Y, Y_pred_bad)
|
||||
# l2_norm = cross_entropy(Y, Y_pred_bad, normalize=True)
|
||||
print(f'Loss1 numpy: {l1:.4f}')
|
||||
# print(f'Loss1 numpy normalized: {l1_norm:.4f}')
|
||||
print(f'Loss2 numpy: {l2:.4f}')
|
||||
# print(f'Loss1 numpy normalized: {l2_norm:.4f}')
|
||||
|
||||
#pytorch
|
||||
# 3 samples
|
||||
loss = nn.CrossEntropyLoss() # includes softmax -> y_pred has raw scores, y has class labels, not one-hot
|
||||
Y = torch.tensor([2, 0, 1])
|
||||
# nsamples x nclasses = 3x3
|
||||
Y_pred_good = torch.tensor([[0.1, 1., 2.1], [2., 1., .1], [0.5, 2., .3]])
|
||||
Y_pred_bad = torch.tensor([[2.1, 1., .1], [.1, 1., 2.1], [.1, 3., .1]])
|
||||
|
||||
l1 = loss(Y_pred_good, Y)
|
||||
l2 = loss(Y_pred_bad, Y)
|
||||
|
||||
print(f'Loss1 pytorch: {l1.item():.4f}')
|
||||
print(f'Loss2 pytorch: {l2.item():.4f}')
|
||||
|
||||
_, prediction1 = torch.max(Y_pred_good, 1)
|
||||
_, prediction2 = torch.max(Y_pred_bad, 1)
|
||||
print(prediction1)
|
||||
print(prediction2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user