Suggested changes + simplify weights logic in patching

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
Sergey Borisov committed 2024-07-30 00:34:37 +03:00
commit 2227a2357f (parent 8500bac3ca)
6 changed files with 76 additions and 108 deletions


@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional

 import torch
 from diffusers import UNet2DConditionModel
@@ -67,29 +67,18 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)

         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
-                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
-                        ext.patch_unet(unet, cached_weights)
-                    )
-                    modified_cached_weights.update(ext_modified_cached_weights)
-                    # store only first returned weight for each key, because
-                    # next extension which changes it, will work with already modified weight
-                    for param_key, weight in ext_modified_weights.items():
-                        if param_key in modified_weights:
-                            continue
-                        modified_weights[param_key] = weight
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))

             yield None

         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
-
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     unet.get_parameter(param_key).copy_(weight)
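
Note on the new contract: after this change, an extension's patch_unet() no longer returns the set of cached keys and the dict of weights it modified. Instead it receives the shared original_weights dict (pre-seeded from cached_weights when available), records a parameter's original value there before patching it, and ExtensionsManager restores every recorded entry under torch.no_grad() in its finally block. Below is a minimal sketch of an extension written against that contract; the ScaleWeightExt class, its parameters, and the save-if-absent detail are illustrative assumptions, not code from this commit (the extension-side changes are in the other changed files, not shown here).

from contextlib import contextmanager
from typing import Dict, Iterator

import torch
from diffusers import UNet2DConditionModel


class ScaleWeightExt:
    """Hypothetical extension showing the simplified patch_unet contract."""

    def __init__(self, param_key: str, scale: float):
        self._param_key = param_key
        self._scale = scale

    @contextmanager
    def patch_unet(
        self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]
    ) -> Iterator[None]:
        param = unet.get_parameter(self._param_key)
        # Save the original value only if no earlier extension (or the
        # cached_weights pre-seeding) already recorded this key, so the final
        # restore returns the parameter to its true original value.
        if self._param_key not in original_weights:
            original_weights[self._param_key] = param.detach().clone()
        with torch.no_grad():
            param.mul_(self._scale)
        yield
        # No cleanup here: ExtensionsManager copies every entry of
        # original_weights back into the UNet in its finally block.

Because every extension writes into the same original_weights dict, stacked patches of one parameter are undone by a single copy back to the first recorded value, which is what lets the manager drop the separate modified_weights / modified_cached_weights bookkeeping shown in the removed lines above.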