Add personalization

This commit is contained in:
Sean McLellan 2022-08-23 18:26:28 -04:00
parent 0cdf5e61b0
commit 8952196bbf
21 changed files with 1924 additions and 37 deletions

166
.gitignore vendored Normal file
View File

@ -0,0 +1,166 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
**/*.ckpt
src/
logs/
**/__pycache__/
outputs

View File

@ -100,6 +100,64 @@ cat aspect of the image and 75% on the white duck aspect
use any combination of integers and floating point numbers, and they use any combination of integers and floating point numbers, and they
do not need to add up to 1. do not need to add up to 1.
## Personalizing Text-to-Image Generation
You may personalize the generated images to provide your own styles or objects by training a new LDM checkpoint
and introducing a new vocabulary to the fixed model.
To train, prepare a folder that contains images sized at 512x512 and execute the following:
~~~~
# As the default backend is not available on Windows, if you're using that platform, execute SET PL_TORCH_DISTRIBUTED_BACKEND=gloo
(ldm) ~/stable-diffusion$ python3 ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml \
-t \
--actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt \
-n my_cat \
-gpus 0, \
--data_root D:/textual-inversion/my_cat \
--init_word 'cat'
~~~~
During the training process, files will be created in /logs/[project][time][project]/
where you can see the process.
conditioning* contains the training prompts
inputs, reconstruction the input images for the training epoch
samples, samples scaled for a sample of the prompt and one with the init word provided
On a RTX3090, the process for SD will take ~1h @1.6 iterations/sec.
Note: According to the associated paper, the optimal number of images is 3-5 any more images than that and your model might not converge.
Training will run indefinately, but you may wish to stop it before the heat death of the universe, when you fine a low loss epoch or around ~5000 iterations.
Once the model is trained, specify the trained .pt file when starting dream using
~~~~
(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt --full_precision
~~~~
Then, to utilize your subject at the dream prompt
~~~
dream> "a photo of *"
~~~
this also works with image2image
~~~~
dream> "waterfall and rainbow in the style of *" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -n4
~~~~
It's also possible to train multiple tokens (modify the placeholder string in configs/stable-diffusion/v1-finetune.yaml) and combine LDM checkpoints using:
~~~~
(ldm) ~/stable-diffusion$ python3 ./scripts/merge_embeddings.py \
--manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...] \
--output_path /path/to/output/embedding.pt
~~~~
Credit goes to @rinongal and the repository located at https://github.com/rinongal/textual_inversion Please see the repository and associated paper for details and limitations.
## Changes ## Changes
* v1.07 (23 August 2022) * v1.07 (23 August 2022)

View File

@ -0,0 +1,105 @@
model:
base_learning_rate: 5.0e-03
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 2
num_workers: 16
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 6100

View File

@ -26,6 +26,15 @@ model:
f_max: [ 1. ] f_max: [ 1. ]
f_min: [ 1. ] f_min: [ 1. ]
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config: unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params: params:

View File

@ -19,6 +19,7 @@ dependencies:
- omegaconf==2.1.1 - omegaconf==2.1.1
- test-tube>=0.7.5 - test-tube>=0.7.5
- streamlit>=0.73.1 - streamlit>=0.73.1
- pillow==9.0.1
- einops==0.3.0 - einops==0.3.0
- torch-fidelity==0.3.0 - torch-fidelity==0.3.0
- transformers==4.19.2 - transformers==4.19.2

160
ldm/data/personalized.py Normal file
View File

