Remove patch_unet logic for now

This commit is contained in:
Sergey Borisov 2024-07-17 03:40:27 +03:00
parent 03e22c257b
commit 137202b77c

View File

@ -108,35 +108,5 @@ class ExtensionsManager:
@contextmanager @contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
exit_stack = ExitStack() # TODO: create logic in PR with extension which uses it
try:
changed_keys = set()
changed_unknown_keys = {}
for ext in self.extensions:
patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
if patch_result is None:
continue
new_keys, new_unk_keys = patch_result
changed_keys.update(new_keys)
# skip already seen keys, as new weight might be changed
for k, v in new_unk_keys.items():
if k in changed_unknown_keys:
continue
changed_unknown_keys[k] = v
yield None yield None
finally:
exit_stack.close()
assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key in changed_keys:
weight = state_dict[module_key]
unet.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
for module_key, weight in changed_unknown_keys.items():
unet.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)