cross_attention_control: stub (no-op) implementations for diffusers

This commit is contained in:
Kevin Turner 2022-11-12 10:10:46 -08:00
parent b6b1a8d97c
commit 95db6e80ee
2 changed files with 34 additions and 13 deletions

View File

@ -213,6 +213,13 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
**extra_step_kwargs):
if run_id is None:
run_id = secrets.token_urlsafe(self.ID_LENGTH)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
step_count=len(self.scheduler.timesteps))
else:
self.invokeai_diffuser.remove_cross_attention_control()
# scale the initial noise by the standard deviation required by the scheduler
latents *= self.scheduler.init_noise_sigma
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,

View File

@ -1,4 +1,5 @@
import enum
import warnings
from typing import Optional
import torch
@ -244,19 +245,32 @@ def inject_attention_function(unet, context: Context):
return attention_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.identifier = name
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
context.get_slicing_strategy(module_identifier))
cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
if type(module).__name__ == "CrossAttention"]
for identifier, module in cross_attention_modules:
module.identifier = identifier
try:
module.set_attention_slice_wrangler(attention_slice_wrangler)
module.set_slicing_strategy_getter(
lambda module: context.get_slicing_strategy(identifier)
)
except AttributeError as e:
if e.name == 'set_attention_slice_wrangler':
warnings.warn(f"TODO: implement for {type(module)}") # TODO
else:
raise
def remove_attention_function(unet):
# clear wrangler callback
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
cross_attention_modules = [module for (_, module) in unet.named_modules()
if type(module).__name__ == "CrossAttention"]
for module in cross_attention_modules:
try:
# clear wrangler callback
module.set_attention_slice_wrangler(None)
module.set_slicing_strategy_getter(None)
except AttributeError as e:
if e.name == 'set_attention_slice_wrangler':
warnings.warn(f"TODO: implement for {type(module)}") # TODO
else:
raise