stable diffusion
17
LICENSE
@ -1,16 +1,9 @@
|
|||||||
MIT License
|
All rights reserved by the authors.
|
||||||
|
You must not distribute the weights provided to you directly or indirectly without explicit consent of the authors.
|
||||||
|
You must not distribute harmful, offensive, dehumanizing content or otherwise harmful representations of people or their environments, cultures, religions, etc. produced with the model weights
|
||||||
|
or other generated content described in the "Misuse and Malicious Use" section in the model card.
|
||||||
|
The model weights are provided for research purposes only.
|
||||||
|
|
||||||
Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
269
README.md
@ -1,11 +1,5 @@
|
|||||||
# Latent Diffusion Models
|
# Stable Diffusion
|
||||||
[arXiv](https://arxiv.org/abs/2112.10752) | [BibTeX](#bibtex)
|
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
|
||||||
|
|
||||||
<p align="center">
|
|
||||||
<img src=assets/results.gif />
|
|
||||||
</p>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
|
[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)<br/>
|
||||||
[Robin Rombach](https://github.com/rromb)\*,
|
[Robin Rombach](https://github.com/rromb)\*,
|
||||||
@ -13,11 +7,18 @@
|
|||||||
[Dominik Lorenz](https://github.com/qp-qp)\,
|
[Dominik Lorenz](https://github.com/qp-qp)\,
|
||||||
[Patrick Esser](https://github.com/pesser),
|
[Patrick Esser](https://github.com/pesser),
|
||||||
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
|
[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)<br/>
|
||||||
\* equal contribution
|
|
||||||
|
|
||||||
<p align="center">
|
which is available on [GitHub](https://github.com/CompVis/latent-diffusion).
|
||||||
<img src=assets/modelfigure.png />
|
|
||||||
</p>
|
![txt2img-stable2](assets/stable-samples/txt2img/merged-0006.png)
|
||||||
|
[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
|
||||||
|
model.
|
||||||
|
Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
|
||||||
|
Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
|
||||||
|
this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
|
||||||
|
With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
|
||||||
|
See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
|
||||||
|
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
A suitable [conda](https://conda.io/) environment named `ldm` can be created
|
||||||
@ -28,176 +29,135 @@ conda env create -f environment.yaml
|
|||||||
conda activate ldm
|
conda activate ldm
|
||||||
```
|
```
|
||||||
|
|
||||||
# Model Zoo
|
You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
|
||||||
|
|
||||||
## Pretrained Autoencoding Models
|
```
|
||||||
![rec2](assets/reconstruction2.png)
|
conda install pytorch torchvision -c pytorch
|
||||||
|
pip install transformers==4.19.2
|
||||||
All models were trained until convergence (no further substantial improvement in rFID).
|
pip install -e .
|
||||||
|
|
||||||
| Model | rFID vs val | train steps |PSNR | PSIM | Link | Comments
|
|
||||||
|-------------------------|------------|----------------|----------------|---------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------|
|
|
||||||
| f=4, VQ (Z=8192, d=3) | 0.58 | 533066 | 27.43 +/- 4.26 | 0.53 +/- 0.21 | https://ommer-lab.com/files/latent-diffusion/vq-f4.zip | |
|
|
||||||
| f=4, VQ (Z=8192, d=3) | 1.06 | 658131 | 25.21 +/- 4.17 | 0.72 +/- 0.26 | https://heibox.uni-heidelberg.de/f/9c6681f64bb94338a069/?dl=1 | no attention |
|
|
||||||
| f=8, VQ (Z=16384, d=4) | 1.14 | 971043 | 23.07 +/- 3.99 | 1.17 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | |
|
|
||||||
| f=8, VQ (Z=256, d=4) | 1.49 | 1608649 | 22.35 +/- 3.81 | 1.26 +/- 0.37 | https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip |
|
|
||||||
| f=16, VQ (Z=16384, d=8) | 5.15 | 1101166 | 20.83 +/- 3.61 | 1.73 +/- 0.43 | https://heibox.uni-heidelberg.de/f/0e42b04e2e904890a9b6/?dl=1 | |
|
|
||||||
| | | | | | | |
|
|
||||||
| f=4, KL | 0.27 | 176991 | 27.53 +/- 4.54 | 0.55 +/- 0.24 | https://ommer-lab.com/files/latent-diffusion/kl-f4.zip | |
|
|
||||||
| f=8, KL | 0.90 | 246803 | 24.19 +/- 4.19 | 1.02 +/- 0.35 | https://ommer-lab.com/files/latent-diffusion/kl-f8.zip | |
|
|
||||||
| f=16, KL (d=16) | 0.87 | 442998 | 24.08 +/- 4.22 | 1.07 +/- 0.36 | https://ommer-lab.com/files/latent-diffusion/kl-f16.zip | |
|
|
||||||
| f=32, KL (d=64) | 2.04 | 406763 | 22.27 +/- 3.93 | 1.41 +/- 0.40 | https://ommer-lab.com/files/latent-diffusion/kl-f32.zip | |
|
|
||||||
|
|
||||||
### Get the models
|
|
||||||
|
|
||||||
Running the following script downloads und extracts all available pretrained autoencoding models.
|
|
||||||
```shell script
|
|
||||||
bash scripts/download_first_stages.sh
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The first stage models can then be found in `models/first_stage_models/<model_spec>`
|
|
||||||
|
## Stable Diffusion v1
|
||||||
|
|
||||||
|
Stable Diffusion v1 refers to a specific configuration of the model
|
||||||
|
architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet
|
||||||
|
and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
|
||||||
|
then finetuned on 512x512 images.
|
||||||
|
|
||||||
|
*Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present
|
||||||
|
in its training data.
|
||||||
|
Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion).
|
||||||
|
Research into the safe deployment of general text-to-image models is an ongoing effort. To prevent misuse and harm, we currently provide access to the checkpoints only for [academic research purposes upon request](TODO).
|
||||||
|
**This is an experiment in safe and community-driven publication of a capable and general text-to-image model. We are working on a public release with a more permissive license that also incorporates ethical considerations.***
|
||||||
|
|
||||||
|
[Request access to Stable Diffusion v1 checkpoints for academic research](TODO)
|
||||||
|
|
||||||
|
### Weights
|
||||||
|
|
||||||
|
We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
|
||||||
|
which were trained as follows,
|
||||||
|
|
||||||
|
- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
||||||
|
194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
||||||
|
- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
|
||||||
|
515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
|
||||||
|
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
||||||
|
- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
||||||
|
|
||||||
|
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
||||||
|
5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
|
||||||
|
steps show the relative improvements of the checkpoints:
|
||||||
|
![sd evaluation results](assets/v1-variants-scores.jpg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Pretrained LDMs
|
### Text-to-Image with Stable Diffusion
|
||||||
| Datset | Task | Model | FID | IS | Prec | Recall | Link | Comments
|
![txt2img-stable2](assets/stable-samples/txt2img/merged-0005.png)
|
||||||
|---------------------------------|------|--------------|---------------|-----------------|------|------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------|
|
![txt2img-stable2](assets/stable-samples/txt2img/merged-0007.png)
|
||||||
| CelebA-HQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0)| 5.11 (5.11) | 3.29 | 0.72 | 0.49 | https://ommer-lab.com/files/latent-diffusion/celeba.zip | |
|
|
||||||
| FFHQ | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 4.98 (4.98) | 4.50 (4.50) | 0.73 | 0.50 | https://ommer-lab.com/files/latent-diffusion/ffhq.zip | |
|
|
||||||
| LSUN-Churches | Unconditional Image Synthesis | LDM-KL-8 (400 DDIM steps, eta=0)| 4.02 (4.02) | 2.72 | 0.64 | 0.52 | https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip | |
|
|
||||||
| LSUN-Bedrooms | Unconditional Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=1)| 2.95 (3.0) | 2.22 (2.23)| 0.66 | 0.48 | https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip | |
|
|
||||||
| ImageNet | Class-conditional Image Synthesis | LDM-VQ-8 (200 DDIM steps, eta=1) | 7.77(7.76)* /15.82** | 201.56(209.52)* /78.82** | 0.84* / 0.65** | 0.35* / 0.63** | https://ommer-lab.com/files/latent-diffusion/cin.zip | *: w/ guiding, classifier_scale 10 **: w/o guiding, scores in bracket calculated with script provided by [ADM](https://github.com/openai/guided-diffusion) |
|
|
||||||
| Conceptual Captions | Text-conditional Image Synthesis | LDM-VQ-f4 (100 DDIM steps, eta=0) | 16.79 | 13.89 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/text2img.zip | finetuned from LAION |
|
|
||||||
| OpenImages | Super-resolution | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip | BSR image degradation |
|
|
||||||
| OpenImages | Layout-to-Image Synthesis | LDM-VQ-4 (200 DDIM steps, eta=0) | 32.02 | 15.92 | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip | |
|
|
||||||
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip | |
|
|
||||||
| Landscapes | Semantic Image Synthesis | LDM-VQ-4 | N/A | N/A | N/A | N/A | https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip | finetuned on resolution 512x512 |
|
|
||||||
|
|
||||||
|
Stable Diffusion is a latent diffusion model conditioned on the (non-pooled) text embeddings of a CLIP ViT-L/14 text encoder.
|
||||||
|
|
||||||
### Get the models
|
After [obtaining the weights](#weights), link them
|
||||||
|
|
||||||
The LDMs listed above can jointly be downloaded and extracted via
|
|
||||||
|
|
||||||
```shell script
|
|
||||||
bash scripts/download_models.sh
|
|
||||||
```
|
```
|
||||||
|
mkdir -p models/ldm/stable-diffusion-v1/
|
||||||
The models can then be found in `models/ldm/<model_spec>`.
|
ln -s <path/to/model.ckpt> models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
|
|
||||||
### Sampling with unconditional models
|
|
||||||
|
|
||||||
We provide a first script for sampling from our unconditional models. Start it via
|
|
||||||
|
|
||||||
```shell script
|
|
||||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python scripts/sample_diffusion.py -r models/ldm/<model_spec>/model.ckpt -l <logdir> -n <\#samples> --batch_size <batch_size> -c <\#ddim steps> -e <\#eta>
|
|
||||||
```
|
```
|
||||||
|
|
||||||
# Inpainting
|
|
||||||
![inpainting](assets/inpainting.png)
|
|
||||||
|
|
||||||
Download the pre-trained weights
|
|
||||||
```
|
|
||||||
wget -O models/ldm/inpainting_big/last.ckpt https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1
|
|
||||||
```
|
|
||||||
|
|
||||||
and sample with
|
and sample with
|
||||||
```
|
```
|
||||||
python scripts/inpaint.py --indir data/inpainting_examples/ --outdir outputs/inpainting_results
|
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
|
||||||
```
|
```
|
||||||
`indir` should contain images `*.png` and masks `<image_fname>_mask.png` like
|
By default, this uses a guidance scale of `--scale 7.5`, [Katherine Crowson's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler,
|
||||||
the examples provided in `data/inpainting_examples`.
|
and renders images of size 512x512 (which it was trained on) in 50 steps. All supported arguments are listed below (type `python scripts/txt2img.py --help`).
|
||||||
|
|
||||||
|
```commandline
|
||||||
|
usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA] [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS]
|
||||||
|
[--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT] [--seed SEED] [--precision {full,autocast}]
|
||||||
|
|
||||||
# Train your own LDMs
|
optional arguments:
|
||||||
|
-h, --help show this help message and exit
|
||||||
## Data preparation
|
--prompt [PROMPT] the prompt to render
|
||||||
|
--outdir [OUTDIR] dir to write results to
|
||||||
### Faces
|
--skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples
|
||||||
For downloading the CelebA-HQ and FFHQ datasets, proceed as described in the [taming-transformers](https://github.com/CompVis/taming-transformers#celeba-hq)
|
--skip_save do not save individual samples. For speed measurements.
|
||||||
repository.
|
--ddim_steps DDIM_STEPS
|
||||||
|
number of ddim sampling steps
|
||||||
### LSUN
|
--plms use plms sampling
|
||||||
|
--laion400m uses the LAION400M model
|
||||||
The LSUN datasets can be conveniently downloaded via the script available [here](https://github.com/fyu/lsun).
|
--fixed_code if enabled, uses the same starting code across samples
|
||||||
We performed a custom split into training and validation images, and provide the corresponding filenames
|
--ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling
|
||||||
at [https://ommer-lab.com/files/lsun.zip](https://ommer-lab.com/files/lsun.zip).
|
--n_iter N_ITER sample this often
|
||||||
After downloading, extract them to `./data/lsun`. The beds/cats/churches subsets should
|
--H H image height, in pixel space
|
||||||
also be placed/symlinked at `./data/lsun/bedrooms`/`./data/lsun/cats`/`./data/lsun/churches`, respectively.
|
--W W image width, in pixel space
|
||||||
|
--C C latent channels
|
||||||
### ImageNet
|
--f F downsampling factor
|
||||||
The code will try to download (through [Academic
|
--n_samples N_SAMPLES
|
||||||
Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
|
how many samples to produce for each given prompt. A.k.a. batch size
|
||||||
is used. However, since ImageNet is quite large, this requires a lot of disk
|
--n_rows N_ROWS rows in the grid (default: n_samples)
|
||||||
space and time. If you already have ImageNet on your disk, you can speed things
|
--scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
|
||||||
up by putting the data into
|
--from-file FROM_FILE
|
||||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
|
if specified, load prompts from this file
|
||||||
`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
|
--config CONFIG path to config which constructs model
|
||||||
of `train`/`validation`. It should have the following structure:
|
--ckpt CKPT path to checkpoint of model
|
||||||
|
--seed SEED the seed (for reproducible sampling)
|
||||||
|
--precision {full,autocast}
|
||||||
|
evaluate at this precision
|
||||||
|
|
||||||
```
|
```
|
||||||
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
|
Note: The inference config for all v1 versions is designed to be used with EMA-only checkpoints.
|
||||||
├── n01440764
|
For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
|
||||||
│ ├── n01440764_10026.JPEG
|
non-EMA to EMA weights. If you want to examine the effect of EMA vs no EMA, we provide "full" checkpoints
|
||||||
│ ├── n01440764_10027.JPEG
|
which contain both types of weights. For these, `use_ema=False` will load and use the non-EMA weights.
|
||||||
│ ├── ...
|
|
||||||
├── n01443537
|
|
||||||
│ ├── n01443537_10007.JPEG
|
### Image Modification with Stable Diffusion
|
||||||
│ ├── n01443537_10014.JPEG
|
|
||||||
│ ├── ...
|
By using a diffusion-denoising mechanism as first proposed by [SDEdit](https://arxiv.org/abs/2108.01073), the model can be used for different
|
||||||
├── ...
|
tasks such as text-guided image-to-image translation and upscaling. Similar to the txt2img sampling script,
|
||||||
|
we provide a script to perform image modification with Stable Diffusion.
|
||||||
|
|
||||||
|
The following describes an example where a rough sketch made in [Pinta](https://www.pinta-project.com/) is converted into a detailed artwork.
|
||||||
```
|
```
|
||||||
|
python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img <path-to-img.jpg> --strength 0.8
|
||||||
If you haven't extracted the data, you can also place
|
|
||||||
`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
|
|
||||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
|
|
||||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
|
|
||||||
extracted into above structure without downloading it again. Note that this
|
|
||||||
will only happen if neither a folder
|
|
||||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
|
|
||||||
`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
|
|
||||||
if you want to force running the dataset preparation again.
|
|
||||||
|
|
||||||
|
|
||||||
## Model Training
|
|
||||||
|
|
||||||
Logs and checkpoints for trained models are saved to `logs/<START_DATE_AND_TIME>_<config_spec>`.
|
|
||||||
|
|
||||||
### Training autoencoder models
|
|
||||||
|
|
||||||
Configs for training a KL-regularized autoencoder on ImageNet are provided at `configs/autoencoder`.
|
|
||||||
Training can be started by running
|
|
||||||
```
|
```
|
||||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/autoencoder/<config_spec>.yaml -t --gpus 0,
|
Here, strength is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
|
||||||
```
|
Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. See the following example.
|
||||||
where `config_spec` is one of {`autoencoder_kl_8x8x64`(f=32, d=64), `autoencoder_kl_16x16x16`(f=16, d=16),
|
|
||||||
`autoencoder_kl_32x32x4`(f=8, d=4), `autoencoder_kl_64x64x3`(f=4, d=3)}.
|
|
||||||
|
|
||||||
For training VQ-regularized models, see the [taming-transformers](https://github.com/CompVis/taming-transformers)
|
**Input**
|
||||||
repository.
|
|
||||||
|
|
||||||
### Training LDMs
|
![sketch-in](assets/stable-samples/img2img/sketch-mountains-input.jpg)
|
||||||
|
|
||||||
In ``configs/latent-diffusion/`` we provide configs for training LDMs on the LSUN-, CelebA-HQ, FFHQ and ImageNet datasets.
|
**Outputs**
|
||||||
Training can be started by running
|
|
||||||
|
|
||||||
```shell script
|
![out3](assets/stable-samples/img2img/mountains-3.png)
|
||||||
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/latent-diffusion/<config_spec>.yaml -t --gpus 0,
|
![out2](assets/stable-samples/img2img/mountains-2.png)
|
||||||
```
|
|
||||||
|
|
||||||
where ``<config_spec>`` is one of {`celebahq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),`ffhq-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
|
This procedure can, for example, also be used to upscale samples from the base model.
|
||||||
`lsun_bedrooms-ldm-vq-4`(f=4, VQ-reg. autoencoder, spatial size 64x64x3),
|
|
||||||
`lsun_churches-ldm-vq-4`(f=8, KL-reg. autoencoder, spatial size 32x32x4),`cin-ldm-vq-8`(f=8, VQ-reg. autoencoder, spatial size 32x32x4)}.
|
|
||||||
|
|
||||||
## Coming Soon...
|
|
||||||
|
|
||||||
* More inference scripts for conditional LDMs.
|
|
||||||
* In the meantime, you can play with our colab notebook https://colab.research.google.com/drive/1xqzUi2iXQXDqXBHQGP9Mqt2YrYW6cx-J?usp=sharing
|
|
||||||
* We will also release some further pretrained models.
|
|
||||||
|
|
||||||
|
|
||||||
## Comments
|
## Comments
|
||||||
|
|
||||||
- Our codebase for the diffusion models builds heavily on [OpenAI's codebase](https://github.com/openai/guided-diffusion)
|
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||||
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
||||||
Thanks for open-sourcing!
|
Thanks for open-sourcing!
|
||||||
|
|
||||||
@ -215,6 +175,7 @@ Thanks for open-sourcing!
|
|||||||
archivePrefix={arXiv},
|
archivePrefix={arXiv},
|
||||||
primaryClass={cs.CV}
|
primaryClass={cs.CV}
|
||||||
}
|
}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
140
Stable_Diffusion_v1_Model_Card.md
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
# Stable Diffusion v1 Model Card
|
||||||
|
This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
|
||||||
|
|
||||||
|
## Model Details
|
||||||
|
- **Developed by:** Robin Rombach, Patrick Esser
|
||||||
|
- **Model type:** Diffusion-based text-to-image generation model
|
||||||
|
- **Language(s):** English
|
||||||
|
- **License:** [Proprietary](LICENSE)
|
||||||
|
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
|
||||||
|
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
|
||||||
|
- **Cite as:**
|
||||||
|
|
||||||
|
@InProceedings{Rombach_2022_CVPR,
|
||||||
|
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||||
|
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
month = {June},
|
||||||
|
year = {2022},
|
||||||
|
pages = {10684-10695}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Uses
|
||||||
|
|
||||||
|
## Direct Use
|
||||||
|
The model is intended for research purposes only. Possible research areas and
|
||||||
|
tasks include
|
||||||
|
|
||||||
|
- Safe deployment of models which have the potential to generate harmful content.
|
||||||
|
- Probing and understanding the limitations and biases of generative models.
|
||||||
|
- Generation of artworks and use in design and other artistic processes.
|
||||||
|
- Applications in educational or creative tools.
|
||||||
|
- Research on generative models.
|
||||||
|
|
||||||
|
Excluded uses are described below.
|
||||||
|
|
||||||
|
### Misuse, Malicious Use, and Out-of-Scope Use
|
||||||
|
_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
|
||||||
|
|
||||||
|
|
||||||
|
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
||||||
|
#### Out-of-Scope Use
|
||||||
|
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||||
|
#### Misuse and Malicious Use
|
||||||
|
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
||||||
|
|
||||||
|
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
||||||
|
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
||||||
|
- Impersonating individuals without their consent.
|
||||||
|
- Sexual content without consent of the people who might see it.
|
||||||
|
- Mis- and disinformation
|
||||||
|
- Representations of egregious violence and gore
|
||||||
|
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
||||||
|
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
||||||
|
|
||||||
|
## Limitations and Bias
|
||||||
|
|
||||||
|
### Limitations
|
||||||
|
|
||||||
|
- The model does not achieve perfect photorealism
|
||||||
|
- The model cannot render legible text
|
||||||
|
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
||||||
|
- Faces and people in general may not be generated properly.
|
||||||
|
- The model was trained mainly with English captions and will not work as well in other languages.
|
||||||
|
- The autoencoding part of the model is lossy
|
||||||
|
- The model was trained on a large-scale dataset
|
||||||
|
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
||||||
|
and is not fit for product use without additional safety mechanisms and
|
||||||
|
considerations.
|
||||||
|
|
||||||
|
### Bias
|
||||||
|
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
||||||
|
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
||||||
|
which consists of images that are primarily limited to English descriptions.
|
||||||
|
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
||||||
|
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
||||||
|
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
||||||
|
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
**Training Data**
|
||||||
|
The model developers used the following dataset for training the model:
|
||||||
|
|
||||||
|
- LAION-2B (en) and subsets thereof (see next section)
|
||||||
|
|
||||||
|
**Training Procedure**
|
||||||
|
Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
||||||
|
|
||||||
|
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
||||||
|
- Text prompts are encoded through a ViT-L/14 text-encoder.
|
||||||
|
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
||||||
|
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
|
||||||
|
|
||||||
|
We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
|
||||||
|
which were trained as follows,
|
||||||
|
|
||||||
|
- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
||||||
|
194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
||||||
|
- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
|
||||||
|
515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
|
||||||
|
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
||||||
|
- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
||||||
|
|
||||||
|
|
||||||
|
- **Hardware:** 32 x 8 x A100 GPUs
|
||||||
|
- **Optimizer:** AdamW
|
||||||
|
- **Gradient Accumulations**: 2
|
||||||
|
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
||||||
|
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
||||||
|
|
||||||
|
## Evaluation Results
|
||||||
|
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
||||||
|
5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
|
||||||
|
steps show the relative improvements of the checkpoints:
|
||||||
|
|
||||||
|
![pareto](assets/v1-variants-scores.jpg)
|
||||||
|
|
||||||
|
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
||||||
|
## Environmental Impact
|
||||||
|
|
||||||
|
**Stable Diffusion v1** **Estimated Emissions**
|
||||||
|
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
||||||
|
|
||||||
|
- **Hardware Type:** A100 PCIe 40GB
|
||||||
|
- **Hours used:** 150000
|
||||||
|
- **Cloud Provider:** AWS
|
||||||
|
- **Compute Region:** US-east
|
||||||
|
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
|
||||||
|
## Citation
|
||||||
|
@InProceedings{Rombach_2022_CVPR,
|
||||||
|
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||||
|
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
month = {June},
|
||||||
|
year = {2022},
|
||||||
|
pages = {10684-10695}
|
||||||
|
}
|
||||||
|
|
||||||
|
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
||||||
|
|
BIN
assets/a-painting-of-a-fire.png
Normal file
After Width: | Height: | Size: 651 KiB |
BIN
assets/a-photograph-of-a-fire.png
Normal file
After Width: | Height: | Size: 596 KiB |
BIN
assets/a-shirt-with-a-fire-printed-on-it.png
Normal file
After Width: | Height: | Size: 609 KiB |
BIN
assets/a-shirt-with-the-inscription-'fire'.png
Normal file
After Width: | Height: | Size: 548 KiB |
BIN
assets/a-watercolor-painting-of-a-fire.png
Normal file
After Width: | Height: | Size: 705 KiB |
BIN
assets/birdhouse.png
Normal file
After Width: | Height: | Size: 757 KiB |
BIN
assets/fire.png
Normal file
After Width: | Height: | Size: 612 KiB |
BIN
assets/rdm-preview.jpg
Normal file
After Width: | Height: | Size: 319 KiB |
BIN
assets/stable-samples/img2img/mountains-1.png
Normal file
After Width: | Height: | Size: 610 KiB |
BIN
assets/stable-samples/img2img/mountains-2.png
Normal file
After Width: | Height: | Size: 643 KiB |
BIN
assets/stable-samples/img2img/mountains-3.png
Normal file
After Width: | Height: | Size: 641 KiB |
BIN
assets/stable-samples/img2img/sketch-mountains-input.jpg
Normal file
After Width: | Height: | Size: 174 KiB |
BIN
assets/stable-samples/img2img/upscaling-in.png
Normal file
After Width: | Height: | Size: 1.1 MiB |
BIN
assets/stable-samples/img2img/upscaling-out.png
Normal file
After Width: | Height: | Size: 1.3 MiB |
BIN
assets/stable-samples/txt2img/000002025.png
Normal file
After Width: | Height: | Size: 945 KiB |
BIN
assets/stable-samples/txt2img/000002035.png
Normal file
After Width: | Height: | Size: 972 KiB |
BIN
assets/stable-samples/txt2img/merged-0005.png
Normal file
After Width: | Height: | Size: 2.5 MiB |
BIN
assets/stable-samples/txt2img/merged-0006.png
Normal file
After Width: | Height: | Size: 2.5 MiB |
BIN
assets/stable-samples/txt2img/merged-0007.png
Normal file
After Width: | Height: | Size: 2.3 MiB |
BIN
assets/the-earth-is-on-fire,-oil-on-canvas.png
Normal file
After Width: | Height: | Size: 662 KiB |
BIN
assets/txt2img-convsample.png
Normal file
After Width: | Height: | Size: 302 KiB |
BIN
assets/txt2img-preview.png
Normal file
After Width: | Height: | Size: 2.2 MiB |
BIN
assets/v1-variants-scores.jpg
Normal file
After Width: | Height: | Size: 70 KiB |
68
configs/latent-diffusion/cin256-v2.yaml
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: class_label
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 192
|
||||||
|
attention_resolutions:
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
num_heads: 1
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 512
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.ClassEmbedder
|
||||||
|
params:
|
||||||
|
n_classes: 1001
|
||||||
|
embed_dim: 512
|
||||||
|
key: class_label
|
71
configs/latent-diffusion/txt2img-1p4B-eval.yaml
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.012
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 32
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1280
|
||||||
|
use_checkpoint: true
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.BERTEmbedder
|
||||||
|
params:
|
||||||
|
n_embed: 1280
|
||||||
|
n_layer: 32
|
68
configs/retrieval-augmented-diffusion/768x768.yaml
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.015
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: jpg
|
||||||
|
cond_stage_key: nix
|
||||||
|
image_size: 48
|
||||||
|
channels: 16
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_by_std: false
|
||||||
|
scale_factor: 0.22765929
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 48
|
||||||
|
in_channels: 16
|
||||||
|
out_channels: 16
|
||||||
|
model_channels: 448
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
use_scale_shift_norm: false
|
||||||
|
resblock_updown: false
|
||||||
|
num_head_channels: 32
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: true
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: val/rec_loss
|
||||||
|
embed_dim: 16
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions:
|
||||||
|
- 16
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config:
|
||||||
|
target: torch.nn.Identity
|
70
configs/stable-diffusion/v1-inference.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
1000
data/imagenet_clsidx_to_label.txt
Executable file
@ -5,9 +5,9 @@ channels:
|
|||||||
dependencies:
|
dependencies:
|
||||||
- python=3.8.5
|
- python=3.8.5
|
||||||
- pip=20.3
|
- pip=20.3
|
||||||
- cudatoolkit=11.0
|
- cudatoolkit=11.3
|
||||||
- pytorch=1.7.0
|
- pytorch=1.11.0
|
||||||
- torchvision=0.8.1
|
- torchvision=0.12.0
|
||||||
- numpy=1.19.2
|
- numpy=1.19.2
|
||||||
- pip:
|
- pip:
|
||||||
- albumentations==0.4.3
|
- albumentations==0.4.3
|
||||||
@ -21,7 +21,7 @@ dependencies:
|
|||||||
- streamlit>=0.73.1
|
- streamlit>=0.73.1
|
||||||
- einops==0.3.0
|
- einops==0.3.0
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- transformers==4.3.1
|
- transformers==4.19.2
|
||||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
- -e .
|
- -e .
|
@ -5,7 +5,8 @@ import numpy as np
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
|
||||||
|
extract_into_tensor
|
||||||
|
|
||||||
|
|
||||||
class DDIMSampler(object):
|
class DDIMSampler(object):
|
||||||
@ -72,6 +73,9 @@ class DDIMSampler(object):
|
|||||||
verbose=True,
|
verbose=True,
|
||||||
x_T=None,
|
x_T=None,
|
||||||
log_every_t=100,
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
@ -100,7 +104,9 @@ class DDIMSampler(object):
|
|||||||
score_corrector=score_corrector,
|
score_corrector=score_corrector,
|
||||||
corrector_kwargs=corrector_kwargs,
|
corrector_kwargs=corrector_kwargs,
|
||||||
x_T=x_T,
|
x_T=x_T,
|
||||||
log_every_t=log_every_t
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
)
|
)
|
||||||
return samples, intermediates
|
return samples, intermediates
|
||||||
|
|
||||||
@ -109,7 +115,8 @@ class DDIMSampler(object):
|
|||||||
x_T=None, ddim_use_original_steps=False,
|
x_T=None, ddim_use_original_steps=False,
|
||||||
callback=None, timesteps=None, quantize_denoised=False,
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||||
device = self.model.betas.device
|
device = self.model.betas.device
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
if x_T is None:
|
if x_T is None:
|
||||||
@ -142,7 +149,9 @@ class DDIMSampler(object):
|
|||||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
corrector_kwargs=corrector_kwargs)
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
img, pred_x0 = outs
|
img, pred_x0 = outs
|
||||||
if callback: callback(i)
|
if callback: callback(i)
|
||||||
if img_callback: img_callback(pred_x0, i)
|
if img_callback: img_callback(pred_x0, i)
|
||||||
@ -155,9 +164,19 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
e_t = self.model.apply_model(x, t, c)
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
if score_corrector is not None:
|
if score_corrector is not None:
|
||||||
assert self.model.parameterization == "eps"
|
assert self.model.parameterization == "eps"
|
||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
@ -183,3 +202,40 @@ class DDIMSampler(object):
|
|||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
return x_prev, pred_x0
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||||
|
# fast, but does not allow for exact reconstruction
|
||||||
|
# t serves as an index to gather the correct alphas
|
||||||
|
if use_original_steps:
|
||||||
|
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||||
|
else:
|
||||||
|
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||||
|
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||||
|
|
||||||
|
if noise is None:
|
||||||
|
noise = torch.randn_like(x0)
|
||||||
|
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||||
|
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||||
|
use_original_steps=False):
|
||||||
|
|
||||||
|
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||||
|
timesteps = timesteps[:t_start]
|
||||||
|
|
||||||
|
time_range = np.flip(timesteps)
|
||||||
|
total_steps = timesteps.shape[0]
|
||||||
|
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||||
|
x_dec = x_latent
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||||
|
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning)
|
||||||
|
return x_dec
|
236
ldm/models/diffusion/plms.py
Normal file
@ -0,0 +1,236 @@
|
|||||||
|
"""SAMPLING ONLY."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||||
|
|
||||||
|
|
||||||
|
class PLMSSampler(object):
|
||||||
|
def __init__(self, model, schedule="linear", **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.model = model
|
||||||
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
|
self.schedule = schedule
|
||||||
|
|
||||||
|
def register_buffer(self, name, attr):
|
||||||
|
if type(attr) == torch.Tensor:
|
||||||
|
if attr.device != torch.device("cuda"):
|
||||||
|
attr = attr.to(torch.device("cuda"))
|
||||||
|
setattr(self, name, attr)
|
||||||
|
|
||||||
|
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||||
|
if ddim_eta != 0:
|
||||||
|
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||||
|
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||||
|
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||||
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
|
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
|
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||||
|
|
||||||
|
self.register_buffer('betas', to_torch(self.model.betas))
|
||||||
|
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||||
|
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||||
|
|
||||||
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||||
|
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||||
|
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||||
|
|
||||||
|
# ddim sampling parameters
|
||||||
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||||
|
ddim_timesteps=self.ddim_timesteps,
|
||||||
|
eta=ddim_eta,verbose=verbose)
|
||||||
|
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||||
|
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||||
|
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||||
|
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||||
|
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||||
|
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||||
|
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||||
|
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self,
|
||||||
|
S,
|
||||||
|
batch_size,
|
||||||
|
shape,
|
||||||
|
conditioning=None,
|
||||||
|
callback=None,
|
||||||
|
normals_sequence=None,
|
||||||
|
img_callback=None,
|
||||||
|
quantize_x0=False,
|
||||||
|
eta=0.,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
temperature=1.,
|
||||||
|
noise_dropout=0.,
|
||||||
|
score_corrector=None,
|
||||||
|
corrector_kwargs=None,
|
||||||
|
verbose=True,
|
||||||
|
x_T=None,
|
||||||
|
log_every_t=100,
|
||||||
|
unconditional_guidance_scale=1.,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
if conditioning is not None:
|
||||||
|
if isinstance(conditioning, dict):
|
||||||
|
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||||
|
if cbs != batch_size:
|
||||||
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
else:
|
||||||
|
if conditioning.shape[0] != batch_size:
|
||||||
|
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||||
|
|
||||||
|
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||||
|
# sampling
|
||||||
|
C, H, W = shape
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
mask=mask, x0=x0,
|
||||||
|
ddim_use_original_steps=False,
|
||||||
|
noise_dropout=noise_dropout,
|
||||||
|
temperature=temperature,
|
||||||
|
score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
x_T=x_T,
|
||||||
|
log_every_t=log_every_t,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
)
|
||||||
|
return samples, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
x_T=None, ddim_use_original_steps=False,
|
||||||
|
callback=None, timesteps=None, quantize_denoised=False,
|
||||||
|
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None,):
|
||||||
|
device = self.model.betas.device
|
||||||
|
b = shape[0]
|
||||||
|
if x_T is None:
|
||||||
|
img = torch.randn(shape, device=device)
|
||||||
|
else:
|
||||||
|
img = x_T
|
||||||
|
|
||||||
|
if timesteps is None:
|
||||||
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
|
elif timesteps is not None and not ddim_use_original_steps:
|
||||||
|
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||||
|
timesteps = self.ddim_timesteps[:subset_end]
|
||||||
|
|
||||||
|
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||||
|
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||||
|
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||||
|
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||||
|
|
||||||
|
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||||
|
old_eps = []
|
||||||
|
|
||||||
|
for i, step in enumerate(iterator):
|
||||||
|
index = total_steps - i - 1
|
||||||
|
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||||
|
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
assert x0 is not None
|
||||||
|
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||||
|
img = img_orig * mask + (1. - mask) * img
|
||||||
|
|
||||||
|
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||||
|
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||||
|
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||||
|
corrector_kwargs=corrector_kwargs,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
old_eps=old_eps, t_next=ts_next)
|
||||||
|
img, pred_x0, e_t = outs
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
if callback: callback(i)
|
||||||
|
if img_callback: img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
return img, intermediates
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||||
|
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
|
||||||
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
|
def get_model_output(x, t):
|
||||||
|
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||||
|
e_t = self.model.apply_model(x, t, c)
|
||||||
|
else:
|
||||||
|
x_in = torch.cat([x] * 2)
|
||||||
|
t_in = torch.cat([t] * 2)
|
||||||
|
c_in = torch.cat([unconditional_conditioning, c])
|
||||||
|
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||||
|
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||||
|
|
||||||
|
if score_corrector is not None:
|
||||||
|
assert self.model.parameterization == "eps"
|
||||||
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||||
|
|
||||||
|
return e_t
|
||||||
|
|
||||||
|
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||||
|
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||||
|
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||||
|
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||||
|
|
||||||
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
|
# select parameters corresponding to the currently considered timestep
|
||||||
|
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||||
|
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||||
|
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||||
|
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||||
|
|
||||||
|
# current prediction for x_0
|
||||||
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
if quantize_denoised:
|
||||||
|
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||||
|
# direction pointing to x_t
|
||||||
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||||
|
if noise_dropout > 0.:
|
||||||
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||||
|
return x_prev, pred_x0
|
||||||
|
|
||||||
|
e_t = get_model_output(x, t)
|
||||||
|
if len(old_eps) == 0:
|
||||||
|
# Pseudo Improved Euler (2nd order)
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||||
|
e_t_next = get_model_output(x_prev, t_next)
|
||||||
|
e_t_prime = (e_t + e_t_next) / 2
|
||||||
|
elif len(old_eps) == 1:
|
||||||
|
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||||
|
elif len(old_eps) == 2:
|
||||||
|
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||||
|
elif len(old_eps) >= 3:
|
||||||
|
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||||
|
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||||
|
|
||||||
|
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||||
|
|
||||||
|
return x_prev, pred_x0, e_t
|
@ -455,7 +455,7 @@ class UNetModel(nn.Module):
|
|||||||
num_classes=None,
|
num_classes=None,
|
||||||
use_checkpoint=False,
|
use_checkpoint=False,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
num_heads=1,
|
num_heads=-1,
|
||||||
num_head_channels=-1,
|
num_head_channels=-1,
|
||||||
num_heads_upsample=-1,
|
num_heads_upsample=-1,
|
||||||
use_scale_shift_norm=False,
|
use_scale_shift_norm=False,
|
||||||
@ -464,21 +464,28 @@ class UNetModel(nn.Module):
|
|||||||
use_spatial_transformer=False, # custom transformer support
|
use_spatial_transformer=False, # custom transformer support
|
||||||
transformer_depth=1, # custom transformer support
|
transformer_depth=1, # custom transformer support
|
||||||
context_dim=None, # custom transformer support
|
context_dim=None, # custom transformer support
|
||||||
n_embed=None # custom support for prediction of discrete ids into codebook of first stage vq model
|
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||||
|
legacy=True,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if use_spatial_transformer:
|
if use_spatial_transformer:
|
||||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||||
|
|
||||||
if context_dim is not None:
|
if context_dim is not None:
|
||||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||||
|
from omegaconf.listconfig import ListConfig
|
||||||
|
if type(context_dim) == ListConfig:
|
||||||
|
context_dim = list(context_dim)
|
||||||
|
|
||||||
if num_heads_upsample == -1:
|
if num_heads_upsample == -1:
|
||||||
num_heads_upsample = num_heads
|
num_heads_upsample = num_heads
|
||||||
|
|
||||||
|
if num_heads == -1:
|
||||||
|
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
|
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||||
|
|
||||||
self.image_size = image_size
|
self.image_size = image_size
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.model_channels = model_channels
|
self.model_channels = model_channels
|
||||||
@ -532,13 +539,20 @@ class UNetModel(nn.Module):
|
|||||||
]
|
]
|
||||||
ch = mult * model_channels
|
ch = mult * model_channels
|
||||||
if ds in attention_resolutions:
|
if ds in attention_resolutions:
|
||||||
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
#num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
ch,
|
ch,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
num_head_channels=num_head_channels,
|
num_head_channels=dim_head,
|
||||||
use_new_attention_order=use_new_attention_order,
|
use_new_attention_order=use_new_attention_order,
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||||
@ -572,7 +586,14 @@ class UNetModel(nn.Module):
|
|||||||
ds *= 2
|
ds *= 2
|
||||||
self._feature_size += ch
|
self._feature_size += ch
|
||||||
|
|
||||||
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
#num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
self.middle_block = TimestepEmbedSequential(
|
self.middle_block = TimestepEmbedSequential(
|
||||||
ResBlock(
|
ResBlock(
|
||||||
ch,
|
ch,
|
||||||
@ -586,7 +607,7 @@ class UNetModel(nn.Module):
|
|||||||
ch,
|
ch,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
num_head_channels=num_head_channels,
|
num_head_channels=dim_head,
|
||||||
use_new_attention_order=use_new_attention_order,
|
use_new_attention_order=use_new_attention_order,
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||||
@ -619,13 +640,20 @@ class UNetModel(nn.Module):
|
|||||||
]
|
]
|
||||||
ch = model_channels * mult
|
ch = model_channels * mult
|
||||||
if ds in attention_resolutions:
|
if ds in attention_resolutions:
|
||||||
|
if num_head_channels == -1:
|
||||||
dim_head = ch // num_heads
|
dim_head = ch // num_heads
|
||||||
|
else:
|
||||||
|
num_heads = ch // num_head_channels
|
||||||
|
dim_head = num_head_channels
|
||||||
|
if legacy:
|
||||||
|
#num_heads = 1
|
||||||
|
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||||
layers.append(
|
layers.append(
|
||||||
AttentionBlock(
|
AttentionBlock(
|
||||||
ch,
|
ch,
|
||||||
use_checkpoint=use_checkpoint,
|
use_checkpoint=use_checkpoint,
|
||||||
num_heads=num_heads_upsample,
|
num_heads=num_heads_upsample,
|
||||||
num_head_channels=num_head_channels,
|
num_head_channels=dim_head,
|
||||||
use_new_attention_order=use_new_attention_order,
|
use_new_attention_order=use_new_attention_order,
|
||||||
) if not use_spatial_transformer else SpatialTransformer(
|
) if not use_spatial_transformer else SpatialTransformer(
|
||||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||||
@ -691,7 +719,6 @@ class UNetModel(nn.Module):
|
|||||||
assert (y is not None) == (
|
assert (y is not None) == (
|
||||||
self.num_classes is not None
|
self.num_classes is not None
|
||||||
), "must specify y if and only if the model is class-conditional"
|
), "must specify y if and only if the model is class-conditional"
|
||||||
assert timesteps is not None, 'need to implement no-timestep usage'
|
|
||||||
hs = []
|
hs = []
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
@ -710,14 +737,12 @@ class UNetModel(nn.Module):
|
|||||||
h = module(h, emb, context)
|
h = module(h, emb, context)
|
||||||
h = h.type(x.dtype)
|
h = h.type(x.dtype)
|
||||||
if self.predict_codebook_ids:
|
if self.predict_codebook_ids:
|
||||||
#return self.out(h), self.id_predictor(h)
|
|
||||||
return self.id_predictor(h)
|
return self.id_predictor(h)
|
||||||
else:
|
else:
|
||||||
return self.out(h)
|
return self.out(h)
|
||||||
|
|
||||||
|
|
||||||
class EncoderUNetModel(nn.Module):
|
class EncoderUNetModel(nn.Module):
|
||||||
# TODO: do we use it ?
|
|
||||||
"""
|
"""
|
||||||
The half UNet model with attention and timestep embedding.
|
The half UNet model with attention and timestep embedding.
|
||||||
For usage, see UNet.
|
For usage, see UNet.
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import clip
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
import kornia
|
||||||
|
|
||||||
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||||
|
|
||||||
@ -129,3 +133,102 @@ class SpatialRescaler(nn.Module):
|
|||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
return self(x)
|
return self(x)
|
||||||
|
|
||||||
|
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||||
|
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
|
||||||
|
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
|
||||||
|
super().__init__()
|
||||||
|
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
|
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||||
|
self.device = device
|
||||||
|
self.max_length = max_length
|
||||||
|
self.freeze()
|
||||||
|
|
||||||
|
def freeze(self):
|
||||||
|
self.transformer = self.transformer.eval()
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||||
|
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||||
|
tokens = batch_encoding["input_ids"].to(self.device)
|
||||||
|
outputs = self.transformer(input_ids=tokens)
|
||||||
|
|
||||||
|
z = outputs.last_hidden_state
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
return self(text)
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenCLIPTextEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Uses the CLIP transformer encoder for text.
|
||||||
|
"""
|
||||||
|
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
|
||||||
|
super().__init__()
|
||||||
|
self.model, _ = clip.load(version, jit=False, device="cpu")
|
||||||
|
self.device = device
|
||||||
|
self.max_length = max_length
|
||||||
|
self.n_repeat = n_repeat
|
||||||
|
self.normalize = normalize
|
||||||
|
|
||||||
|
def freeze(self):
|
||||||
|
self.model = self.model.eval()
|
||||||
|
for param in self.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
def forward(self, text):
|
||||||
|
tokens = clip.tokenize(text).to(self.device)
|
||||||
|
z = self.model.encode_text(tokens)
|
||||||
|
if self.normalize:
|
||||||
|
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
|
||||||
|
return z
|
||||||
|
|
||||||
|
def encode(self, text):
|
||||||
|
z = self(text)
|
||||||
|
if z.ndim==2:
|
||||||
|
z = z[:, None, :]
|
||||||
|
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenClipImageEmbedder(nn.Module):
|
||||||
|
"""
|
||||||
|
Uses the CLIP image encoder.
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
jit=False,
|
||||||
|
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||||
|
antialias=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||||
|
|
||||||
|
self.antialias = antialias
|
||||||
|
|
||||||
|
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||||
|
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||||
|
|
||||||
|
def preprocess(self, x):
|
||||||
|
# normalize to [0,1]
|
||||||
|
x = kornia.geometry.resize(x, (224, 224),
|
||||||
|
interpolation='bicubic',align_corners=True,
|
||||||
|
antialias=self.antialias)
|
||||||
|
x = (x + 1.) / 2.
|
||||||
|
# renormalize according to clip
|
||||||
|
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x is assumed to be in range [-1,1]
|
||||||
|
return self.model.encode_image(self.preprocess(x))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from ldm.util import count_params
|
||||||
|
model = FrozenCLIPEmbedder()
|
||||||
|
count_params(model, verbose=True)
|
@ -407,7 +407,7 @@ class AttentionLayers(nn.Module):
|
|||||||
self.rotary_pos_emb = always(None)
|
self.rotary_pos_emb = always(None)
|
||||||
|
|
||||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||||
self.rel_pos = always(None)
|
self.rel_pos = None
|
||||||
|
|
||||||
self.pre_norm = pre_norm
|
self.pre_norm = pre_norm
|
||||||
|
|
||||||
|
121
ldm/util.py
@ -2,6 +2,13 @@ import importlib
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from collections import abc
|
||||||
|
from einops import rearrange
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
|
from threading import Thread
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
from inspect import isfunction
|
from inspect import isfunction
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
@ -38,7 +45,7 @@ def ismap(x):
|
|||||||
|
|
||||||
|
|
||||||
def isimage(x):
|
def isimage(x):
|
||||||
if not isinstance(x,torch.Tensor):
|
if not isinstance(x, torch.Tensor):
|
||||||
return False
|
return False
|
||||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||||
|
|
||||||
@ -64,7 +71,7 @@ def mean_flat(tensor):
|
|||||||
def count_params(model, verbose=False):
|
def count_params(model, verbose=False):
|
||||||
total_params = sum(p.numel() for p in model.parameters())
|
total_params = sum(p.numel() for p in model.parameters())
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||||
return total_params
|
return total_params
|
||||||
|
|
||||||
|
|
||||||
@ -84,3 +91,113 @@ def get_obj_from_str(string, reload=False):
|
|||||||
module_imp = importlib.import_module(module)
|
module_imp = importlib.import_module(module)
|
||||||
importlib.reload(module_imp)
|
importlib.reload(module_imp)
|
||||||
return getattr(importlib.import_module(module, package=None), cls)
|
return getattr(importlib.import_module(module, package=None), cls)
|
||||||
|
|
||||||
|
|
||||||
|
def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
|
||||||
|
# create dummy dataset instance
|
||||||
|
|
||||||
|
# run prefetching
|
||||||
|
if idx_to_fn:
|
||||||
|
res = func(data, worker_id=idx)
|
||||||
|
else:
|
||||||
|
res = func(data)
|
||||||
|
Q.put([idx, res])
|
||||||
|
Q.put("Done")
|
||||||
|
|
||||||
|
|
||||||
|
def parallel_data_prefetch(
|
||||||
|
func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
|
||||||
|
):
|
||||||
|
# if target_data_type not in ["ndarray", "list"]:
|
||||||
|
# raise ValueError(
|
||||||
|
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
|
||||||
|
# )
|
||||||
|
if isinstance(data, np.ndarray) and target_data_type == "list":
|
||||||
|
raise ValueError("list expected but function got ndarray.")
|
||||||
|
elif isinstance(data, abc.Iterable):
|
||||||
|
if isinstance(data, dict):
|
||||||
|
print(
|
||||||
|
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
|
||||||
|
)
|
||||||
|
data = list(data.values())
|
||||||
|
if target_data_type == "ndarray":
|
||||||
|
data = np.asarray(data)
|
||||||
|
else:
|
||||||
|
data = list(data)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if cpu_intensive:
|
||||||
|
Q = mp.Queue(1000)
|
||||||
|
proc = mp.Process
|
||||||
|
else:
|
||||||
|
Q = Queue(1000)
|
||||||
|
proc = Thread
|
||||||
|
# spawn processes
|
||||||
|
if target_data_type == "ndarray":
|
||||||
|
arguments = [
|
||||||
|
[func, Q, part, i, use_worker_id]
|
||||||
|
for i, part in enumerate(np.array_split(data, n_proc))
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
step = (
|
||||||
|
int(len(data) / n_proc + 1)
|
||||||
|
if len(data) % n_proc != 0
|
||||||
|
else int(len(data) / n_proc)
|
||||||
|
)
|
||||||
|
arguments = [
|
||||||
|
[func, Q, part, i, use_worker_id]
|
||||||
|
for i, part in enumerate(
|
||||||
|
[data[i: i + step] for i in range(0, len(data), step)]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
processes = []
|
||||||
|
for i in range(n_proc):
|
||||||
|
p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
|
||||||
|
processes += [p]
|
||||||
|
|
||||||
|
# start processes
|
||||||
|
print(f"Start prefetching...")
|
||||||
|
import time
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
gather_res = [[] for _ in range(n_proc)]
|
||||||
|
try:
|
||||||
|
for p in processes:
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
k = 0
|
||||||
|
while k < n_proc:
|
||||||
|
# get result
|
||||||
|
res = Q.get()
|
||||||
|
if res == "Done":
|
||||||
|
k += 1
|
||||||
|
else:
|
||||||
|
gather_res[res[0]] = res[1]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Exception: ", e)
|
||||||
|
for p in processes:
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
print(f"Prefetching complete. [{time.time() - start} sec.]")
|
||||||
|
|
||||||
|
if target_data_type == 'ndarray':
|
||||||
|
if not isinstance(gather_res[0], np.ndarray):
|
||||||
|
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
|
||||||
|
|
||||||
|
# order outputs
|
||||||
|
return np.concatenate(gather_res, axis=0)
|
||||||
|
elif target_data_type == 'list':
|
||||||
|
out = []
|
||||||
|
for r in gather_res:
|
||||||
|
out.extend(r)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return gather_res
|
||||||
|
293
scripts/img2img.py
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
"""make variations of input image"""
|
||||||
|
|
||||||
|
import argparse, os, sys, glob
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import nullcontext
|
||||||
|
import time
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_img(path):
|
||||||
|
image = Image.open(path).convert("RGB")
|
||||||
|
w, h = image.size
|
||||||
|
print(f"loaded input image of size ({w}, {h}) from {path}")
|
||||||
|
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||||
|
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
|
||||||
|
image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
image = image[None].transpose(0, 3, 1, 2)
|
||||||
|
image = torch.from_numpy(image)
|
||||||
|
return 2.*image - 1.
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="a painting of a virus monster playing guitar",
|
||||||
|
help="the prompt to render"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--init-img",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="path to the input image"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--outdir",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="dir to write results to",
|
||||||
|
default="outputs/img2img-samples"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_save",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save indiviual samples. For speed measurements.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="number of ddim sampling steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--plms",
|
||||||
|
action='store_true',
|
||||||
|
help="use plms sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fixed_code",
|
||||||
|
action='store_true',
|
||||||
|
help="if enabled, uses the same starting code across all samples ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_eta",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_iter",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="sample this often",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--C",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="latent channels",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--f",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="downsampling factor, most often 8 or 16",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_samples",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale",
|
||||||
|
type=float,
|
||||||
|
default=5.0,
|
||||||
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--strength",
|
||||||
|
type=float,
|
||||||
|
default=0.75,
|
||||||
|
help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/stable-diffusion/v1-inference.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="the seed (for reproducible sampling)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
help="evaluate at this precision",
|
||||||
|
choices=["full", "autocast"],
|
||||||
|
default="autocast"
|
||||||
|
)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
if opt.plms:
|
||||||
|
raise NotImplementedError("PLMS sampler not (yet) supported")
|
||||||
|
sampler = PLMSSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
|
if not opt.from_file:
|
||||||
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
|
sample_path = os.path.join(outpath, "samples")
|
||||||
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
assert os.path.isfile(opt.init_img)
|
||||||
|
init_image = load_img(opt.init_img).to(device)
|
||||||
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
|
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
|
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
t_enc = int(opt.strength * opt.ddim_steps)
|
||||||
|
print(f"target t_enc is {t_enc} steps")
|
||||||
|
|
||||||
|
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
tic = time.time()
|
||||||
|
all_samples = list()
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
uc = None
|
||||||
|
if opt.scale != 1.0:
|
||||||
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
# encode (scaled latent)
|
||||||
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
|
||||||
|
# decode it
|
||||||
|
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
|
||||||
|
unconditional_conditioning=uc,)
|
||||||
|
|
||||||
|
x_samples = model.decode_first_stage(samples)
|
||||||
|
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if not opt.skip_save:
|
||||||
|
for x_sample in x_samples:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
|
base_count += 1
|
||||||
|
all_samples.append(x_samples)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
# additionally, save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||||
|
f" \nEnjoy.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
398
scripts/knn2img.py
Normal file
@ -0,0 +1,398 @@
|
|||||||
|
import argparse, os, sys, glob
|
||||||
|
import clip
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import scann
|
||||||
|
import time
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config, parallel_data_prefetch
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
|
||||||
|
|
||||||
|
DATABASES = [
|
||||||
|
"openimages",
|
||||||
|
"artbench-art_nouveau",
|
||||||
|
"artbench-baroque",
|
||||||
|
"artbench-expressionism",
|
||||||
|
"artbench-impressionism",
|
||||||
|
"artbench-post_impressionism",
|
||||||
|
"artbench-realism",
|
||||||
|
"artbench-romanticism",
|
||||||
|
"artbench-renaissance",
|
||||||
|
"artbench-surrealism",
|
||||||
|
"artbench-ukiyo_e",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
class Searcher(object):
|
||||||
|
def __init__(self, database, retriever_version='ViT-L/14'):
|
||||||
|
assert database in DATABASES
|
||||||
|
# self.database = self.load_database(database)
|
||||||
|
self.database_name = database
|
||||||
|
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
|
||||||
|
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
|
||||||
|
self.retriever = self.load_retriever(version=retriever_version)
|
||||||
|
self.database = {'embedding': [],
|
||||||
|
'img_id': [],
|
||||||
|
'patch_coords': []}
|
||||||
|
self.load_database()
|
||||||
|
self.load_searcher()
|
||||||
|
|
||||||
|
def train_searcher(self, k,
|
||||||
|
metric='dot_product',
|
||||||
|
searcher_savedir=None):
|
||||||
|
|
||||||
|
print('Start training searcher')
|
||||||
|
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
|
||||||
|
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
|
||||||
|
k, metric)
|
||||||
|
self.searcher = searcher.score_brute_force().build()
|
||||||
|
print('Finish training searcher')
|
||||||
|
|
||||||
|
if searcher_savedir is not None:
|
||||||
|
print(f'Save trained searcher under "{searcher_savedir}"')
|
||||||
|
os.makedirs(searcher_savedir, exist_ok=True)
|
||||||
|
self.searcher.serialize(searcher_savedir)
|
||||||
|
|
||||||
|
def load_single_file(self, saved_embeddings):
|
||||||
|
compressed = np.load(saved_embeddings)
|
||||||
|
self.database = {key: compressed[key] for key in compressed.files}
|
||||||
|
print('Finished loading of clip embeddings.')
|
||||||
|
|
||||||
|
def load_multi_files(self, data_archive):
|
||||||
|
out_data = {key: [] for key in self.database}
|
||||||
|
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||||
|
for key in d.files:
|
||||||
|
out_data[key].append(d[key])
|
||||||
|
|
||||||
|
return out_data
|
||||||
|
|
||||||
|
def load_database(self):
|
||||||
|
|
||||||
|
print(f'Load saved patch embedding from "{self.database_path}"')
|
||||||
|
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
|
||||||
|
|
||||||
|
if len(file_content) == 1:
|
||||||
|
self.load_single_file(file_content[0])
|
||||||
|
elif len(file_content) > 1:
|
||||||
|
data = [np.load(f) for f in file_content]
|
||||||
|
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
|
||||||
|
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||||
|
|
||||||
|
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
|
||||||
|
self.database}
|
||||||
|
else:
|
||||||
|
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
|
||||||
|
|
||||||
|
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
|
||||||
|
|
||||||
|
def load_retriever(self, version='ViT-L/14', ):
|
||||||
|
model = FrozenClipImageEmbedder(model=version)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_searcher(self):
|
||||||
|
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
|
||||||
|
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
|
||||||
|
print('Finished loading searcher.')
|
||||||
|
|
||||||
|
def search(self, x, k):
|
||||||
|
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
|
||||||
|
self.train_searcher(k) # quickly fit searcher on the fly for small databases
|
||||||
|
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
|
||||||
|
if isinstance(x, torch.Tensor):
|
||||||
|
x = x.detach().cpu().numpy()
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[:, 0]
|
||||||
|
query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
out_embeddings = self.database['embedding'][nns]
|
||||||
|
out_img_ids = self.database['img_id'][nns]
|
||||||
|
out_pc = self.database['patch_coords'][nns]
|
||||||
|
|
||||||
|
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
|
||||||
|
'img_ids': out_img_ids,
|
||||||
|
'patch_coords': out_pc,
|
||||||
|
'queries': x,
|
||||||
|
'exec_time': end - start,
|
||||||
|
'nns': nns,
|
||||||
|
'q_embeddings': query_embeddings}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def __call__(self, x, n):
|
||||||
|
return self.search(x, n)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
# TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
|
||||||
|
# TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="a painting of a virus monster playing guitar",
|
||||||
|
help="the prompt to render"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--outdir",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="dir to write results to",
|
||||||
|
default="outputs/txt2img-samples"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="number of ddim sampling steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_repeat",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="number of repeats in CLIP latent space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--plms",
|
||||||
|
action='store_true',
|
||||||
|
help="use plms sampling",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_eta",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_iter",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="sample this often",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--H",
|
||||||
|
type=int,
|
||||||
|
default=768,
|
||||||
|
help="image height, in pixel space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--W",
|
||||||
|
type=int,
|
||||||
|
default=768,
|
||||||
|
help="image width, in pixel space",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_samples",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="how many samples to produce for each given prompt. A.k.a batch size",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale",
|
||||||
|
type=float,
|
||||||
|
default=5.0,
|
||||||
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/retrieval-augmented-diffusion/768x768.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="models/rdm/rdm768x768/model.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--clip_type",
|
||||||
|
type=str,
|
||||||
|
default="ViT-L/14",
|
||||||
|
help="which CLIP model to use for retrieval and NN encoding",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--database",
|
||||||
|
type=str,
|
||||||
|
default='artbench-surrealism',
|
||||||
|
choices=DATABASES,
|
||||||
|
help="The database used for the search, only applied when --use_neighbors=True",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_neighbors",
|
||||||
|
default=False,
|
||||||
|
action='store_true',
|
||||||
|
help="Include neighbors in addition to text prompt for conditioning",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--knn",
|
||||||
|
default=10,
|
||||||
|
type=int,
|
||||||
|
help="The number of included neighbors, only applied when --use_neighbors=True",
|
||||||
|
)
|
||||||
|
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
|
||||||
|
|
||||||
|
if opt.plms:
|
||||||
|
sampler = PLMSSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
|
if not opt.from_file:
|
||||||
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
|
sample_path = os.path.join(outpath, "samples")
|
||||||
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
print(f"sampling scale for cfg is {opt.scale:.2f}")
|
||||||
|
|
||||||
|
searcher = None
|
||||||
|
if opt.use_neighbors:
|
||||||
|
searcher = Searcher(opt.database)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
with model.ema_scope():
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
all_samples = list()
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
print("sampling prompts:", prompts)
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = clip_text_encoder.encode(prompts)
|
||||||
|
uc = None
|
||||||
|
if searcher is not None:
|
||||||
|
nn_dict = searcher(c, opt.knn)
|
||||||
|
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
|
||||||
|
if opt.scale != 1.0:
|
||||||
|
uc = torch.zeros_like(c)
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
|
||||||
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
|
conditioning=c,
|
||||||
|
batch_size=c.shape[0],
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=opt.scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=opt.ddim_eta,
|
||||||
|
)
|
||||||
|
|
||||||
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
for x_sample in x_samples_ddim:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
|
base_count += 1
|
||||||
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
# additionally, save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
|
429
scripts/latent_imagenet_diffusion.ipynb
Normal file
147
scripts/train_searcher.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
import os, sys
|
||||||
|
import numpy as np
|
||||||
|
import scann
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
from multiprocessing import cpu_count
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from ldm.util import parallel_data_prefetch
|
||||||
|
|
||||||
|
|
||||||
|
def search_bruteforce(searcher):
|
||||||
|
return searcher.score_brute_force().build()
|
||||||
|
|
||||||
|
|
||||||
|
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
|
||||||
|
partioning_trainsize, num_leaves, num_leaves_to_search):
|
||||||
|
return searcher.tree(num_leaves=num_leaves,
|
||||||
|
num_leaves_to_search=num_leaves_to_search,
|
||||||
|
training_sample_size=partioning_trainsize). \
|
||||||
|
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
|
||||||
|
|
||||||
|
|
||||||
|
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
|
||||||
|
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
|
||||||
|
reorder_k).build()
|
||||||
|
|
||||||
|
def load_datapool(dpath):
|
||||||
|
|
||||||
|
|
||||||
|
def load_single_file(saved_embeddings):
|
||||||
|
compressed = np.load(saved_embeddings)
|
||||||
|
database = {key: compressed[key] for key in compressed.files}
|
||||||
|
return database
|
||||||
|
|
||||||
|
def load_multi_files(data_archive):
|
||||||
|
database = {key: [] for key in data_archive[0].files}
|
||||||
|
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
|
||||||
|
for key in d.files:
|
||||||
|
database[key].append(d[key])
|
||||||
|
|
||||||
|
return database
|
||||||
|
|
||||||
|
print(f'Load saved patch embedding from "{dpath}"')
|
||||||
|
file_content = glob.glob(os.path.join(dpath, '*.npz'))
|
||||||
|
|
||||||
|
if len(file_content) == 1:
|
||||||
|
data_pool = load_single_file(file_content[0])
|
||||||
|
elif len(file_content) > 1:
|
||||||
|
data = [np.load(f) for f in file_content]
|
||||||
|
prefetched_data = parallel_data_prefetch(load_multi_files, data,
|
||||||
|
n_proc=min(len(data), cpu_count()), target_data_type='dict')
|
||||||
|
|
||||||
|
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
|
||||||
|
else:
|
||||||
|
raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
|
||||||
|
|
||||||
|
print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
|
||||||
|
return data_pool
|
||||||
|
|
||||||
|
|
||||||
|
def train_searcher(opt,
|
||||||
|
metric='dot_product',
|
||||||
|
partioning_trainsize=None,
|
||||||
|
reorder_k=None,
|
||||||
|
# todo tune
|
||||||
|
aiq_thld=0.2,
|
||||||
|
dims_per_block=2,
|
||||||
|
num_leaves=None,
|
||||||
|
num_leaves_to_search=None,):
|
||||||
|
|
||||||
|
data_pool = load_datapool(opt.database)
|
||||||
|
k = opt.knn
|
||||||
|
|
||||||
|
if not reorder_k:
|
||||||
|
reorder_k = 2 * k
|
||||||
|
|
||||||
|
# normalize
|
||||||
|
# embeddings =
|
||||||
|
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
|
||||||
|
pool_size = data_pool['embedding'].shape[0]
|
||||||
|
|
||||||
|
print(*(['#'] * 100))
|
||||||
|
print('Initializing scaNN searcher with the following values:')
|
||||||
|
print(f'k: {k}')
|
||||||
|
print(f'metric: {metric}')
|
||||||
|
print(f'reorder_k: {reorder_k}')
|
||||||
|
print(f'anisotropic_quantization_threshold: {aiq_thld}')
|
||||||
|
print(f'dims_per_block: {dims_per_block}')
|
||||||
|
print(*(['#'] * 100))
|
||||||
|
print('Start training searcher....')
|
||||||
|
print(f'N samples in pool is {pool_size}')
|
||||||
|
|
||||||
|
# this reflects the recommended design choices proposed at
|
||||||
|
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
|
||||||
|
if pool_size < 2e4:
|
||||||
|
print('Using brute force search.')
|
||||||
|
searcher = search_bruteforce(searcher)
|
||||||
|
elif 2e4 <= pool_size and pool_size < 1e5:
|
||||||
|
print('Using asymmetric hashing search and reordering.')
|
||||||
|
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||||
|
else:
|
||||||
|
print('Using using partioning, asymmetric hashing search and reordering.')
|
||||||
|
|
||||||
|
if not partioning_trainsize:
|
||||||
|
partioning_trainsize = data_pool['embedding'].shape[0] // 10
|
||||||
|
if not num_leaves:
|
||||||
|
num_leaves = int(np.sqrt(pool_size))
|
||||||
|
|
||||||
|
if not num_leaves_to_search:
|
||||||
|
num_leaves_to_search = max(num_leaves // 20, 1)
|
||||||
|
|
||||||
|
print('Partitioning params:')
|
||||||
|
print(f'num_leaves: {num_leaves}')
|
||||||
|
print(f'num_leaves_to_search: {num_leaves_to_search}')
|
||||||
|
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
|
||||||
|
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
|
||||||
|
partioning_trainsize, num_leaves, num_leaves_to_search)
|
||||||
|
|
||||||
|
print('Finish training searcher')
|
||||||
|
searcher_savedir = opt.target_path
|
||||||
|
os.makedirs(searcher_savedir, exist_ok=True)
|
||||||
|
searcher.serialize(searcher_savedir)
|
||||||
|
print(f'Saved trained searcher under "{searcher_savedir}"')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--database',
|
||||||
|
'-d',
|
||||||
|
default='data/rdm/retrieval_databases/openimages',
|
||||||
|
type=str,
|
||||||
|
help='path to folder containing the clip feature of the database')
|
||||||
|
parser.add_argument('--target_path',
|
||||||
|
'-t',
|
||||||
|
default='data/rdm/searchers/openimages',
|
||||||
|
type=str,
|
||||||
|
help='path to the target folder where the searcher shall be stored.')
|
||||||
|
parser.add_argument('--knn',
|
||||||
|
'-k',
|
||||||
|
default=20,
|
||||||
|
type=int,
|
||||||
|
help='number of nearest neighbors, for which the searcher shall be optimized')
|
||||||
|
|
||||||
|
opt, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
train_searcher(opt,)
|
279
scripts/txt2img.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
import argparse, os, sys, glob
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
from itertools import islice
|
||||||
|
from einops import rearrange
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import time
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from torch import autocast
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
|
||||||
|
def chunk(it, size):
|
||||||
|
it = iter(it)
|
||||||
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
|
print(f"Loading model from {ckpt}")
|
||||||
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
|
if "global_step" in pl_sd:
|
||||||
|
print(f"Global Step: {pl_sd['global_step']}")
|
||||||
|
sd = pl_sd["state_dict"]
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
m, u = model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
default="a painting of a virus monster playing guitar",
|
||||||
|
help="the prompt to render"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--outdir",
|
||||||
|
type=str,
|
||||||
|
nargs="?",
|
||||||
|
help="dir to write results to",
|
||||||
|
default="outputs/txt2img-samples"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_grid",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_save",
|
||||||
|
action='store_true',
|
||||||
|
help="do not save individual samples. For speed measurements.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_steps",
|
||||||
|
type=int,
|
||||||
|
default=50,
|
||||||
|
help="number of ddim sampling steps",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--plms",
|
||||||
|
action='store_true',
|
||||||
|
help="use plms sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--laion400m",
|
||||||
|
action='store_true',
|
||||||
|
help="uses the LAION400M model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--fixed_code",
|
||||||
|
action='store_true',
|
||||||
|
help="if enabled, uses the same starting code across samples ",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ddim_eta",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_iter",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="sample this often",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--H",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="image height, in pixel space",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--W",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="image width, in pixel space",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--C",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="latent channels",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--f",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="downsampling factor",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_samples",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="how many samples to produce for each given prompt. A.k.a. batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n_rows",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="rows in the grid (default: n_samples)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--scale",
|
||||||
|
type=float,
|
||||||
|
default=7.5,
|
||||||
|
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from-file",
|
||||||
|
type=str,
|
||||||
|
help="if specified, load prompts from this file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default="configs/stable-diffusion/v1-inference.yaml",
|
||||||
|
help="path to config which constructs model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ckpt",
|
||||||
|
type=str,
|
||||||
|
default="models/ldm/stable-diffusion-v1/model.ckpt",
|
||||||
|
help="path to checkpoint of model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
default=42,
|
||||||
|
help="the seed (for reproducible sampling)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--precision",
|
||||||
|
type=str,
|
||||||
|
help="evaluate at this precision",
|
||||||
|
choices=["full", "autocast"],
|
||||||
|
default="autocast"
|
||||||
|
)
|
||||||
|
opt = parser.parse_args()
|
||||||
|
|
||||||
|
if opt.laion400m:
|
||||||
|
print("Falling back to LAION 400M model...")
|
||||||
|
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
|
||||||
|
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
|
||||||
|
opt.outdir = "outputs/txt2img-samples-laion400m"
|
||||||
|
|
||||||
|
seed_everything(opt.seed)
|
||||||
|
|
||||||
|
config = OmegaConf.load(f"{opt.config}")
|
||||||
|
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||||
|
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
if opt.plms:
|
||||||
|
sampler = PLMSSampler(model)
|
||||||
|
else:
|
||||||
|
sampler = DDIMSampler(model)
|
||||||
|
|
||||||
|
os.makedirs(opt.outdir, exist_ok=True)
|
||||||
|
outpath = opt.outdir
|
||||||
|
|
||||||
|
batch_size = opt.n_samples
|
||||||
|
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
|
||||||
|
if not opt.from_file:
|
||||||
|
prompt = opt.prompt
|
||||||
|
assert prompt is not None
|
||||||
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"reading prompts from {opt.from_file}")
|
||||||
|
with open(opt.from_file, "r") as f:
|
||||||
|
data = f.read().splitlines()
|
||||||
|
data = list(chunk(data, batch_size))
|
||||||
|
|
||||||
|
sample_path = os.path.join(outpath, "samples")
|
||||||
|
os.makedirs(sample_path, exist_ok=True)
|
||||||
|
base_count = len(os.listdir(sample_path))
|
||||||
|
grid_count = len(os.listdir(outpath)) - 1
|
||||||
|
|
||||||
|
start_code = None
|
||||||
|
if opt.fixed_code:
|
||||||
|
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
|
||||||
|
|
||||||
|
precision_scope = autocast if opt.precision=="autocast" else nullcontext
|
||||||
|
with torch.no_grad():
|
||||||
|
with precision_scope("cuda"):
|
||||||
|
with model.ema_scope():
|
||||||
|
tic = time.time()
|
||||||
|
all_samples = list()
|
||||||
|
for n in trange(opt.n_iter, desc="Sampling"):
|
||||||
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
uc = None
|
||||||
|
if opt.scale != 1.0:
|
||||||
|
uc = model.get_learned_conditioning(batch_size * [""])
|
||||||
|
if isinstance(prompts, tuple):
|
||||||
|
prompts = list(prompts)
|
||||||
|
c = model.get_learned_conditioning(prompts)
|
||||||
|
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
|
||||||
|
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
|
||||||
|
conditioning=c,
|
||||||
|
batch_size=opt.n_samples,
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=opt.scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=opt.ddim_eta,
|
||||||
|
x_T=start_code)
|
||||||
|
|
||||||
|
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||||
|
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
|
|
||||||
|
if not opt.skip_save:
|
||||||
|
for x_sample in x_samples_ddim:
|
||||||
|
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||||
|
Image.fromarray(x_sample.astype(np.uint8)).save(
|
||||||
|
os.path.join(sample_path, f"{base_count:05}.png"))
|
||||||
|
base_count += 1
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
all_samples.append(x_samples_ddim)
|
||||||
|
|
||||||
|
if not opt.skip_grid:
|
||||||
|
# additionally, save as grid
|
||||||
|
grid = torch.stack(all_samples, 0)
|
||||||
|
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||||
|
grid = make_grid(grid, nrow=n_rows)
|
||||||
|
|
||||||
|
# to image
|
||||||
|
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
|
||||||
|
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
|
||||||
|
grid_count += 1
|
||||||
|
|
||||||
|
toc = time.time()
|
||||||
|
|
||||||
|
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
|
||||||
|
f" \nEnjoy.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|