Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Change attention processor apply logic
@@ -98,39 +98,14 @@ class ExtensionsManager:
         if name in self._callbacks:
             self._callbacks[name](*args, **kwargs)
 
-    # TODO: is there any need in such high abstarction
-    # @contextmanager
-    # def patch_extensions(self):
-    #     exit_stack = ExitStack()
-    #     try:
-    #         for ext in self.extensions:
-    #             exit_stack.enter_context(ext.patch_extension(self))
-    #
-    #         yield None
-    #
-    #     finally:
-    #         exit_stack.close()
-
     @contextmanager
-    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
-        unet_orig_processors = unet.attn_processors
-        exit_stack = ExitStack()
-        try:
-            # just to be sure that attentions have not same processor instance
-            attn_procs = {}
-            for name in unet.attn_processors.keys():
-                attn_procs[name] = attn_processor_cls()
-            unet.set_attn_processor(attn_procs)
-
-            for ext in self.extensions:
-                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
-
-            yield None
-
-        finally:
-            unet.set_attn_processor(unet_orig_processors)
-            exit_stack.close()
+    def patch_extensions(self, context: DenoiseContext):
+        with ExitStack() as exit_stack:
+            for ext in self.extensions:
+                exit_stack.enter_context(ext.patch_extension(context))
+
+            yield None
 
     @contextmanager
    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
         exit_stack = ExitStack()
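
As the hunk shows, the manager no longer swaps every UNet attention processor for a fresh attn_processor_cls instance itself; patch_extensions(context) now simply enters each extension's own patch_extension(context) context manager and lets the extension decide what to patch and restore. Below is a minimal sketch of what an extension could look like under this API; only patch_extension(context) and the manager's patch_extensions(context) come from the diff, while the LoggingExt class, its body, and the usage comment are illustrative assumptions rather than repository code.

# Minimal sketch of the extension side of this API (hypothetical example).
from contextlib import contextmanager


class LoggingExt:
    @contextmanager
    def patch_extension(self, context):
        # Apply whatever temporary change this extension needs for the denoise run.
        print(f"patching with context {type(context).__name__}")
        try:
            yield
        finally:
            # Undo the change when the manager's ExitStack unwinds at the end of
            # patch_extensions().
            print("unpatched")


# Hypothetical usage: the manager enters every registered extension's patch in one scope.
# with extensions_manager.patch_extensions(denoise_context):
#     ...run the denoise loop...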