mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove patch_unet logic for now
This commit is contained in:
parent
03e22c257b
commit
137202b77c
@ -108,35 +108,5 @@ class ExtensionsManager:
|
|||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
|
||||||
exit_stack = ExitStack()
|
# TODO: create logic in PR with extension which uses it
|
||||||
try:
|
yield None
|
||||||
changed_keys = set()
|
|
||||||
changed_unknown_keys = {}
|
|
||||||
|
|
||||||
for ext in self.extensions:
|
|
||||||
patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet))
|
|
||||||
if patch_result is None:
|
|
||||||
continue
|
|
||||||
new_keys, new_unk_keys = patch_result
|
|
||||||
changed_keys.update(new_keys)
|
|
||||||
# skip already seen keys, as new weight might be changed
|
|
||||||
for k, v in new_unk_keys.items():
|
|
||||||
if k in changed_unknown_keys:
|
|
||||||
continue
|
|
||||||
changed_unknown_keys[k] = v
|
|
||||||
|
|
||||||
yield None
|
|
||||||
|
|
||||||
finally:
|
|
||||||
exit_stack.close()
|
|
||||||
assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
|
||||||
with torch.no_grad():
|
|
||||||
for module_key in changed_keys:
|
|
||||||
weight = state_dict[module_key]
|
|
||||||
unet.get_submodule(module_key).weight.copy_(
|
|
||||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
|
||||||
)
|
|
||||||
for module_key, weight in changed_unknown_keys.items():
|
|
||||||
unet.get_submodule(module_key).weight.copy_(
|
|
||||||
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
|
|
||||||
)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user