diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 4a4700232c..5d6b66ccf0 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union, Callable import PIL.Image import torch -from diffusers import StableDiffusionPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess @@ -260,8 +260,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image.convert('RGB')) - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength) + img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components) + img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -292,17 +293,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): return self.scheduler.add_noise(init_latents, noise, timestep) - def _diffusers08_get_timesteps(self, num_inference_steps, strength): - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:] - - return timesteps - def check_for_safety(self, output, dtype): with torch.inference_mode(): screened_images, has_nsfw_concept = self.run_safety_checker(