diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index b0b4ac2a67..95de688f57 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -22,8 +22,9 @@ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils.outputs import BaseOutput from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer @@ -53,39 +54,76 @@ _default_personalization_config_params = dict( @dataclass class AddsMaskLatents: + """Add the channels required for inpainting model input. + + The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask + and the latent encoding of the base image. + + This class assumes the same mask and base image should apply to all items in the batch. + """ forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] mask: torch.FloatTensor - mask_latents: torch.FloatTensor + initial_image_latents: torch.FloatTensor def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor: - batch_size = latents.size(0) - mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) - mask_latents = einops.repeat(self.mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) - model_input, _ = einops.pack([latents, mask, mask_latents], 'b * h w') + model_input = self.add_mask_channels(latents) return self.forward(model_input, t, text_embeddings) + def add_mask_channels(self, latents): + batch_size = latents.size(0) + # duplicate mask and latents for each batch + mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) + image_latents = einops.repeat(self.initial_image_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) + # add mask and image as additional channels + model_input, _ = einops.pack([latents, mask, image_latents], 'b * h w') + return model_input + + +def are_like_tensors(a: torch.Tensor, b: object) -> bool: + return ( + isinstance(b, torch.Tensor) + and (a.size() == b.size()) + ) @dataclass class AddsMaskGuidance: - forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] mask: torch.FloatTensor mask_latents: torch.FloatTensor _scheduler: SchedulerMixin _noise_func: Callable _debug: Optional[Callable] = None - def __call__(self, latents: torch.FloatTensor, t: torch.Tensor, text_embeddings: torch.FloatTensor) -> torch.Tensor: + def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput: + output_class = step_output.__class__ # We'll create a new one with masked data. + + # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it. + # It's reasonable to assume the first thing is prev_sample, but then does it have other things + # like pred_original_sample? Should we apply the mask to them too? + # But what if there's just some other random field? + prev_sample = step_output[0] + # Mask anything that has the same shape as prev_sample, return others as-is. + return output_class( + {k: (self.apply_mask(v, self._t_for_field(k, t)) + if are_like_tensors(prev_sample, v) else v) + for k, v in step_output.items()} + ) + + def _t_for_field(self, field_name:str, t): + if field_name == "pred_original_sample": + return torch.zeros_like(t, dtype=t.dtype) # it represents t=0 + return t + + def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor: batch_size = latents.size(0) mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size) noise = self._noise_func(self.mask_latents) - mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t[0]) # .to(dtype=mask_latents.dtype) + mask_latents = self._scheduler.add_noise(self.mask_latents, noise, t) + # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size) - # if self._debug: - # self._debug(latents, f"t={t[0]} latents") masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) if self._debug: - self._debug(masked_input, f"t={t[0]} lerped") - return self.forward(masked_input, t, text_embeddings) + self._debug(masked_input, f"t={t} lerped") + return masked_input def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor: @@ -263,10 +301,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): run_id: str = None, extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None, timesteps = None, + additional_guidance: List[Callable] = None, **extra_step_kwargs): if run_id is None: run_id = secrets.token_urlsafe(self.ID_LENGTH) + if additional_guidance is None: + additional_guidance = [] + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count=len(self.scheduler.timesteps)) @@ -289,7 +331,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): batched_t.fill_(t) step_output = self.step(batched_t, latents, guidance_scale, text_embeddings, unconditioned_embeddings, - i, **extra_step_kwargs) + i, additional_guidance=additional_guidance, + **extra_step_kwargs) latents = step_output.prev_sample predicted_original = getattr(step_output, 'pred_original_sample', None) yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents, @@ -306,11 +349,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): @torch.inference_mode() def step(self, t: torch.Tensor, latents: torch.Tensor, guidance_scale: float, text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor, - step_index:int | None = None, + step_index:int | None = None, additional_guidance: List[Callable] = None, **extra_step_kwargs): # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value timestep = t[0] + if additional_guidance is None: + additional_guidance = [] + # TODO: should this scaling happen here or inside self._unet_forward? # i.e. before or after passing it to InvokeAIDiffuserComponent latent_model_input = self.scheduler.scale_model_input(latents, timestep) @@ -323,7 +369,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): step_index=step_index) # compute the previous noisy sample x_t -> x_t-1 - return self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs) + step_output = self.scheduler.step(noise_pred, timestep, latents, **extra_step_kwargs) + + # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent. + # But the way things are now, scheduler runs _after_ that, so there was + # no way to use it to apply an operation that happens after the last scheduler.step. + for guidance in additional_guidance: + step_output = guidance(step_output, timestep, (unconditioned_embeddings, text_embeddings)) + + return step_output def _unet_forward(self, latents, t, text_embeddings): # predict the noise residual @@ -401,6 +455,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device) + assert img2img_pipeline.scheduler is self.scheduler + # 6. Prepare latent variables latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) latents, init_image_latents = self.prepare_latents_from_image(init_image, latent_timestep, latents_dtype, device, noise_func) @@ -410,13 +466,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR) \ .to(device=device, dtype=latents_dtype) + guidance: List[Callable] = [] + if is_inpainting_model(self.unet): + # TODO: we should probably pass this in so we don't have to try/finally around setting it. self.invokeai_diffuser.model_forward_callback = \ AddsMaskLatents(self._unet_forward, mask, init_image_latents) else: - self.invokeai_diffuser.model_forward_callback = \ - AddsMaskGuidance(self._unet_forward, mask, init_image_latents, - self.scheduler, noise_func) # self.debug_latents) + guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func)) result = None @@ -425,7 +482,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): latents, text_embeddings, unconditioned_embeddings, guidance_scale, extra_conditioning_info=extra_conditioning_info, timesteps=timesteps, - run_id=run_id, **extra_step_kwargs): + run_id=run_id, additional_guidance=guidance, **extra_step_kwargs): if callback is not None and isinstance(result, PipelineIntermediateState): callback(result) if result is None: