diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 41e57ec63e..e94daf70bd 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -657,155 +657,155 @@ class DenoiseLatentsInvocation(BaseInvocation):
         return 1 - mask, masked_latents, self.denoise_mask.gradient
 
     @torch.no_grad()
+    @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        with SilenceWarnings():  # this quenches NSFW nag from diffusers
-            seed = None
-            noise = None
-            if self.noise is not None:
-                noise = context.tensors.load(self.noise.latents_name)
-                seed = self.noise.seed
-
-            if self.latents is not None:
-                latents = context.tensors.load(self.latents.latents_name)
-                if seed is None:
-                    seed = self.latents.seed
-
-                if noise is not None and noise.shape[1:] != latents.shape[1:]:
-                    raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
-
-            elif noise is not None:
-                latents = torch.zeros_like(noise)
-            else:
-                raise Exception("'latents' or 'noise' must be provided!")
+        seed = None
+        noise = None
+        if self.noise is not None:
+            noise = context.tensors.load(self.noise.latents_name)
+            seed = self.noise.seed
 
+        if self.latents is not None:
+            latents = context.tensors.load(self.latents.latents_name)
             if seed is None:
-                seed = 0
+                seed = self.latents.seed
 
-            mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+            if noise is not None and noise.shape[1:] != latents.shape[1:]:
+                raise Exception(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
 
-            # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-            # below. Investigate whether this is appropriate.
-            t2i_adapter_data = self.run_t2i_adapters(
-                context,
-                self.t2i_adapter,
-                latents.shape,
-                do_classifier_free_guidance=True,
+        elif noise is not None:
+            latents = torch.zeros_like(noise)
+        else:
+            raise Exception("'latents' or 'noise' must be provided!")
+
+        if seed is None:
+            seed = 0
+
+        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+
+        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
+        # below. Investigate whether this is appropriate.
+        t2i_adapter_data = self.run_t2i_adapters(
+            context,
+            self.t2i_adapter,
+            latents.shape,
+            do_classifier_free_guidance=True,
+        )
+
+        ip_adapters: List[IPAdapterField] = []
+        if self.ip_adapter is not None:
+            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
+            if isinstance(self.ip_adapter, list):
+                ip_adapters = self.ip_adapter
+            else:
+                ip_adapters = [self.ip_adapter]
+
+        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
+        # a series of image conditioning embeddings. This is being done here rather than in the
+        # big model context below in order to use less VRAM on low-VRAM systems.
+        # The image prompts are then passed to prep_ip_adapter_data().
+        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+
+        # get the unet's config so that we can pass the base to dispatch_progress()
+        unet_config = context.models.get_config(self.unet.unet.key)
+
+        def step_callback(state: PipelineIntermediateState) -> None:
+            context.util.sd_step_callback(state, unet_config.base)
+
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self.unet.loras:
+                lora_info = context.models.load(lora.lora)
+                assert isinstance(lora_info.model, LoRAModelRaw)
+                yield (lora_info.model, lora.weight)
+                del lora_info
+            return
+
+        unet_info = context.models.load(self.unet.unet)
+        assert isinstance(unet_info.model, UNet2DConditionModel)
+        with (
+            ExitStack() as exit_stack,
+            unet_info.model_on_device() as (model_state_dict, unet),
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
+        ):
+            assert isinstance(unet, UNet2DConditionModel)
+            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            if noise is not None:
+                noise = noise.to(device=unet.device, dtype=unet.dtype)
+            if mask is not None:
+                mask = mask.to(device=unet.device, dtype=unet.dtype)
+            if masked_latents is not None:
+                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
+            scheduler = get_scheduler(
+                context=context,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
             )
-            ip_adapters: List[IPAdapterField] = []
-            if self.ip_adapter is not None:
-                # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-                if isinstance(self.ip_adapter, list):
-                    ip_adapters = self.ip_adapter
-                else:
-                    ip_adapters = [self.ip_adapter]
+            pipeline = self.create_pipeline(unet, scheduler)
 
-            # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-            # a series of image conditioning embeddings. This is being done here rather than in the
-            # big model context below in order to use less VRAM on low-VRAM systems.
-            # The image prompts are then passed to prep_ip_adapter_data().
-            image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
 
-            # get the unet's config so that we can pass the base to dispatch_progress()
-            unet_config = context.models.get_config(self.unet.unet.key)
+            controlnet_data = self.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=latents.shape,
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
 
-            def step_callback(state: PipelineIntermediateState) -> None:
-                context.util.sd_step_callback(state, unet_config.base)
+            ip_adapter_data = self.prep_ip_adapter_data(
+                context=context,
+                ip_adapters=ip_adapters,
+                image_prompts=image_prompts,
+                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
+            )
 
-            def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-                for lora in self.unet.loras:
-                    lora_info = context.models.load(lora.lora)
-                    assert isinstance(lora_info.model, LoRAModelRaw)
-                    yield (lora_info.model, lora.weight)
-                    del lora_info
-                return
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
+                scheduler,
+                device=unet.device,
+                steps=self.steps,
+                denoising_start=self.denoising_start,
+                denoising_end=self.denoising_end,
+                seed=seed,
+            )
 
-            unet_info = context.models.load(self.unet.unet)
-            assert isinstance(unet_info.model, UNet2DConditionModel)
-            with (
-                ExitStack() as exit_stack,
-                unet_info.model_on_device() as (model_state_dict, unet),
-                ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-                set_seamless(unet, self.unet.seamless_axes),  # FIXME
-                # Apply the LoRA after unet has been moved to its target device for faster patching.
-                ModelPatcher.apply_lora_unet(
-                    unet,
-                    loras=_lora_loader(),
-                    model_state_dict=model_state_dict,
-                ),
-            ):
-                assert isinstance(unet, UNet2DConditionModel)
-                latents = latents.to(device=unet.device, dtype=unet.dtype)
-                if noise is not None:
-                    noise = noise.to(device=unet.device, dtype=unet.dtype)
-                if mask is not None:
-                    mask = mask.to(device=unet.device, dtype=unet.dtype)
-                if masked_latents is not None:
-                    masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+            result_latents = pipeline.latents_from_embeddings(
+                latents=latents,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise,
+                seed=seed,
+                mask=mask,
+                masked_latents=masked_latents,
+                gradient_mask=gradient_mask,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=controlnet_data,
+                ip_adapter_data=ip_adapter_data,
+                t2i_adapter_data=t2i_adapter_data,
+                callback=step_callback,
+            )
 
-                scheduler = get_scheduler(
-                    context=context,
-                    scheduler_info=self.unet.scheduler,
-                    scheduler_name=self.scheduler,
-                    seed=seed,
-                )
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()
 
-                pipeline = self.create_pipeline(unet, scheduler)
-
-                _, _, latent_height, latent_width = latents.shape
-                conditioning_data = self.get_conditioning_data(
-                    context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
-                )
-
-                controlnet_data = self.prep_control_data(
-                    context=context,
-                    control_input=self.control,
-                    latents_shape=latents.shape,
-                    # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                    do_classifier_free_guidance=True,
-                    exit_stack=exit_stack,
-                )
-
-                ip_adapter_data = self.prep_ip_adapter_data(
-                    context=context,
-                    ip_adapters=ip_adapters,
-                    image_prompts=image_prompts,
-                    exit_stack=exit_stack,
-                    latent_height=latent_height,
-                    latent_width=latent_width,
-                    dtype=unet.dtype,
-                )
-
-                num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
-                    scheduler,
-                    device=unet.device,
-                    steps=self.steps,
-                    denoising_start=self.denoising_start,
-                    denoising_end=self.denoising_end,
-                    seed=seed,
-                )
-
-                result_latents = pipeline.latents_from_embeddings(
-                    latents=latents,
-                    timesteps=timesteps,
-                    init_timestep=init_timestep,
-                    noise=noise,
-                    seed=seed,
-                    mask=mask,
-                    masked_latents=masked_latents,
-                    gradient_mask=gradient_mask,
-                    num_inference_steps=num_inference_steps,
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                    conditioning_data=conditioning_data,
-                    control_data=controlnet_data,
-                    ip_adapter_data=ip_adapter_data,
-                    t2i_adapter_data=t2i_adapter_data,
-                    callback=step_callback,
-                )
-
-                # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-                result_latents = result_latents.to("cpu")
-                TorchDevice.empty_cache()
-
-                name = context.tensors.save(tensor=result_latents)
+        name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
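Reviewer note: this refactor relies on SilenceWarnings being usable both as a context manager (the old `with SilenceWarnings():` form) and as a decorator (the new `@SilenceWarnings()` form). The sketch below is an illustrative assumption, not InvokeAI's actual implementation; Python's contextlib.ContextDecorator is the standard way to support both call sites with one class:

    # Illustrative sketch (assumed implementation, not InvokeAI's actual class):
    # a warnings silencer usable as `with SilenceWarnings(): ...` and as a
    # `@SilenceWarnings()` decorator, via contextlib.ContextDecorator.
    import warnings
    from contextlib import ContextDecorator

    class SilenceWarnings(ContextDecorator):
        """Suppress all Python warnings inside the decorated function or `with` block."""

        def __enter__(self) -> "SilenceWarnings":
            # Save the active warning filters so they can be restored on exit.
            self._catcher = warnings.catch_warnings()
            self._catcher.__enter__()
            warnings.simplefilter("ignore")
            return self

        def __exit__(self, *exc) -> bool:
            # Restore the previous filters; returning False propagates exceptions.
            self._catcher.__exit__(*exc)
            return False

Because ContextDecorator wraps each call to the decorated function in `with self:`, `@SilenceWarnings()` on invoke() behaves the same as the removed `with` block, which is what allows the whole method body to be de-indented in this diff.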