介绍一个好用的张量操作库——einops

2024-09-22 83 0

最近一段时间看论文代码,发现有不少项目用到了einops库来实现语义化操作,感觉很好用,所以写一个博客来记录下,以免以后要用到。

维度交换

import torch
from einops import rearrange

input_tensor = torch.randint(low=0,high=255,size=(2,3,224,224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
print(input_tensor.shape)

# 交换顺序
output_tensor = rearrange(input_tensor,"bs c h w -> bs h w c")
print(output_tensor.shape)

维度合并

import torch
from einops import rearrange

input_tensor = torch.randint(low=0,high=255,size=(2,3,224,224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
print(input_tensor.shape)

# 维度合并
output_tensor = rearrange(input_tensor,"bs c h w -> bs c (h w)")
print(output_tensor.shape)

等价于

print(input_tensor.view(input_tensor.shape[0],input_tensor.shape[1],input_tensor.shape[2]*input_tensor.shape[3]).shape)

维度拆分

拆分时需要指定拆分的规则

import torch
from einops import rearrange

input_tensor = torch.randint(low=0,high=255,size=(2,3,224*224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
print(input_tensor.shape)

# 维度拆分
output_tensor = rearrange(input_tensor,"bs c (h w) -> bs c h w",h=224)
print(output_tensor.shape)

降维

对某个维度以求平均/最大/最小/求和/求乘积的形式降维4

import torch
from einops import rearrange, reduce

input_tensor = torch.randint(low=0,high=255,size=(2,3,224,224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
input_tensor = input_tensor.to(float)
print(input_tensor.shape)

# 降维
output_tensor = reduce(input_tensor,"bs c h w -> bs h w","mean")
print(output_tensor.shape)

对某个维度求其平均/最大/最小/求和/求乘积

对batch中的每个图片的每个chanel,逐行求均值

import torch
from einops import rearrange, reduce

input_tensor = torch.randint(low=0,high=255,size=(2,3,224,224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
input_tensor = input_tensor.to(float)
print(input_tensor.shape)

output_tensor = reduce(input_tensor,"bs c h w -> bs c h ()","mean")
print(output_tensor.shape)

对batch中的每个图片,逐chanel求均值

import torch
from einops import rearrange, reduce

input_tensor = torch.randint(low=0,high=255,size=(2,3,224,224)) # 创建一个batch_size为2,分辨率为224*224的RGB图片
input_tensor = input_tensor.to(float)
print(input_tensor.shape)

output_tensor = reduce(input_tensor,"bs c h w -> bs c () ()","mean")
print(output_tensor.shape)

2D MaxPooling (Kernel_size = 2*2)

y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)

Adaptive 2d max-pooling to 3 * 4 grid

reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4)

添加/删除某个维度

x = rearrange(ims, "b h w c -> b 1 h w 1 c")
x = rearrange(x, "b 1 h w 1 c -> b h w c")

重复

把new_axis后面的东西重复

repeat(x, "h w c -> h new_axis w c", new_axis=5)

沿着某个维度重复

file

repeat(ims, "h w c -> h (repeat w) c", repeat=3)

重复每个元素

file

repeat(ims[0], "h w c -> h (w repeat) c", repeat=3)

相关文章

记Vision Mamba在Pascal架构的显卡下训练出错的的解决办法
NJUCS 2024 分布式系统 课程内容整理(Last Updated 2024/12/30)

发布评论