From 137202b77cadf5f6c9205a376177eaf89516e51d Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 17 Jul 2024 03:40:27 +0300 Subject: [PATCH] Remove patch_unet logic for now --- .../stable_diffusion/extensions_manager.py | 34 ++----------------- 1 file changed, 2 insertions(+), 32 deletions(-) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index e747579d8b..08004339e9 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -108,35 +108,5 @@ class ExtensionsManager: @contextmanager def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): - exit_stack = ExitStack() - try: - changed_keys = set() - changed_unknown_keys = {} - - for ext in self.extensions: - patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet)) - if patch_result is None: - continue - new_keys, new_unk_keys = patch_result - changed_keys.update(new_keys) - # skip already seen keys, as new weight might be changed - for k, v in new_unk_keys.items(): - if k in changed_unknown_keys: - continue - changed_unknown_keys[k] = v - - yield None - - finally: - exit_stack.close() - assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() - with torch.no_grad(): - for module_key in changed_keys: - weight = state_dict[module_key] - unet.get_submodule(module_key).weight.copy_( - weight, non_blocking=TorchDevice.get_non_blocking(weight.device) - ) - for module_key, weight in changed_unknown_keys.items(): - unet.get_submodule(module_key).weight.copy_( - weight, non_blocking=TorchDevice.get_non_blocking(weight.device) - ) + # TODO: create logic in PR with extension which uses it + yield None