diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 7fd101a3a0..41be7f7138 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -320,8 +320,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): else: c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_", zero_on_empty=True) - print(f"{c1.shape=} {c2.shape=} {c2_pooled.shape=} {self.prompt=}") - original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) target_size = (self.target_height, self.target_width) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a63f98de24..fef3bcbf6f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -122,6 +122,7 @@ class TextToLatentsInvocation(BaseInvocation): scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) unet: UNetField = Field(default=None, description="UNet submodel") control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") + denoising_end: float = Field(default=1.0, ge=0, le=1, description="") # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") # fmt: on @@ -310,6 +311,25 @@ class TextToLatentsInvocation(BaseInvocation): # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): + # apply denoising_start + num_inference_steps = steps + scheduler.set_timesteps(num_inference_steps, device=device) + + t_start = int(round(denoising_start * num_inference_steps)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + num_inference_steps = num_inference_steps - t_start + + # apply denoising_end + num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0) + + skipped_final_steps = int(round((1 - denoising_end) * steps)) + num_inference_steps = num_inference_steps - skipped_final_steps + timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps] + + return num_inference_steps, timesteps + + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: with SilenceWarnings(): @@ -359,12 +379,21 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, exit_stack=exit_stack, ) + + num_inference_steps, timesteps = self.init_scheduler( + scheduler, + device=unet.device, + steps=self.steps, + denoising_start=0.0, + denoising_end=self.denoising_end, + ) # TODO: Verify the noise is the right size result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), noise=noise, - num_inference_steps=self.steps, + timesteps=timesteps, + num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData] callback=step_callback, @@ -385,8 +414,12 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + noise: Optional[LatentsField] = Field(description="The noise to use (test override for future optional)") + + denoising_start: float = Field(default=0.0, ge=0, le=1, description="") + #denoising_end: float = Field(default=1.0, ge=0, le=1, description="") + + latents: Optional[LatentsField] = Field(description="The latents to use as a base image") # Schema customisation class Config(InvocationConfig): @@ -405,7 +438,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: with SilenceWarnings(): # this quenches NSFW nag from diffusers - noise = context.services.latents.get(self.noise.latents_name) + noise = None + if self.noise is not None: + noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) @@ -432,7 +467,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( unet_info.context.model, _lora_loader() ), unet_info as unet: - noise = noise.to(device=unet.device, dtype=unet.dtype) + if noise is not None: + noise = noise.to(device=unet.device, dtype=unet.dtype) latent = latent.to(device=unet.device, dtype=unet.dtype) scheduler = get_scheduler( @@ -448,28 +484,30 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): model=pipeline, context=context, control_input=self.control, - latents_shape=noise.shape, + latents_shape=latent.shape, # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) do_classifier_free_guidance=True, exit_stack=exit_stack, ) # TODO: Verify the noise is the right size - initial_latents = ( - latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) - ) + initial_latents = latent + if self.denoising_start <= 0.0: + initial_latents = torch.zeros_like(latent, device=unet.device, dtype=latent.dtype) - timesteps, _ = pipeline.get_img2img_timesteps( - self.steps, - self.strength, + num_inference_steps, timesteps = self.init_scheduler( + scheduler, device=unet.device, + steps=self.steps, + denoising_start=self.denoising_start, + denoising_end=self.denoising_end, ) result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( latents=initial_latents, timesteps=timesteps, noise=noise, - num_inference_steps=self.steps, + num_inference_steps=num_inference_steps, conditioning_data=conditioning_data, control_data=control_data, # list[ControlNetData] callback=step_callback, diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 8a7616f1f1..ed1c8deeb5 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -340,33 +340,39 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if xformers is available, use it, otherwise use sliced attention. """ config = InvokeAIAppConfig.get_config() - if torch.cuda.is_available() and is_xformers_available() and not config.disable_xformers: - self.enable_xformers_memory_efficient_attention() + if self.unet.device.type == "cuda": + if is_xformers_available() and not config.disable_xformers: + self.enable_xformers_memory_efficient_attention() + return + elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): + # diffusers enable sdp automatically + return + + + if self.device.type == "cpu" or self.device.type == "mps": + mem_free = psutil.virtual_memory().free + elif self.device.type == "cuda": + mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device)) else: - if self.device.type == "cpu" or self.device.type == "mps": - mem_free = psutil.virtual_memory().free - elif self.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.device)) - else: - raise ValueError(f"unrecognized device {self.device}") - # input tensor of [1, 4, h/8, w/8] - # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 - max_size_required_for_baddbmm = ( - 16 - * latents.size(dim=2) - * latents.size(dim=3) - * latents.size(dim=2) - * latents.size(dim=3) - * bytes_per_element_needed_for_baddbmm_duplication - ) - if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size="max") - elif torch.backends.mps.is_available(): - # diffusers recommends always enabling for mps - self.enable_attention_slicing(slice_size="max") - else: - self.disable_attention_slicing() + raise ValueError(f"unrecognized device {self.device}") + # input tensor of [1, 4, h/8, w/8] + # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] + bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 + max_size_required_for_baddbmm = ( + 16 + * latents.size(dim=2) + * latents.size(dim=3) + * latents.size(dim=2) + * latents.size(dim=3) + * bytes_per_element_needed_for_baddbmm_duplication + ) + if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code + self.enable_attention_slicing(slice_size="max") + elif torch.backends.mps.is_available(): + # diffusers recommends always enabling for mps + self.enable_attention_slicing(slice_size="max") + else: + self.disable_attention_slicing() def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): # overridden method; types match the superclass. @@ -398,7 +404,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): num_inference_steps: int, conditioning_data: ConditioningData, *, - noise: torch.Tensor, + noise: Optional[torch.Tensor], timesteps=None, additional_guidance: List[Callable] = None, run_id=None, @@ -434,7 +440,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): timesteps, conditioning_data: ConditioningData, *, - noise: torch.Tensor, + noise: Optional[torch.Tensor], run_id: str = None, additional_guidance: List[Callable] = None, control_data: List[ControlNetData] = None, @@ -457,8 +463,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): dtype=timesteps.dtype, device=self._model_group.device_for(self.unet), ) - #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers - latents = self.scheduler.add_noise(latents, noise, batched_t) + if noise is not None: + #latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + latents = self.scheduler.add_noise(latents, noise, batched_t) yield PipelineIntermediateState( run_id=run_id,