mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
@ -70,26 +70,24 @@ class ExtensionsManager:
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
|
||||
exit_stack = ExitStack()
|
||||
try:
|
||||
for ext in self._extensions:
|
||||
res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
|
||||
if res is None:
|
||||
continue
|
||||
ext_modified_cached_weights, ext_modified_weights = res
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
||||
ext.patch_unet(unet, cached_weights)
|
||||
)
|
||||
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
|
||||
yield None
|
||||
yield None
|
||||
|
||||
finally:
|
||||
exit_stack.close()
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
|
Reference in New Issue
Block a user