PyTorch中的混合精度训练与AMP

编程艺术家 2019-05-08 ⋅ 26 阅读

介绍

混合精度训练是一种通过使用不同的精度(浮点16和浮点32)来加速模型训练的技术。在传统的训练过程中,所有的参数和梯度都使用浮点32进行计算,而在混合精度训练中,一部分参数和梯度会用浮点16进行计算,从而减少了存储和计算的开销。

PyTorch的自动混合精度(Automatic Mixed Precision,AMP)是使用PyTorch实现混合精度训练的一种方法。AMP可以自动为模型选择合适的数值精度,同时提供额外的功能来处理精度损失和数值不稳定问题。AMP通过使用混合精度训练可以加速训练,并且在大多数情况下不会明显降低模型的性能。

PyTorch中的AMP

在PyTorch中使用AMP进行混合精度训练非常简单。首先,需要导入相应的库:

import torch
from torch.cuda.amp import autocast, GradScaler

接下来,在训练过程中的前向传播和反向传播部分使用autocast上下文管理器来启用混合精度:

with autocast():
    # 前向传播
    output = model(input)
    loss = criterion(output, target)

# 反向传播
scaler = GradScaler()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

使用autocast上下文管理器后,PyTorch会根据输入的数据类型自动选择合适的精度进行计算。对于支持的硬件和操作,PyTorch会将梯度计算和参数更新转换为浮点16,从而减少了GPU存储和计算的开销。

此外,使用GradScaler可以自动缩放损失和梯度,以减少精度损失的影响。scaler.scale(loss)用于缩放损失和梯度,scaler.step(optimizer)用于更新参数,scaler.update()用于更新缩放因子。

实例

下面以一个简单的图像分类任务为例,演示如何在PyTorch中使用AMP进行混合精度训练。

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# 加载数据集
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 128 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 使用AMP进行混合精度训练
scaler = GradScaler()

for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        with autocast():
            outputs = net(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0

print('Finished Training')

在上述代码中,使用autocast上下文管理器启用混合精度,在每个mini-batch上进行前向传播和反向传播。使用GradScaler进行梯度缩放和参数更新。上述代码会训练一个简单的CNN模型进行CIFAR10图像分类。

结论

混合精度训练及其在PyTorch中的实现(AMP)是一种有效的方法,可以在减少存储和计算成本的同时加速模型训练。通过使用autocastGradScaler,可以简化混合精度训练的实现,并减小精度损失的影响。在实际应用中,可以根据硬件设备和任务需求选择最适合的精度设置,从而提高模型训练的效率和性能。

以上就是关于PyTorch中的混合精度训练与AMP的介绍和示例。希望对你有所帮助!


全部评论: 0

    我有话说: