このブログでは、PyTorchと事前学習モデルを使用して画像分類タスクを実行する方法について詳しく説明します。具体的には、ResNet-50という有名な深層学習モデルを使用し、その重みはImageNetという大規模な画像データセットで訓練されたものです。
目次
1. 必要なライブラリのインストール
まず、このプロジェクトで必要となるPythonライブラリをインストールします。これらにはtorch
, torchvision
, numpy
等が含まれます。
!pip install torch torchvision numpy
2. データセットの準備
今回はPyTorchが提供するCIFAR-10という小規模な画像分類用データセットを使います。
import torch from torchvision import datasets, transforms transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
3. 事前学習モデルのロード
ここでは、ResNet-50 モデルをロードします。また、pretrained=True
を指定することで ImageNet データセットで訓練された重みをロードします。
from torchvision import models resnet50 = models.resnet50(pretrained=True)
4. 事前学習モデルで特徴抽出
事前学習モデルを使って画像から特徴量を抽出します。そのためには、まず全結合層を取り除きます。
resnet50 = torch.nn.Sequential(*(list(resnet50.children())[:-1]))
5. 新しいクラス分類器の訓練
次に、新しいクラス分類器(全結合層)を追加して、CIFAR-10の10クラス分類が可能なようにします。そして、この新しいクラス分類器のみを訓練します。
num_classes = 10 classifier = torch.nn.Linear(2048, num_classes) # Train the classifier for images, labels in train_loader: features = resnet50(images) features = features.view(features.size(0), -1) outputs = classifier(features) # Compute loss and backpropagation here...
6. 結果の評価
最後に、テストデータセット上でモデルの性能を評価します。
correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: features = resnet50(images) features = features.view(features.size(0), -1) outputs = classifier(features) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the test images: %d %%' % ( 100 * correct / total))
以上がPyTorchと事前学習モデル(ResNet-50)を用いた画像分類タスクの一例です。この方法は転移学習と呼ばれ、小規模なデータセットでも高精度な結果が得られることが多いです。