@ -0,0 +1,160 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
imagenet_templates_smallest = [
'a photo of a {}',
]
imagenet_templates_small = [
'a photo of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a photo of a clean {}',
'a photo of a dirty {}',
'a dark photo of the {}',
'a photo of my {}',
'a photo of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a photo of the {}',
'a good photo of the {}',
'a photo of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a photo of the clean {}',
'a rendition of a {}',
'a photo of a nice {}',
'a good photo of a {}',
'a photo of the nice {}',
'a photo of the small {}',
'a photo of the weird {}',
'a photo of the large {}',
'a photo of a cool {}',
'a photo of a small {}',
]
imagenet_dual_templates_small = [
'a photo of a {} with {}',
'a rendering of a {} with {}',
'a cropped photo of the {} with {}',
'the photo of a {} with {}',
'a photo of a clean {} with {}',
'a photo of a dirty {} with {}',
'a dark photo of the {} with {}',
'a photo of my {} with {}',
'a photo of the cool {} with {}',
'a close-up photo of a {} with {}',
'a bright photo of the {} with {}',
'a cropped photo of a {} with {}',
'a photo of the {} with {}',
'a good photo of the {} with {}',
'a photo of one {} with {}',
'a close-up photo of the {} with {}',
'a rendition of the {} with {}',
'a photo of the clean {} with {}',
'a rendition of a {} with {}',
'a photo of a nice {} with {}',
'a good photo of a {} with {}',
'a photo of the nice {} with {}',
'a photo of the small {} with {}',
'a photo of the weird {} with {}',
'a photo of the large {} with {}',
'a photo of a cool {} with {}',
'a photo of a small {} with {}',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images])
else:
text = random.choice(imagenet_templates_small).format(placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example

View File

@ -0,0 +1,129 @@
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
imagenet_templates_small = [
'a painting in the style of {}',
'a rendering in the style of {}',
'a cropped painting in the style of {}',
'the painting in the style of {}',
'a clean painting in the style of {}',
'a dirty painting in the style of {}',
'a dark painting in the style of {}',
'a picture in the style of {}',
'a cool painting in the style of {}',
'a close-up painting in the style of {}',
'a bright painting in the style of {}',
'a cropped painting in the style of {}',
'a good painting in the style of {}',
'a close-up painting in the style of {}',
'a rendition in the style of {}',
'a nice painting in the style of {}',
'a small painting in the style of {}',
'a weird painting in the style of {}',
'a large painting in the style of {}',
]
imagenet_dual_templates_small = [
'a painting in the style of {} with {}',
'a rendering in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'the painting in the style of {} with {}',
'a clean painting in the style of {} with {}',
'a dirty painting in the style of {} with {}',
'a dark painting in the style of {} with {}',
'a cool painting in the style of {} with {}',
'a close-up painting in the style of {} with {}',
'a bright painting in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'a good painting in the style of {} with {}',
'a painting of one {} in the style of {}',
'a nice painting in the style of {} with {}',
'a small painting in the style of {} with {}',
'a weird painting in the style of {} with {}',
'a large painting in the style of {} with {}',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example

View File

@ -7,7 +7,9 @@ https://github.com/CompVis/taming-transformers
""" """
import torch import torch
import torch.nn as nn import torch.nn as nn
import os
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
@ -64,6 +66,7 @@ class DDPM(pl.LightningModule):
cosine_s=8e-3, cosine_s=8e-3,
given_betas=None, given_betas=None,
original_elbo_weight=0., original_elbo_weight=0.,
embedding_reg_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1., l_simple_weight=1.,
conditioning_key=None, conditioning_key=None,
@ -98,6 +101,7 @@ class DDPM(pl.LightningModule):
self.v_posterior = v_posterior self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight self.l_simple_weight = l_simple_weight
self.embedding_reg_weight = embedding_reg_weight
if monitor is not None: if monitor is not None:
self.monitor = monitor self.monitor = monitor
@ -427,6 +431,7 @@ class LatentDiffusion(DDPM):
def __init__(self, def __init__(self,
first_stage_config, first_stage_config,
cond_stage_config, cond_stage_config,
personalization_config,
num_timesteps_cond=None, num_timesteps_cond=None,
cond_stage_key="image", cond_stage_key="image",
cond_stage_trainable=False, cond_stage_trainable=False,
@ -436,6 +441,7 @@ class LatentDiffusion(DDPM):
scale_factor=1.0, scale_factor=1.0,
scale_by_std=False, scale_by_std=False,
*args, **kwargs): *args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1) self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps'] assert self.num_timesteps_cond <= kwargs['timesteps']
@ -450,6 +456,7 @@ class LatentDiffusion(DDPM):
self.concat_mode = concat_mode self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key self.cond_stage_key = cond_stage_key
try: try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except: except:
@ -460,6 +467,7 @@ class LatentDiffusion(DDPM):
self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config) self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config) self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False self.clip_denoised = False
self.bbox_tokenizer = None self.bbox_tokenizer = None
@ -469,6 +477,25 @@ class LatentDiffusion(DDPM):
self.init_from_ckpt(ckpt_path, ignore_keys) self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True self.restarted_from_ckpt = True
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
self.model.eval()
self.model.train = disabled_train
for param in self.model.parameters():
param.requires_grad = False
self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model)
self.emb_ckpt_counter = 0
# if self.embedding_manager.is_clip:
# self.cond_stage_model.update_embedding_func(self.embedding_manager)
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
def make_cond_schedule(self, ): def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
@ -530,6 +557,15 @@ class LatentDiffusion(DDPM):
except urllib.error.URLError: except urllib.error.URLError:
raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.") raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.")
self.cond_stage_model = model self.cond_stage_model = model
def instantiate_embedding_manager(self, config, embedder):
model = instantiate_from_config(config, embedder=embedder)
if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string
model.load(config.params.embedding_manager_ckpt)
return model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = [] denoise_row = []
@ -555,7 +591,7 @@ class LatentDiffusion(DDPM):
def get_learned_conditioning(self, c): def get_learned_conditioning(self, c):
if self.cond_stage_forward is None: if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c) c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
if isinstance(c, DiagonalGaussianDistribution): if isinstance(c, DiagonalGaussianDistribution):
c = c.mode() c = c.mode()
else: else:
@ -880,6 +916,7 @@ class LatentDiffusion(DDPM):
if self.shorten_cond_schedule: # TODO: drop this option if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device) tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs) return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
@ -1046,6 +1083,14 @@ class LatentDiffusion(DDPM):
loss += (self.original_elbo_weight * loss_vlb) loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss}) loss_dict.update({f'{prefix}/loss': loss})
if self.embedding_reg_weight > 0:
loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean()
loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})
loss += (self.embedding_reg_weight * loss_embedding_reg)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
@ -1250,11 +1295,10 @@ class LatentDiffusion(DDPM):
return samples, intermediates return samples, intermediates
@torch.no_grad() @torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
plot_diffusion_rows=True, **kwargs): plot_diffusion_rows=False, **kwargs):
use_ddim = ddim_steps is not None use_ddim = ddim_steps is not None
@ -1312,6 +1356,16 @@ class LatentDiffusion(DDPM):
if plot_denoise_rows: if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row) denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid log["denoise_row"] = denoise_grid
uc = self.get_learned_conditioning(len(c) * [""])
sample_scaled, _ = self.sample_log(cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=5.0,
unconditional_conditioning=uc)
log["samples_scaled"] = self.decode_first_stage(sample_scaled)
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage): self.first_stage_model, IdentityFirstStage):
@ -1364,13 +1418,18 @@ class LatentDiffusion(DDPM):
def configure_optimizers(self): def configure_optimizers(self):
lr = self.learning_rate lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable: if self.embedding_manager is not None:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = list(self.embedding_manager.embedding_parameters())
params = params + list(self.cond_stage_model.parameters()) # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters())
if self.learn_logvar: else:
print('Diffusion model optimizing logvar') params = list(self.model.parameters())
params.append(self.logvar) if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr) opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler: if self.use_scheduler:
assert 'target' in self.scheduler_config assert 'target' in self.scheduler_config
@ -1395,6 +1454,18 @@ class LatentDiffusion(DDPM):
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x return x
@rank_zero_only
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
if os.path.isdir(self.trainer.checkpoint_callback.dirpath):
self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt"))
if (self.global_step - self.emb_ckpt_counter) > 500:
self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt"))
self.emb_ckpt_counter += 500
class DiffusionWrapper(pl.LightningModule): class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key): def __init__(self, diff_model_config, conditioning_key):

View File

@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments. explicitly take as arguments.
:param flag: if False, disable gradient checkpointing. :param flag: if False, disable gradient checkpointing.
""" """
if flag: if False: # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params) args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args) return CheckpointFunction.apply(func, len(inputs), *args)
else: else:

View File

@ -0,0 +1,165 @@
from cmath import log
import torch
from torch import nn
import sys
from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
from functools import partial
DEFAULT_PLACEHOLDER_TOKEN = ["*"]
PROGRESSIVE_SCALE = 2000
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"]
sys.stdout.write(f"tokeme: {tokens}")
assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string"
return tokens[0, 1]
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
# assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string"
token = token[0, 1]
return token
def get_embedding_for_clip_token(embedder, token):
return embedder(token.unsqueeze(0))[0, 0]
class EmbeddingManager(nn.Module):
def __init__(
self,
embedder,
placeholder_strings=None,
initializer_words=None,
per_image_tokens=False,
num_vectors_per_token=1,
progressive_words=False,
**kwargs
):
super().__init__()
self.string_to_token_dict = {}
self.string_to_param_dict = nn.ParameterDict()
self.initial_embeddings = nn.ParameterDict() # These should not be optimized
self.progressive_words = progressive_words
self.progressive_counter = 0
self.max_vectors_per_token = num_vectors_per_token
if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder
self.is_clip = True
get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)
token_dim = 1280
else: # using LDM's BERT encoder
self.is_clip = False
get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)
get_embedding_for_tkn = embedder.transformer.token_emb
token_dim = 1280
if per_image_tokens:
placeholder_strings.extend(per_img_token_list)
for idx, placeholder_string in enumerate(placeholder_strings):
token = get_token_for_string(placeholder_string)
if initializer_words and idx < len(initializer_words):
init_word_token = get_token_for_string(initializer_words[idx])
with torch.no_grad():
init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())
token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)
self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)
else:
token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))
self.string_to_token_dict[placeholder_string] = token
self.string_to_param_dict[placeholder_string] = token_params
def forward(
self,
tokenized_text,
embedded_text,
):
b, n, device = *tokenized_text.shape, tokenized_text.device
for placeholder_string, placeholder_token in self.string_to_token_dict.items():
placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)
if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement
placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))
embedded_text[placeholder_idx] = placeholder_embedding
else: # otherwise, need to insert and keep track of changing indices
if self.progressive_words:
self.progressive_counter += 1
max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE
else:
max_step_tokens = self.max_vectors_per_token
num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)
placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))
if placeholder_rows.nelement() == 0:
continue
sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)
sorted_rows = placeholder_rows[sort_idx]
for idx in range(len(sorted_rows)):
row = sorted_rows[idx]
col = sorted_cols[idx]
new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]
new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]
embedded_text[row] = new_embed_row
tokenized_text[row] = new_token_row
return embedded_text
def save(self, ckpt_path):
torch.save({"string_to_token": self.string_to_token_dict,
"string_to_param": self.string_to_param_dict}, ckpt_path)
def load(self, ckpt_path):
ckpt = torch.load(ckpt_path, map_location='cpu')
self.string_to_token_dict = ckpt["string_to_token"]
self.string_to_param_dict = ckpt["string_to_param"]
def get_embedding_norms_squared(self):
all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim
param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders
return param_norm_squared
def embedding_parameters(self):
return self.string_to_param_dict.parameters()
def embedding_to_coarse_loss(self):
loss = 0.
num_embeddings = len(self.initial_embeddings)
for key in self.initial_embeddings:
optimized = self.string_to_param_dict[key]
coarse = self.initial_embeddings[key].clone().to(optimized.device)
loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings
return loss

