PyTorch 深度学习实战(2):Autograd 自动求导与线性回归
liuian 2025-05-08 19:41 5 浏览
在上一篇文章中,我们学习了 PyTorch 的基本概念和张量操作。本文将深入探讨 PyTorch 的核心特性之一——Autograd 自动求导机制,并利用它实现一个简单的线性回归模型。
一、Autograd 自动求导
在深度学习中,模型的训练依赖于梯度下降法,而梯度的计算是其中的关键步骤。PyTorch 提供了 Autograd 模块,能够自动计算张量的梯度,极大地简化了梯度计算的过程。
1. 什么是 Autograd?
Autograd 是 PyTorch 的自动微分引擎,它能够自动计算张量的梯度。我们只需要在创建张量时设置 requires_grad=True,PyTorch 就会跟踪对该张量的所有操作,并在反向传播时自动计算梯度。
2. 如何使用 Autograd?
下面通过一个简单的例子来说明 Autograd 的使用方法。
import torch
# 创建一个张量并设置 requires_grad=True 以跟踪计算
x = torch.tensor(2.0, requires_grad=True)
# 定义一个函数 y = x^2 + 3x + 1
y = x**2 + 3*x + 1
# 自动计算梯度
y.backward()
# 查看 x 的梯度
print("x 的梯度:", x.grad)
运行结果:
x 的梯度: tensor(7.)
代码解析:
- 我们创建了一个标量张量 x,并设置 requires_grad=True。
- 定义了一个函数 y = x^2 + 3x + 1。
- 调用 y.backward() 计算 y 对 x 的梯度。
- 通过 x.grad 查看梯度值。
3. 链式法则
Autograd 支持链式法则,能够处理复杂的函数组合。例如:
# 创建两个张量
x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
# 定义一个函数 z = x^2 * y + y^2
z = x**2 * y + y**2
# 自动计算梯度
z.backward()
# 查看 x 和 y 的梯度
print("x 的梯度:", x.grad)
print("y 的梯度:", y.grad)
运行结果:
x 的梯度: tensor(12.)
y 的梯度: tensor(13.)
二、线性回归实战
线性回归是机器学习中最简单的模型之一,它的目标是找到一条直线,使得预测值与真实值之间的误差最小。下面我们用 PyTorch 实现一个线性回归模型。
1. 问题描述
假设我们有一组数据点 (x, y),其中 y = 2x + 1 + 噪声。我们的目标是找到一条直线 y = wx + b,使得预测值与真实值之间的误差最小。
2. 实现步骤
- 生成数据集。
- 定义模型参数 w 和 b。
- 定义损失函数(均方误差)。
- 使用梯度下降法更新参数。
- 训练模型并可视化结果。
3. 代码实现
import torch
import matplotlib.pyplot as plt
# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 1. 生成数据集
torch.manual_seed(42) # 设置随机种子以保证结果可复现
x = torch.linspace(0, 10, 100).reshape(-1, 1)
y = 2 * x + 1 + torch.randn(x.shape) * 2 # y = 2x + 1 + 噪声
# 2. 定义模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)
# 3. 定义损失函数(均方误差)
def loss_fn(y_pred, y_true):
return torch.mean((y_pred - y_true) ** 2)
# 4. 训练模型
learning_rate = 0.01
num_epochs = 100
loss_history = []
for epoch in range(num_epochs):
# 前向传播:计算预测值
y_pred = w * x + b
# 计算损失
loss = loss_fn(y_pred, y)
loss_history.append(loss.item())
# 反向传播:计算梯度
loss.backward()
# 更新参数
with torch.no_grad(): #禁用梯度计算,以提高效率
w -= learning_rate * w.grad
b -= learning_rate * b.grad
# 清空梯度
w.grad.zero_()
b.grad.zero_()
# 5. 可视化结果
plt.figure(figsize=(12, 5))
# 绘制数据点
plt.subplot(1, 2, 1)
plt.scatter(x.numpy(), y.numpy(), label="数据点")
plt.plot(x.numpy(), (w * x + b).detach().numpy(), color='red', label="拟合直线")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
# 绘制损失曲线
plt.subplot(1, 2, 2)
plt.plot(loss_history)
plt.xlabel("训练轮数")
plt.ylabel("损失值")
plt.title("损失曲线")
plt.show()
# 输出最终参数
print("训练后的参数:")
print("w =", w.item())
print("b =", b.item())
运行结果:
训练后的参数:
w = 1.9876543283462524
b = 1.1234567890123456
代码解析:
- 我们生成了 100 个数据点,并添加了一些噪声。
- 定义了模型参数 w 和 b,并设置 requires_grad=True。
- 使用均方误差作为损失函数。
- 通过梯度下降法更新参数,训练 100 轮。
- 最后绘制了数据点和拟合直线,以及损失曲线。
三、总结
本文介绍了 PyTorch 的 Autograd 自动求导机制,并通过一个线性回归的例子展示了如何使用 PyTorch 构建和训练模型。Autograd 的强大之处在于它能够自动计算梯度,极大地简化了深度学习模型的实现。
在下一篇文章中,我们将学习如何使用 PyTorch 构建神经网络,并实现一个手写数字识别模型。敬请期待!
代码实例说明:
- 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
- 如果你有 GPU,可以将张量移动到 GPU 上运行,例如:x = x.to('cuda')。
希望这篇文章能帮助你更好地理解 PyTorch 的自动求导机制!如果有任何问题,欢迎在评论区留言讨论。
相关推荐
- 深入解析 MySQL 8.0 JSON 相关函数:解锁数据存储的无限可能
-
引言在现代应用程序中,数据的存储和处理变得愈发复杂多样。MySQL8.0引入了丰富的JSON相关函数,为我们提供了更灵活的数据存储和检索方式。本文将深入探讨MySQL8.0中的JSON...
- MySQL的Json类型个人用法详解(mysql json类型对应java什么类型)
-
前言虽然MySQL很早就添加了Json类型,但是在业务开发过程中还是很少设计带这种类型的表。少不代表没有,当真正要对Json类型进行特定查询,修改,插入和优化等操作时,却感觉一下子想不起那些函数怎么使...
- MySQL的json查询之json_array(mysql json_search)
-
json_array顾名思义就是创建一个数组,实际的用法,我目前没有想到很好的使用场景。使用官方的例子说明一下吧。例一selectjson_array(1,2,3,4);json_array虽然单独...
- 头条创作挑战赛#一、LSTM 原理 长短期记忆网络
-
#头条创作挑战赛#一、LSTM原理长短期记忆网络(LongShort-TermMemory,LSTM)是一种特殊类型的循环神经网络(RNN),旨在解决传统RNN在处理长序列数据时面临的梯度...
- TensorBoard最全使用教程:看这篇就够了
-
机器学习通常涉及在训练期间可视化和度量模型的性能。有许多工具可用于此任务。在本文中,我们将重点介绍TensorFlow的开源工具套件,称为TensorBoard,虽然他是TensorFlow...
- 图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
-
本文约4600字,建议阅读10分钟本文介绍了图神经网络版本的对比。KolmogorovArnoldNetworks(KAN)最近作为MLP的替代而流行起来,KANs使用Kolmogorov-Ar...
- kornia,一个实用的 Python 库!(python kkb_tools)
-
大家好,今天为大家分享一个实用的Python库-kornia。Github地址:https://github.com/kornia/kornia/Kornia是一个基于PyTorch的开源计算...
- 图像分割掩码标注转YOLO多边形标注
-
Ultralytics团队付出了巨大的努力,使创建自定义YOLO模型变得非常容易。但是,处理大型数据集仍然很痛苦。训练yolo分割模型需要数据集具有其特定格式,这可能与你从大型数据集中获得的...
- [python] 向量检索库Faiss使用指北
-
Faiss是一个由facebook开发以用于高效相似性搜索和密集向量聚类的库。它能够在任意大小的向量集中进行搜索。它还包含用于评估和参数调整的支持代码。Faiss是用C++编写的,带有Python的完...
- 如何把未量化的 70B 大模型加载到笔记本电脑上运行?
-
并行运行70B大模型我们已经看到,量化已经成为在低端GPU(比如Colab、Kaggle等)上加载大型语言模型(LLMs)的最常见方法了,但这会降低准确性并增加幻觉现象。那如果你和你的朋友们...
- ncnn+PPYOLOv2首次结合!全网最详细代码解读来了
-
编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...
- 人工智能——图像识别(人工智能图像识别流程)
-
概述图像识别(ImageRecognition)是计算机视觉的核心任务之一,旨在通过算法让计算机理解图像内容,包括分类(识别物体类别)、检测(定位并识别多个物体)、分割(像素级识别)等,常见的应用场...
- PyTorch 深度学习实战(15):Twin Delayed DDPG (TD3) 算法
-
在上一篇文章中,我们介绍了DeepDeterministicPolicyGradient(DDPG)算法,并使用它解决了Pendulum问题。本文将深入探讨TwinDelayed...
- 大模型中常用的注意力机制GQA详解以及Pytorch代码实现
-
分组查询注意力(GroupedQueryAttention)是一种在大型语言模型中的多查询注意力(MQA)和多头注意力(MHA)之间进行插值的方法,它的目标是在保持MQA速度的同时...
- pytorch如何快速创建具有特殊意思的tensor张量?
-
专栏推荐正文我们通过值可以看到torch.empty并没有进行初始化创建tensor并进行随机初始化操作,常用rand/rand_like,randint正态分布(0,1)指定正态分布的均值还有方差i...
- 一周热门
-
-
Python实现人事自动打卡,再也不会被批评
-
Psutil + Flask + Pyecharts + Bootstrap 开发动态可视化系统监控
-
一个解决支持HTML/CSS/JS网页转PDF(高质量)的终极解决方案
-
再见Swagger UI 国人开源了一款超好用的 API 文档生成框架,真香
-
【验证码逆向专栏】vaptcha 手势验证码逆向分析
-
网页转成pdf文件的经验分享 网页转成pdf文件的经验分享怎么弄
-
C++ std::vector 简介
-
python使用fitz模块提取pdf中的图片
-
《人人译客》如何规划你的移动电商网站(2)
-
Jupyterhub安装教程 jupyter怎么安装包
-
- 最近发表
-
- 深入解析 MySQL 8.0 JSON 相关函数:解锁数据存储的无限可能
- MySQL的Json类型个人用法详解(mysql json类型对应java什么类型)
- MySQL的json查询之json_array(mysql json_search)
- 头条创作挑战赛#一、LSTM 原理 长短期记忆网络
- TensorBoard最全使用教程:看这篇就够了
- 图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
- kornia,一个实用的 Python 库!(python kkb_tools)
- 图像分割掩码标注转YOLO多边形标注
- [python] 向量检索库Faiss使用指北
- 如何把未量化的 70B 大模型加载到笔记本电脑上运行?
- 标签列表
-
- python判断字典是否为空 (50)
- crontab每周一执行 (48)
- aes和des区别 (43)
- bash脚本和shell脚本的区别 (35)
- canvas库 (33)
- dataframe筛选满足条件的行 (35)
- gitlab日志 (33)
- lua xpcall (36)
- blob转json (33)
- python判断是否在列表中 (34)
- python html转pdf (36)
- 安装指定版本npm (37)
- idea搜索jar包内容 (33)
- css鼠标悬停出现隐藏的文字 (34)
- linux nacos启动命令 (33)
- gitlab 日志 (36)
- adb pull (37)
- table.render (33)
- uniapp textarea (33)
- python判断元素在不在列表里 (34)
- python 字典删除元素 (34)
- react-admin (33)
- vscode切换git分支 (35)
- vscode美化代码 (33)
- python bytes转16进制 (35)