Add invocation cancellation logic to patchers

This commit is contained in:
Sergey Borisov 2024-07-19 23:17:01 +03:00
parent 83a86abce2
commit 39e10d894c

View File

@ -44,8 +44,6 @@ class ExtensionsManager:
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order) self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext): def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
# TODO: add to patchers too?
# and if so, should it be only in beginning of function or in for loop
if self._is_canceled and self._is_canceled(): if self._is_canceled and self._is_canceled():
raise CanceledException raise CanceledException
@ -55,6 +53,9 @@ class ExtensionsManager:
@contextmanager @contextmanager
def patch_extensions(self, context: DenoiseContext): def patch_extensions(self, context: DenoiseContext):
if self._is_canceled and self._is_canceled():
raise CanceledException
with ExitStack() as exit_stack: with ExitStack() as exit_stack:
for ext in self._extensions: for ext in self._extensions:
exit_stack.enter_context(ext.patch_extension(context)) exit_stack.enter_context(ext.patch_extension(context))
@ -63,5 +64,8 @@ class ExtensionsManager:
@contextmanager @contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create logic in PR with extension which uses it # TODO: create logic in PR with extension which uses it
yield None yield None