View File

@ -0,0 +1,397 @@
import torch
import torch.nn as nn
from functools import partial
import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text, embedding_manager=None):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
return z
def encode(self, text, **kwargs):
# output of length 77
return self(text, **kwargs)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
def forward(self,x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
self.freeze()
def embedding_forward(
self,
input_ids = None,
position_ids = None,
inputs_embeds = None,
embedding_manager = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
def encoder_forward(
self,
inputs_embeds,
attention_mask = None,
causal_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
def text_encoder_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.final_layer_norm(last_hidden_state)
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
def transformer_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
):
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager = embedding_manager
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
return z
def encode(self, text, **kwargs):
return self(text, **kwargs)
class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
self.normalize = normalize
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = clip.tokenize(text).to(self.device)
z = self.model.encode_text(tokens)
if self.normalize:
z = z / torch.linalg.norm(z, dim=1, keepdim=True)
return z
def encode(self, text):
z = self(text)
if z.ndim==2:
z = z[:, None, :]
z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
return z
class FrozenClipImageEmbedder(nn.Module):
"""
Uses the CLIP image encoder.
"""
def __init__(
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
return self.model.encode_image(self.preprocess(x))
if __name__ == "__main__":
from ldm.util import count_params
model = FrozenCLIPEmbedder()
count_params(model, verbose=True)

View File

@ -8,6 +8,27 @@ import kornia
from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
def _expand_mask(mask, dtype, tgt_len = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _build_causal_attention_mask(bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class AbstractEncoder(nn.Module): class AbstractEncoder(nn.Module):
def __init__(self): def __init__(self):
@ -98,18 +119,17 @@ class BERTEmbedder(AbstractEncoder):
attn_layers=Encoder(dim=n_embed, depth=n_layer), attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout) emb_dropout=embedding_dropout)
def forward(self, text): def forward(self, text, embedding_manager=None):
if self.use_tknz_fn: if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device) tokens = self.tknz_fn(text)#.to(self.device)
else: else:
tokens = text tokens = text
z = self.transformer(tokens, return_embeddings=True) z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)
return z return z
def encode(self, text): def encode(self, text, **kwargs):
# output of length 77 # output of length 77
return self(text) return self(text, **kwargs)
class SpatialRescaler(nn.Module): class SpatialRescaler(nn.Module):
def __init__(self, def __init__(self,
@ -152,22 +172,165 @@ class FrozenCLIPEmbedder(AbstractEncoder):
self.max_length = max_length self.max_length = max_length
self.freeze() self.freeze()
def embedding_forward(
self,
input_ids = None,
position_ids = None,
inputs_embeds = None,
embedding_manager = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
if embedding_manager is not None:
inputs_embeds = embedding_manager(input_ids, inputs_embeds)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)
def encoder_forward(
self,
inputs_embeds,
attention_mask = None,
causal_attention_mask = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
causal_attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return hidden_states
self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)
def text_encoder_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
last_hidden_state = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = self.final_layer_norm(last_hidden_state)
return last_hidden_state
self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)
def transformer_forward(
self,
input_ids = None,
attention_mask = None,
position_ids = None,
output_attentions = None,
output_hidden_states = None,
return_dict = None,
embedding_manager = None,
):
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
embedding_manager = embedding_manager
)
self.transformer.forward = transformer_forward.__get__(self.transformer)
def freeze(self): def freeze(self):
self.transformer = self.transformer.eval() self.transformer = self.transformer.eval()
for param in self.parameters(): for param in self.parameters():
param.requires_grad = False param.requires_grad = False
def forward(self, text): def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt") return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device) tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens) z = self.transformer(input_ids=tokens, **kwargs)
z = outputs.last_hidden_state
return z return z
def encode(self, text): def encode(self, text, **kwargs):
return self(text) return self(text, **kwargs)
class FrozenCLIPTextEmbedder(nn.Module): class FrozenCLIPTextEmbedder(nn.Module):

View File

@ -485,7 +485,8 @@ class AttentionLayers(nn.Module):
mask=None, mask=None,
context_mask=None, context_mask=None,
mems=None, mems=None,
return_hiddens=False return_hiddens=False,
**kwargs
): ):
hiddens = [] hiddens = []
intermediates = [] intermediates = []
@ -603,11 +604,19 @@ class TransformerWrapper(nn.Module):
return_mems=False, return_mems=False,
return_attn=False, return_attn=False,
mems=None, mems=None,
embedding_manager=None,
**kwargs **kwargs
): ):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x) embedded_x = self.token_emb(x)
if embedding_manager:
x = embedding_manager(x, embedded_x)
else:
x = embedded_x
x = x + self.pos_emb(x)
x = self.emb_dropout(x) x = self.emb_dropout(x)
x = self.project_emb(x) x = self.project_emb(x)

