Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov
2024-07-27 04:25:15 +03:00
parent faa88f72bf
commit 9e582563eb
5 changed files with 41 additions and 26 deletions

View File

@ -70,26 +70,24 @@ class ExtensionsManager:
modified_weights: Dict[str, torch.Tensor] = {}
modified_cached_weights: Set[str] = set()
exit_stack = ExitStack()
try:
for ext in self._extensions:
res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
if res is None:
continue
ext_modified_cached_weights, ext_modified_weights = res
with ExitStack() as exit_stack:
for ext in self._extensions:
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
ext.patch_unet(unet, cached_weights)
)
modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in ext_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
modified_cached_weights.update(ext_modified_cached_weights)
# store only first returned weight for each key, because
# next extension which changes it, will work with already modified weight
for param_key, weight in ext_modified_weights.items():
if param_key in modified_weights:
continue
modified_weights[param_key] = weight
yield None
yield None
finally:
exit_stack.close()
with torch.no_grad():
for param_key in modified_cached_weights:
unet.get_parameter(param_key).copy_(cached_weights[param_key])