diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 861bf22a7a..911de67601 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -3,10 +3,12 @@ import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Union, Callable
 
+import PIL.Image
 import torch
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -210,6 +212,7 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                                  *,
                                  run_id: str = None,
                                  extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                 timesteps = None,
                                  **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
@@ -220,16 +223,19 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
+        if timesteps is None:
+            timesteps = self.scheduler.timesteps
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                         latents=latents)
 
         batch_size = latents.shape[0]
-        batched_t = torch.full((batch_size,), self.scheduler.timesteps[0],
-                               dtype=self.scheduler.timesteps.dtype, device=self.unet.device)
+        batched_t = torch.full((batch_size,), timesteps[0],
+                               dtype=timesteps.dtype, device=self.unet.device)
         # NOTE: Depends on scheduler being already initialized!
-        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(timesteps)):
             batched_t.fill_(t)
             step_output = self.step(batched_t, latents, guidance_scale,
                                     text_embeddings, unconditioned_embeddings,
@@ -272,6 +278,68 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
         # predict the noise residual
         return self.unet(latents, t, encoder_hidden_states=text_embeddings).sample
 
+    def img2img_from_embeddings(self,
+                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
+                                strength: float,
+                                num_inference_steps: int,
+                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
+                                guidance_scale: float,
+                                *, callback: Callable[[PipelineIntermediateState], None] = None,
+                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                run_id=None,
+                                noise_func=None,
+                                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
+        device = self.unet.device
+        latents_dtype = text_embeddings.dtype
+        batch_size = 1
+        num_images_per_prompt = 1
+
+        if isinstance(init_image, PIL.Image.Image):
+            init_image = preprocess(init_image.convert('RGB'))
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self._diffusers08_get_timesteps(num_inference_steps, strength)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+        # 6. Prepare latent variables
+        latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func)
+
+        result = None
+        for result in self.generate_from_embeddings(
+                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, **extra_step_kwargs):
+            if callback is not None and isinstance(result, PipelineIntermediateState):
+                callback(result)
+        if result is None:
+            raise AssertionError("why was that an empty generator?")
+        return result
+
+    def prepare_latents_from_image(self, init_image, timestep, dtype, device, noise_func) -> torch.FloatTensor:
+        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
+        # because we have our own noise function
+        init_image = init_image.to(device=device, dtype=dtype)
+        with torch.inference_mode():
+            init_latent_dist = self.vae.encode(init_image).latent_dist
+            init_latents = init_latent_dist.sample()  # FIXME: uses torch.randn. make reproducible!
+        init_latents = 0.18215 * init_latents
+
+        noise = noise_func(init_latents)
+
+        return self.scheduler.add_noise(init_latents, noise, timestep)
+
+    def _diffusers08_get_timesteps(self, num_inference_steps, strength):
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        t_start = max(num_inference_steps - init_timestep + offset, 0)
+        timesteps = self.scheduler.timesteps[t_start:]
+
+        return timesteps
+
     @torch.inference_mode()
     def check_for_safety(self, output):
         if not getattr(self, 'feature_extractor') or not getattr(self, 'safety_checker'):
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index edcc855a29..6ea41fda33 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -2,14 +2,10 @@
 ldm.invoke.generator.img2img descends from ldm.invoke.generator
 '''
 
-import PIL
-import numpy as np
 import torch
-from PIL import Image
-from torch import Tensor
 
-from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
+from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 
 class Img2Img(Generator):
@@ -25,66 +21,51 @@ class Img2Img(Generator):
         """
         self.perlin = perlin
 
-        sampler.make_schedule(
-            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
-        )
-
-        if isinstance(init_image, PIL.Image.Image):
-            init_image = self._image_to_tensor(init_image.convert('RGB'))
-
-        scope = choose_autocast(self.precision)
-        with scope(self.model.device.type):
-            self.init_latent = self.model.get_first_stage_encoding(
-                self.model.encode_first_stage(init_image)
-            ) # move to latent space
-
-        t_enc = int(strength * steps)
         uc, c, extra_conditioning_info = conditioning
 
+        # noinspection PyTypeChecker
+        pipeline: StableDiffusionGeneratorPipeline = self.model
+        pipeline.scheduler = sampler
+
         def make_image(x_T):
-            # encode (scaled latent)
-            z_enc = sampler.stochastic_encode(
-                self.init_latent,
-                torch.tensor([t_enc]).to(self.model.device),
-                noise=x_T
-            )
-            # decode it
-            samples = sampler.decode(
-                z_enc,
-                c,
-                t_enc,
-                img_callback = step_callback,
-                unconditional_guidance_scale=cfg_scale,
-                unconditional_conditioning=uc,
-                init_latent = self.init_latent,  # changes how noising is performed in ksampler
-                extra_conditioning_info = extra_conditioning_info,
-                all_timesteps_count = steps
+            # FIXME: use x_T for initial seeded noise
+            pipeline_output = pipeline.img2img_from_embeddings(
+                init_image, strength, steps, c, uc, cfg_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                noise_func=self.get_noise_like,
+                callback=step_callback
             )
-            return self.sample_to_image(samples)
+            return pipeline.numpy_to_pil(pipeline_output.images)[0]
 
         return make_image
 
-    def get_noise(self,width,height):
-        device      = self.model.device
-        init_latent = self.init_latent
-        assert init_latent is not None,'call to get_noise() when init_latent not set'
+    def get_noise_like(self, like: torch.Tensor):
+        device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(init_latent, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu').to(device)
         else:
-            x = torch.randn_like(init_latent, device=device)
+            x = torch.randn_like(like, device=device)
         if self.perlin > 0.0:
-            shape = init_latent.shape
+            shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
 
-    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
-        image = np.array(image).astype(np.float32) / 255.0
-        if len(image.shape) == 2:  # 'L' image, as in a mask
-            image = image[None,None]
-        else:                      # 'RGB' image
-            image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        if normalize:
-            image = 2.0 * image - 1.0
-        return image.to(self.model.device)
+    def get_noise(self,width,height):
+        # copy of the Txt2Img.get_noise
+        device = self.model.device
+        if self.use_mps_noise or device.type == 'mps':
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device='cpu').to(device)
+        else:
+            x = torch.randn([1,
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                            device=device)
+        if self.perlin > 0.0:
+            x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
+        return x
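
Note on the _diffusers08_get_timesteps port: strength controls how large a suffix of the scheduler's timestep sequence img2img actually runs. A minimal standalone sketch of that arithmetic, not part of the patch (timestep_window is a hypothetical name):

    # Sketch of the strength -> timestep-window arithmetic in _diffusers08_get_timesteps.
    def timestep_window(num_inference_steps: int, strength: float, steps_offset: int = 0) -> int:
        # how many denoising steps to actually run
        init_timestep = min(int(num_inference_steps * strength) + steps_offset, num_inference_steps)
        # index of the first scheduler timestep to keep
        return max(num_inference_steps - init_timestep + steps_offset, 0)

    # e.g. 50 steps at strength 0.6: skip the first 20 scheduler timesteps
    # and run only the final 30 denoising steps.
    assert timestep_window(50, 0.6) == 20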
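
On the "FIXME: use x_T for initial seeded noise": one option would be to thread a seeded noise_func through img2img_from_embeddings. A minimal sketch, assuming the signature prepare_latents_from_image expects (tensor in, same-shaped noise out); make_seeded_noise_func is hypothetical and not part of this patch:

    import torch

    def make_seeded_noise_func(seed: int):
        def noise_func(like: torch.Tensor) -> torch.Tensor:
            # draw on CPU with a fixed seed so results are reproducible across devices
            generator = torch.Generator(device='cpu').manual_seed(seed)
            noise = torch.randn(like.shape, generator=generator, dtype=like.dtype)
            return noise.to(like.device)
        return noise_func

Such a helper could then be passed as noise_func=make_seeded_noise_func(seed) in place of self.get_noise_like.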