From 821c7df2401deb3101654304d963642d22256f53 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Wed, 7 Dec 2022 18:20:56 -0800
Subject: [PATCH] refactor(diffusers): reduce some code duplication amongst the different tasks

---
 ldm/invoke/generator/diffusers_pipeline.py | 176 +++++++--------------
 1 file changed, 61 insertions(+), 115 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 35db1db383..4dbd576c51 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -157,6 +157,8 @@ ParamType = ParamSpec('ParamType')
 
 @dataclass(frozen=True)
 class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
+    """Convert a generator to a function with a callback and a return value."""
+
     generator_method: Callable[ParamType, ReturnType]
     callback_arg_type: Type[CallbackType]
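The hunk above only adds a docstring; the body of GeneratorToCallbackinator sits outside the diff context. Judging from its call sites in this patch, and from the hand-written loops it replaces in image_from_embeddings and inpaint_from_embeddings below, the adapter behaves roughly like the following sketch. The __call__ signature and its handling of the callback keyword are assumptions, not the committed implementation:

    from dataclasses import dataclass
    from typing import Callable, Generic, Type, TypeVar

    from typing_extensions import ParamSpec

    ParamType = ParamSpec('ParamType')
    ReturnType = TypeVar('ReturnType')
    CallbackType = TypeVar('CallbackType')

    @dataclass(frozen=True)
    class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
        """Convert a generator to a function with a callback and a return value."""

        generator_method: Callable[ParamType, ReturnType]
        callback_arg_type: Type[CallbackType]

        def __call__(self, *args, callback: Callable[[CallbackType], None] = None, **kwargs) -> ReturnType:
            # Drive the wrapped generator to completion, forwarding each
            # intermediate value of the declared type to the callback, and
            # return the final value the generator produced.
            result = None
            for result in self.generator_method(*args, **kwargs):
                if callback is not None and isinstance(result, self.callback_arg_type):
                    callback(result)
            if result is None:
                raise AssertionError("why was that an empty generator?")
            return result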
@@ -261,111 +263,48 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         :param run_id:
         :param extra_step_kwargs:
         """
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        result = None
-        for result in self.generate_from_embeddings(
-                latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                extra_conditioning_info=extra_conditioning_info,
-                run_id=run_id, **extra_step_kwargs):
-            if callback is not None and isinstance(result, PipelineIntermediateState):
-                callback(result)
-        if result is None:
-            raise AssertionError("why was that an empty generator?")
-        return result
-
-    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
-                                text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
-                                guidance_scale: float,
-                                *, callback: Callable[[PipelineIntermediateState], None]=None,
-                                extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo=None,
-                                run_id=None,
-                                **extra_step_kwargs) -> PipelineIntermediateState:
-        self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
-        f = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                 extra_conditioning_info=extra_conditioning_info,
-                 run_id=run_id,
-                 callback=callback,
-                 **extra_step_kwargs)
-
-    def generate(
-        self,
-        prompt: Union[str, List[str]],
-        *,
-        opposing_prompt: Union[str, List[str]] = None,
-        height: Optional[int] = 512,
-        width: Optional[int] = 512,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        generator: Optional[torch.Generator] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        run_id: str = None,
-        **extra_step_kwargs,
-    ):
-        if isinstance(prompt, str):
-            batch_size = 1
-        else:
-            batch_size = len(prompt)
-
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
-        combined_embeddings = self._encode_prompt(prompt, device=self._execution_device, num_images_per_prompt=1,
-                                                  do_classifier_free_guidance=do_classifier_free_guidance,
-                                                  negative_prompt=opposing_prompt)
-        text_embeddings, unconditioned_embeddings = combined_embeddings.chunk(2)
-        self.scheduler.set_timesteps(num_inference_steps)
-        latents = self.prepare_latents(batch_size=batch_size, num_channels_latents=self.unet.in_channels,
-                                       height=height, width=width,
-                                       dtype=self.unet.dtype, device=self._execution_device,
-                                       generator=generator,
-                                       latents=latents)
-
-        yield from self.generate_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
-                                                 guidance_scale, run_id=run_id, **extra_step_kwargs)
-
-    def generate_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        text_embeddings: torch.Tensor,
-        unconditioned_embeddings: torch.Tensor,
-        guidance_scale: float,
-        *,
-        run_id: str = None,
-        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
-        timesteps = None,
-        additional_guidance: List[Callable] = None,
-        **extra_step_kwargs):
-        latents = yield from self.generate_latents_from_embeddings(latents, text_embeddings, unconditioned_embeddings,
-            guidance_scale, run_id=run_id, extra_conditioning_info=extra_conditioning_info,
-            timesteps=timesteps, additional_guidance=additional_guidance, **extra_step_kwargs)
-
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            run_id=run_id, callback=callback, **extra_step_kwargs
+        )
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
 
         with torch.inference_mode():
-            image = self.decode_latents(latents)
+            image = self.decode_latents(result_latents)
             output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
-            yield self.check_for_safety(output, dtype=text_embeddings.dtype)
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
 
-    def generate_latents_from_embeddings(
-        self,
-        latents: torch.Tensor,
-        text_embeddings: torch.Tensor,
-        unconditioned_embeddings: torch.Tensor,
+    def latents_from_embeddings(
+        self, latents: torch.Tensor, num_inference_steps: int,
+        text_embeddings: torch.Tensor, unconditioned_embeddings: torch.Tensor,
         guidance_scale: float,
         *,
-        run_id: str = None,
-        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
         timesteps = None,
+        extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
         additional_guidance: List[Callable] = None,
+        run_id=None,
+        callback: Callable[[PipelineIntermediateState], None]=None,
         **extra_step_kwargs
-    ):
+    ) -> torch.Tensor:
+        if timesteps is None:
+            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
+            timesteps = self.scheduler.timesteps
+        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
+        return infer_latents_from_embeddings(
+            latents, timesteps, text_embeddings, unconditioned_embeddings, guidance_scale,
+            extra_conditioning_info=extra_conditioning_info,
+            additional_guidance=additional_guidance,
+            run_id=run_id,
+            callback=callback,
+            **extra_step_kwargs).latents
+
+    def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps, text_embeddings: torch.Tensor,
+                                         unconditioned_embeddings: torch.Tensor, guidance_scale: float, *,
+                                         run_id: str = None,
+                                         extra_conditioning_info: InvokeAIDiffuserComponent.ExtraConditioningInfo = None,
+                                         additional_guidance: List[Callable] = None, **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
         if additional_guidance is None:
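With this hunk, all three tasks funnel through the reworked latents_from_embeddings, which initializes the scheduler when no explicit timesteps are supplied, wraps the step generator in a GeneratorToCallbackinator, and returns the final latents while reporting progress through an optional per-step callback. A hypothetical caller looks like the following; the step and latents field names on PipelineIntermediateState are assumptions:

    # Usage sketch; field names on PipelineIntermediateState are assumed.
    def report_progress(state: PipelineIntermediateState):
        print(f"step {state.step}: latents {tuple(state.latents.shape)}")

    result_latents = pipeline.latents_from_embeddings(
        latents, num_inference_steps=50,
        text_embeddings=text_embeddings,
        unconditioned_embeddings=unconditioned_embeddings,
        guidance_scale=7.5,
        callback=report_progress)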
@@ -375,9 +314,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 step_count=len(self.scheduler.timesteps))
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
-        if timesteps is None:
-            # NOTE: Depends on scheduler being already initialized!
-            timesteps = self.scheduler.timesteps
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
@@ -448,8 +384,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                run_id=None, noise_func=None,
                                **extra_step_kwargs) -> StableDiffusionPipelineOutput:
-        device = self.unet.device
-        latents_dtype = self.unet.dtype
 
         if isinstance(init_image, PIL.Image.Image):
             init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))
@@ -457,13 +391,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')
 
         # 6. Prepare latent variables
+        device = self.unet.device
+        latents_dtype = self.unet.dtype
         initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
 
-        result = self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
+        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps, text_embeddings,
                                                           unconditioned_embeddings, guidance_scale, strength,
                                                           extra_conditioning_info, noise_func, run_id, callback,
                                                           **extra_step_kwargs)
-        return result
 
     def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps, text_embeddings,
                                             unconditioned_embeddings, guidance_scale, strength, extra_conditioning_info,
@@ -476,13 +411,21 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         latent_timestep = timesteps[:1].repeat(batch_size)
         latents = self.noise_latents_for_time(initial_latents, latent_timestep, noise_func=noise_func)
 
-        f = GeneratorToCallbackinator(self.generate_from_embeddings, PipelineIntermediateState)
-        return f(latents, text_embeddings, unconditioned_embeddings, guidance_scale,
+        result_latents = self.latents_from_embeddings(
+            latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
             extra_conditioning_info=extra_conditioning_info,
             timesteps=timesteps,
             callback=callback,
             run_id=run_id, **extra_step_kwargs)
 
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
+
     def inpaint_from_embeddings(
             self,
             init_image: torch.FloatTensor,
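img2img is a caller that passes its own timesteps: it runs only the tail of the schedule, starting from a noise level chosen by strength, and noise_latents_for_time noises the init-image latents to that first timestep. The derivation of the truncated schedule sits above this hunk and is not shown in the diff context; the conventional diffusers-style computation is approximately:

    # Sketch only: how a strength-truncated schedule is conventionally derived.
    # The actual code lives outside this hunk's context.
    self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = self.scheduler.timesteps[t_start:]  # drop the noisiest leading steps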
@@ -536,22 +479,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             else:
                 guidance.append(AddsMaskGuidance(mask, init_image_latents, self.scheduler, noise_func))
 
-        result = None
         try:
-            for result in self.generate_from_embeddings(
-                    latents, text_embeddings, unconditioned_embeddings, guidance_scale,
-                    extra_conditioning_info=extra_conditioning_info,
-                    timesteps=timesteps,
-                    run_id=run_id, additional_guidance=guidance, **extra_step_kwargs):
-                if callback is not None and isinstance(result, PipelineIntermediateState):
-                    callback(result)
-            if result is None:
-                raise AssertionError("why was that an empty generator?")
-            return result
+            result_latents = self.latents_from_embeddings(
+                latents, num_inference_steps, text_embeddings, unconditioned_embeddings, guidance_scale,
+                extra_conditioning_info=extra_conditioning_info,
+                timesteps=timesteps,
+                run_id=run_id, additional_guidance=guidance,
+                callback=callback,
+                **extra_step_kwargs)
         finally:
             self.invokeai_diffuser.model_forward_callback = self._unet_forward
 
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        torch.cuda.empty_cache()
+
+        with torch.inference_mode():
+            image = self.decode_latents(result_latents)
+            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
+            return self.check_for_safety(output, dtype=text_embeddings.dtype)
+
     def non_noised_latents_from_image(self, init_image, *, device, dtype):
         init_image = init_image.to(device=device, dtype=dtype)
         with torch.inference_mode():
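After this patch, the empty_cache/decode/safety-check tail appears verbatim three times, in image_from_embeddings, img2img_from_latents_and_embeddings, and inpaint_from_embeddings. A natural follow-up is to hoist it into one helper; the method name below is hypothetical and not part of this patch, but the body is exactly the repeated code:

    # Hypothetical follow-up helper consolidating the tail duplicated across
    # the three task methods above.
    def _decode_and_check(self, result_latents: torch.Tensor, dtype) -> StableDiffusionPipelineOutput:
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()
        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = StableDiffusionPipelineOutput(images=image, nsfw_content_detected=[])
            return self.check_for_safety(output, dtype=dtype)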