cross_attention_control: stub (no-op) implementations for diffusers

Kevin Turner 2022-11-12 10:10:46 -08:00
parent b6b1a8d97c
commit 95db6e80ee
2 changed files with 34 additions and 13 deletions

View File

@@ -213,6 +213,13 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                  **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
+                                                                 step_count=len(self.scheduler.timesteps))
+        else:
+            self.invokeai_diffuser.remove_cross_attention_control()
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
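The guard added here is presumably doing two jobs: cross-attention control (prompt-to-prompt style attention editing) schedules its edits across the denoising run, so the controller needs `step_count` up front; and because the same `invokeai_diffuser` is reused from run to run, the `else` branch removes hooks that an earlier controlled run may have left installed. A standalone toy sketch of that install/teardown symmetry (hypothetical class, not InvokeAI's `InvokeAIDiffuserComponent`):

    class ToyDiffuser:
        # Illustration only: a component that survives across pipeline runs.
        def __init__(self):
            self.hooks = []

        def setup_control(self, info, step_count):
            # install one hook per denoising step for this run's edits
            self.hooks = [(info, step) for step in range(step_count)]

        def remove_control(self):
            self.hooks = []

    diffuser = ToyDiffuser()
    diffuser.setup_control("swap 'cat' -> 'dog'", step_count=3)  # controlled run
    # an uncontrolled run must clear leftover hooks explicitly, as the
    # else branch above does, or the previous run's edits would persist:
    diffuser.remove_control()
    assert diffuser.hooks == []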

View File

@@ -1,4 +1,5 @@
 import enum
+import warnings
 from typing import Optional
 import torch
@@ -244,19 +245,32 @@ def inject_attention_function(unet, context: Context):
         return attention_slice

-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.identifier = name
-            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                                               context.get_slicing_strategy(module_identifier))
+    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for identifier, module in cross_attention_modules:
+        module.identifier = identifier
+        try:
+            module.set_attention_slice_wrangler(attention_slice_wrangler)
+            module.set_slicing_strategy_getter(
+                lambda module: context.get_slicing_strategy(identifier)
+            )
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise


 def remove_attention_function(unet):
-    # clear wrangler callback
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.set_attention_slice_wrangler(None)
-            module.set_slicing_strategy_getter(None)
+    cross_attention_modules = [module for (_, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for module in cross_attention_modules:
+        try:
+            # clear wrangler callback
+            module.set_attention_slice_wrangler(None)
+            module.set_slicing_strategy_getter(None)
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
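Two caveats in the new `inject_attention_function` body are worth flagging. First, `lambda module: context.get_slicing_strategy(identifier)` closes over the loop variable, so with Python's late binding every getter installed by the loop will look up the identifier of the last `CrossAttention` module; the removed code guarded against this with a `module_identifier=name` default argument. Second, `AttributeError.name` only exists on Python 3.10 and later, so on older interpreters the `except` handler itself raises `AttributeError`. A standalone sketch of both points (illustrative, not part of the commit):

    import warnings

    # late binding: all three closures see the final loop value ...
    getters = [lambda: i for i in ("a", "b", "c")]
    assert [g() for g in getters] == ["c", "c", "c"]

    # ... unless each current value is bound via a default argument
    bound = [lambda i=i: i for i in ("a", "b", "c")]
    assert [g() for g in bound] == ["a", "b", "c"]

    # probing with getattr duck-types the optional wrangler API without
    # relying on AttributeError.name (which requires Python >= 3.10)
    class LegacyCrossAttention:  # hypothetical module lacking the wrangler API
        pass

    module = LegacyCrossAttention()
    setter = getattr(module, "set_attention_slice_wrangler", None)
    if setter is None:
        warnings.warn(f"TODO: implement for {type(module)}")
    else:
        setter(lambda *args: None)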