下面是一个使用 PyTorch 进行图像分类的示例，其中用到了数据增强和迁移学习：
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
# Training-time data augmentation.
# Each image is randomly cropped, resized to the 224x224 input size used by
# ResNet-18, and randomly flipped horizontally. It is then normalised with the
# ImageNet channel statistics that the pretrained weights were trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
# Custom dataset class (the CIFAR-10 example in this file loads torchvision's
# CIFAR10 directly and does not use it; it is provided for user-supplied data).
class CustomDataset(Dataset):
    """Map-style dataset over parallel, index-aligned sequences of samples and labels.

    Args:
        imgs: Indexable sequence of samples; each element is passed to
            ``transform`` (when given) on access.
        labels: Indexable sequence of targets; ``labels[i]`` belongs to ``imgs[i]``.
        transform: Optional callable applied to each sample in ``__getitem__``.

    Raises:
        ValueError: If ``imgs`` and ``labels`` have different lengths.
    """

    def __init__(self, imgs, labels, transform=None):
        # Fail fast: a length mismatch would otherwise either silently drop
        # labels or surface as an IndexError somewhere in the middle of an epoch.
        if len(imgs) != len(labels):
            raise ValueError(
                'imgs and labels must have the same length, got {} and {}'.format(
                    len(imgs), len(labels)))
        self.imgs = imgs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``, transforming the sample if configured."""
        img = self.imgs[index]
        label = self.labels[index]
        # Compare against None explicitly instead of relying on the callable's truthiness.
        if self.transform is not None:
            img = self.transform(img)
        return img, label
# Load CIFAR-10 (downloaded into ./data on first use).
# Training images go through the random augmentation pipeline. Test images go
# through a deterministic pipeline, so the reported accuracy does not depend on
# random crops/flips. The test pipeline reuses the exact Normalize step from
# data_transform so that train and test inputs are standardised identically.
_normalize = next(
    t for t in data_transform.transforms if isinstance(t, transforms.Normalize))
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # same output size as RandomResizedCrop(224)
    transforms.ToTensor(),
    _normalize,
])
train_set = torchvision.datasets.CIFAR10('./data', train=True,
                                         download=True, transform=data_transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False,
                                        download=True, transform=test_transform)
# Data loaders for the training and test sets.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=0)
# Transfer-learning model: ImageNet-pretrained ResNet-18 with the early layers
# frozen and a new 10-way classification head for CIFAR-10.
try:
    # torchvision >= 0.13: the `weights` API replaces the deprecated `pretrained` flag.
    model = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.DEFAULT)
except AttributeError:
    # Older torchvision releases without ResNet18_Weights.
    model = torchvision.models.resnet18(pretrained=True)
# Freeze the stem and the first residual stage so that their pretrained
# features are reused as-is; later stages and the new head are fine-tuned.
# NOTE: requires_grad=False stops weight updates only; BatchNorm running
# statistics in these layers are still updated while the model is in train mode.
for frozen_module in (model.conv1, model.bn1, model.layer1):
    for param in frozen_module.parameters():
        param.requires_grad = False
# Replace the 1000-class ImageNet head with a freshly initialised 10-class head.
model.fc = nn.Linear(model.fc.in_features, 10)
# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that are actually trained; frozen
# (requires_grad=False) weights are then never tracked or updated by Adam.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0001)
# Train the model, evaluating on the test set after every epoch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Module.to() moves parameters in place, so the optimizer's parameter
# references created earlier remain valid.
model.to(device)
num_epochs = 5
for epoch in range(num_epochs):
    # Train mode: BatchNorm normalises with batch statistics and updates its
    # running averages.
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluate the model.
    # Eval mode: BatchNorm uses its stored running statistics, and the test data
    # does not modify them.
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Epoch:{}, Accuracy:{:.2f}'.format(epoch+1, correct/total))
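该代码使用 CIFAR-10 数据集，以在 ImageNet 上预训练的 ResNet-18 作为迁移学习的基础模型：先将原来的全连接层替换为新的 10 类分类层，再对网络进行微调训练，并在每个 epoch 结束后在测试集上评估准确率。若要冻结前几层网络，需要将这些层参数的 requires_grad 设为 False，使其在训练中不被更新。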