View File

@ -89,6 +89,7 @@ class T2I:
downsampling_factor downsampling_factor
precision precision
strength strength
embedding_path
The vast majority of these arguments default to reasonable values. The vast majority of these arguments default to reasonable values.
""" """
@ -113,6 +114,7 @@ The vast majority of these arguments default to reasonable values.
precision='autocast', precision='autocast',
full_precision=False, full_precision=False,
strength=0.75, # default in scripts/img2img.py strength=0.75, # default in scripts/img2img.py
embedding_path=None,
latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt
): ):
self.outdir = outdir self.outdir = outdir
@ -133,6 +135,7 @@ The vast majority of these arguments default to reasonable values.
self.precision = precision self.precision = precision
self.full_precision = full_precision self.full_precision = full_precision
self.strength = strength self.strength = strength
self.embedding_path = embedding_path
self.model = None # empty for now self.model = None # empty for now
self.sampler = None self.sampler = None
self.latent_diffusion_weights=latent_diffusion_weights self.latent_diffusion_weights=latent_diffusion_weights
@ -143,7 +146,7 @@ The vast majority of these arguments default to reasonable values.
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None, def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False): cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,skip_normalize=False):
""" """
Generate an image from the prompt, writing iteration images into the outdir Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -158,6 +161,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength # not actually used here, but preserved for code refactoring strength = strength or self.strength # not actually used here, but preserved for code refactoring
embedding_path = embedding_path or self.embedding_path
model = self.load_model() # will instantiate the model or return it from cache model = self.load_model() # will instantiate the model or return it from cache
@ -268,7 +272,7 @@ The vast majority of these arguments default to reasonable values.
# There is lots of shared code between this and txt2img and should be refactored. # There is lots of shared code between this and txt2img and should be refactored.
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,skip_normalize=False):
""" """
Generate an image from the prompt and the initial image, writing iteration images into the outdir Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@ -281,6 +285,7 @@ The vast majority of these arguments default to reasonable values.
batch_size = batch_size or self.batch_size batch_size = batch_size or self.batch_size
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
embedding_path = embedding_path or self.embedding_path
if init_img is None: if init_img is None:
print("no init_img provided!") print("no init_img provided!")
@ -431,6 +436,7 @@ The vast majority of these arguments default to reasonable values.
config = OmegaConf.load(self.config) config = OmegaConf.load(self.config)
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = self._load_model_from_config(config,self.weights) model = self._load_model_from_config(config,self.weights)
model.embedding_manager.load(self.embedding_path)
self.model = model.to(self.device) self.model = model.to(self.device)
except AttributeError: except AttributeError:
raise SystemExit raise SystemExit

