diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..e05fa41e06 --- /dev/null +++ b/.gitignore @@ -0,0 +1,166 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +**/*.ckpt +src/ +logs/ +**/__pycache__/ +outputs \ No newline at end of file diff --git a/README.md b/README.md index c191ffc6c2..f0ddb978ff 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,64 @@ cat aspect of the image and 75% on the white duck aspect use any combination of integers and floating point numbers, and they do not need to add up to 1. +## Personalizing Text-to-Image Generation + +You may personalize the generated images with your own styles or objects by training a new LDM checkpoint +that introduces new vocabulary to the otherwise fixed model. + +To train, prepare a folder that contains images sized at 512x512 and execute the following: + +~~~~ +# The default distributed backend is not available on Windows; if you're on that platform, first execute SET PL_TORCH_DISTRIBUTED_BACKEND=gloo +(ldm) ~/stable-diffusion$ python3 ./main.py --base ./configs/stable-diffusion/v1-finetune.yaml \ + -t \ + --actual_resume ./models/ldm/stable-diffusion-v1/model.ckpt \ + -n my_cat \ + --gpus 0, \ + --data_root D:/textual-inversion/my_cat \ + --init_word 'cat' +~~~~ + +During training, intermediate results are written to /logs/[project][time][project]/, +where you can monitor progress: + +conditioning* contains the training prompts; +inputs* and reconstruction* contain the input images for the training epoch; +samples* and samples_scaled* contain a sample of the prompt and one with the init word provided. + +On an RTX 3090, the process for SD will take ~1h at ~1.6 iterations/sec. + +Note: According to the associated paper, the optimal number of images is 3-5; with any more than that, your model might not converge. + +Training will run indefinitely, but you may wish to stop it before the heat death of the universe, when you find a low-loss epoch or after around 5000 iterations. + +Once the model is trained, specify the trained .pt file when starting dream using: + +~~~~ +(ldm) ~/stable-diffusion$ python3 ./scripts/dream.py --embedding_path /path/to/embedding.pt --full_precision +~~~~ + +Then, to use your subject at the dream prompt: + +~~~ +dream> "a photo of *" +~~~ + +This also works with image2image: +~~~~ +dream> "waterfall and rainbow in the style of *" --init_img=./init-images/crude_drawing.png --strength=0.5 -s100 -n4 +~~~~ + +It's also possible to train multiple tokens (modify the placeholder string in configs/stable-diffusion/v1-finetune.yaml; see the config excerpt below) and combine LDM checkpoints using: + +~~~~ +(ldm) ~/stable-diffusion$ python3 ./scripts/merge_embeddings.py \ + --manager_ckpts /path/to/first/embedding.pt /path/to/second/embedding.pt [...] \ + --output_path /path/to/output/embedding.pt +~~~~ + +Credit goes to @rinongal and the repository located at https://github.com/rinongal/textual_inversion. Please see the repository and associated paper for details and limitations.
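+
+For reference, the placeholder and its initialization are defined in the personalization_config block of configs/stable-diffusion/v1-finetune.yaml. The excerpt below shows the shipped defaults; placeholder_strings is the value to edit if you want to train a different or additional token:
+
+~~~~
+personalization_config:
+  target: ldm.modules.embedding_manager.EmbeddingManager
+  params:
+    placeholder_strings: ["*"]   # the string(s) that stand in for your subject in prompts
+    initializer_words: ["sculpture"]
+    per_image_tokens: false
+    num_vectors_per_token: 1
+    progressive_words: False
+~~~~
+
+Note that the --init_word argument shown above overrides initializer_words[0] at launch.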
+ ## Changes * v1.08 (24 August 2022) diff --git a/configs/stable-diffusion/v1-finetune.yaml b/configs/stable-diffusion/v1-finetune.yaml new file mode 100644 index 0000000000..da6a5e775b --- /dev/null +++ b/configs/stable-diffusion/v1-finetune.yaml @@ -0,0 +1,105 @@ +model: + base_learning_rate: 5.0e-03 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + embedding_reg_weight: 0.0 + + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ["sculpture"] + per_image_tokens: false + num_vectors_per_token: 1 + progressive_words: False + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 2 + num_workers: 16 + wrap: false + train: + target: ldm.data.personalized.PersonalizedBase + params: + size: 512 + set: train + per_image_tokens: false + repeats: 100 + validation: + target: ldm.data.personalized.PersonalizedBase + params: + size: 512 + set: val + per_image_tokens: false + repeats: 10 + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 500 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True + max_steps: 6100 \ No newline at end of file diff --git a/configs/stable-diffusion/v1-finetune_style.yaml b/configs/stable-diffusion/v1-finetune_style.yaml new file mode 100644 index 0000000000..1964d925e1 --- /dev/null +++ b/configs/stable-diffusion/v1-finetune_style.yaml @@ -0,0 +1,103 @@ +model: + base_learning_rate: 5.0e-03 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: image + cond_stage_key: caption + image_size: 64 + channels: 4 + cond_stage_trainable: true # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + embedding_reg_weight: 0.0 + + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ["painting"] + per_image_tokens: false + num_vectors_per_token: 1 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + 
attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 2 + num_workers: 16 + wrap: false + train: + target: ldm.data.personalized_style.PersonalizedBase + params: + size: 512 + set: train + per_image_tokens: false + repeats: 100 + validation: + target: ldm.data.personalized_style.PersonalizedBase + params: + size: 512 + set: val + per_image_tokens: false + repeats: 10 + +lightning: + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 500 + max_images: 8 + increase_log_steps: False + + trainer: + benchmark: True \ No newline at end of file diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index d4effe569e..59d8f33125 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -26,6 +26,15 @@ model: f_max: [ 1. ] f_min: [ 1. ] + personalization_config: + target: ldm.modules.embedding_manager.EmbeddingManager + params: + placeholder_strings: ["*"] + initializer_words: ["sculpture"] + per_image_tokens: false + num_vectors_per_token: 1 + progressive_words: False + unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: diff --git a/environment.yaml b/environment.yaml index 0de05e815a..83669e254c 100644 --- a/environment.yaml +++ b/environment.yaml @@ -19,6 +19,7 @@ dependencies: - omegaconf==2.1.1 - test-tube>=0.7.5 - streamlit>=0.73.1 + - pillow==9.0.1 - einops==0.3.0 - torch-fidelity==0.3.0 - transformers==4.19.2 diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py new file mode 100644 index 0000000000..c8a57d09fa --- /dev/null +++ b/ldm/data/personalized.py @@ -0,0 +1,160 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_smallest = [ + 'a photo of a {}', +] + +imagenet_templates_small = [ + 'a photo of a {}', + 'a rendering of a {}', + 'a cropped photo of the {}', + 'the photo of a {}', + 'a photo of a clean {}', + 'a photo of a dirty {}', + 'a dark photo of the {}', + 'a photo of my {}', + 'a photo of the cool {}', + 'a close-up photo of a {}', + 'a bright photo of the {}', + 'a cropped photo of a {}', + 'a photo of the {}', + 'a good photo of the {}', + 'a photo of one {}', + 'a close-up photo of the {}', + 'a rendition of the {}', + 'a photo of the clean {}', + 'a rendition of a {}', + 'a photo of a nice {}', + 'a good photo of a {}', + 'a photo of the nice {}', + 'a photo of the small {}', + 'a photo of the weird {}', + 'a photo of the large {}', + 'a photo of a cool {}', + 'a photo of a small {}', +] + +imagenet_dual_templates_small = [ + 'a photo of a {} with {}', + 'a rendering of a {} with {}', + 'a cropped photo of the {} with {}', + 'the photo of a {} with {}', + 'a photo of a clean {} with {}', + 'a photo of a dirty {} with {}', + 'a 
dark photo of the {} with {}', + 'a photo of my {} with {}', + 'a photo of the cool {} with {}', + 'a close-up photo of a {} with {}', + 'a bright photo of the {} with {}', + 'a cropped photo of a {} with {}', + 'a photo of the {} with {}', + 'a good photo of the {} with {}', + 'a photo of one {} with {}', + 'a close-up photo of the {} with {}', + 'a rendition of the {} with {}', + 'a photo of the clean {} with {}', + 'a rendition of a {} with {}', + 'a photo of a nice {} with {}', + 'a good photo of a {} with {}', + 'a photo of the nice {} with {}', + 'a photo of the small {} with {}', + 'a photo of the weird {} with {}', + 'a photo of the large {} with {}', + 'a photo of a cool {} with {}', + 'a photo of a small {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + mixing_prob=0.25, + coarse_class_text=None, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + self.mixing_prob = mixing_prob + + self.coarse_class_text = coarse_class_text + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
+ + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + if self.coarse_class_text: + placeholder_string = f"{self.coarse_class_text} {placeholder_string}" + + if self.per_image_tokens and np.random.uniform() < self.mixing_prob: + text = random.choice(imagenet_dual_templates_small).format(placeholder_string, per_img_token_list[i % self.num_images]) + else: + text = random.choice(imagenet_templates_small).format(placeholder_string) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/data/personalized_style.py b/ldm/data/personalized_style.py new file mode 100644 index 0000000000..b6be7b15c4 --- /dev/null +++ b/ldm/data/personalized_style.py @@ -0,0 +1,129 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_small = [ + 'a painting in the style of {}', + 'a rendering in the style of {}', + 'a cropped painting in the style of {}', + 'the painting in the style of {}', + 'a clean painting in the style of {}', + 'a dirty painting in the style of {}', + 'a dark painting in the style of {}', + 'a picture in the style of {}', + 'a cool painting in the style of {}', + 'a close-up painting in the style of {}', + 'a bright painting in the style of {}', + 'a cropped painting in the style of {}', + 'a good painting in the style of {}', + 'a close-up painting in the style of {}', + 'a rendition in the style of {}', + 'a nice painting in the style of {}', + 'a small painting in the style of {}', + 'a weird painting in the style of {}', + 'a large painting in the style of {}', +] + +imagenet_dual_templates_small = [ + 'a painting in the style of {} with {}', + 'a rendering in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'the painting in the style of {} with {}', + 'a clean painting in the style of {} with {}', + 'a dirty painting in the style of {} with {}', + 'a dark painting in the style of {} with {}', + 'a cool painting in the style of {} with {}', + 'a close-up painting in the style of {} with {}', + 'a bright painting in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'a good painting in the style of {} with {}', + 'a painting of one {} in the style of {}', + 'a nice painting in the style of {} with {}', + 'a small painting in the style of {} with {}', + 'a weird painting in the style of {} with {}', + 'a large painting in the style of 
{} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + if self.per_image_tokens and np.random.uniform() < 0.25: + text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images]) + else: + text = random.choice(imagenet_templates_small).format(self.placeholder_token) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 8bc93258a4..d5f74a0fbe 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -7,7 +7,9 @@ https://github.com/CompVis/taming-transformers """ import torch + import torch.nn as nn +import os import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR @@ -64,6 +66,7 @@ class DDPM(pl.LightningModule): cosine_s=8e-3, given_betas=None, original_elbo_weight=0., + embedding_reg_weight=0., v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1., conditioning_key=None, @@ -98,6 +101,7 @@ class DDPM(pl.LightningModule): self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight + self.embedding_reg_weight = embedding_reg_weight if monitor is not None: self.monitor = monitor @@ -427,6 +431,7 @@ class LatentDiffusion(DDPM): def __init__(self, first_stage_config, cond_stage_config, + personalization_config, num_timesteps_cond=None, cond_stage_key="image", 
cond_stage_trainable=False, @@ -436,6 +441,7 @@ class LatentDiffusion(DDPM): scale_factor=1.0, scale_by_std=False, *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs['timesteps'] @@ -450,6 +456,7 @@ class LatentDiffusion(DDPM): self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key + try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: @@ -460,6 +467,7 @@ class LatentDiffusion(DDPM): self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None @@ -469,6 +477,25 @@ class LatentDiffusion(DDPM): self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + + self.model.eval() + self.model.train = disabled_train + for param in self.model.parameters(): + param.requires_grad = False + + self.embedding_manager = self.instantiate_embedding_manager(personalization_config, self.cond_stage_model) + + self.emb_ckpt_counter = 0 + + # if self.embedding_manager.is_clip: + # self.cond_stage_model.update_embedding_func(self.embedding_manager) + + for param in self.embedding_manager.embedding_parameters(): + param.requires_grad = True + def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() @@ -530,6 +557,15 @@ class LatentDiffusion(DDPM): except urllib.error.URLError: raise SystemExit("* Couldn't load a dependency. 
Try running scripts/preload_models.py from an internet-conected machine.") self.cond_stage_model = model + + + def instantiate_embedding_manager(self, config, embedder): + model = instantiate_from_config(config, embedder=embedder) + + if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string + model.load(config.params.embedding_manager_ckpt) + + return model def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): denoise_row = [] @@ -555,7 +591,7 @@ class LatentDiffusion(DDPM): def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c) + c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: @@ -880,6 +916,7 @@ class LatentDiffusion(DDPM): if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset @@ -1046,6 +1083,14 @@ class LatentDiffusion(DDPM): loss += (self.original_elbo_weight * loss_vlb) loss_dict.update({f'{prefix}/loss': loss}) + if self.embedding_reg_weight > 0: + loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean() + + loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) + + loss += (self.embedding_reg_weight * loss_embedding_reg) + loss_dict.update({f'{prefix}/loss': loss}) + return loss, loss_dict def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, @@ -1250,11 +1295,10 @@ class LatentDiffusion(DDPM): return samples, intermediates - @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, **kwargs): use_ddim = ddim_steps is not None @@ -1312,6 +1356,16 @@ class LatentDiffusion(DDPM): if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid + + uc = self.get_learned_conditioning(len(c) * [""]) + sample_scaled, _ = self.sample_log(cond=c, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=5.0, + unconditional_conditioning=uc) + log["samples_scaled"] = self.decode_first_stage(sample_scaled) if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( self.first_stage_model, IdentityFirstStage): @@ -1364,13 +1418,18 @@ class LatentDiffusion(DDPM): def configure_optimizers(self): lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) + + if self.embedding_manager is not None: + params = list(self.embedding_manager.embedding_parameters()) + # params = 
list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters()) + else: + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert 'target' in self.scheduler_config @@ -1395,6 +1454,18 @@ class LatentDiffusion(DDPM): x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x + @rank_zero_only + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + + if os.path.isdir(self.trainer.checkpoint_callback.dirpath): + self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt")) + + if (self.global_step - self.emb_ckpt_counter) > 500: + self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt")) + + self.emb_ckpt_counter += 500 + class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index a952e6c403..6b5b9dc9e2 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag): explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ - if flag: + if False: # disabled checkpointing to allow requires_grad = False for main model args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py new file mode 100644 index 0000000000..7020a27b9a --- /dev/null +++ b/ldm/modules/embedding_manager.py @@ -0,0 +1,164 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + + return tokens[0, 1] + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" + + token = token[0, 1] + + return token + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 
1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss \ No newline at end of file diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 5c02d8c6f5..def6d2136d 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -8,6 +8,27 @@ import kornia from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +def _expand_mask(mask, dtype, tgt_len = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + +def _build_causal_attention_mask(bsz, seq_len, dtype): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) + mask.triu_(1) # zero out the lower diagonal + mask = mask.unsqueeze(1) # expand mask + return mask class AbstractEncoder(nn.Module): def __init__(self): @@ -98,18 +119,17 @@ class BERTEmbedder(AbstractEncoder): attn_layers=Encoder(dim=n_embed, depth=n_layer), emb_dropout=embedding_dropout) - def forward(self, text): + def forward(self, text, embedding_manager=None): if self.use_tknz_fn: tokens = self.tknz_fn(text)#.to(self.device) else: tokens = text - z = self.transformer(tokens, return_embeddings=True) + z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager) return z - def encode(self, text): + def encode(self, text, **kwargs): # output of length 77 - return self(text) - + return self(text, **kwargs) class SpatialRescaler(nn.Module): def __init__(self, @@ -152,22 +172,165 @@ class FrozenCLIPEmbedder(AbstractEncoder): self.max_length = max_length self.freeze() + def embedding_forward( + self, + input_ids = None, + position_ids = None, + inputs_embeds = None, + embedding_manager = None, + ) -> torch.Tensor: + + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + if embedding_manager is not None: + inputs_embeds = embedding_manager(input_ids, inputs_embeds) + + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings) + + def encoder_forward( + self, + inputs_embeds, + attention_mask = None, + causal_attention_mask = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + causal_attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + return hidden_states + + self.transformer.text_model.encoder.forward = 
encoder_forward.__get__(self.transformer.text_model.encoder) + + + def text_encoder_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify either input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager) + + bsz, seq_len = input_shape + # CLIP's text model uses causal mask, prepare it here. + # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 + causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _expand_mask(attention_mask, hidden_states.dtype) + + last_hidden_state = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + causal_attention_mask=causal_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = self.final_layer_norm(last_hidden_state) + + return last_hidden_state + + self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model) + + def transformer_forward( + self, + input_ids = None, + attention_mask = None, + position_ids = None, + output_attentions = None, + output_hidden_states = None, + return_dict = None, + embedding_manager = None, + ): + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + embedding_manager = embedding_manager + ) + + self.transformer.forward = transformer_forward.__get__(self.transformer) + + def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False - def forward(self, text): + def forward(self, text, **kwargs): batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt") tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens) + z = self.transformer(input_ids=tokens, **kwargs) - z = outputs.last_hidden_state return z - def encode(self, text): - return self(text) + def encode(self, text, **kwargs): + return self(text, **kwargs) class FrozenCLIPTextEmbedder(nn.Module): diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py index 5fc15bf9cf..1316dbd505 100644 --- a/ldm/modules/x_transformer.py +++ b/ldm/modules/x_transformer.py @@ -485,7 +485,8 @@ class AttentionLayers(nn.Module): mask=None, context_mask=None, mems=None, - return_hiddens=False + return_hiddens=False, + **kwargs ): hiddens = [] intermediates = [] @@ -603,11 +604,19 @@ class TransformerWrapper(nn.Module): 
return_mems=False, return_attn=False, mems=None, + embedding_manager=None, **kwargs ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens - x = self.token_emb(x) - x += self.pos_emb(x) + + embedded_x = self.token_emb(x) + + if embedding_manager: + x = embedding_manager(x, embedded_x) + else: + x = embedded_x + + x = x + self.pos_emb(x) x = self.emb_dropout(x) x = self.project_emb(x) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index dc93ef4fe0..8782d621d3 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -450,7 +450,8 @@ The vast majority of these arguments default to reasonable values. config = OmegaConf.load(self.config) self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = self._load_model_from_config(config,self.weights) - model.embedding_manager.load(self.embedding_path) + if self.embedding_path is not None: + model.embedding_manager.load(self.embedding_path) self.model = model.to(self.device) except AttributeError: raise SystemExit diff --git a/ldm/util.py b/ldm/util.py index e052c6b5ea..3affd249de 100644 --- a/ldm/util.py +++ b/ldm/util.py @@ -12,6 +12,7 @@ from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont + def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot @@ -20,7 +21,7 @@ def log_txt_as_img(wh, xc, size=10): for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) @@ -73,14 +74,14 @@ def count_params(model, verbose=False): return total_params -def instantiate_from_config(config): +def instantiate_from_config(config, **kwargs): if not "target" in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) + return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs) def get_obj_from_str(string, reload=False): diff --git a/main.py b/main.py index e8e18c18fb..5653cf5e06 100644 --- a/main.py +++ b/main.py @@ -2,6 +2,7 @@ import argparse, os, sys, datetime, glob, importlib, csv import numpy as np import time import torch + import torchvision import pytorch_lightning as pl @@ -20,6 +21,22 @@ from pytorch_lightning.utilities import rank_zero_info from ldm.data.base import Txt2ImgIterableBaseDataset from ldm.util import instantiate_from_config +def load_model_from_config(config, ckpt, verbose=False): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + sd = pl_sd["state_dict"] + config.model.params.ckpt_path = ckpt + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + model.cuda() + return model def get_parser(**parser_kwargs): def str2bool(v): @@ -120,6 +137,23 @@ def get_parser(**parser_kwargs): default=True, help="scale base-lr by ngpu * batch_size * n_accumulate", ) + + parser.add_argument( + "--datadir_in_name", + type=str2bool, + nargs="?", + const=True, + default=True, + help="Prepend the final directory in the data_root to the output directory name") + + 
parser.add_argument("--actual_resume", type=str, default="", help="Path to model to actually resume from") + parser.add_argument("--data_root", type=str, required=True, help="Path to directory with training images") + + parser.add_argument("--embedding_manager_ckpt", type=str, default="", help="Initialize embedding manager from a checkpoint") + parser.add_argument("--placeholder_tokens", type=str, nargs="+", default=["*"]) + + parser.add_argument("--init_word", type=str, help="Word to use as source for initial token embedding.") + return parser @@ -502,6 +536,10 @@ if __name__ == "__main__": name = "_" + cfg_name else: name = "" + + if opt.datadir_in_name: + now = os.path.basename(os.path.normpath(opt.data_root)) + now + nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) @@ -532,7 +570,18 @@ if __name__ == "__main__": lightning_config.trainer = trainer_config # model - model = instantiate_from_config(config.model) + + # config.model.params.personalization_config.params.init_word = opt.init_word + config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt + config.model.params.personalization_config.params.placeholder_tokens = opt.placeholder_tokens + + if opt.init_word: + config.model.params.personalization_config.params.initializer_words[0] = opt.init_word + + if opt.actual_resume: + model = load_model_from_config(config, opt.actual_resume) + else: + model = instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() @@ -655,11 +704,16 @@ if __name__ == "__main__": del callbacks_cfg['ignore_keys_callback'] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] + trainer_kwargs["max_steps"] = opt.max_steps trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir ### # data + config.data.params.train.params.data_root = opt.data_root + config.data.params.validation.params.data_root = opt.data_root + data = instantiate_from_config(config.data) + data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. 
@@ -710,8 +764,8 @@ if __name__ == "__main__": import signal - signal.signal(signal.SIGUSR1, melk) - signal.signal(signal.SIGUSR2, divein) + signal.signal(signal.SIGTERM, melk) + signal.signal(signal.SIGTERM, divein) # run if opt.train: @@ -737,5 +791,5 @@ if __name__ == "__main__": dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) - if trainer.global_rank == 0: - print(trainer.profiler.summary()) + # if trainer.global_rank == 0: + # print(trainer.profiler.summary()) diff --git a/requirements.txt b/requirements.txt index 4c6dcdaab1..79f5fee300 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ kornia==0.6.0 numpy==1.19.2 omegaconf==2.1.1 opencv-python==4.1.2.30 +pillow==9.0.1 pudb==2019.2 pytorch pytorch-lightning==1.4.2 diff --git a/scripts/dream.py b/scripts/dream.py index 24063d6d5b..3b79e5dd35 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -59,7 +59,8 @@ def main(): weights=weights, full_precision=opt.full_precision, config=config, - latent_diffusion_weights=opt.laion400m # this is solely for recreating the prompt + latent_diffusion_weights=opt.laion400m, # this is solely for recreating the prompt + embedding_path=opt.embedding_path ) # make sure the output directory exists @@ -279,6 +280,10 @@ def create_argv_parser(): type=str, default="outputs/img-samples", help="directory in which to place generated images and a log of prompts and seeds") + + parser.add_argument('--embedding_path', + type=str, + help="Path to a pre-trained embedding manager checkpoint - can only be set on command line") return parser diff --git a/scripts/merge_embeddings.py b/scripts/merge_embeddings.py new file mode 100644 index 0000000000..0367d50dcf --- /dev/null +++ b/scripts/merge_embeddings.py @@ -0,0 +1,83 @@ +from ldm.modules.encoders.modules import BERTTokenizer +from ldm.modules.embedding_manager import EmbeddingManager + +import argparse, os +from functools import partial + +import torch + +def get_placeholder_loop(placeholder_string, tokenizer): + + new_placeholder = None + + while True: + if new_placeholder is None: + new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ") + else: + new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ") + + token = tokenizer(new_placeholder) + + if torch.count_nonzero(token) == 3: + return new_placeholder, token[0, 1] + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manager_ckpts", + type=str, + nargs="+", + required=True, + help="Paths to a set of embedding managers to be merged." 
+ ) + + parser.add_argument( + "--output_path", + type=str, + required=True, + help="Output path for the merged manager", + ) + + args = parser.parse_args() + + tokenizer = BERTTokenizer(vq_interface=False, max_length=77) + EmbeddingManager = partial(EmbeddingManager, tokenizer, ["*"]) + + string_to_token_dict = {} + string_to_param_dict = torch.nn.ParameterDict() + + placeholder_to_src = {} + + for manager_ckpt in args.manager_ckpts: + print(f"Parsing {manager_ckpt}...") + + manager = EmbeddingManager() + manager.load(manager_ckpt) + + for placeholder_string in manager.string_to_token_dict: + if not placeholder_string in string_to_token_dict: + string_to_token_dict[placeholder_string] = manager.string_to_token_dict[placeholder_string] + string_to_param_dict[placeholder_string] = manager.string_to_param_dict[placeholder_string] + + placeholder_to_src[placeholder_string] = manager_ckpt + else: + new_placeholder, new_token = get_placeholder_loop(placeholder_string, tokenizer) + string_to_token_dict[new_placeholder] = new_token + string_to_param_dict[new_placeholder] = manager.string_to_param_dict[placeholder_string] + + placeholder_to_src[new_placeholder] = manager_ckpt + + print("Saving combined manager...") + merged_manager = EmbeddingManager() + merged_manager.string_to_param_dict = string_to_param_dict + merged_manager.string_to_token_dict = string_to_token_dict + merged_manager.save(args.output_path) + + print("Managers merged. Final list of placeholders: ") + print(placeholder_to_src) + + + +
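For reference, the embeddings .pt file written by EmbeddingManager.save() (the file passed to --embedding_path and consumed by merge_embeddings.py) is a plain torch checkpoint containing two dictionaries, "string_to_token" and "string_to_param". A minimal inspection sketch follows, assuming a hypothetical checkpoint path; the printed shapes depend on the text encoder and num_vectors_per_token:

~~~~
import torch

# Load an embeddings checkpoint produced by EmbeddingManager.save()
# (hypothetical path -- substitute your own logs/.../checkpoints/embeddings.pt)
ckpt = torch.load("logs/my_cat/checkpoints/embeddings.pt", map_location="cpu")

# "string_to_token": placeholder string -> tokenizer id used for that placeholder
# "string_to_param": placeholder string -> learned embedding(s), one row per vector
for placeholder, embedding in ckpt["string_to_param"].items():
    print(placeholder, tuple(embedding.shape))
~~~~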