PyTorch: 事前学習モデルを活用した転移学習画像分類

このブログでは、PyTorchと事前学習モデルを使用して転移学習による画像分類タスクの実行方法について詳しく説明します。具体的には、ResNet-50という有名な深層学習モデルを使用し、その重みはImageNetという大規模な画像データセットで訓練されたものです。

目次

  1. 必要なライブラリのインストール
  2. データセットの準備
  3. 事前学習モデルのロード
  4. 新しいクラス分類器の訓練
  5. 結果の評価

1. 必要なライブラリのインストール

まず、このプロジェクトで必要となるPythonライブラリをインストールします。これらにはtorch, torchvision, numpy等が含まれます。

!pip install torch torchvision numpy

2. データセットの準備

ここでは自身で用意した「my_dataset」という名前のフォルダ内にある「train」と「test」フォルダから画像を読み込むコード例です。「my_dataset」フォルダは次の構造であることが想定されています。

my_dataset/
    train/
        class_1/
            img001.jpg
            img002.jpg
            ...
        class_2/
            img001.jpg
            img002.jpg
            ...
    test/
        class_1/
            img001.jpg
            img002.jpg
            ...
        class_2/
            img001.jpg
            img002.jpg
            ...
import torch 
from torchvision import datasets, transforms 

transform = transforms.Compose([
    transforms.Resize((224,224)), 
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = datasets.ImageFolder('./my_dataset/train', transform=transform) 

test_dataset = datasets.ImageFolder('./my_dataset/test', transform=transform) 

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) 

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

3. 事前学習モデルのロード

ここでは、ResNet-50 モデルをロードします。また、pretrained=True を指定することで ImageNet データセットで訓練された重みをロードします。

from torchvision import models

resnet50 = models.resnet50(pretrained=True)

# Freeze the parameters of the model to avoid their updates
for param in resnet50.parameters():
    param.requires_grad = False

4. 新しいクラス分類器の訓練

次に、新しいクラス分類器(全結合層)を追加して、自身のデータセットに対応するようにします。そして、この新しいクラス分類器のみを訓練します。

import torch.nn as nn 

num_classes = len(train_dataset.classes) 
resnet50.fc = nn.Linear(resnet50.fc.in_features, num_classes)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet50.fc.parameters())

# Training loop...
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        # Forward pass
        outputs = resnet50(inputs)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

5. 結果の評価

最後に、テストデータセット上でモデルの性能を評価します。

correct_predictions = 0 
total_predictions   = 0 

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs     = resnet50(inputs)
        _, predicted_labels   = torch.max(outputs.data ,1)

        
total_predictions += labels.size(0)
correct_predictions += (predicted_labels == labels).sum().item()

print('Accuracy of the network on test images: %d %%' % (
    100 * correct_predictions / total_predictions))

以上がPyTorchと事前学習モデル(ResNet-50)を用いた転移学習画像分類タスクの一例です。この方法は小規模なデータセットでも高精度な結果が得られることが多いです。