1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
def sgd(params, lr, batch_size):
    """Apply one minibatch stochastic-gradient-descent step in place.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.

    Args:
        params: Iterable of tensors to update; each is read through
            ``.grad`` and written through ``.data``.
        lr: Learning rate (step size multiplier).
        batch_size: Divisor applied to the gradient. Callers in this file
            pass the minibatch size because the loss is summed, not
            averaged, over the batch.

    Parameters whose ``.grad`` is ``None`` (no gradient has been computed
    for them yet) are left unchanged instead of raising ``TypeError``.
    """
    for param in params:
        if param.grad is None:
            # Nothing to apply: this tensor did not take part in backward().
            continue
        # Write through .data so the update itself is not tracked by autograd.
        param.data -= lr * param.grad / batch_size
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    """Train ``net`` for ``num_epochs`` epochs and print metrics per epoch.

    For every ``(X, y)`` batch from ``train_iter`` the summed loss is
    back-propagated and the parameters are updated, either by
    ``optimizer.step()`` or, when no optimizer is given, by
    ``sgd(params, lr, batch_size)``. After each epoch the accuracy on
    ``test_iter`` is computed with ``evaluate_accuracy`` and a summary line
    (epoch number, mean training loss, training accuracy, test accuracy) is
    printed.

    Args:
        net: Callable mapping a batch ``X`` to predictions ``y_hat``; the
            predicted class is taken as ``y_hat.argmax(dim=1)``.
        train_iter: Iterable of ``(X, y)`` training batches, iterated once
            per epoch.
        test_iter: Data passed unchanged to ``evaluate_accuracy``.
        loss: Callable ``loss(y_hat, y)`` returning a tensor; its ``.sum()``
            is the quantity that is back-propagated.
        num_epochs: Number of passes over ``train_iter``.
        batch_size: Forwarded to ``sgd``; unused when ``optimizer`` is given.
        params: Tensors to update with ``sgd``; ignored when ``optimizer``
            is given.
        lr: Learning rate for ``sgd``; ignored when ``optimizer`` is given.
        optimizer: Optional object providing ``zero_grad()`` and ``step()``.

    Raises:
        TypeError: If ``optimizer`` is ``None`` and ``params`` or ``lr`` is
            missing.
    """
    # Fail fast: without an optimizer the manual sgd path needs both values.
    # Previously this only surfaced as a TypeError inside sgd after a full
    # forward/backward pass had already run.
    if optimizer is None and (params is None or lr is None):
        raise TypeError(
            "train_ch3() requires either optimizer, or both params and lr")

    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            batch_loss = loss(y_hat, y).sum()

            # Reset gradients left over from the previous step.
            if optimizer is not None:
                optimizer.zero_grad()
            else:
                # Check every parameter individually: .grad is None until the
                # first backward(), and may stay None for some tensors only.
                for param in params:
                    if param.grad is not None:
                        param.grad.data.zero_()

            batch_loss.backward()
            if optimizer is None:
                sgd(params, lr, batch_size)
            else:
                optimizer.step()

            train_l_sum += batch_loss.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]

        test_acc = evaluate_accuracy(test_iter, net)
        # An empty train_iter leaves n == 0; report NaN averages rather than
        # raising ZeroDivisionError.
        denom = n if n else float('nan')
        # Fields: epoch, mean training loss, training accuracy, test accuracy.
        print('周期 %d, 损失 %.4f, 数据集准确率 %.3f, 测试集准确率 %.3f'
              % (epoch + 1, train_l_sum / denom, train_acc_sum / denom,
                 test_acc))


if __name__ == '__main__':
    # NOTE(review): load_data_fashion_mnist, net, cross_entropy,
    # evaluate_accuracy and sgd are not defined in this part of the file;
    # they are expected to be provided elsewhere in the module.
    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size)

    num_inputs = 28 * 28
    num_outputs = 10

    # Weights drawn from N(0, 0.01) with shape (num_inputs, num_outputs);
    # bias initialised to zeros.
    W = torch.tensor(np.random.normal(0, 0.01, (num_inputs, num_outputs)),
                     dtype=torch.float)
    b = torch.zeros(num_outputs, dtype=torch.float)

    # Enable gradient tracking so backward() populates W.grad and b.grad.
    W.requires_grad_(requires_grad=True)
    b.requires_grad_(requires_grad=True)

    num_epochs, lr = 5, 0.1
    train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs,
              batch_size, [W, b], lr)
|