百度360必应搜狗淘宝本站头条
当前位置:网站首页 > IT知识 > 正文

Pytorch学习-Day 3:张量高级操作(张量svd)

liuian 2025-05-08 19:40 15 浏览

以下是 PyTorch 第3天学习任务的示例代码,涵盖张量形状变换(view 和 reshape)、广播机制以及 GPU 加速的实现。代码中包含详细注释,帮助理解每个部分的功能,并完成任务:将张量移动到 GPU 并计算两个张量的矩阵乘法。

python

# 示例代码:PyTorch Day 3 - 张量高级操作

# 导入 PyTorch 库
import torch

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前设备: {device}")

# 1. 张量形状变换(view 和 reshape)
def tensor_shape_transform():
    # 创建一个 3x4 的张量
    x = torch.arange(12).float()  # 生成 0 到 11 的张量,转换为浮点型
    print("原始张量 x:")
    print(x)
    print(f"x 的形状: {x.shape}")  # torch.Size([12])

    # 使用 view 变换形状(要求元素总数不变)
    x_view = x.view(3, 4)  # 变换为 3x4
    print("\n使用 view 变换后的 x:")
    print(x_view)
    print(f"x_view 的形状: {x_view.shape}")  # torch.Size([3, 4])

    # 使用 reshape 变换形状(更灵活,可自动推导维度)
    x_reshape = x.reshape(2, -1)  # -1 表示自动计算此维度,变为 2x6
    print("\n使用 reshape 变换后的 x:")
    print(x_reshape)
    print(f"x_reshape 的形状: {x_reshape.shape}")  # torch.Size([2, 6])

# 2. 广播机制(Broadcasting)
def tensor_broadcasting():
    # 创建两个形状不同的张量
    a = torch.tensor([[1, 2, 3]])  # 形状: 1x3
    b = torch.tensor([[4], [5], [6]])  # 形状: 3x1
    print("\n张量 a:")
    print(a)
    print(f"a 的形状: {a.shape}")
    print("张量 b:")
    print(b)
    print(f"b 的形状: {b.shape}")

    # 广播机制自动扩展维度并进行逐元素相加
    c = a + b  # 结果形状为 3x3
    print("\n广播后 a + b 的结果:")
    print(c)
    print(f"c 的形状: {c.shape}")

# 3. GPU 加速 + 矩阵乘法任务
def matrix_multiplication_on_gpu():
    # 创建两个张量
    A = torch.randn(3, 4)  # 随机生成 3x4 矩阵
    B = torch.randn(4, 5)  # 随机生成 4x5 矩阵
    print("\n矩阵 A:")
    print(A)
    print("矩阵 B:")
    print(B)

    # 将张量移动到 GPU(如果可用)
    if torch.cuda.is_available():
        A = A.to(device)
        B = B.to(device)
        print(f"\n张量已移动到 {device}")

    # 计算矩阵乘法
    C = torch.matmul(A, B)  # 结果形状为 3x5
    print("\n矩阵乘法结果 C (A @ B):")
    print(C)
    print(f"C 的形状: {C.shape}")

# 主函数:依次运行所有任务
def main():
    print("=== 张量形状变换 ===")
    tensor_shape_transform()
    print("\n=== 广播机制 ===")
    tensor_broadcasting()
    print("\n=== GPU 加速与矩阵乘法 ===")
    matrix_multiplication_on_gpu()

if __name__ == "__main__":
    main()

"""
学习内容说明:
1. 张量形状变换:
   - view(): 改变张量形状,要求内存连续且元素总数不变。
   - reshape(): 更灵活的形状变换,支持非连续内存。

2. 广播机制:
   - 自动扩展张量维度以匹配操作(如加法),无需手动调整形状。
   - 规则:从尾部对齐维度,小维度扩展为大维度或补1。

3. GPU 加速:
   - 使用 torch.device 和 .to() 将张量移到 GPU。
   - torch.cuda.is_available() 检查 GPU 可用性。

任务完成:
- 将张量 A 和 B 移动到 GPU(如果可用)。
- 使用 torch.matmul() 计算矩阵乘法。
"""

代码说明

  1. 张量形状变换:
  2. view(3, 4): 将一维张量变换为 3x4 的二维张量。
  3. reshape(2, -1): 将张量变为 2 行,列数自动推导为 6。
  4. 广播机制:
  5. 两个张量 a (1x3) 和 b (3x1) 通过广播扩展为 3x3,然后逐元素相加。
  6. 展示了 PyTorch 如何自动处理维度不匹配的情况。
  7. GPU 加速与矩阵乘法:
  8. 检查 GPU 可用性并定义 device。
  9. 使用 .to(device) 将张量 A 和 B 移到 GPU。
  10. 使用 torch.matmul() 计算矩阵乘法,结果为 3x5 的矩阵。

运行要求

  • 安装 PyTorch: pip install torch(如果需要 GPU 支持,确保安装 CUDA 版本,例如 pip install torch torchvision -f https://download.pytorch.org/whl/cu117)。
  • 有 GPU 的环境(可选):如果没有 GPU,代码会自动在 CPU 上运行。

如何运行

  1. 保存代码为 pytorch_day3.py。
  2. 在终端运行:python pytorch_day3.py。
  3. 观察输出,验证形状变换、广播和矩阵乘法的结果。

输出示例(部分)

当前设备: cuda
=== 张量形状变换 ===
原始张量 x:
tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.])
x 的形状: torch.Size([12])

