PyTorch: 事前学習モデルを活用した画像分類の例

このブログでは、PyTorchと事前学習モデルを使用して画像分類タスクを実行する方法について詳しく説明します。具体的には、ResNet-50という有名な深層学習モデルを使用し、その重みはImageNetという大規模な画像データセットで訓練されたものです。

目次

  1. 必要なライブラリのインストール
  2. データセットの準備
  3. 事前学習モデルのロード
  4. 事前学習モデルで特徴抽出
  5. 新しいクラス分類器の訓練
  6. 結果の評価

1. 必要なライブラリのインストール

まず、このプロジェクトで必要となるPythonライブラリをインストールします。これらにはtorch, torchvision, numpy等が含まれます。

!pip install torch torchvision numpy

2. データセットの準備

今回はPyTorchが提供するCIFAR-10という小規模な画像分類用データセットを使います。

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

test_dataset = datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                          shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64,
                                         shuffle=False)

3. 事前学習モデルのロード

ここでは、ResNet-50 モデルをロードします。また、pretrained=True を指定することで ImageNet データセットで訓練された重みをロードします。

from torchvision import models

resnet50 = models.resnet50(pretrained=True)

4. 事前学習モデルで特徴抽出

事前学習モデルを使って画像から特徴量を抽出します。そのためには、まず全結合層を取り除きます。

resnet50 = torch.nn.Sequential(*(list(resnet50.children())[:-1]))

5. 新しいクラス分類器の訓練

次に、新しいクラス分類器(全結合層)を追加して、CIFAR-10の10クラス分類が可能なようにします。そして、この新しいクラス分類器のみを訓練します。

num_classes = 10
classifier = torch.nn.Linear(2048, num_classes)

# Train the classifier
for images, labels in train_loader:
    features = resnet50(images)
    features = features.view(features.size(0), -1)
    
    outputs = classifier(features)
    
    # Compute loss and backpropagation here...

6. 結果の評価

最後に、テストデータセット上でモデルの性能を評価します。

correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        features = resnet50(images)
        features = features.view(features.size(0), -1)

        outputs = classifier(features)

        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))

以上がPyTorchと事前学習モデル(ResNet-50)を用いた画像分類タスクの一例です。この方法は転移学習と呼ばれ、小規模なデータセットでも高精度な結果が得られることが多いです。