温馨提示×

PyTorch中怎么实现批量归一化

小亿
136
2024-05-10 19:07:56
栏目: 深度学习
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

在PyTorch中,可以使用torch.nn.BatchNorm1dtorch.nn.BatchNorm2d来实现批量归一化。具体代码示例如下:

import torch
import torch.nn as nn

# 对输入数据进行批量归一化
input_data = torch.randn(20, 16, 50, 50)  # 输入数据的shape为(batch_size, channels, height, width)

# 对2D数据进行批量归一化
batchnorm = nn.BatchNorm2d(16)  # 对通道维度进行批量归一化
output_data = batchnorm(input_data)

# 对1D数据进行批量归一化
input_data = torch.randn(20, 16, 100)  # 输入数据的shape为(batch_size, channels, length)
batchnorm = nn.BatchNorm1d(16)  # 对特征维度进行批量归一化
output_data = batchnorm(input_data)

上述代码中,nn.BatchNorm2d用于对2D数据(如图像数据)进行批量归一化,nn.BatchNorm1d用于对1D数据进行批量归一化。需要注意的是,这两个函数都会自动计算并更新均值和方差,同时也会学习伽马和贝塔参数来进行缩放和偏移。

亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

推荐阅读:TensorFlow中怎么实现批量归一化

0