【GitHubリポジトリ紹介】AnoGen: 少数の異常画像からの生成と検知

ECCV 2024で発表された論文「Few-Shot Anomaly-Driven Generation for Anomaly Classification and Segmentation」の実装である、AnoGen の使用方法をまとめました。

このプロジェクトは、少数の異常画像(Few-Shot)から多様でリアルな異常画像を生成し、それを用いて異常検知モデル(DRAEMやDeSTSegなど)の精度を向上させることを目的としています。

リポジトリ: gaobb/AnoGen

1. 概要

AnoGenは、Latent Diffusion Model (LDM) をベースにしており、以下の3段階のプロセスで構成されています。

  1. Embedding Learning: 少数の異常画像から欠陥の特徴(埋め込み表現)を学習。
  2. Anomaly Generation: 正常画像とマスク、学習した埋め込みを用いて異常画像を生成。
  3. Anomaly Detection: 生成した画像を用いて異常検知モデルを学習。

2. 環境構築 (Environment Setup)

このプロジェクトを動かすには、Python 3.8+ と PyTorch を含む環境が必要です。高速なパッケージマネージャーである uv を使用したセットアップ例を紹介します。

仮想環境の作成

# 仮想環境の作成 (Python 3.8を指定)
uv venv .venv --python 3.8

# 仮想環境の有効化
# Windows
.venv\Scripts\activate
# Linux/macOS
source .venv/bin/activate

主要なパッケージのインストール

リポジトリ直下に requirements.txt がない場合は、主要なコンポーネント(LDM, DRAEM, DeSTSeg)に必要な以下のパッケージをインストールしてください。

# PyTorch (CUDA環境に合わせて選択)
uv pip install torch torchvision

# LDM関連
uv pip install pytorch-lightning omegaconf transformers clip
uv pip install taming-transformers-rom1504  # VQGAN用

# 画像処理・ユーティリティ
uv pip install numpy scipy scikit-image opencv-python matplotlib tqdm

3. 準備 (Preparation)

まず、データセットと事前学習済みモデルを準備します。

データセットのダウンロード (MVTec AD)

MVTec Anomaly Detection データセットをダウンロードし、配置します。

mkdir ./datasets/mvtec
cd ./datasets/mvtec
wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz
tar -xf mvtec_anomaly_detection.tar.xz
rm mvtec_anomaly_detection.tar.xz

事前学習済み Diffusion Model のダウンロード

LDM (Latent Diffusion Model) の事前学習済み重みをダウンロードします。

cd DIFFUSION
mkdir -p models/ldm/text2img-large/
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

Note: 開発者が事前に生成した異常画像データも提供されている場合があります(詳細はリポジトリのリンク参照)。

3. 使用手順

Stage 1: Few-Shot 画像からの Embedding 学習

特定の欠陥カテゴリに対して、埋め込み表現を学習させます。

特定のカテゴリのみ実行する場合の例:

cd DIFFUSION
# 例としてボトルの欠陥などを学習
python main.py \
 --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
 -t --actual_resume models/ldm/text2img-large/model.ckpt \
 -n v2_bottle_broken_large \
 --gpus 0, \
 --data_root mvtec_train_data/bottle/broken_large \
 --init_word "defect"

train.sh スクリプトを使用すれば、全カテゴリを一括で学習可能です。

Stage 2: 異常画像の生成 (Anomaly-Driven Generation)

学習した Embedding (embeddings.pt)、正常画像、および欠陥位置を指定するマスク画像を使用して、新しい異常画像を生成します。

cd DIFFUSION
python scripts/txt2img.py \
 --ddim_eta 0.0 \
 --n_samples 1 \
 --n_iter 2 \
 --scale 10.0 \
 --ddim_steps 50 \
 --embedding_path "logs/bottle_broken_large/embeddings.pt" \
 --ckpt_path "models/ldm/text2img-large/model.ckpt" \
 --prompt "*" \
 --mask_prompt "images/demo_images/mask.png" \
 --image_prompt "images/demo_images/bottle.png" \
 --outdir "outputs"
  • --embedding_path: Stage 1で学習した埋め込みファイルへのパス
  • --image_prompt: ベースとなる正常画像
  • --mask_prompt: 欠陥を生成したい領域を指定するマスク画像

Stage 3: 異常検知モデルの学習 (Weakly-Supervised)

生成された異常画像を使用して、既存の異常検知モデル(DRAEM や DeSTSeg)を学習させます。

DRAEM の場合:

cd DREAM
sh ./train.sh  # 学習
sh ./test.sh   # テスト

DeSTSeg の場合:

cd DeSTSeg
sh train.sh    # 学習
sh test.sh     # テスト

依存関係

このプロジェクトは、以下のリポジトリのコードや手法をベースにしています。