Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
cross_attention_control: stub (no-op) implementations for diffusers
parent b6b1a8d97c
commit 95db6e80ee
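The diff below teaches the cross-attention control plumbing to degrade gracefully when it meets diffusers' stock CrossAttention blocks, which lack InvokeAI's slice-wrangler hooks: instead of raising, injection warns and becomes a no-op stub. A minimal sketch of that behavior, using made-up stand-in classes rather than the real InvokeAI or diffusers types:

import warnings


class HookedCrossAttention:
    """Stand-in for an attention block carrying InvokeAI's slice-wrangler hook."""
    def set_attention_slice_wrangler(self, wrangler):
        self._wrangler = wrangler


class CrossAttention:
    """Stand-in for a stock diffusers attention block with no such hook."""


# A fake module mapping; the real code iterates unet.named_modules().
modules = {"down.attn": HookedCrossAttention(), "up.attn": CrossAttention()}

for name, module in modules.items():
    try:
        module.set_attention_slice_wrangler(lambda *args: None)
        print(f"hooked {name}")
    except AttributeError:
        # No hook on this module yet: warn and carry on rather than fail.
        warnings.warn(f"TODO: implement slice wrangling for {type(module).__name__} ({name})")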
@@ -213,6 +213,13 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
                  **extra_step_kwargs):
         if run_id is None:
             run_id = secrets.token_urlsafe(self.ID_LENGTH)
+
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
+                                                                  step_count=len(self.scheduler.timesteps))
+        else:
+            self.invokeai_diffuser.remove_cross_attention_control()
+
         # scale the initial noise by the standard deviation required by the scheduler
         latents *= self.scheduler.init_noise_sigma
         yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
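The hunk above gates cross-attention control per run: it is set up only when the conditioning info asks for it, and the else branch explicitly tears it down, presumably so hooks from an earlier invocation cannot linger. A rough sketch of that gate, with InvokeAIDiffuserComponent and the conditioning info replaced by hypothetical stand-ins:

from dataclasses import dataclass


@dataclass
class FakeConditioningInfo:
    # Stand-in for extra_conditioning_info; the real object exposes a
    # wants_cross_attention_control flag, modeled here as a plain bool.
    wants_cross_attention_control: bool = False


class FakeDiffuser:
    # Stand-in for self.invokeai_diffuser.
    def setup_cross_attention_control(self, info, step_count: int):
        print(f"cross-attention control set up for {step_count} steps")

    def remove_cross_attention_control(self):
        print("cross-attention control removed")


def configure(diffuser, info, step_count: int):
    # Mirrors the conditional added in the hunk above.
    if info is not None and info.wants_cross_attention_control:
        diffuser.setup_cross_attention_control(info, step_count=step_count)
    else:
        diffuser.remove_cross_attention_control()


configure(FakeDiffuser(), FakeConditioningInfo(wants_cross_attention_control=True), step_count=50)
configure(FakeDiffuser(), None, step_count=50)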
@@ -1,4 +1,5 @@
 import enum
+import warnings
 from typing import Optional
 
 import torch
@@ -244,19 +245,32 @@ def inject_attention_function(unet, context: Context):
 
         return attention_slice
 
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.identifier = name
-            module.set_attention_slice_wrangler(attention_slice_wrangler)
-            module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
-                context.get_slicing_strategy(module_identifier))
+    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for identifier, module in cross_attention_modules:
+        module.identifier = identifier
+        try:
+            module.set_attention_slice_wrangler(attention_slice_wrangler)
+            module.set_slicing_strategy_getter(
+                lambda module: context.get_slicing_strategy(identifier)
+            )
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
 
 
 def remove_attention_function(unet):
-    # clear wrangler callback
-    for name, module in unet.named_modules():
-        module_name = type(module).__name__
-        if module_name == "CrossAttention":
-            module.set_attention_slice_wrangler(None)
-            module.set_slicing_strategy_getter(None)
+    cross_attention_modules = [module for (_, module) in unet.named_modules()
+                               if type(module).__name__ == "CrossAttention"]
+    for module in cross_attention_modules:
+        try:
+            # clear wrangler callback
+            module.set_attention_slice_wrangler(None)
+            module.set_slicing_strategy_getter(None)
+        except AttributeError as e:
+            if e.name == 'set_attention_slice_wrangler':
+                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
+            else:
+                raise
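Both rewritten loops key their fallback on AttributeError.name, an attribute Python added in 3.10. Checking it keeps the stub path narrow: only a missing set_attention_slice_wrangler hook is downgraded to a warning, while an AttributeError raised for any other reason still propagates. A small, self-contained illustration of that distinction with made-up classes:

import warnings


class MissingHook:
    """Stand-in with no wrangler hook at all (the stubbed, no-op case)."""


class BrokenHook:
    """Stand-in whose hook exists but fails internally on some other attribute."""
    def set_attention_slice_wrangler(self, wrangler):
        raise AttributeError("unrelated failure", name="some_other_attribute", obj=self)


def install(module):
    try:
        module.set_attention_slice_wrangler(lambda *args: None)
    except AttributeError as e:
        if e.name == 'set_attention_slice_wrangler':
            # The hook itself is missing: degrade to a warning, not a crash.
            warnings.warn(f"TODO: implement for {type(module)}")
        else:
            # A different AttributeError surfaced from inside the hook: re-raise it.
            raise


install(MissingHook())       # warns and continues
try:
    install(BrokenHook())    # re-raised, not swallowed
except AttributeError as e:
    print(f"propagated: {e.name}")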