mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
cross_attention_control: stub (no-op) implementations for diffusers
This commit is contained in:
parent
b6b1a8d97c
commit
95db6e80ee
@ -213,6 +213,13 @@ class StableDiffusionGeneratorPipeline(DiffusionPipeline):
|
|||||||
**extra_step_kwargs):
|
**extra_step_kwargs):
|
||||||
if run_id is None:
|
if run_id is None:
|
||||||
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
run_id = secrets.token_urlsafe(self.ID_LENGTH)
|
||||||
|
|
||||||
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
|
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info,
|
||||||
|
step_count=len(self.scheduler.timesteps))
|
||||||
|
else:
|
||||||
|
self.invokeai_diffuser.remove_cross_attention_control()
|
||||||
|
|
||||||
# scale the initial noise by the standard deviation required by the scheduler
|
# scale the initial noise by the standard deviation required by the scheduler
|
||||||
latents *= self.scheduler.init_noise_sigma
|
latents *= self.scheduler.init_noise_sigma
|
||||||
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import enum
|
import enum
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -244,19 +245,32 @@ def inject_attention_function(unet, context: Context):
|
|||||||
|
|
||||||
return attention_slice
|
return attention_slice
|
||||||
|
|
||||||
for name, module in unet.named_modules():
|
cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
|
||||||
module_name = type(module).__name__
|
if type(module).__name__ == "CrossAttention"]
|
||||||
if module_name == "CrossAttention":
|
for identifier, module in cross_attention_modules:
|
||||||
module.identifier = name
|
module.identifier = identifier
|
||||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
try:
|
||||||
module.set_slicing_strategy_getter(lambda module, module_identifier=name: \
|
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||||
context.get_slicing_strategy(module_identifier))
|
module.set_slicing_strategy_getter(
|
||||||
|
lambda module: context.get_slicing_strategy(identifier)
|
||||||
|
)
|
||||||
|
except AttributeError as e:
|
||||||
|
if e.name == 'set_attention_slice_wrangler':
|
||||||
|
warnings.warn(f"TODO: implement for {type(module)}") # TODO
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def remove_attention_function(unet):
|
def remove_attention_function(unet):
|
||||||
# clear wrangler callback
|
cross_attention_modules = [module for (_, module) in unet.named_modules()
|
||||||
for name, module in unet.named_modules():
|
if type(module).__name__ == "CrossAttention"]
|
||||||
module_name = type(module).__name__
|
for module in cross_attention_modules:
|
||||||
if module_name == "CrossAttention":
|
try:
|
||||||
module.set_attention_slice_wrangler(None)
|
# clear wrangler callback
|
||||||
module.set_slicing_strategy_getter(None)
|
module.set_attention_slice_wrangler(None)
|
||||||
|
module.set_slicing_strategy_getter(None)
|
||||||
|
except AttributeError as e:
|
||||||
|
if e.name == 'set_attention_slice_wrangler':
|
||||||
|
warnings.warn(f"TODO: implement for {type(module)}") # TODO
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
Loading…
x
Reference in New Issue
Block a user