diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 1ef89913c5..a63159b118 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -17,10 +17,8 @@ import PIL.Image
 import einops
 import torch
 import torchvision.transforms as T
-from diffusers.models import attention
 from diffusers.utils.import_utils import is_xformers_available
 
-from ...models.diffusion import cross_attention_control
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
 
@@ -506,11 +504,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             strength,
             noise: torch.Tensor, run_id=None, callback=None
             ) -> InvokeAIStableDiffusionPipelineOutput:
-        device = self.unet.device
-        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
-        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
-
+        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
         result_latents, result_attention_maps = self.latents_from_embeddings(
             initial_latents, num_inference_steps, conditioning_data,
             timesteps=timesteps,
@@ -526,6 +520,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
         return self.check_for_safety(output, dtype=conditioning_data.dtype)
 
+    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int):
+        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
+        assert img2img_pipeline.scheduler is self.scheduler
+        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
+        # Workaround for low strength resulting in zero timesteps.
+        # TODO: submit upstream fix for zero-step img2img
+        if timesteps.numel() == 0:
+            timesteps = self.scheduler.timesteps[-1:]
+            adjusted_steps = timesteps.numel()
+        return timesteps, adjusted_steps
+
     def inpaint_from_embeddings(
         self,
         init_image: torch.FloatTensor,
@@ -549,11 +555,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if init_image.dim() == 3:
             init_image = init_image.unsqueeze(0)
 
-        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
-        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
-        timesteps, _ = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
-
-        assert img2img_pipeline.scheduler is self.scheduler
+        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device)
 
         # 6. Prepare latent variables
         # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
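
The new get_img2img_timesteps helper consolidates the timestep setup that img2img and inpainting previously duplicated, and adds a fallback for the case where a very low strength leaves no timesteps to run. The standalone sketch below (not part of the patch) reproduces that edge case with a stand-in scheduler; the int(num_inference_steps * strength) rounding mirrors how diffusers' StableDiffusionImg2ImgPipeline.get_timesteps computed the slice at the time, which is an assumption about the upstream version in use, and sketch_img2img_timesteps is a hypothetical name used only for illustration.

# Minimal sketch of the zero-timestep workaround; assumes upstream computes
# init_timestep = int(num_inference_steps * strength), as diffusers did here.
import torch


def sketch_img2img_timesteps(num_inference_steps: int, strength: float):
    # Stand-in for scheduler.timesteps after set_timesteps(): descending from
    # most to least noisy.
    scheduler_timesteps = torch.linspace(999, 0, num_inference_steps).long()

    # Upstream-style slice: low strength rounds init_timestep down to zero,
    # leaving an empty slice.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler_timesteps[t_start:]

    # Same fallback as the patch: run at least one step at the final
    # (least noisy) timestep instead of zero steps.
    if timesteps.numel() == 0:
        timesteps = scheduler_timesteps[-1:]
    return timesteps, timesteps.numel()


if __name__ == "__main__":
    print(sketch_img2img_timesteps(50, 0.75))  # 37 timesteps
    print(sketch_img2img_timesteps(50, 0.01))  # falls back to a single step

With strength=0.01 and 50 steps, int(50 * 0.01) rounds to 0, so the upstream slice is empty; the fallback keeps a single final timestep so the pipeline still performs one denoising step rather than returning the input unchanged.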