From 95db6e80ee8c20196e637a2d1aa1421d3b931044 Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sat, 12 Nov 2022 10:10:46 -0800
Subject: [PATCH] cross_attention_control: stub (no-op) implementations for
 diffusers

---
 ldm/invoke/generator/diffusers_pipeline.py   |  7 ++++
 .../diffusion/cross_attention_control.py     | 40 +++++++++++++------
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 0bd096ff6b..861bf22a7a 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -213,6 +213,13 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                              **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
+                                                                 step_count=len(self.scheduler.timesteps))
+        else:
+            self.invokeai_diffuser.remove_cross_attention_control()
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index a4362e0770..3284f990ce 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -1,4 +1,5 @@
 import enum
+import warnings
 from typing import Optional
 
 import torch
@@ -244,19 +245,32 @@ def inject_attention_function(unet, context: Context):
 
             return attention_slice
 
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.identifier = name
-            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                context.get_slicing_strategy(module_identifier))
+    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for identifier, module in cross_attention_modules:
+        module.identifier = identifier
+        try:
+            module.set_attention_slice_wrangler(attention_slice_wrangler)
+            module.set_slicing_strategy_getter(
+                lambda module, module_identifier=identifier: context.get_slicing_strategy(module_identifier)
+            )
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
 
 
 def remove_attention_function(unet):
-    # clear wrangler callback
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.set_attention_slice_wrangler(None)
-            module.set_slicing_strategy_getter(None)
+    cross_attention_modules = [module for (_, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for module in cross_attention_modules:
+        try:
+            # clear wrangler callback
+            module.set_attention_slice_wrangler(None)
+            module.set_slicing_strategy_getter(None)
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
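
Note (not part of the diff above): the slicing-strategy getter in inject_attention_function is created inside a loop, so it binds each module's identifier as a lambda default argument (module_identifier=identifier) rather than closing over the loop variable; a plain closure is late-bound in Python, and every getter would end up reporting the strategy for the last CrossAttention module visited. The try/except AttributeError blocks are the "stub (no-op) implementations" from the subject line: when a module does not expose set_attention_slice_wrangler (which appears to be the case for diffusers' CrossAttention at this point), the code warns and continues instead of failing. The snippet below is a minimal, standalone sketch of the binding difference only; ContextStub and the module names are invented for illustration and are not part of the InvokeAI code.

    class ContextStub:
        """Stand-in for cross_attention_control.Context, for this demo only."""
        def get_slicing_strategy(self, identifier: str) -> str:
            return f"strategy for {identifier}"


    context = ContextStub()
    identifiers = ["down_blocks.0.attn2", "up_blocks.1.attn2"]  # made-up module names

    late_bound, value_bound = [], []
    for identifier in identifiers:
        # Late binding: `identifier` is looked up when the lambda is *called*,
        # so every getter sees the loop variable's final value.
        late_bound.append(lambda module=None: context.get_slicing_strategy(identifier))
        # Default-argument binding: `module_identifier` is evaluated *now*,
        # so each getter keeps the identifier of its own module.
        value_bound.append(lambda module=None, module_identifier=identifier:
                           context.get_slicing_strategy(module_identifier))

    print([f(None) for f in late_bound])
    # ['strategy for up_blocks.1.attn2', 'strategy for up_blocks.1.attn2']
    print([f(None) for f in value_bound])
    # ['strategy for down_blocks.0.attn2', 'strategy for up_blocks.1.attn2']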