mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor: remove backported img2img.get_timesteps
now that we can use it directly from diffusers 0.8.1
This commit is contained in:
parent
ceb53ccdfb
commit
b7864aa1a7
@ -5,7 +5,7 @@ from typing import List, Optional, Union, Callable
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
|
||||
@ -260,8 +260,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = preprocess(init_image.convert('RGB'))
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
|
||||
img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
|
||||
img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
@ -292,17 +293,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
return self.scheduler.add_noise(init_latents, noise, timestep)
|
||||
|
||||
def _diffusers08_get_timesteps(self, num_inference_steps, strength):
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
|
||||
return timesteps
|
||||
|
||||
def check_for_safety(self, output, dtype):
|
||||
with torch.inference_mode():
|
||||
screened_images, has_nsfw_concept = self.run_safety_checker(
|
||||
|
Loading…
x
Reference in New Issue
Block a user