From 8952196bbfa75cf2f4a66a847e1f3b7a50d76cf5 Mon Sep 17 00:00:00 2001 From: Sean McLellan Date: Tue, 23 Aug 2022 18:26:28 -0400 Subject: [PATCH] Add personalization --- .gitignore | 166 +++++++++ README.md | 58 +++ configs/stable-diffusion/v1-finetune.yaml | 105 ++++++ configs/stable-diffusion/v1-inference.yaml | 9 + environment.yaml | 1 + ldm/data/personalized.py | 160 +++++++++ ldm/data/personalized_style.py | 129 +++++++ ldm/models/diffusion/ddpm.py | 93 ++++- ldm/modules/diffusionmodules/util.py | 2 +- ldm/modules/embedding_manager.py | 165 +++++++++ ldm/modules/encoders/modules copy.py | 397 +++++++++++++++++++++ ldm/modules/encoders/modules.py | 183 +++++++++- ldm/modules/x_transformer.py | 15 +- ldm/simplet2i.py | 10 +- ldm/stable_txt2img.py | 287 +++++++++++++++ ldm/util.py | 7 +- main.py | 66 +++- requirements.txt | 1 + scripts/dream.py | 7 +- scripts/merge_embeddings.py | 83 +++++ train.ps1 | 17 + 21 files changed, 1924 insertions(+), 37 deletions(-) create mode 100644 .gitignore create mode 100644 configs/stable-diffusion/v1-finetune.yaml create mode 100644 ldm/data/personalized.py create mode 100644 ldm/data/personalized_style.py create mode 100644 ldm/modules/embedding_manager.py create mode 100644 ldm/modules/encoders/modules copy.py create mode 100644 ldm/stable_txt2img.py create mode 100644 scripts/merge_embeddings.py create mode 100644 train.ps1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..e05fa41e06 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +**/*.ckpt +src/ +logs/ +**/__pycache__/ +outputs \ No newline at end of file diff --git a/README.md b/README.md index 48dbc60815..5d733d8af3 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,64 @@ cat aspect of the image and 75% on the white duck aspect use any combination of integers and floating point numbers, and they do not need to add up to 1. +## Personalizing Text-to-Image Generation + +You may personalize the generated images to provide your own styles or objects by training a new LDM checkpoint +and introducing a new vocabulary to the fixed model. + +To train, prepare a folder that contains images sized at 512x512 and execute the following: + +~~~~ +# As the default backend is not available on Windows, if you're using that platform, execute SET PL_TORCH_DISTRIBUTED_BACKEND=gloo +(ldm) ~/stable-diffusion$ python3 ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml \ + -t \ + --actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt \ + -n my_cat \ + -gpus 0, \ + --data_root D:/textual-inversion/my_cat \ + --init_word 'cat' +~~~~ + +During the training process, files will be created in /logs/[project][time][project]/ +where you can see the process. + +conditioning* contains the training prompts +inputs, reconstruction the input images for the training epoch +samples, samples scaled for a sample of the prompt and one with the init word provided + +On a RTX3090, the process for SD will take ~1h @1.6 iterations/sec. + +Note: According to the associated paper, the optimal number of images is 3-5 any more images than that and your model might not converge. + +Training will run indefinately, but you may wish to stop it before the heat death of the universe, when you fine a low loss epoch or around ~5000 iterations. + +Once the model is trained, specify the trained .pt file when starting dream using + +~~~~ +(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt --full_precision +~~~~ + +Then, to utilize your subject at the dream prompt + +~~~ +dream> "a photo of *" +~~~ + +this also works with image2image +~~~~ +dream> "waterfall and rainbow in the style of *" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -n4 +~~~~ + +It's also possible to train multiple tokens (modify the placeholder string in configs/stable-diffusion/v1-finetune.yaml) and combine LDM checkpoints using: + +~~~~ +(ldm) ~/stable-diffusion$ python3 ./scripts/merge_embeddings.py \ + --manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...] \ + --output_path /path/to/output/embedding.pt +~~~~ + +Credit goes to @rinongal and the repository located at https://github.com/rinongal/textual_inversion Please see the repository and associated paper for details and limitations. + ## Changes * v1.07 (23 August 2022) diff --git a/configs/stable-diffusion/v1-finetune.yaml b/configs/stable-diffusion/v1-finetune.yaml new file mode 100644 index 0000000000..da6a5e775b --- /dev/null +++ b/configs/stable-diffusion/v1-finetune.yaml @@ -0,0 +1,105 @@ +model: + base_learning_rate: 5.0e-03 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + embedding_reg_weight: 0.0 + + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ["sculpture"] + per_image_tokens: false + num_vectors_per_token: 1 + progressive_words: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 2 + num_workers: 16 + wrap: false + train: + target: ldm.data.personalized.PersonalizedBase + params: + size: 512 + set: train + per_image_tokens: false + repeats: 100 + validation: + target: ldm.data.personalized.PersonalizedBase + params: + size: 512 + set: val + per_image_tokens: false + repeats: 10 + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 500 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True + max_steps: 6100 \ No newline at end of file diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index d4effe569e..59d8f33125 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -26,6 +26,15 @@ model: f_max: [ 1. ] f_min: [ 1. ] + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ["sculpture"] + per_image_tokens: false + num_vectors_per_token: 1 + progressive_words: False + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: diff --git a/environment.yaml b/environment.yaml index 0de05e815a..83669e254c 100644 --- a/environment.yaml +++ b/environment.yaml @@ -19,6 +19,7 @@ dependencies: - omegaconf==2.1.1 - test-tube>=0.7.5 - streamlit>=0.73.1 + - pillow==9.0.1 - einops==0.3.0 - torch-fidelity==0.3.0 - transformers==4.19.2 diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py new file mode 100644 index 0000000000..c8a57d09fa --- /dev/null +++ b/ldm/data/personalized.py @@ -0,0 +1,160 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_smallest = [ + 'a photo of a {}', +] + +imagenet_templates_small = [ + 'a photo of a {}', + 'a rendering of a {}', + 'a cropped photo of the {}', + 'the photo of a {}', + 'a photo of a clean {}', + 'a photo of a dirty {}', + 'a dark photo of the {}', + 'a photo of my {}', + 'a photo of the cool {}', + 'a close-up photo of a {}', + 'a bright photo of the {}', + 'a cropped photo of a {}', + 'a photo of the {}', + 'a good photo of the {}', + 'a photo of one {}', + 'a close-up photo of the {}', + 'a rendition of the {}', + 'a photo of the clean {}', + 'a rendition of a {}', + 'a photo of a nice {}', + 'a good photo of a {}', + 'a photo of the nice {}', + 'a photo of the small {}', + 'a photo of the weird {}', + 'a photo of the large {}', + 'a photo of a cool {}', + 'a photo of a small {}', +] + +imagenet_dual_templates_small = [ + 'a photo of a {} with {}', + 'a rendering of a {} with {}', + 'a cropped photo of the {} with {}', + 'the photo of a {} with {}', + 'a photo of a clean {} with {}', + 'a photo of a dirty {} with {}', + 'a dark photo of the {} with {}', + 'a photo of my {} with {}', + 'a photo of the cool {} with {}', + 'a close-up photo of a {} with {}', + 'a bright photo of the {} with {}', + 'a cropped photo of a {} with {}', + 'a photo of the {} with {}', + 'a good photo of the {} with {}', + 'a photo of one {} with {}', + 'a close-up photo of the {} with {}', + 'a rendition of the {} with {}', + 'a photo of the clean {} with {}', + 'a rendition of a {} with {}', + 'a photo of a nice {} with {}', + 'a good photo of a {} with {}', + 'a photo of the nice {} with {}', + 'a photo of the small {} with {}', + 'a photo of the weird {} with {}', + 'a photo of the large {} with {}', + 'a photo of a cool {} with {}', + 'a photo of a small {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + mixing_prob=0.25, + coarse_class_text=None, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + self.mixing_prob = mixing_prob + + self.coarse_class_text = coarse_class_text + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + if self.coarse_class_text: + placeholder_string = f"{self.coarse_class_text} {placeholder_string}" + + if self.per_image_tokens and np.random.uniform() < self.mixing_prob: + text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) + else: + text = random.choice(imagenet_templates_small).format(placeholder_string) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/data/personalized_style.py b/ldm/data/personalized_style.py new file mode 100644 index 0000000000..b6be7b15c4 --- /dev/null +++ b/ldm/data/personalized_style.py @@ -0,0 +1,129 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_small = [ + 'a painting in the style of {}', + 'a rendering in the style of {}', + 'a cropped painting in the style of {}', + 'the painting in the style of {}', + 'a clean painting in the style of {}', + 'a dirty painting in the style of {}', + 'a dark painting in the style of {}', + 'a picture in the style of {}', + 'a cool painting in the style of {}', + 'a close-up painting in the style of {}', + 'a bright painting in the style of {}', + 'a cropped painting in the style of {}', + 'a good painting in the style of {}', + 'a close-up painting in the style of {}', + 'a rendition in the style of {}', + 'a nice painting in the style of {}', + 'a small painting in the style of {}', + 'a weird painting in the style of {}', + 'a large painting in the style of {}', +] + +imagenet_dual_templates_small = [ + 'a painting in the style of {} with {}', + 'a rendering in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'the painting in the style of {} with {}', + 'a clean painting in the style of {} with {}', + 'a dirty painting in the style of {} with {}', + 'a dark painting in the style of {} with {}', + 'a cool painting in the style of {} with {}', + 'a close-up painting in the style of {} with {}', + 'a bright painting in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'a good painting in the style of {} with {}', + 'a painting of one {} in the style of {}', + 'a nice painting in the style of {} with {}', + 'a small painting in the style of {} with {}', + 'a weird painting in the style of {} with {}', + 'a large painting in the style of {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + if self.per_image_tokens and np.random.uniform() < 0.25: + text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) + else: + text = random.choice(imagenet_templates_small).format(self.placeholder_token) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 8bc93258a4..d5f74a0fbe 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -7,7 +7,9 @@ https://github.com/CompVis/taming-transformers """ import torch + import torch.nn as nn +import os import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR @@ -64,6 +66,7 @@ class DDPM(pl.LightningModule): cosine_s=8e-3, given_betas=None, original_elbo_weight=0., + embedding_reg_weight=0., v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1., conditioning_key=None, @@ -98,6 +101,7 @@ class DDPM(pl.LightningModule): self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight + self.embedding_reg_weight = embedding_reg_weight if monitor is not None: self.monitor = monitor @@ -427,6 +431,7 @@ class LatentDiffusion(DDPM): def __init__(self, first_stage_config, cond_stage_config, + personalization_config, num_timesteps_cond=None, cond_stage_key="image", cond_stage_trainable=False, @@ -436,6 +441,7 @@ class LatentDiffusion(DDPM): scale_factor=1.0, scale_by_std=False, *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs['timesteps'] @@ -450,6 +456,7 @@ class LatentDiffusion(DDPM): self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key + try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: @@ -460,6 +467,7 @@ class LatentDiffusion(DDPM): self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None @@ -469,6 +477,25 @@ class LatentDiffusion(DDPM): self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + + self.model.eval() + self.model.train = disabled_train + for param in self.model.parameters(): + param.requires_grad = False + + self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model) + + self.emb_ckpt_counter = 0 + + # if self.embedding_manager.is_clip: + # self.cond_stage_model.update_embedding_func(self.embedding_manager) + + for param in self.embedding_manager.embedding_parameters(): + param.requires_grad = True + def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() @@ -530,6 +557,15 @@ class LatentDiffusion(DDPM): except urllib.error.URLError: raise SystemExit("* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine.") self.cond_stage_model = model + + + def instantiate_embedding_manager(self, config, embedder): + model = instantiate_from_config(config, embedder=embedder) + + if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string + model.load(config.params.embedding_manager_ckpt) + + return model def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): denoise_row = [] @@ -555,7 +591,7 @@ class LatentDiffusion(DDPM): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c) + c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: @@ -880,6 +916,7 @@ class LatentDiffusion(DDPM): if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset @@ -1046,6 +1083,14 @@ class LatentDiffusion(DDPM): loss += (self.original_elbo_weight * loss_vlb) loss_dict.update({f'{prefix}/loss': loss}) + if self.embedding_reg_weight > 0: + loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean() + + loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) + + loss += (self.embedding_reg_weight * loss_embedding_reg) + loss_dict.update({f'{prefix}/loss': loss}) + return loss, loss_dict def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, @@ -1250,11 +1295,10 @@ class LatentDiffusion(DDPM): return samples, intermediates - @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, **kwargs): use_ddim = ddim_steps is not None @@ -1312,6 +1356,16 @@ class LatentDiffusion(DDPM): if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid + + uc = self.get_learned_conditioning(len(c) * [""]) + sample_scaled, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=5.0, + unconditional_conditioning=uc) + log["samples_scaled"] = self.decode_first_stage(sample_scaled) if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( self.first_stage_model, IdentityFirstStage): @@ -1364,13 +1418,18 @@ class LatentDiffusion(DDPM): def configure_optimizers(self): lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) + + if self.embedding_manager is not None: + params = list(self.embedding_manager.embedding_parameters()) + # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters()) + else: + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert 'target' in self.scheduler_config @@ -1395,6 +1454,18 @@ class LatentDiffusion(DDPM): x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x + @rank_zero_only + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + + if os.path.isdir(self.trainer.checkpoint_callback.dirpath): + self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt")) + + if (self.global_step - self.emb_ckpt_counter) > 500: + self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt")) + + self.emb_ckpt_counter += 500 + class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index a952e6c403..6b5b9dc9e2 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag): explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ - if flag: + if False: # disabled checkpointing to allow requires_grad = False for main model args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py new file mode 100644 index 0000000000..dae0ca6ebf --- /dev/null +++ b/ldm/modules/embedding_manager.py @@ -0,0 +1,165 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + sys.stdout.write(f"tokeme: {tokens}") + assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + + return tokens[0, 1] + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + + token = token[0, 1] + + return token + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss \ No newline at end of file diff --git a/ldm/modules/encoders/modules copy.py b/ldm/modules/encoders/modules copy.py new file mode 100644 index 0000000000..4fe1dcc27e --- /dev/null +++ b/ldm/modules/encoders/modules copy.py @@ -0,0 +1,397 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + +def _expand_mask(mask, dtype, tgt_len = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + +def _build_causal_attention_mask(bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text, embedding_manager=None): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager) + return z + + def encode(self, text, **kwargs): + # output of length 77 + return self(text, **kwargs) + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def embedding_forward( + self, + input_ids = None, + position_ids = None, + inputs_embeds = None, + embedding_manager = None, + ) -> torch.Tensor: + + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) + + def encoder_forward( + self, + inputs_embeds, + attention_mask = None, + causal_attention_mask = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) + + + def text_encoder_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) + + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager = embedding_manager + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text, **kwargs): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + z = self.transformer(input_ids=tokens, **kwargs) + + return z + + def encode(self, text, **kwargs): + return self(text, **kwargs) + + +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 5c02d8c6f5..def6d2136d 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -8,6 +8,27 @@ import kornia from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +def _expand_mask(mask, dtype, tgt_len = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + +def _build_causal_attention_mask(bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask class AbstractEncoder(nn.Module): def __init__(self): @@ -98,18 +119,17 @@ class BERTEmbedder(AbstractEncoder): attn_layers=Encoder(dim=n_embed, depth=n_layer), emb_dropout=embedding_dropout) - def forward(self, text): + def forward(self, text, embedding_manager=None): if self.use_tknz_fn: tokens = self.tknz_fn(text)#.to(self.device) else: tokens = text - z = self.transformer(tokens, return_embeddings=True) + z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager) return z - def encode(self, text): + def encode(self, text, **kwargs): # output of length 77 - return self(text) - + return self(text, **kwargs) class SpatialRescaler(nn.Module): def __init__(self, @@ -152,22 +172,165 @@ class FrozenCLIPEmbedder(AbstractEncoder): self.max_length = max_length self.freeze() + def embedding_forward( + self, + input_ids = None, + position_ids = None, + inputs_embeds = None, + embedding_manager = None, + ) -> torch.Tensor: + + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) + + def encoder_forward( + self, + inputs_embeds, + attention_mask = None, + causal_attention_mask = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return hidden_states + + self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder) + + + def text_encoder_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) + + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager = embedding_manager + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False - def forward(self, text): + def forward(self, text, **kwargs): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens) + z = self.transformer(input_ids=tokens, **kwargs) - z = outputs.last_hidden_state return z - def encode(self, text): - return self(text) + def encode(self, text, **kwargs): + return self(text, **kwargs) class FrozenCLIPTextEmbedder(nn.Module): diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py index 5fc15bf9cf..1316dbd505 100644 --- a/ldm/modules/x_transformer.py +++ b/ldm/modules/x_transformer.py @@ -485,7 +485,8 @@ class AttentionLayers(nn.Module): mask=None, context_mask=None, mems=None, - return_hiddens=False + return_hiddens=False, + **kwargs ): hiddens = [] intermediates = [] @@ -603,11 +604,19 @@ class TransformerWrapper(nn.Module): return_mems=False, return_attn=False, mems=None, + embedding_manager=None, **kwargs ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens - x = self.token_emb(x) - x += self.pos_emb(x) + + embedded_x = self.token_emb(x) + + if embedding_manager: + x = embedding_manager(x, embedded_x) + else: + x = embedded_x + + x = x + self.pos_emb(x) x = self.emb_dropout(x) x = self.project_emb(x) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index c7f6263816..25897ed500 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -89,6 +89,7 @@ class T2I: downsampling_factor precision strength + embedding_path The vast majority of these arguments default to reasonable values. """ @@ -113,6 +114,7 @@ The vast majority of these arguments default to reasonable values. precision='autocast', full_precision=False, strength=0.75, # default in scripts/img2img.py + embedding_path=None, latent_diffusion_weights=False # just to keep track of this parameter when regenerating prompt ): self.outdir = outdir @@ -133,6 +135,7 @@ The vast majority of these arguments default to reasonable values. self.precision = precision self.full_precision = full_precision self.strength = strength + self.embedding_path = embedding_path self.model = None # empty for now self.sampler = None self.latent_diffusion_weights=latent_diffusion_weights @@ -143,7 +146,7 @@ The vast majority of these arguments default to reasonable values. def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,init_img=None,skip_normalize=False): + cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,skip_normalize=False): """ Generate an image from the prompt, writing iteration images into the outdir The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] @@ -158,6 +161,7 @@ The vast majority of these arguments default to reasonable values. batch_size = batch_size or self.batch_size iterations = iterations or self.iterations strength = strength or self.strength # not actually used here, but preserved for code refactoring + embedding_path = embedding_path or self.embedding_path model = self.load_model() # will instantiate the model or return it from cache @@ -268,7 +272,7 @@ The vast majority of these arguments default to reasonable values. # There is lots of shared code between this and txt2img and should be refactored. def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,skip_normalize=False): + cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,skip_normalize=False): """ Generate an image from the prompt and the initial image, writing iteration images into the outdir The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] @@ -281,6 +285,7 @@ The vast majority of these arguments default to reasonable values. batch_size = batch_size or self.batch_size iterations = iterations or self.iterations strength = strength or self.strength + embedding_path = embedding_path or self.embedding_path if init_img is None: print("no init_img provided!") @@ -431,6 +436,7 @@ The vast majority of these arguments default to reasonable values. config = OmegaConf.load(self.config) self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = self._load_model_from_config(config,self.weights) + model.embedding_manager.load(self.embedding_path) self.model = model.to(self.device) except AttributeError: raise SystemExit diff --git a/ldm/stable_txt2img.py b/ldm/stable_txt2img.py new file mode 100644 index 0000000000..1f9cc4af52 --- /dev/null +++ b/ldm/stable_txt2img.py @@ -0,0 +1,287 @@ +import argparse, os, sys, glob +import torch +import numpy as np +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + + +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + model.eval() + return model + + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--prompt", + type=str, + nargs="?", + default="a painting of a virus monster playing guitar", + help="the prompt to render" + ) + parser.add_argument( + "--outdir", + type=str, + nargs="?", + help="dir to write results to", + default="outputs/txt2img-samples" + ) + parser.add_argument( + "--skip_grid", + action='store_true', + help="do not save a grid, only individual samples. Helpful when evaluating lots of samples", + ) + parser.add_argument( + "--skip_save", + action='store_true', + help="do not save individual samples. For speed measurements.", + ) + parser.add_argument( + "--ddim_steps", + type=int, + default=50, + help="number of ddim sampling steps", + ) + parser.add_argument( + "--plms", + action='store_true', + help="use plms sampling", + ) + parser.add_argument( + "--laion400m", + action='store_true', + help="uses the LAION400M model", + ) + parser.add_argument( + "--fixed_code", + action='store_true', + help="if enabled, uses the same starting code across samples ", + ) + parser.add_argument( + "--ddim_eta", + type=float, + default=0.0, + help="ddim eta (eta=0.0 corresponds to deterministic sampling", + ) + parser.add_argument( + "--n_iter", + type=int, + default=2, + help="sample this often", + ) + parser.add_argument( + "--H", + type=int, + default=512, + help="image height, in pixel space", + ) + parser.add_argument( + "--W", + type=int, + default=512, + help="image width, in pixel space", + ) + parser.add_argument( + "--C", + type=int, + default=4, + help="latent channels", + ) + parser.add_argument( + "--f", + type=int, + default=8, + help="downsampling factor", + ) + parser.add_argument( + "--n_samples", + type=int, + default=3, + help="how many samples to produce for each given prompt. A.k.a. batch size", + ) + parser.add_argument( + "--n_rows", + type=int, + default=0, + help="rows in the grid (default: n_samples)", + ) + parser.add_argument( + "--scale", + type=float, + default=7.5, + help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", + ) + parser.add_argument( + "--from-file", + type=str, + help="if specified, load prompts from this file", + ) + parser.add_argument( + "--config", + type=str, + default="configs/stable-diffusion/v1-inference.yaml", + help="path to config which constructs model", + ) + parser.add_argument( + "--ckpt", + type=str, + default="models/ldm/stable-diffusion-v1/model.ckpt", + help="path to checkpoint of model", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="the seed (for reproducible sampling)", + ) + parser.add_argument( + "--precision", + type=str, + help="evaluate at this precision", + choices=["full", "autocast"], + default="autocast" + ) + + + parser.add_argument( + "--embedding_path", + type=str, + help="Path to a pre-trained embedding manager checkpoint") + + opt = parser.parse_args() + + if opt.laion400m: + print("Falling back to LAION 400M model...") + opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml" + opt.ckpt = "models/ldm/text2img-large/model.ckpt" + opt.outdir = "outputs/txt2img-samples-laion400m" + + seed_everything(opt.seed) + + config = OmegaConf.load(f"{opt.config}") + model = load_model_from_config(config, f"{opt.ckpt}") + model.embedding_manager.load(opt.embedding_path) + + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + + if opt.plms: + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + os.makedirs(opt.outdir, exist_ok=True) + outpath = opt.outdir + + batch_size = opt.n_samples + n_rows = opt.n_rows if opt.n_rows > 0 else batch_size + if not opt.from_file: + prompt = opt.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + else: + print(f"reading prompts from {opt.from_file}") + with open(opt.from_file, "r") as f: + data = f.read().splitlines() + data = list(chunk(data, batch_size)) + + sample_path = os.path.join(outpath, "samples") + os.makedirs(sample_path, exist_ok=True) + base_count = len(os.listdir(sample_path)) + grid_count = len(os.listdir(outpath)) - 1 + + start_code = None + if opt.fixed_code: + start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device) + + precision_scope = autocast if opt.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + all_samples = list() + for n in trange(opt.n_iter, desc="Sampling"): + for prompts in tqdm(data, desc="data"): + uc = None + if opt.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [opt.C, opt.H // opt.f, opt.W // opt.f] + samples_ddim, _ = sampler.sample(S=opt.ddim_steps, + conditioning=c, + batch_size=opt.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=opt.scale, + unconditional_conditioning=uc, + eta=opt.ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + if not opt.skip_save: + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + Image.fromarray(x_sample.astype(np.uint8)).save( + os.path.join(sample_path, f"{base_count:05}.jpg")) + base_count += 1 + + if not opt.skip_grid: + all_samples.append(x_samples_ddim) + + if not opt.skip_grid: + # additionally, save as grid + grid = torch.stack(all_samples, 0) + grid = rearrange(grid, 'n b c h w -> (n b) c h w') + grid = make_grid(grid, nrow=n_rows) + + # to image + grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() + Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(" ", "-")}-{grid_count:04}.jpg')) + grid_count += 1 + + toc = time.time() + + print(f"Your samples are ready and waiting for you here: \n{outpath} \n" + f" \nEnjoy.") + + +if __name__ == "__main__": + main() diff --git a/ldm/util.py b/ldm/util.py index e052c6b5ea..3affd249de 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -12,6 +12,7 @@ from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot @@ -20,7 +21,7 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) @@ -73,14 +74,14 @@ def count_params(model, verbose=False): return total_params -def instantiate_from_config(config): +def instantiate_from_config(config, **kwargs): if not "target" in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) + return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) def get_obj_from_str(string, reload=False): diff --git a/main.py b/main.py index e8e18c18fb..0c611cd743 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import argparse, os, sys, datetime, glob, importlib, csv import numpy as np import time import torch + import torchvision import pytorch_lightning as pl @@ -20,6 +21,22 @@ from pytorch_lightning.utilities import rank_zero_info from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + sd = pl_sd["state_dict"] + config.model.params.ckpt_path = ckpt + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + return model def get_parser(**parser_kwargs): def str2bool(v): @@ -120,6 +137,23 @@ def get_parser(**parser_kwargs): default=True, help="scale base-lr by ngpu * batch_size * n_accumulate", ) + + parser.add_argument( + "--datadir_in_name", + type=str2bool, + nargs="?", + const=True, + default=True, + help="Prepend the final directory in the data_root to the output directory name") + + parser.add_argument("--actual_resume", type=str, default="", help="Path to model to actually resume from") + parser.add_argument("--data_root", type=str, required=True, help="Path to directory with training images") + + parser.add_argument("--embedding_manager_ckpt", type=str, default="", help="Initialize embedding manager from a checkpoint") + parser.add_argument("--placeholder_tokens", type=str, nargs="+", default=["*"]) + + parser.add_argument("--init_word", type=str, help="Word to use as source for initial token embedding.") + return parser @@ -502,6 +536,10 @@ if __name__ == "__main__": name = "_" + cfg_name else: name = "" + + if opt.datadir_in_name: + now = os.path.basename(os.path.normpath(opt.data_root)) + now + nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) @@ -532,7 +570,18 @@ if __name__ == "__main__": lightning_config.trainer = trainer_config # model - model = instantiate_from_config(config.model) + + # config.model.params.personalization_config.params.init_word = opt.init_word + config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt + config.model.params.personalization_config.params.placeholder_tokens = opt.placeholder_tokens + + if opt.init_word: + config.model.params.personalization_config.params.initializer_words[0] = opt.init_word + + if opt.actual_resume: + model = load_model_from_config(config, opt.actual_resume) + else: + model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() @@ -578,7 +627,7 @@ if __name__ == "__main__": if hasattr(model, "monitor"): print(f"Monitoring {model.monitor} as checkpoint metric.") default_modelckpt_cfg["params"]["monitor"] = model.monitor - default_modelckpt_cfg["params"]["save_top_k"] = 3 + default_modelckpt_cfg["params"]["save_top_k"] = 1 if "modelcheckpoint" in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint @@ -655,11 +704,16 @@ if __name__ == "__main__": del callbacks_cfg['ignore_keys_callback'] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + trainer_kwargs["max_steps"] = opt.max_steps trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir ### # data + config.data.params.train.params.data_root = opt.data_root + config.data.params.validation.params.data_root = opt.data_root + data = instantiate_from_config(config.data) + data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. @@ -710,8 +764,8 @@ if __name__ == "__main__": import signal - signal.signal(signal.SIGUSR1, melk) - signal.signal(signal.SIGUSR2, divein) + signal.signal(signal.SIGTERM, melk) + signal.signal(signal.SIGTERM, divein) # run if opt.train: @@ -737,5 +791,5 @@ if __name__ == "__main__": dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) - if trainer.global_rank == 0: - print(trainer.profiler.summary()) + # if trainer.global_rank == 0: + # print(trainer.profiler.summary()) diff --git a/requirements.txt b/requirements.txt index 4c6dcdaab1..79f5fee300 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ kornia==0.6.0 numpy==1.19.2 omegaconf==2.1.1 opencv-python==4.1.2.30 +pillow==9.0.1 pudb==2019.2 pytorch pytorch-lightning==1.4.2 diff --git a/scripts/dream.py b/scripts/dream.py index fb8fec2384..3c97476877 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -57,7 +57,8 @@ def main(): weights=weights, full_precision=opt.full_precision, config=config, - latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt + latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt + embedding_path=opt.embedding_path ) # make sure the output directory exists @@ -268,6 +269,10 @@ def create_argv_parser(): type=str, default="outputs/img-samples", help="directory in which to place generated images and a log of prompts and seeds") + + parser.add_argument('--embedding_path', + type=str, + help="Path to a pre-trained embedding manager checkpoint - can only be set on command line") return parser diff --git a/scripts/merge_embeddings.py b/scripts/merge_embeddings.py new file mode 100644 index 0000000000..0367d50dcf --- /dev/null +++ b/scripts/merge_embeddings.py @@ -0,0 +1,83 @@ +from ldm.modules.encoders.modules import BERTTokenizer +from ldm.modules.embedding_manager import EmbeddingManager + +import argparse, os +from functools import partial + +import torch + +def get_placeholder_loop(placeholder_string, tokenizer): + + new_placeholder = None + + while True: + if new_placeholder is None: + new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ") + else: + new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ") + + token = tokenizer(new_placeholder) + + if torch.count_nonzero(token) == 3: + return new_placeholder, token[0, 1] + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manager_ckpts", + type=str, + nargs="+", + required=True, + help="Paths to a set of embedding managers to be merged." + ) + + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output path for the merged manager", + ) + + args = parser.parse_args() + + tokenizer = BERTTokenizer(vq_interface=False, max_length=77) + EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"]) + + string_to_token_dict = {} + string_to_param_dict = torch.nn.ParameterDict() + + placeholder_to_src = {} + + for manager_ckpt in args.manager_ckpts: + print(f"Parsing {manager_ckpt}...") + + manager = EmbeddingManager() + manager.load(manager_ckpt) + + for placeholder_string in manager.string_to_token_dict: + if not placeholder_string in string_to_token_dict: + string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] + string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] + + placeholder_to_src[placeholder_string] = manager_ckpt + else: + new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer) + string_to_token_dict[new_placeholder] = new_token + string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string] + + placeholder_to_src[new_placeholder] = manager_ckpt + + print("Saving combined manager...") + merged_manager = EmbeddingManager() + merged_manager.string_to_param_dict = string_to_param_dict + merged_manager.string_to_token_dict = string_to_token_dict + merged_manager.save(args.output_path) + + print("Managers merged. Final list of placeholders: ") + print(placeholder_to_src) + + + + diff --git a/train.ps1 b/train.ps1 new file mode 100644 index 0000000000..9753381831 --- /dev/null +++ b/train.ps1 @@ -0,0 +1,17 @@ +conda activate ldm +pip install -e . +$env:PL_TORCH_DISTRIBUTED_BACKEND="gloo" +python ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml ` + -t ` + --actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt ` + -n my_cat ` + --gpus 0, ` + --data_root D:/textual-inversion/my_cat ` + --init_Word 'Isla Fisher' +# python ./scripts/train_personalization.py --base ./configs/stable-diffusion/v1-finetune.yaml ` +# -t ` +# --actual_resume ../stable-diffusion-dream/models/ldm/stable-diffusion-v1/model.ckpt ` +# --gpus 1 ` +# --data_root D:/textual-inversion/isla_fisher ` +# --resume 'logs/my_cat2022-08-23T01-46-37_my_cat' ` +# --init_Word 'cat'