使用 view 变换后的 x:
tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
x_view 的形状: torch.Size([3, 4])

...
=== GPU 加速与矩阵乘法 ===
张量已移动到 cuda
矩阵乘法结果 C (A @ B):
tensor([[...], [...], [...]])  # 具体值因随机生成而异
C 的形状: torch.Size([3, 5])

这个代码完整实现了第3天的学习任务,并通过注释和输出展示了 PyTorch 的张量高级操作。建议参考 PyTorch 官方教程“Tensor Operations”进一步深入学习!

相关推荐

Optional是个好东西,如果用错了就太可惜了

原文出处:https://xie.infoq.cn/article/e3d1f0f4f095397c44812a5be我们都知道,在Java8新增了一个类-Optional,主要是用来解决程...

IDEA建议:不要在字段上使用@Autowire了!

在使用IDEA写Spring相关的项目的时候,在字段上使用@Autowired注解时,总是会有一个波浪线提示:Fieldinjectionisnotrecommended.纳尼?我天天用,咋...

Spring源码|Spring实例Bean的方法

Spring实例Bean的方法,在AbstractAutowireCapableBeanFactory中的protectedBeanWrappercreateBeanInstance(String...

Spring技巧:深入研究Java 14和SpringBoot

在本期文章中,我们将介绍Java14中的新特性及其在构建基于SpringBoot的应用程序中的应用。开始,我们需要使用Java的最新版本,也是最棒的版本,Java14,它现在还没有发布。预计将于2...

Java开发200+个学习知识路线-史上最全(框架篇)

1.Spring框架深入SpringIOC容器:BeanFactory与ApplicationContextBean生命周期:实例化、属性填充、初始化、销毁依赖注入方式:构造器注入、Setter注...

年末将至,Java 开发者必须了解的 15 个Java 顶级开源项目

专注于Java领域优质技术,欢迎关注作者:SnailClimbStar的数量统计于2019-12-29。1.JavaGuideGuide哥大三开始维护的,目前算是纯Java类型项目中Sta...

字节跨平台框架 Lynx 开源:一个 Web 开发者的原生体验

最近各大厂都在开源自己的跨平台框架,前脚腾讯刚宣布计划四月开源基于Kotlin的跨平台框架「Kuikly」,后脚字节跳动旧开源了他们的跨平台框架「Lynx」,如果说Kuikly是一个面向...

我要狠狠的反驳“公司禁止使用Lombok”的观点

经常在其它各个地方在说公司禁止使用Lombok,我一直不明白为什么不让用,今天看到一篇文章列举了一下“缺点”,这里我只想狠狠地反驳,看到列举的理由我竟无言以对。原文如下:下面,结合我自己使用Lomb...

SpringBoot Lombok使用详解:从入门到精通(注解最全)

一、Lombok概述与基础使用1.1Lombok是什么Lombok是一个Java库,它通过注解的方式自动生成Java代码(如getter、setter、toString等),从而减少样板代码的编写,...

Java 8之后的那些新特性(六):记录类 Record Class

Java是一门面向对象的语言,而对于面向对象的语言中,一个众所周知的概念就是,对象是包含属性与行为的。比如HR系统中都会有雇员的概念,那雇员会有姓名,ID身份,性别等,这些我们称之为属性;而雇员同时肯...

为什么大厂要求安卓开发者掌握Kotlin和Jetpack?优雅草卓伊凡

为什么大厂要求安卓开发者掌握Kotlin和Jetpack?深度解析现代Android开发生态优雅草卓伊凡一、Kotlin:Android开发的现代语言选择1.1Kotlin是什么?Kotlin是由...

Kotlin这5招太绝了!码农秒变优雅艺术家!

Kotlin因其简洁性、空安全性和与Java的无缝互操作性而备受喜爱。虽然许多开发者熟悉协程、扩展函数和数据类等特性,但还有一些鲜为人知的特性可以让你的代码从仅仅能用变得真正优雅且异常简洁。让我们来看...

自行部署一款免费高颜值的IT资产管理系统-咖啡壶chemex

在运维时,ICT资产太多怎么办,还是用excel表格来管理?效率太低,也不好多人使用。在几个IT资产管理系统中选择比较中,最终在Snipe-IT和chemex间选择了chemex咖啡壶。Snip...

PHP对接百度语音识别技术(php对接百度语音识别技术实验报告)

引言在目前的各种应用场景中,语音识别技术已经越来越常用,并且其应用场景正在不断扩大。百度提供的语音识别服务允许用户通过简单的接口调用,将语音内容转换为文本。本文将通过PHP语言集成百度的语音识别服务,...

知识付费系统功能全解析(知识付费项目怎么样)

开发知识付费系统需包含核心功能模块,确保内容变现、用户体验及运营管理需求。以下是完整功能架构:一、用户端功能注册登录:手机号/邮箱注册,第三方登录(微信、QQ)内容浏览:分类展示课程、文章、音频等付费...