刘二大人pytorch教程课后作业(05)——pytorch的API+多种优化器

发布时间:2026/6/27 7:50:00
刘二大人pytorch教程课后作业(05)——pytorch的API+多种优化器 import torch import matplotlib.pyplot as plt # prepare dataset # x,y是矩阵3行1列 也就是说总共有3个数据每个数据只有1个特征 x_data torch.tensor([[1.0], [2.0], [3.0]]) y_data torch.tensor([[2.0], [4.0], [6.0]]) # design model using class # 1. 定义一个线性模型模板继承自 PyTorch 的基础模块 class LinearModel(torch.nn.Module): # 2. 初始化函数当你创建模型时比如 model LinearModel()这里会自动执行 def __init__(self): # 3. 调用父类torch.nn.Module的初始化告诉 PyTorch“我要开始建网络了” super(LinearModel, self).__init__() # 4. 在“我自己”身上安装一个输入1维、输出1维的线性层里面包含了要学习的 w 和 b self.linear torch.nn.Linear(1, 1)#第一个数字为输入的数量第二个数字为输出的数量 # 5. 前向传播函数定义数据是怎么流过这个模型的 def forward(self, x): # 6. 把输入 x 传给“我自己的 ”线性层得到预测值 y_pred y_pred self.linear(x) # 7. 把结果返回 return y_pred #创建模型 model LinearModel() # construct loss and optimizer # criterion torch.nn.MSELoss(size_average False) criterion torch.nn.MSELoss(reductionsum)#定义损失函数Loss Function #optimizer torch.optim.SGD(model.parameters(), lr0.01) # 定义优化器Optimizer也就是决定模型 #model.parameters()这是一个生成器它会自动把模型内部所有需要学习的参数如权重 weight 和偏置 bias打包传给优化器。 #optimizer torch.optim.Adagrad(model.parameters(), lr0.01) #optimizer torch.optim.Adam(model.parameters(), lr0.01) #optimizer torch.optim.Adamax(model.parameters(), lr0.01) #optimizer torch.optim.ASGD(model.parameters(), lr0.01) #optimizer torch.optim.LBFGS(model.parameters(), lr0.01) #optimizer torch.optim.RMSprop(model.parameters(), lr0.01) optimizer torch.optim.Rprop(model.parameters(), lr0.01) # training cycle forward, backward, update epoch_list[] loss_list[] for epoch in range(1000): y_pred model(x_data) # forward:predict loss criterion(y_pred, y_data) # forward: loss print(epoch, loss.item()) optimizer.zero_grad() # 由于pytorch会累加梯度因此这里需要清0 loss.backward() # backward: autograd自动计算梯度 optimizer.step() # update 参数即更新w和b的值 epoch_list.append(epoch) loss_list.append(loss.item()) print(w , model.linear.weight.item()) print(b , model.linear.bias.item()) x_test torch.tensor([[4.0]]) y_test model(x_test) print(y_pred , y_test.data) plt.plot(epoch_list, loss_list) #plt.ylabel(loss_SGD) #plt.ylabel(loss_Adagrad) #plt.ylabel(loss_Adam) #plt.ylabel(loss_Adamax) #plt.ylabel(loss_ASGD) #plt.ylabel(loss_LBFGS) #plt.ylabel(loss_RMSprop) plt.ylabel(loss_Rprop) plt.xlabel(epoch) plt.show()