ECCV 2024で発表された論文「Few-Shot Anomaly-Driven Generation for Anomaly Classification and Segmentation」の実装である、AnoGen の使用方法をまとめました。
このプロジェクトは、少数の異常画像(Few-Shot)から多様でリアルな異常画像を生成し、それを用いて異常検知モデル(DRAEMやDeSTSegなど)の精度を向上させることを目的としています。
リポジトリ: gaobb/AnoGen
1. 概要
AnoGenは、Latent Diffusion Model (LDM) をベースにしており、以下の3段階のプロセスで構成されています。
- Embedding Learning: 少数の異常画像から欠陥の特徴(埋め込み表現)を学習。
- Anomaly Generation: 正常画像とマスク、学習した埋め込みを用いて異常画像を生成。
- Anomaly Detection: 生成した画像を用いて異常検知モデルを学習。
2. 環境構築 (Environment Setup)
このプロジェクトを動かすには、Python 3.8+ と PyTorch を含む環境が必要です。高速なパッケージマネージャーである uv を使用したセットアップ例を紹介します。
仮想環境の作成
# 仮想環境の作成 (Python 3.8を指定)
uv venv .venv --python 3.8
# 仮想環境の有効化
# Windows
.venv\Scripts\activate
# Linux/macOS
source .venv/bin/activate
主要なパッケージのインストール
リポジトリ直下に requirements.txt がない場合は、主要なコンポーネント(LDM, DRAEM, DeSTSeg)に必要な以下のパッケージをインストールしてください。
# PyTorch (CUDA環境に合わせて選択)
uv pip install torch torchvision
# LDM関連
uv pip install pytorch-lightning omegaconf transformers clip
uv pip install taming-transformers-rom1504 # VQGAN用
# 画像処理・ユーティリティ
uv pip install numpy scipy scikit-image opencv-python matplotlib tqdm
3. 準備 (Preparation)
まず、データセットと事前学習済みモデルを準備します。
データセットのダウンロード (MVTec AD)
MVTec Anomaly Detection データセットをダウンロードし、配置します。
mkdir ./datasets/mvtec
cd ./datasets/mvtec
wget https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz
tar -xf mvtec_anomaly_detection.tar.xz
rm mvtec_anomaly_detection.tar.xz
事前学習済み Diffusion Model のダウンロード
LDM (Latent Diffusion Model) の事前学習済み重みをダウンロードします。
cd DIFFUSION
mkdir -p models/ldm/text2img-large/
wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt
Note: 開発者が事前に生成した異常画像データも提供されている場合があります(詳細はリポジトリのリンク参照)。
3. 使用手順
Stage 1: Few-Shot 画像からの Embedding 学習
特定の欠陥カテゴリに対して、埋め込み表現を学習させます。
特定のカテゴリのみ実行する場合の例:
cd DIFFUSION
# 例としてボトルの欠陥などを学習
python main.py \
--base configs/latent-diffusion/txt2img-1p4B-finetune.yaml \
-t --actual_resume models/ldm/text2img-large/model.ckpt \
-n v2_bottle_broken_large \
--gpus 0, \
--data_root mvtec_train_data/bottle/broken_large \
--init_word "defect"
※ train.sh スクリプトを使用すれば、全カテゴリを一括で学習可能です。
Stage 2: 異常画像の生成 (Anomaly-Driven Generation)
学習した Embedding (embeddings.pt)、正常画像、および欠陥位置を指定するマスク画像を使用して、新しい異常画像を生成します。
cd DIFFUSION
python scripts/txt2img.py \
--ddim_eta 0.0 \
--n_samples 1 \
--n_iter 2 \
--scale 10.0 \
--ddim_steps 50 \
--embedding_path "logs/bottle_broken_large/embeddings.pt" \
--ckpt_path "models/ldm/text2img-large/model.ckpt" \
--prompt "*" \
--mask_prompt "images/demo_images/mask.png" \
--image_prompt "images/demo_images/bottle.png" \
--outdir "outputs"
--embedding_path: Stage 1で学習した埋め込みファイルへのパス--image_prompt: ベースとなる正常画像--mask_prompt: 欠陥を生成したい領域を指定するマスク画像
Stage 3: 異常検知モデルの学習 (Weakly-Supervised)
生成された異常画像を使用して、既存の異常検知モデル(DRAEM や DeSTSeg)を学習させます。
DRAEM の場合:
cd DREAM
sh ./train.sh # 学習
sh ./test.sh # テスト
DeSTSeg の場合:
cd DeSTSeg
sh train.sh # 学習
sh test.sh # テスト
依存関係
このプロジェクトは、以下のリポジトリのコードや手法をベースにしています。
- LDM: CompVis/latent-diffusion
- DRAEM: VitjanZ/DRAEM
- DeSTSeg: apple/ml-destseg