PyTorch:tensor-张量维度操作(拼接、维度扩展、压缩、转置、重复……)
1. 张量基础与维度操作概览在PyTorch中张量Tensor是多维数组的核心数据结构类似于NumPy的ndarray但具备GPU加速和自动求导功能。理解张量维度操作是深度学习模型开发的基础技能就像厨师需要掌握切菜技巧一样重要。张量维度操作主要分为以下几类拼接操作将多个张量合并为一个如torch.cat和torch.stack维度调整改变张量的形状和维度如view、reshape、permute扩展压缩增加或减少维度如unsqueeze和squeeze重复复制沿特定维度复制数据如repeat和expand我刚开始学习时经常混淆stack和cat的区别直到在图像分类任务中处理批次数据时才真正理解cat是在现有维度上拼接而stack会创建新维度。比如处理5张224x224的RGB图片时# 使用stack会得到[5,3,224,224]的四维张量 # 使用cat(dim0)会得到[15,224,224]的三维张量通常不是我们想要的2. 张量拼接torch.cat与torch.stack详解2.1 torch.cat沿现有维度拼接torch.cat是最常用的拼接方法它不会创建新维度而是在现有维度上连接张量。就像把多个书架上的书合并到一个更长的书架上。import torch # 创建两个形状相同的张量 a torch.randn(2, 3) # 2行3列 b torch.randn(2, 3) # 沿第0维行方向拼接 c torch.cat([a, b], dim0) # 结果形状[4,3] # 沿第1维列方向拼接 d torch.cat([a, b], dim1) # 结果形状[2,6]实际应用场景在自然语言处理中处理变长序列时常用cat拼接不同长度的句子配合padding mask使用。2.2 torch.stack创建新维度拼接torch.stack会在新创建的维度上拼接张量要求所有输入张量形状完全相同。就像把多张照片放入相册形成一个新的照片索引维度。# 继续使用前面的a和b张量 e torch.stack([a, b], dim0) # 形状[2,2,3] f torch.stack([a, b], dim1) # 形状[2,2,3] g torch.stack([a, b], dim2) # 形状[2,3,2]关键区别stack输入张量列表cat输入张量序列stack会新增维度cat不会stack要求所有张量形状相同cat只需在拼接维度外其他维度相同我在处理3D医学图像时常用stack将多个2D切片组合成3D体积数据slice_list [torch.randn(256,256) for _ in range(100)] volume torch.stack(slice_list, dim0) # [100,256,256]3. 维度扩展与压缩3.1 unsqueeze增加维度torch.unsqueeze在指定位置插入长度为1的维度就像给向量加一个括号使其成为矩阵x torch.tensor([1,2,3]) # 形状[3] y x.unsqueeze(0) # 形状[1,3]相当于[[1,2,3]] z x.unsqueeze(1) # 形状[3,1]相当于[[1],[2],[3]]等价操作y x[None,:] # 同unsqueeze(0) z x[:,None] # 同unsqueeze(1)3.2 squeeze压缩维度torch.squeeze移除长度为1的维度默认移除所有也可指定维度a torch.randn(1,3,1,2) b a.squeeze() # 形状[3,2] c a.squeeze(0) # 形状[3,1,2] d a.squeeze(2) # 形状[1,3,2]注意如果指定压缩的维度长度不为1则不会发生任何变化。3.3 expand内存高效的维度扩展expand不会实际复制数据而是通过广播机制实现维度扩展x torch.tensor([[1],[2],[3]]) # [3,1] y x.expand(3,4) # [3,4] 结果 [[1,1,1,1], [2,2,2,2], [3,3,3,3]] 使用限制只能将长度为1的维度扩展到更大尺寸-1表示保持该维度不变原始张量在非扩展维度上必须与新尺寸匹配4. 张量变形与转置操作4.1 view与reshape改变张量形状view和reshape都能改变张量形状而不改变数据x torch.randn(4,4) y x.view(16) # 展平 z x.view(-1,8) # -1表示自动计算该维度大小 w x.reshape(2,8) # 功能类似view重要区别view要求张量内存连续否则会报错reshape会自动处理非连续张量但可能产生拷贝在模型训练中建议先用contiguous()确保连续性再用view4.2 转置操作t、transpose和permute对于矩阵转置2D张量x torch.randn(3,4) y x.t() # [4,3] z x.T # 同t()高维张量转置x torch.randn(2,3,4) y x.transpose(0,1) # [3,2,4] z x.permute(2,0,1) # [4,2,3]permute比transpose更灵活可以一次性重新排列所有维度顺序。5. 张量复制与重复5.1 repeat数据复制的维度扩展repeat会实际复制数据沿各维度重复指定次数x torch.tensor([1,2,3]) y x.repeat(2,3) # [2,9] [[1,2,3,1,2,3,1,2,3], [1,2,3,1,2,3,1,2,3]] 5.2 expand与repeat的选择expand适用于广播场景不实际增加内存占用repeat适用于需要独立副本的场景会增加内存修改expand结果会影响原始张量repeat则不会6. 高级维度操作技巧6.1 gather与scatter索引操作gather按照索引从输入张量收集数据x torch.tensor([[1,2],[3,4]]) index torch.tensor([[0,0],[1,0]]) y x.gather(1, index) # 沿dim1收集 [[1,1], [4,3]] scatter是gather的逆操作将值分散到指定位置z torch.zeros(2,2) z.scatter_(1, index, x) # 将x的值按index分散到z6.2 内存共享机制许多张量操作如view、transpose、narrow会共享底层存储修改一个会影响另一个x torch.randn(3,4) y x.view(4,3) y[0,0] 100 # x也会被修改要避免这种情况可以使用clone()创建独立副本z x.clone().view(4,3) # 不共享内存7. 实际应用场景示例7.1 图像数据处理处理批次图像时常用维度操作# 单张图像 [C,H,W] - 批次 [B,C,H,W] img torch.randn(3,224,224) batch img.unsqueeze(0).expand(32,-1,-1,-1) # 特征图拼接 feat1 torch.randn(32,64,56,56) feat2 torch.randn(32,128,56,56) combined torch.cat([feat1,feat2], dim1) # [32,192,56,56]7.2 自然语言处理在Transformer模型中# 多头注意力机制中的维度变换 q torch.randn(32,10,64) # [batch, seq_len, dim] k q.view(32,10,8,8).transpose(1,2) # [32,8,10,8]7.3 模型权重初始化初始化线性层权重时weight torch.empty(3,4) nn.init.kaiming_normal_(weight) # 添加batch维度 weight weight.unsqueeze(0).expand(32,-1,-1) # [32,3,4]8. 常见错误与调试技巧维度不匹配错误# 错误示例 a torch.randn(2,3) b torch.randn(2,4) torch.cat([a,b], dim1) # 报错除dim1外其他维度必须相同非连续内存错误x torch.randn(3,4).transpose(0,1) y x.view(12) # 报错张量不连续 # 修复方案 y x.contiguous().view(12)广播机制误解a torch.randn(3,1) b torch.randn(1,3) c a b # 正确[3,3] d a.expand(3,4) # 正确 e a.expand(4,3) # 报错无法将3扩展到4调试建议经常检查张量shapeprint(tensor.shape)使用assert确保维度符合预期对复杂操作分步验证