mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove patch_unet logic for now
This commit is contained in:
parent
03e22c257b
commit
137202b77c
@@ -108,35 +108,5 @@ class ExtensionsManager:
|
||||
|
||||
@contextmanager
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
    """Apply every registered extension's UNet weight patches, then restore on exit.

    Enters each extension's ``patch_unet`` context via an ExitStack and records
    which submodule weights were changed.  A patch context may return
    ``(new_keys, new_unk_keys)``: keys whose original weights can be re-read
    from ``state_dict``, and a mapping of key -> restore-weight for keys whose
    originals are not in ``state_dict``.  On exit (normal or exceptional) the
    extension contexts are closed and the recorded weights are copied back
    into ``unet``.

    NOTE(review): the commit this was taken from ("Remove patch_unet logic for
    now") deletes this body and replaces it with a stub that only does
    ``yield None`` behind a "# TODO: create logic in PR with extension which
    uses it" comment — confirm which version is current before relying on the
    restore behavior documented here.
    """
    # Keeps all extension patch contexts open until the finally block runs.
    exit_stack = ExitStack()
    try:
        # Submodule keys whose originals will be restored from `state_dict`.
        changed_keys = set()
        # key -> weight to restore, for keys NOT backed by `state_dict`.
        changed_unknown_keys = {}

        for ext in self.extensions:
            patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
            if patch_result is None:
                # Extension reported nothing to restore.
                continue
            new_keys, new_unk_keys = patch_result

            changed_keys.update(new_keys)
            # skip already seen keys, as new weight might be changed
            # (first extension to report an unknown key wins — later values
            # for the same key are the already-patched weights, not originals)
            for k, v in new_unk_keys.items():
                if k in changed_unknown_keys:
                    continue
                changed_unknown_keys[k] = v

        # Hand control to the caller with all patches applied.
        yield None

    finally:
        # Unwind extension contexts first, then restore recorded weights.
        exit_stack.close()
        assert hasattr(unet, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
        with torch.no_grad():
            # Restore keys whose original weights live in `state_dict`.
            for module_key in changed_keys:
                weight = state_dict[module_key]
                unet.get_submodule(module_key).weight.copy_(
                    weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                )
            # Restore keys the extensions supplied restore-weights for directly.
            for module_key, weight in changed_unknown_keys.items():
                unet.get_submodule(module_key).weight.copy_(
                    weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
                )
|
||||
|
Loading…
Reference in New Issue
Block a user