diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 22fee10019..35db1db383 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 import secrets
 import warnings
 from dataclasses import dataclass
-from typing import List, Optional, Union, Callable
+from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any, ParamSpec
 
 import PIL.Image
 import einops
@@ -11,11 +11,11 @@ import torch
 import torchvision.transforms as T
 from diffusers.models import attention
 
-from ldm.models.diffusion.cross_attention_control import InvokeAIDiffusersCrossAttention
+from ...models.diffusion import cross_attention_control
 
 # monkeypatch diffusers CrossAttention 🙈
 # this is to make prompt2prompt and (future) attention maps work
-attention.CrossAttention = InvokeAIDiffusersCrossAttention
+attention.CrossAttention = cross_attention_control.InvokeAIDiffusersCrossAttention
 
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -126,6 +126,10 @@ class AddsMaskGuidance:
         return masked_input
 
 
+def trim_to_multiple_of(*args, multiple_of=8):
+    return tuple((x - x % multiple_of) for x in args)
+
+
 def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
     """
 
@@ -133,8 +137,7 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
     :param normalize: scale the range to [-1, 1] instead of [0, 1]
     :param multiple_of: resize the input so both dimensions are a multiple of this
     """
-    w, h = image.size
-    w, h = map(lambda x: x - x % multiple_of, (w, h))  # resize to integer multiple of 8
+    w, h = trim_to_multiple_of(*image.size)
     transformation = T.Compose([
         T.Resize((h, w), T.InterpolationMode.LANCZOS),
         T.ToTensor(),
@@ -148,6 +151,26 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True
 def is_inpainting_model(unet: UNet2DConditionModel):
     return unet.conv_in.in_channels == 9
 
+CallbackType = TypeVar('CallbackType')
+ReturnType = TypeVar('ReturnType')
+ParamType = ParamSpec('ParamType')
+
+@dataclass(frozen=True)
+class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    generator_method: Callable[ParamType, ReturnType]
+    callback_arg_type: Type[CallbackType]
+
+    def __call__(self, *args: ParamType.args,
+                 callback:Callable[[CallbackType], Any]=None,
+                 **kwargs: ParamType.kwargs) -> ReturnType:
+        result = None
+        for result in self.generator_method(*args, **kwargs):
+            if callback is not None and isinstance(result, self.callback_arg_type):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
 
 class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
     r"""
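
GeneratorToCallbackinator turns a generator method into a plain call: it
drives the generator to completion, fires `callback` for every yielded value
of the declared type, and returns the last value yielded. A minimal usage
sketch (illustrative only; `count_to` is a hypothetical stand-in, not part of
this patch):

    def count_to(n: int):
        """Yield 0 .. n-1; the last yielded value doubles as the result."""
        for i in range(n):
            yield i

    f = GeneratorToCallbackinator(count_to, int)
    result = f(3, callback=print)   # prints 0, 1, 2
    assert result == 2              # caller receives the last yielded value
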
@@ -250,6 +273,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             raise AssertionError("why was that an empty generator?")
         return result
 
+    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None]=None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
+                                run_id=None,
+                                **extra_step_kwargs) -> PipelineIntermediateState:
+        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                 extra_conditioning_info=extra_conditioning_info,
+                 run_id=run_id,
+                 callback=callback,
+                 **extra_step_kwargs)
+
     def generate(
         self,
         prompt: Union[str, List[str]],
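
The new entry point initializes the scheduler itself, so callers no longer
need to call set_timesteps() beforehand. Expected call-site shape, as a
sketch (`pipeline`, `noise`, `c`, and `uc` are placeholders assumed to exist,
not defined in this patch):

    state = pipeline.latents_from_embeddings(
        latents=noise,                # initial noise, shape [B, C, H//8, W//8]
        num_inference_steps=30,
        text_embeddings=c,
        unconditioned_embeddings=uc,
        guidance_scale=7.5,
        callback=lambda s: print(f"step {s.step} @ t={s.timestep}"),
    )
    final_latents = state.latents     # last PipelineIntermediateState yielded
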
@@ -303,19 +341,42 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                  timesteps = None,
                                  additional_guidance: List[Callable] = None,
                                  **extra_step_kwargs):
+        latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
+                                                                   guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
+                                                                   timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+
+    def generate_latents_from_embeddings(
+            self,
+            latents: torch.Tensor,
+            text_embeddings: torch.Tensor,
+            unconditioned_embeddings: torch.Tensor,
+            guidance_scale: float,
+            *,
+            run_id: str = None,
+            extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+            timesteps = None,
+            additional_guidance: List[Callable] = None,
+            **extra_step_kwargs
+    ):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
-
         if additional_guidance is None:
             additional_guidance = []
-
         if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
             self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
                                                                  step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
-
         if timesteps is None:
+            # NOTE: Depends on scheduler being already initialized!
             timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
@@ -326,7 +387,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         batch_size = latents.shape[0]
         batched_t = torch.full((batch_size,), timesteps[0],
                                dtype=timesteps.dtype, device=self.unet.device)
-        # NOTE: Depends on scheduler being already initialized!
+
         for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
@@ -337,14 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             predicted_original = getattr(step_output, 'pred_original_sample', None)
             yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                             predicted_original=predicted_original)
-
-        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        torch.cuda.empty_cache()
-
-        with torch.inference_mode():
-            image = self.decode_latents(latents)
-            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+        return latents
 
     @torch.inference_mode()
     def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float,
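
The split above leans on a Python detail worth spelling out: a generator's
return value travels in StopIteration and is captured by `yield from`, so
generate() can re-yield every intermediate state while still receiving the
final latents. A self-contained illustration:

    def inner():
        yield "intermediate"
        return "final"               # becomes StopIteration.value

    def outer():
        final = yield from inner()   # re-yields "intermediate", captures "final"
        yield f"got {final}"

    assert list(outer()) == ["intermediate", "got final"]

Note that GeneratorToCallbackinator's plain for-loop discards that return
value, which is why latents_from_embeddings is annotated as returning the
last PipelineIntermediateState rather than the raw latents.
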
@@ -396,34 +450,38 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                               **extra_step_kwargs) -> StableDiffusionPipelineOutput:
         device = self.unet.device
         latents_dtype = self.unet.dtype
-        batch_size = 1
-        num_images_per_prompt = 1
-
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
 
         if init_image.dim() == 3:
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
+        # 6. Prepare latent variables
+        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+
+        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+                                                          unconditioned_embeddings, guidance_scale, strength,
+                                                          extra_conditioning_info, noise_func, run_id, callback,
+                                                          **extra_step_kwargs)
+        return result
+
+    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
+                                            unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
+                                            noise_func, run_id=None, callback=None, **extra_step_kwargs):
+        device = self.unet.device
+        batch_size = initial_latents.size(0)
         img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
         img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        latent_timestep = timesteps[:1].repeat(batch_size)
+        latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        # 6. Prepare latent variables
-        latents, _ = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
-
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                timesteps=timesteps,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
+        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
+        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                 extra_conditioning_info=extra_conditioning_info,
+                 timesteps=timesteps,
+                 callback=callback,
+                 run_id=run_id, **extra_step_kwargs)
 
     def inpaint_from_embeddings(
         self,
@@ -459,7 +517,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         # 6. Prepare latent variables
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
+        latents = self.noise_latents_for_time(init_image_latents, latent_timestep, noise_func=noise_func)
 
         if mask.dim() == 3:
             mask = mask.unsqueeze(0)
@@ -491,19 +552,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
-
-    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> (torch.FloatTensor, torch.FloatTensor):
-        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
-        # because we have our own noise function
+    def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
         init_latents = 0.18215 * init_latents
+        return init_latents
 
-        noise = noise_func(init_latents)
-        noised_latents = self.scheduler.add_noise(init_latents, noise, timestep)
-        return noised_latents, init_latents
+    def noise_latents_for_time(self, latents, timestep, *, noise_func):
+        noise = noise_func(latents)
+        noised_latents = self.scheduler.add_noise(latents, noise, timestep)
+        return noised_latents
 
     def check_for_safety(self, output, dtype):
         with torch.inference_mode():
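
prepare_latents_from_image bundled VAE encoding and noising into one call;
splitting it lets the img2img and inpainting paths compose the halves
differently (inpainting keeps the un-noised latents around for mask
guidance). A caller now does roughly this (sketch; `pipeline`, `init_image`,
`latent_timestep`, and `noise_func` are placeholders):

    init_latents = pipeline.non_noised_latents_from_image(
        init_image, device=pipeline.unet.device, dtype=pipeline.unet.dtype)
    latents = pipeline.noise_latents_for_time(
        init_latents, latent_timestep, noise_func=noise_func)
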
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index f9af1ac3ed..36f5219b28 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -29,10 +29,6 @@ class Txt2Img(Generator):
         pipeline.scheduler = sampler
 
         def make_image(x_T) -> PIL.Image.Image:
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem and self.model.model.device != self.model.device:
-            #     self.model.model.to(self.model.device)
-
             pipeline_output = pipeline.image_from_embeddings(
                 latents=x_T,
                 num_inference_steps=steps,
@@ -45,10 +41,6 @@ class Txt2Img(Generator):
                 # TODO: threshold = threshold,
             )
 
-            # FIXME: restore free_gpu_mem functionality
-            # if self.free_gpu_mem:
-            #     self.model.model.to("cpu")
-
             return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 3da42ebb8a..35c6a39ca2 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -3,13 +3,12 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 '''
 
 import math
+from typing import Callable, Optional
 
 import torch
-from PIL import Image
 
 from ldm.invoke.generator.base import Generator
-from ldm.invoke.generator.omnibus import Omnibus
-from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline
 
 
 class Txt2Img2Img(Generator):
@@ -17,9 +16,9 @@ class Txt2Img2Img(Generator):
         super().__init__(model, precision)
         self.init_latent = None    # for get_noise()
 
-    @torch.no_grad()
-    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
-                       conditioning,width,height,strength,step_callback=None,**kwargs):
+    def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
+                       conditioning, width:int, height:int, strength:float,
+                       step_callback:Optional[Callable]=None, **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
@@ -29,125 +28,72 @@ class Txt2Img2Img(Generator):
         scale_dim = min(width, height)
         scale = 512 / scale_dim
 
-        init_width = math.ceil(scale * width / 64) * 64
-        init_height = math.ceil(scale * height / 64) * 64
+        init_width, init_height = trim_to_multiple_of(scale * width, scale * height)
+
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
 
-        @torch.no_grad()
         def make_image(x_T):
-            shape = [
-                self.latent_channels,
-                init_height // self.downsampling_factor,
-                init_width // self.downsampling_factor,
-            ]
-
-            sampler.make_schedule(
-                    ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
+            pipeline_output = pipeline.latents_from_embeddings(
+                latents=x_T,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale,
+                callback=step_callback,
+                extra_conditioning_info=extra_conditioning_info,
+                # TODO: eta = ddim_eta,
+                # TODO: threshold = threshold,
             )
 
-            #x = self.get_noise(init_width, init_height)
-            x = x_T
-
-            if self.free_gpu_mem and self.model.model.device != self.model.device:
-                self.model.model.to(self.model.device)
-
-            samples, _ = sampler.sample(
-                batch_size = 1,
-                S = steps,
-                x_T = x,
-                conditioning = c,
-                shape = shape,
-                verbose = False,
-                unconditional_guidance_scale = cfg_scale,
-                unconditional_conditioning = uc,
-                eta = ddim_eta,
-                img_callback = step_callback,
-                extra_conditioning_info = extra_conditioning_info
-            )
+            first_pass_latent_output = pipeline_output.latents
 
             print(
                 f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
             )
 
             # resizing
-            samples = torch.nn.functional.interpolate(
-                samples,
+            resized_latents = torch.nn.functional.interpolate(
+                first_pass_latent_output,
                 size=(height // self.downsampling_factor, width // self.downsampling_factor),
                 mode="bilinear"
             )
 
-            t_enc = int(strength * steps)
-            ddim_sampler = DDIMSampler(self.model, device=self.model.device)
-            ddim_sampler.make_schedule(
-                ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-            )
-
-            z_enc = ddim_sampler.stochastic_encode(
-                samples,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=self.get_noise(width,height,False)
-            )
-
-            # decode it
-            samples = ddim_sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
+            pipeline_output = pipeline.img2img_from_latents_and_embeddings(
+                resized_latents,
+                num_inference_steps=steps,
+                text_embeddings=c,
+                unconditioned_embeddings=uc,
+                guidance_scale=cfg_scale,
                 strength=strength,
                 extra_conditioning_info=extra_conditioning_info,
-                all_timesteps_count=steps
-            )
+                noise_func=self.get_noise_like,
+                callback=step_callback)
 
-            if self.free_gpu_mem:
-                self.model.model.to("cpu")
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
-            return self.sample_to_image(samples)
+
+        # FIXME: do we really need something entirely different for the inpainting model?
         # in the case of the inpainting model being loaded, the trick of
         # providing an interpolated latent doesn't work, so we transiently
         # create a 512x512 PIL image, upscale it, and run the inpainting
         # over it in img2img mode. Because the inpaing model is so conservative
         # it doesn't change the image (much)
-        def inpaint_make_image(x_T):
-            omnibus = Omnibus(self.model,self.precision)
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                width=init_width,
-                height=init_height,
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            assert result is not None and len(result)>0,'** txt2img failed **'
-            image = result[0][0]
-            interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
-            print(kwargs.pop('init_image',None))
-            result = omnibus.generate(
-                prompt,
-                sampler=sampler,
-                init_image=interpolated_image,
-                width=width,
-                height=height,
-                seed=result[0][1],
-                step_callback=step_callback,
-                steps = steps,
-                cfg_scale = cfg_scale,
-                ddim_eta = ddim_eta,
-                conditioning = conditioning,
-                **kwargs
-            )
-            return result[0][0]
-
-        if sampler.uses_inpainting_model():
-            return inpaint_make_image
+
+        return make_image
+
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
+        if device.type == 'mps':
+            x = torch.randn_like(like, device='cpu').to(device)
         else:
-            return make_image
+            x = torch.randn_like(like, device=device)
+        if self.perlin > 0.0:
+            shape = like.shape
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
+        return x
 
     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self,width,height,scale = True):
@@ -175,4 +121,3 @@ class Txt2Img2Img(Generator):
                             scaled_height // self.downsampling_factor,
                             scaled_width // self.downsampling_factor],
                             device=device)
-
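
One behavioral note on reusing trim_to_multiple_of for the first-pass size:
the old arithmetic rounded up to a multiple of 64, while the helper trims
down to a multiple of 8, and it performs no int conversion, so called with
float inputs like `scale * width` it returns floats. Illustrative values,
not from the patch:

    assert trim_to_multiple_of(515, 300) == (512, 296)          # ints stay ints
    assert trim_to_multiple_of(515.0, 300.0) == (512.0, 296.0)  # floats stay floats

That is harmless here, where init_width and init_height only feed the log
line, but worth keeping in mind if they are ever used as tensor sizes again.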