Optimize weights handling

This commit is contained in:
Sergey Borisov
2024-07-30 03:39:01 +03:00
parent 1fd9631f2d
commit 86f705bf48
6 changed files with 62 additions and 27 deletions

View File

@ -7,6 +7,7 @@ import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@ -67,10 +68,7 @@ class ExtensionsManager:
if self._is_canceled and self._is_canceled():
raise CanceledException
original_weights: Dict[str, torch.Tensor] = {}
if cached_weights:
original_weights.update(cached_weights)
original_weights = OriginalWeightsStorage(cached_weights)
try:
with ExitStack() as exit_stack:
for ext in self._extensions:
@ -80,5 +78,5 @@ class ExtensionsManager:
finally:
with torch.no_grad():
for param_key, weight in original_weights.items():
for param_key, weight in original_weights.get_changed_weights():
unet.get_parameter(param_key).copy_(weight)