mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Optimize weights handling
This commit is contained in:
@ -7,6 +7,7 @@ import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
|
||||
@ -67,10 +68,7 @@ class ExtensionsManager:
|
||||
if self._is_canceled and self._is_canceled():
|
||||
raise CanceledException
|
||||
|
||||
original_weights: Dict[str, torch.Tensor] = {}
|
||||
if cached_weights:
|
||||
original_weights.update(cached_weights)
|
||||
|
||||
original_weights = OriginalWeightsStorage(cached_weights)
|
||||
try:
|
||||
with ExitStack() as exit_stack:
|
||||
for ext in self._extensions:
|
||||
@ -80,5 +78,5 @@ class ExtensionsManager:
|
||||
|
||||
finally:
|
||||
with torch.no_grad():
|
||||
for param_key, weight in original_weights.items():
|
||||
for param_key, weight in original_weights.get_changed_weights():
|
||||
unet.get_parameter(param_key).copy_(weight)
|
||||
|
Reference in New Issue
Block a user