287
ldm/stable_txt2img.py Normal file
View File

@ -0,0 +1,287 @@
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return model
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--prompt",
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="number of ddim sampling steps",
)
parser.add_argument(
"--plms",
action='store_true',
help="use plms sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
action='store_true',
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
"--ddim_eta",
type=float,
default=0.0,
help="ddim eta (eta=0.0 corresponds to deterministic sampling",
)
parser.add_argument(
"--n_iter",
type=int,
default=2,
help="sample this often",
)
parser.add_argument(
"--H",
type=int,
default=512,
help="image height, in pixel space",
)
parser.add_argument(
"--W",
type=int,
default=512,
help="image width, in pixel space",
)
parser.add_argument(
"--C",
type=int,
default=4,
help="latent channels",
)
parser.add_argument(
"--f",
type=int,
default=8,
help="downsampling factor",
)
parser.add_argument(
"--n_samples",
type=int,
default=3,
help="how many samples to produce for each given prompt. A.k.a. batch size",
)
parser.add_argument(
"--n_rows",
type=int,
default=0,
help="rows in the grid (default: n_samples)",
)
parser.add_argument(
"--scale",
type=float,
default=7.5,
help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
)
parser.add_argument(
"--from-file",
type=str,
help="if specified, load prompts from this file",
)
parser.add_argument(
"--config",
type=str,
default="configs/stable-diffusion/v1-inference.yaml",
help="path to config which constructs model",
)
parser.add_argument(
"--ckpt",
type=str,
default="models/ldm/stable-diffusion-v1/model.ckpt",
help="path to checkpoint of model",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
)
parser.add_argument(
"--embedding_path",
type=str,
help="Path to a pre-trained embedding manager checkpoint")
opt = parser.parse_args()
if opt.laion400m:
print("Falling back to LAION 400M model...")
opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
model.embedding_manager.load(opt.embedding_path)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
if opt.plms:
sampler = PLMSSampler(model)
else:
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
if not opt.from_file:
prompt = opt.prompt
assert prompt is not None
data = [batch_size * [prompt]]
else:
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
data = list(chunk(data, batch_size))
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
tic = time.time()
all_samples = list()
for n in trange(opt.n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
uc = None
if opt.scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.jpg"))
base_count += 1
if not opt.skip_grid:
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}-{grid_count:04}.jpg'))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
if __name__ == "__main__":
main()

