Change attention processor apply logic

Sergey Borisov
2024-07-16 20:03:29 +03:00
parent 608cbe3f5c
commit cec345cb5c
5 changed files with 36 additions and 34 deletions

@@ -98,39 +98,14 @@ class ExtensionsManager:
         if name in self._callbacks:
             self._callbacks[name](*args, **kwargs)
 
-    # TODO: is there any need in such high abstarction
-    # @contextmanager
-    # def patch_extensions(self):
-    #     exit_stack = ExitStack()
-    #     try:
-    #         for ext in self.extensions:
-    #             exit_stack.enter_context(ext.patch_extension(self))
-    #
-    #         yield None
-    #
-    #     finally:
-    #         exit_stack.close()
-
     @contextmanager
-    def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object):
-        unet_orig_processors = unet.attn_processors
-        exit_stack = ExitStack()
-        try:
-            # just to be sure that attentions have not same processor instance
-            attn_procs = {}
-            for name in unet.attn_processors.keys():
-                attn_procs[name] = attn_processor_cls()
-            unet.set_attn_processor(attn_procs)
-
+    def patch_extensions(self, context: DenoiseContext):
+        with ExitStack() as exit_stack:
             for ext in self.extensions:
-                exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls))
+                exit_stack.enter_context(ext.patch_extension(context))
 
             yield None
 
-        finally:
-            unet.set_attn_processor(unet_orig_processors)
-            exit_stack.close()
-
     @contextmanager
     def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
         exit_stack = ExitStack()
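
The new patch_extensions composes every extension's patch_extension(context) context manager on a single ExitStack: all patches are applied before denoising starts and are unwound in reverse order when the block exits, even if one of them raises. Below is a minimal, self-contained sketch of that composition pattern; the DemoExtension class, its print statements, and the bare context argument are hypothetical illustrations, not the repository's actual API.

from contextlib import ExitStack, contextmanager


class DemoExtension:
    """Hypothetical extension: applies its patch on enter, reverts it on exit."""

    def __init__(self, name: str):
        self.name = name

    @contextmanager
    def patch_extension(self, context):
        print(f"{self.name}: patch applied")
        try:
            yield
        finally:
            # Runs even if a later extension or the denoise loop raises.
            print(f"{self.name}: patch reverted")


class ManagerSketch:
    """Mirrors the composition done by the new ExtensionsManager.patch_extensions."""

    def __init__(self, extensions):
        self.extensions = extensions

    @contextmanager
    def patch_extensions(self, context):
        # ExitStack unwinds the entered context managers in reverse order.
        with ExitStack() as exit_stack:
            for ext in self.extensions:
                exit_stack.enter_context(ext.patch_extension(context))
            yield None


manager = ManagerSketch([DemoExtension("ext_a"), DemoExtension("ext_b")])
with manager.patch_extensions(context=None):
    print("denoising runs with all patches active")

Compared with the removed patch_attention_processor, which rebuilt the UNet's attention processors itself and restored them in a finally block, this moves the apply/revert responsibility into each extension, so the manager no longer needs to know anything about attention processors.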