mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
@ -67,29 +67,18 @@ class ExtensionsManager:
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
modified_weights: Dict[str, torch.Tensor] = {}
|
||||
modified_cached_weights: Set[str] = set()
|
||||
original_weights: Dict[str, torch.Tensor] = {}
|
||||
if cached_weights:
|
||||
original_weights.update(cached_weights)
|
||||
|
||||
try:
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
|
||||
ext.patch_unet(unet, cached_weights)
|
||||
)
|
||||
|
||||
modified_cached_weights.update(ext_modified_cached_weights)
|
||||
# store only first returned weight for each key, because
|
||||
# next extension which changes it, will work with already modified weight
|
||||
for param_key, weight in ext_modified_weights.items():
|
||||
if param_key in modified_weights:
|
||||
continue
|
||||
modified_weights[param_key] = weight
|
||||
exit_stack.enter_context(ext.patch_unet(unet, original_weights))
|
||||
|
||||
yield None
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key in modified_cached_weights:
|
||||
unet.get_parameter(param_key).copy_(cached_weights[param_key])
|
||||
for param_key, weight in modified_weights.items():
|
||||
for param_key, weight in original_weights.items():
|
||||
unet.get_parameter(param_key).copy_(weight)
|
||||
|
Reference in New Issue
Block a user