View File

@ -12,6 +12,7 @@ from queue import Queue
from inspect import isfunction from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
def log_txt_as_img(wh, xc, size=10): def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height) # wh a tuple of (width, height)
# xc a list of captions to plot # xc a list of captions to plot
@ -20,7 +21,7 @@ def log_txt_as_img(wh, xc, size=10):
for bi in range(b): for bi in range(b):
txt = Image.new("RGB", wh, color="white") txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt) draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256)) nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
@ -73,14 +74,14 @@ def count_params(model, verbose=False):
return total_params return total_params
def instantiate_from_config(config): def instantiate_from_config(config, **kwargs):
if not "target" in config: if not "target" in config:
if config == '__is_first_stage__': if config == '__is_first_stage__':
return None return None
elif config == "__is_unconditional__": elif config == "__is_unconditional__":
return None return None
raise KeyError("Expected key `target` to instantiate.") raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict())) return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False): def get_obj_from_str(string, reload=False):

66
main.py
View File

@ -2,6 +2,7 @@ import argparse, os, sys, datetime, glob, importlib, csv
import numpy as np import numpy as np
import time import time
import torch import torch
import torchvision import torchvision
import pytorch_lightning as pl import pytorch_lightning as pl
@ -20,6 +21,22 @@ from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
return model
def get_parser(**parser_kwargs): def get_parser(**parser_kwargs):
def str2bool(v): def str2bool(v):
@ -120,6 +137,23 @@ def get_parser(**parser_kwargs):
default=True, default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate", help="scale base-lr by ngpu * batch_size * n_accumulate",
) )
parser.add_argument(
"--datadir_in_name",
type=str2bool,
nargs="?",
const=True,
default=True,
help="Prepend the final directory in the data_root to the output directory name")
parser.add_argument("--actual_resume", type=str, default="", help="Path to model to actually resume from")
parser.add_argument("--data_root", type=str, required=True, help="Path to directory with training images")
parser.add_argument("--embedding_manager_ckpt", type=str, default="", help="Initialize embedding manager from a checkpoint")
parser.add_argument("--placeholder_tokens", type=str, nargs="+", default=["*"])
parser.add_argument("--init_word", type=str, help="Word to use as source for initial token embedding.")
return parser return parser
@ -502,6 +536,10 @@ if __name__ == "__main__":
name = "_" + cfg_name name = "_" + cfg_name
else: else:
name = "" name = ""
if opt.datadir_in_name:
now = os.path.basename(os.path.normpath(opt.data_root)) + now
nowname = now + name + opt.postfix nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname) logdir = os.path.join(opt.logdir, nowname)
@ -532,7 +570,18 @@ if __name__ == "__main__":
lightning_config.trainer = trainer_config lightning_config.trainer = trainer_config
# model # model
model = instantiate_from_config(config.model)
# config.model.params.personalization_config.params.init_word = opt.init_word
config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt
config.model.params.personalization_config.params.placeholder_tokens = opt.placeholder_tokens
if opt.init_word:
config.model.params.personalization_config.params.initializer_words[0] = opt.init_word
if opt.actual_resume:
model = load_model_from_config(config, opt.actual_resume)
else:
model = instantiate_from_config(config.model)
# trainer and callbacks # trainer and callbacks
trainer_kwargs = dict() trainer_kwargs = dict()
@ -578,7 +627,7 @@ if __name__ == "__main__":
if hasattr(model, "monitor"): if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.") print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3 default_modelckpt_cfg["params"]["save_top_k"] = 1
if "modelcheckpoint" in lightning_config: if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint modelckpt_cfg = lightning_config.modelcheckpoint
@ -655,11 +704,16 @@ if __name__ == "__main__":
del callbacks_cfg['ignore_keys_callback'] del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs["max_steps"] = opt.max_steps
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ### trainer.logdir = logdir ###
# data # data
config.data.params.train.params.data_root = opt.data_root
config.data.params.validation.params.data_root = opt.data_root
data = instantiate_from_config(config.data)
data = instantiate_from_config(config.data) data = instantiate_from_config(config.data)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is. # calling these ourselves should not be necessary but it is.
@ -710,8 +764,8 @@ if __name__ == "__main__":
import signal import signal
signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGTERM, melk)
signal.signal(signal.SIGUSR2, divein) signal.signal(signal.SIGTERM, divein)
# run # run
if opt.train: if opt.train:
@ -737,5 +791,5 @@ if __name__ == "__main__":
dst = os.path.join(dst, "debug_runs", name) dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True) os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst) os.rename(logdir, dst)
if trainer.global_rank == 0: # if trainer.global_rank == 0:
print(trainer.profiler.summary()) # print(trainer.profiler.summary())

