8.2.7. サンプル: Stable Diffusion
Stable Diffusion は、テキストプロンプトで条件付けされた潜在表現を段階的にデノイズすることで画像を生成する、テキストから画像への生成モデルです。このサンプルでは、データセットの準備、Stable Diffusion モデルのファインチューニング、および MLSDK を使用した MN-Core2 または PFVM 上での推論の実行方法を説明します。
このサンプルでは、軽量なデモ用データセットとして Caltech 256 を、ベースモデルとして CompVis/stable-diffusion-v1-4 を使用します。学習サンプルでは既定で UNet コンポーネントをファインチューニングし、後続の推論で再利用できるよう、生成されたモデルを Diffusers 互換形式で保存します。
The source code for this example is located at /opt/pfn/pfcomp/codegen/MLSDK/examples/stable_diffusion.
注釈
この学習および評価ワークフローで使用するモデルは CompVis/stable-diffusion-v1-4 です。ライセンスは CreativeML Open RAIL-M です。このサンプルには、必要に応じてモデルを自動的にダウンロードする処理が含まれています。
注釈
このサンプルでは、データセットとして Caltech 256 の改変版を使用します。データセットは CC BY 4.0 で配布されています。具体的な変更内容は Dataset Preparation で説明しています。
このサンプルは次の 3 つの段階で構成されています:
8.2.7.1. データセット準備
preparation.sh と preparation.py は、元の Caltech 256 アーカイブを、学習スクリプトが使用する Hugging Face の imagefolder dataset loader が期待するディレクトリ構造に変換します。
preparation.sh は全体の処理を次のように実行します:
${dataset_dir}配下に Caltech 256 アーカイブがまだ存在しない場合はダウンロードしますアーカイブを
${dataset_dir}/data/trainに展開しますデータセットがまだ変換されていない場合は
preparation.pyを実行します
preparation.py は続いて、展開済みのデータセットをこのサンプルに適した形式へ正規化します:
元の Caltech 256 アーカイブに含まれる数値プレフィックスを取り除いて、各カテゴリディレクトリ名を変更します
カテゴリ名に
-101接尾辞が付いている場合は削除します画像ファイル名は末尾のファイル名部分だけが残るようにリネームします
想定される命名規則に一致しないファイルは削除します
各カテゴリディレクトリ内に
metadata.csvファイルを生成します
生成される各 metadata.csv には file_name 列と caption 列が含まれます。caption にはカテゴリ名が設定されるため、このデータセットはそのままテキスト画像生成のファインチューニングに使用できます。たとえば、airplanes ディレクトリ内のすべての画像には airplanes という caption が設定されます。
--target_folders を指定すると、列挙したカテゴリだけを残し、前処理中にそれ以外の展開済みカテゴリを削除します。これは準備時間を短縮したい場合や、より小さいクラス集合で学習したい場合に便利です。
8.2.7.1.1. 使用方法
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/stable_diffusion
$ dataset_dir=/path/to/dataset ./preparation.sh [--target_folders [FOLDER_NAMES]]
パラメータ:
dataset_dir: ダウンロードしたアーカイブと準備済みデータセットの両方を保存するディレクトリです。準備後、学習画像は
/path/to/dataset/data/trainに配置されます。target_folders: (省略可能) 準備対象を特定の Caltech 256 オブジェクトカテゴリに限定します。たとえば
--target_folders airplanes mushroomを指定すると、それらのカテゴリだけを準備します。省略した場合はすべてのカテゴリを準備します。有効な名前については name_list.txt を参照してください。
8.2.7.2. 学習
stable_diffusion_training.sh は実行環境を準備してから stable_diffusion_training.py を起動します。このラッパースクリプトは Python 仮想環境を作成または再利用し、サンプルの依存パッケージをインストールし、必要な環境変数を設定したうえで、指定されたコマンドライン引数とともに学習スクリプトを呼び出します。
stable_diffusion_training.py は実際の学習ワークフローとして次を実行します:
Hugging Face
datasetsのimagefolderローダーを使って準備済みデータセットを読み込みますデータセット準備時に作成されたクラスごとの
metadata.csvから caption を読み取ります入力画像を設定された解像度にリサイズし、ランダムクロップします
設定されたモデルパスから Stable Diffusion の各コンポーネントを読み込みます
MLSDK でコンパイルされた関数を生成します
既定では UNet をファインチューニングし、必要に応じて結果のモデルを Diffusers 形式で保存します
既定では、このスクリプトは 1 エポックの短いデモ実行を行い、学習済みモデルを <out_dir>/model に保存します。出力ディレクトリには、実行中に生成されたコンパイルキャッシュや計測用アーティファクトが含まれる場合もあります。
学習スクリプトは次のバックエンドをサポートします:
mncore2:0: 最初の MN-Core2 デバイス上で実行しますpfvm:cpu: PFVM を介して CPU 上で実行します
実行時の既定値は configs.toml に定義されています。学習スクリプトと評価スクリプトは apply_toml_defaults() を呼び出すため、configs.toml の各項目はコマンドラインオプションとして公開され、シェルから上書きできます。
8.2.7.2.1. 使用方法
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/stable_diffusion
$ dataset_dir=/path/to/dataset ./stable_diffusion_training.sh \
--backend mncore2:0 \
--outdir /path/to/train/output
パラメータ:
dataset_dir: 準備済みデータセットのルートディレクトリへのパスです。ラッパーは
${dataset_dir}/dataを学習スクリプトに渡します。backend: 学習に使用するバックエンドです。有効な値は
mncore2:0とpfvm:cpuです。outdir: 学習結果を保存するディレクトリです。省略した場合は
/tmp配下にバックエンド別のディレクトリが自動的に作成されます。
8.2.7.2.2. 出力
学習が成功すると、出力ディレクトリには次の内容が含まれます:
model/:save_pretrained()で保存されたファインチューニング済み Diffusers モデルtext_encoder/、unet/、vae_encoder/などのコンポーネントごとのコード生成ディレクトリtext_encoder_cache/、unet_cache/、vae_encoder_cache/などのコンポーネントごとのキャッシュディレクトリ各生成コンポーネントディレクトリ配下のコンパイラレポート、レイアウトダンプ、トレース、および関連アーティファクト
8.2.7.3. 推論
stable_diffusion_eval.sh は学習スクリプトと同じ Python 環境を準備してから stable_diffusion_eval.py を起動します。学習済みモデルを読み込む際には、サンプルがインストール済みの Diffusers のバージョンや関連 Python パッケージに依存するため、同じ環境を使うことが重要です。学習側と同様に、stable_diffusion_eval.py も apply_toml_defaults() を通じて configs.toml を読み込むため、configs.toml に記載された既定値は推論にも適用されます。
stable_diffusion_eval.py は次の処理を実行します:
設定されたモデルパスから Stable Diffusion パイプラインを読み込みます
必要に応じて safety checker を無効化します
text encoder、UNet、および VAE decoder 用のコンパイル済み評価関数を生成します
指定したプロンプトに対する画像を生成します
選択した出力ディレクトリに生成画像を
output_eval.pngとして保存します
safety checker は、デコード後の画像に NSFW コンテンツが含まれる可能性がないかを検査する、Stable Diffusion 標準の後処理コンポーネントです。このサンプルでは、その動作は configs.toml 内の skip_safety_check で制御されます。現在の既定値ではチェックをスキップします。これにより、safety checker が生成結果を抑制し、サンプルとして確認できる画像の代わりに実質的に真っ黒な画像が出力されるケースを避けています。追加のコンテンツフィルタリング手順が必要な場合は、推論実行前に skip_safety_check = false を設定してください。
model_path が学習ステップで保存したディレクトリを指している場合、推論ではそのファインチューニング済みモデルを使用します。指定しない場合は、configs.toml に定義された既定のベースモデルを使用します。
8.2.7.3.1. 使用方法
$ cd /opt/pfn/pfcomp/codegen/MLSDK/examples/stable_diffusion
$ ./stable_diffusion_eval.sh \
--backend mncore2:0 \
--outdir /path/to/inference/output \
[--model_path /path/to/train/output/model] \
[--prompt "your text prompt here"]
パラメータ:
backend: 推論に使用するバックエンドです。有効な値は
mncore2:0とpfvm:cpuです。outdir: 推論アーティファクトを保存するディレクトリです。生成画像は
/path/to/inference/output/output_eval.pngとして保存されます。model_path: (省略可能) ファインチューニング済みモデルディレクトリへのパスです。通常は
<training outdir>/modelを指定します。省略した場合はconfigs.tomlの既定モデルが使用されます。prompt: (省略可能) 画像生成を誘導するテキストプロンプトです。省略した場合の既定プロンプトは
dogです。
8.2.7.3.2. 出力
推論が成功すると、出力ディレクトリには次の内容が含まれます:
output_eval.png: 生成された画像text_encoder_eval_eval/、unet_eval_eval/、vae_decoder_eval/などのコンポーネントごとの評価コード生成ディレクトリtext_encoder_eval_cache/、unet_eval_cache/、vae_decoder_cache/などのコンポーネントごとのキャッシュディレクトリ各生成済み評価コンポーネントディレクトリ配下のコンパイラレポート、レイアウトダンプ、トレース、および関連アーティファクト
図 8.6 プロンプト dog に対して MN-Core2 上で生成した画像
8.2.7.4. 付録
8.2.7.4.1. configs.toml
title = "stable_diffusion_training"
[mlsdk]
num_compiler_threads = -1
do_quiet_compilation = false
skip_text_encoder_compilation = false
skip_unet_compilation = false
skip_vae_encoder_compilation = false
skip_vae_decoder_compilation = false
[model]
model_path = "CompVis/stable-diffusion-v1-4"
# Can also set model_path to fine-tuned model created by `stable_diffusion_training.py`
# For example,
# ./stable_diffusion_training.py --outdir /tmp/sd_train
# Then, the model will be saved to
# model_path = "/tmp/sd_train/model"
height = 512 # height for input image
width = 512 # width for input image
do_lora = false
lora_rank = 4
init_lora_weights = ""
[dataset]
data_cache_dir = ""
[training]
epoch = 1
save_model = true
optimizer = "sgd"
learning_rate = 1e-4
momentum = 0.2
weight_decay = 1e-2
lr_scheduler = "constant"
use_mncore_lr_scheduler = false
lr_warmup = 500
use_pretrained_unet = true # Fine-tune pretrained UNet
[inference]
guidance_scale = 7.5
skip_safety_check = true # Set to `true` to prevent black image generation
[misc]
seed = 0
batch_size = 1
8.2.7.4.2. name_list.txt
airplanes,
ak47,
american-flag,
backpack,
baseball-bat,
baseball-glove,
basketball-hoop,
bat,
bathtub,
bear,
beer-mug,
billiards,
binoculars,
birdbath,
blimp,
bonsai,
boom-box,
bowling-ball,
bowling-pin,
boxing-glove,
brain,
breadmaker,
buddha,
bulldozer,
butterfly,
cactus,
cake,
calculator,
camel,
cannon,
canoe,
car-side,
car-tire,
cartman,
cd,
centipede,
cereal-box,
chandelier,
chess-board,
chimp,
chopsticks,
clutter,
cockroach,
coffee-mug,
coffin,
coin,
comet,
computer-keyboard,
computer-monitor,
computer-mouse,
conch,
cormorant,
covered-wagon,
cowboy-hat,
crab,
desk-globe,
diamond-ring,
dice,
dog,
dolphin,
doorknob,
drinking-straw,
duck,
dumb-bell,
eiffel-tower,
electric-guitar,
elephant,
elk,
ewer,
eyeglasses,
faces-easy,
fern,
fighter-jet,
fire-extinguisher,
fire-hydrant,
fire-truck,
fireworks,
flashlight,
floppy-disk,
football-helmet,
french-horn,
fried-egg,
frisbee,
frog,
frying-pan,
galaxy,
gas-pump,
giraffe,
goat,
golden-gate-bridge,
goldfish,
golf-ball,
goose,
gorilla,
grand-piano,
grapes,
grasshopper,
greyhound,
guitar-pick,
hamburger,
hammock,
harmonica,
harp,
harpsichord,
hawksbill,
head-phones,
helicopter,
hibiscus,
homer-simpson,
horse,
horseshoe-crab,
hot-air-balloon,
hot-dog,
hot-tub,
hourglass,
house-fly,
human-skeleton,
hummingbird,
ibis,
ice-cream-cone,
iguana,
ipod,
iris,
jesus-christ,
joy-stick,
kangaroo,
kayak,
ketch,
killer-whale,
knife,
ladder,
laptop,
lathe,
leopards,
license-plate,
light-house,
lightbulb,
lightning,
llama,
mailbox,
mandolin,
mars,
mattress,
megaphone,
menorah,
microscope,
microwave,
minaret,
minotaur,
motorbikes,
mountain-bike,
mushroom,
mussels,
necktie,
octopus,
ostrich,
owl,
palm-pilot,
palm-tree,
paper-shredder,
paperclip,
pci-card,
penguin,
people,
pez-dispenser,
photocopier,
picnic-table,
playing-card,
porcupine,
pram,
praying-mantis,
pyramid,
raccoon,
radio-telescope,
rainbow,
refrigerator,
revolver,
rifle,
rotary-phone,
roulette-wheel,
saddle,
saturn,
school-bus,
scorpion,
screwdriver,
segway,
self-propelled-lawn-mower,
sextant,
sheet-music,
skateboard,
skunk,
skyscraper,
smokestack,
snail,
snake,
sneaker,
snowmobile,
soccer-ball,
socks,
soda-can,
spaghetti,
speed-boat,
spider,
spoon,
stained-glass,
starfish,
steering-wheel,
stirrups,
sunflower,
superman,
sushi,
swan,
swiss-army-knife,
sword,
syringe,
t-shirt,
tambourine,
teapot,
teddy-bear,
teepee,
telephone-box,
tennis-ball,
tennis-court,
tennis-racket,
tennis-shoes,
theodolite,
toad,
toaster,
tomato,
tombstone,
top-hat,
touring-bike,
tower-pisa,
traffic-light,
treadmill,
triceratops,
tricycle,
trilobite,
tripod,
tuning-fork,
tweezer,
umbrella,
unicorn,
vcr,
video-projector,
washing-machine,
watch,
waterfall,
watermelon,
welding-mask,
wheelbarrow,
windmill,
wine-bottle,
xylophone,
yarmulke,
yo-yo,
zebra,