From 7e8f364d8db4af04c745c3f324090d98712dd5bd Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 16 Jan 2023 19:32:06 -0500
Subject: [PATCH 1/5] do not use autocast for diffusers

- All tensors in the diffusers code path are now set explicitly to
  float32 or float16, depending on the --precision flag.
- autocast is still used in the ckpt path, since that path is being
  deprecated.
---
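Reviewer note (text here, between the --- marker and the diffstat, is
ignored by git am): the patch replaces implicit autocast conversion with
explicit tensor dtypes, selected by the torch_dtype() helpers added below.
A minimal standalone sketch of the idea; the `precision` variable is a
hypothetical stand-in for the parsed --precision flag:

    import torch

    precision = 'float16'  # hypothetical stand-in for the --precision flag
    dtype = torch.float16 if precision == 'float16' else torch.float32

    # Create latent noise at the working dtype up front, so no autocast
    # context is needed around the sampling loop.
    noise = torch.randn([1, 4, 64, 64], dtype=dtype)
    assert noise.dtype == dtype
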
 ldm/generate.py                               |  3 ++-
 ldm/invoke/ckpt_generator/base.py             |  3 ++-
 ldm/invoke/ckpt_generator/txt2img.py          | 18 +++++++++-------
 ldm/invoke/conditioning.py                    |  3 +--
 ldm/invoke/devices.py                         |  9 ++++++++
 ldm/invoke/generator/base.py                  |  6 +++++-
 ldm/invoke/generator/txt2img.py               | 21 ++++++++++---------
 ldm/invoke/globals.py                         |  3 +++
 ldm/invoke/model_manager.py                   |  2 +-
 ldm/modules/prompt_to_embeddings_converter.py |  4 ++--
 ldm/modules/textual_inversion_manager.py      |  1 -
 11 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 7670448a93..63eaf79b50 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -29,7 +29,7 @@ from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.generator.inpaint import infill_methods
-from ldm.invoke.globals import global_cache_dir
+from ldm.invoke.globals import global_cache_dir, Globals
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter
@@ -201,6 +201,7 @@ class Generate:
             self.precision = 'float32'
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
+        Globals.full_precision = self.precision=='float32'

         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
diff --git a/ldm/invoke/ckpt_generator/base.py b/ldm/invoke/ckpt_generator/base.py
index c73bb50c9b..c84550a6e3 100644
--- a/ldm/invoke/ckpt_generator/base.py
+++ b/ldm/invoke/ckpt_generator/base.py
@@ -335,4 +335,5 @@ class CkptGenerator():
             os.makedirs(dirname, exist_ok=True)
         image.save(filepath,'PNG')

-
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
diff --git a/ldm/invoke/ckpt_generator/txt2img.py b/ldm/invoke/ckpt_generator/txt2img.py
index 48b83be2ed..825b8583b9 100644
--- a/ldm/invoke/ckpt_generator/txt2img.py
+++ b/ldm/invoke/ckpt_generator/txt2img.py
@@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
         device = self.model.device
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                            self.latent_channels,
-                            height // self.downsampling_factor,
-                            width // self.downsampling_factor],
-                            device='cpu').to(device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                            self.latent_channels,
-                            height // self.downsampling_factor,
-                            width // self.downsampling_factor],
-                            device=device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index fec3c7e7b1..2a687d4c51 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -17,7 +17,7 @@ from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
-
+from ldm.invoke.devices import torch_dtype


 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
@@ -238,7 +238,6 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
     if log_tokens:
         text = " ".join(fragments)
         log_tokenization(text, model, display_label=log_display_label)
-
     return embeddings, tokens
diff --git a/ldm/invoke/devices.py b/ldm/invoke/devices.py
index 0fc749c4a4..94ddb74b24 100644
--- a/ldm/invoke/devices.py
+++ b/ldm/invoke/devices.py
@@ -21,10 +21,19 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'

+def torch_dtype(device) -> torch.dtype:
+    if Globals.full_precision:
+        return torch.float32
+    if choose_precision(device) == 'float16':
+        return torch.float16
+    else:
+        return torch.float32
+
 def choose_autocast(precision):
     '''Returns an autocast context or nullcontext for the given precision string'''
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
+    print(f'DEBUG: choose_autocast() called')
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 68c5ccdeff..bac7bbb333 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -22,6 +22,7 @@ from ldm.invoke.devices import choose_autocast
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
+from contextlib import nullcontext

 downsampling = 8
 CAUTION_IMG = 'assets/caution.png'
@@ -64,7 +65,7 @@ class Generator:
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
-        scope = choose_autocast(self.precision)
+        scope = nullcontext
         self.safety_checker = safety_checker
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
@@ -341,3 +342,6 @@ class Generator:
         image.save(filepath,'PNG')

+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
+
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 174c1e469d..77b16a734e 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -36,10 +36,9 @@ class Txt2Img(Generator):
                                      threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

-
         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
-                latents=torch.zeros_like(x_T),
+                latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
                 noise=x_T,
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
@@ -59,16 +58,18 @@ class Txt2Img(Generator):
         input_channels = min(self.latent_channels, 4)
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                            input_channels,
-                            height // self.downsampling_factor,
-                            width // self.downsampling_factor],
-                            device='cpu').to(device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                            input_channels,
-                            height // self.downsampling_factor,
-                            width // self.downsampling_factor],
-                            device=device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
diff --git a/ldm/invoke/globals.py b/ldm/invoke/globals.py
index c67dbac145..897bf5e204 100644
--- a/ldm/invoke/globals.py
+++ b/ldm/invoke/globals.py
@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True

+# whether we are forcing full precision
+Globals.full_precision = False
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)
diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py
index 880d75476f..bc19ba1449 100644
--- a/ldm/invoke/model_manager.py
+++ b/ldm/invoke/model_manager.py
@@ -761,7 +761,7 @@ class ModelManager(object):
         for model in legacy_locations:
             source = models_dir /model
             if source.exists():
-                print(f'DEBUG: Moving {models_dir / model} into hub')
+                print(f'** Moving {models_dir / model} into hub')
                 move(models_dir / model, hub)

         # anything else gets moved into the diffusers directory
diff --git a/ldm/modules/prompt_to_embeddings_converter.py b/ldm/modules/prompt_to_embeddings_converter.py
index b52577c83c..ab989e4892 100644
--- a/ldm/modules/prompt_to_embeddings_converter.py
+++ b/ldm/modules/prompt_to_embeddings_converter.py
@@ -4,7 +4,7 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel

 from ldm.modules.textual_inversion_manager import TextualInversionManager
-
+from ldm.invoke.devices import torch_dtype

 class WeightedPromptFragmentsToEmbeddingsConverter():
@@ -207,7 +207,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
             per_token_weights += [1.0] * pad_length

         all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
         #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor
diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py
index 471a8ee07b..f7ced79a52 100644
--- a/ldm/modules/textual_inversion_manager.py
+++ b/ldm/modules/textual_inversion_manager.py
@@ -111,7 +111,6 @@ class TextualInversionManager():
         if ti.trigger_token_id is not None:
             raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")

-        print(f'DEBUG: Injecting token {ti.trigger_string}')
         trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])

         if ti.embedding_vector_length > 1:

From ce00c9856fdf077f115d2468e9c11da907a4a3a2 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 16 Jan 2023 22:50:13 -0500
Subject: [PATCH 2/5] fix perlin noise and txt2img2img

---
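Reviewer note (ignored by git am): the perlin path broke under float16
because PyTorch type promotion silently widens mixed-dtype arithmetic. A
small standalone illustration of the failure mode being fixed here; the
shapes are arbitrary:

    import torch

    # Stand-ins: float16 latent noise plus float32 perlin noise, as
    # rand_perlin_2d() returned before this patch.
    x = torch.randn([1, 4, 8, 8]).to(torch.float16)
    perlin = torch.randn([1, 4, 8, 8])   # defaults to float32

    blended = 0.75 * x + 0.25 * perlin
    print(blended.dtype)                 # torch.float32 -- silently widened

    # The fix: cast the perlin noise to the working dtype before blending.
    blended = 0.75 * x + 0.25 * perlin.to(dtype=x.dtype)
    print(blended.dtype)                 # torch.float16
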
 ldm/invoke/generator/base.py        |  3 ++-
 ldm/invoke/generator/txt2img2img.py | 10 ++++++----
 ldm/invoke/model_manager.py         |  2 +-
 ldm/util.py                         |  4 +++-
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index bac7bbb333..25cd281cfe 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -237,7 +237,8 @@ class Generator:

     def get_perlin_noise(self,width,height):
         fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
-        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        return noise

     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index e356f719c4..1dba0cfafb 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(like, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
-            x = torch.randn_like(like, device=device)
+            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
         if self.perlin > 0.0:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
@@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
-                               device='cpu').to(device)
+                                dtype=self.torch_dtype(),
+                                device='cpu').to(device)
         else:
             return torch.randn([1,
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
-                               device=device)
+                                dtype=self.torch_dtype(),
+                                device=device)
diff --git a/ldm/invoke/model_manager.py b/ldm/invoke/model_manager.py
index bc19ba1449..f74706aaef 100644
--- a/ldm/invoke/model_manager.py
+++ b/ldm/invoke/model_manager.py
@@ -349,7 +349,7 @@ class ModelManager(object):

         if self.precision == 'float16':
             print(' | Using faster float16 precision')
-            model.to(torch.float16)
+            model = model.to(torch.float16)
         else:
             print(' | Using more accurate float32 precision')

diff --git a/ldm/util.py b/ldm/util.py
index 282a56c3e5..7d44dcd266 100644
--- a/ldm/util.py
+++ b/ldm/util.py
@@ -8,6 +8,7 @@ from threading import Thread
 from urllib import request
 from tqdm import tqdm
 from pathlib import Path
+from ldm.invoke.devices import torch_dtype

 import numpy as np
 import torch
@@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t*
     n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
-    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    return noise.to(dtype=torch_dtype(device))

 def ask_user(question: str, answers: list):
     from itertools import chain, repeat

From ce1c5e70b8f260955592b837db78332db0cf70f5 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 16 Jan 2023 23:18:43 -0500
Subject: [PATCH 3/5] fix autocast dependency in cross_attention_control

---
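Reviewer note (ignored by git am): with autocast gone, the float32
cross-attention mask is no longer converted implicitly when it meets
float16 attention slices, so this patch casts it explicitly. A rough
standalone equivalent of the masked blend; shapes are hypothetical, and
the real mask comes from context.cross_attention_mask:

    import torch

    suggested = torch.rand([8, 77, 77]).to(torch.float16)  # current attention slice
    saved = torch.rand([8, 77, 77]).to(torch.float16)      # remapped saved slice
    mask = torch.ones([8, 77, 77])                         # float32 edit mask

    # Casting the mask to the slices' dtype keeps the blend in float16.
    mask = mask.to(suggested.dtype)
    attention_slice = saved * mask + suggested * (1 - mask)
    assert attention_slice.dtype == torch.float16
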
 ldm/models/diffusion/cross_attention_control.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 7415f1435b..03d5a5bcec 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from ldm.invoke.devices import torch_dtype

 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -383,7 +384,7 @@ def inject_attention_function(unet, context: Context):
             remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
             this_attention_slice = suggested_attention_slice

-            mask = context.cross_attention_mask
+            mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
             saved_mask = mask
             this_mask = 1 - mask
             attention_slice = remapped_saved_attention_slice * saved_mask + \

From 3c919f0337d38e72c8e007a33458dbb070e4e697 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Tue, 17 Jan 2023 11:37:14 -0800
Subject: [PATCH 4/5] Restore ldm/invoke/conditioning.py

---
 ldm/invoke/conditioning.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 2a687d4c51..fec3c7e7b1 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -17,7 +17,7 @@ from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
-from ldm.invoke.devices import torch_dtype
+


 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
@@ -238,6 +238,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
     if log_tokens:
         text = " ".join(fragments)
         log_tokenization(text, model, display_label=log_display_label)
+
     return embeddings, tokens

From 5aec48735e0bb2fbb19f4b69469bce3c46035e28 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Tue, 17 Jan 2023 11:44:45 -0800
Subject: [PATCH 5/5] =?UTF-8?q?lint(generator):=20=F0=9F=9A=AE=20remove=20?=
 =?UTF-8?q?unused=20imports?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ldm/invoke/generator/base.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 25cd281cfe..3fd34765c6 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -8,6 +8,7 @@ import os
 import os.path as osp
 import random
 import traceback
+from contextlib import nullcontext

 import cv2
 import numpy as np
@@ -18,11 +19,8 @@ from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange

-from ldm.invoke.devices import choose_autocast
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
-from contextlib import nullcontext

 downsampling = 8
 CAUTION_IMG = 'assets/caution.png'