View File

@ -9,6 +9,7 @@ kornia==0.6.0
numpy==1.19.2 numpy==1.19.2
omegaconf==2.1.1 omegaconf==2.1.1
opencv-python==4.1.2.30 opencv-python==4.1.2.30
pillow==9.0.1
pudb==2019.2 pudb==2019.2
pytorch pytorch
pytorch-lightning==1.4.2 pytorch-lightning==1.4.2

View File

@ -57,7 +57,8 @@ def main():
weights=weights, weights=weights,
full_precision=opt.full_precision, full_precision=opt.full_precision,
config=config, config=config,
latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt
embedding_path=opt.embedding_path
) )
# make sure the output directory exists # make sure the output directory exists
@ -268,6 +269,10 @@ def create_argv_parser():
type=str, type=str,
default="outputs/img-samples", default="outputs/img-samples",
help="directory in which to place generated images and a log of prompts and seeds") help="directory in which to place generated images and a log of prompts and seeds")
parser.add_argument('--embedding_path',
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")
return parser return parser

View File

@ -0,0 +1,83 @@
from ldm.modules.encoders.modules import BERTTokenizer
from ldm.modules.embedding_manager import EmbeddingManager
import argparse, os
from functools import partial
import torch
def get_placeholder_loop(placeholder_string, tokenizer):
new_placeholder = None
while True:
if new_placeholder is None:
new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
else:
new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
token = tokenizer(new_placeholder)
if torch.count_nonzero(token) == 3:
return new_placeholder, token[0, 1]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--manager_ckpts",
type=str,
nargs="+",
required=True,
help="Paths to a set of embedding managers to be merged."
)
parser.add_argument(
"--output_path",
type=str,
required=True,
help="Output path for the merged manager",
)
args = parser.parse_args()
tokenizer = BERTTokenizer(vq_interface=False, max_length=77)
EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"])
string_to_token_dict = {}
string_to_param_dict = torch.nn.ParameterDict()
placeholder_to_src = {}
for manager_ckpt in args.manager_ckpts:
print(f"Parsing {manager_ckpt}...")
manager = EmbeddingManager()
manager.load(manager_ckpt)
for placeholder_string in manager.string_to_token_dict:
if not placeholder_string in string_to_token_dict:
string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string]
string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string]
placeholder_to_src[placeholder_string] = manager_ckpt
else:
new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer)
string_to_token_dict[new_placeholder] = new_token
string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string]
placeholder_to_src[new_placeholder] = manager_ckpt
print("Saving combined manager...")
merged_manager = EmbeddingManager()
merged_manager.string_to_param_dict = string_to_param_dict
merged_manager.string_to_token_dict = string_to_token_dict
merged_manager.save(args.output_path)
print("Managers merged. Final list of placeholders: ")
print(placeholder_to_src)

17
train.ps1 Normal file
View File

@ -0,0 +1,17 @@
conda activate ldm
pip install -e .
$env:PL_TORCH_DISTRIBUTED_BACKEND="gloo"
python ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml `
-t `
--actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt `
-n my_cat `
--gpus 0, `
--data_root D:/textual-inversion/my_cat `
--init_Word 'Isla Fisher'
# python ./scripts/train_personalization.py --base ./configs/stable-diffusion/v1-finetune.yaml `
# -t `
# --actual_resume ../stable-diffusion-dream/models/ldm/stable-diffusion-v1/model.ckpt `
# --gpus 1 `
# --data_root D:/textual-inversion/isla_fisher `
# --resume 'logs/my_cat2022-08-23T01-46-37_my_cat' `
# --init_Word 'cat'