Handle LoRAs in modular denoise

Sergey Borisov
2024-07-24 05:07:29 +03:00
parent 7c975f0d00
commit ab0bfa709a
4 changed files with 227 additions and 4 deletions


@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -67,9 +67,31 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create weight patch logic in PR with extension which uses it
-        with ExitStack() as exit_stack:
+        modified_weights: Dict[str, torch.Tensor] = {}
+        modified_cached_weights: Set[str] = set()
+
+        exit_stack = ExitStack()
+        try:
             for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+                res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+                if res is None:
+                    continue
+                ext_modified_cached_weights, ext_modified_weights = res
+
+                modified_cached_weights.update(ext_modified_cached_weights)
+                # store only the first returned weight for each key, because the
+                # next extension that changes it will work with the already modified weight
+                for param_key, weight in ext_modified_weights.items():
+                    if param_key in modified_weights:
+                        continue
+                    modified_weights[param_key] = weight
 
             yield None
+
+        finally:
+            exit_stack.close()
+            with torch.no_grad():
+                for param_key in modified_cached_weights:
+                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    unet.get_parameter(param_key).copy_(weight)
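
The manager-side change above relies on each extension's patch_unet being a context manager that yields which UNet parameters it touched. The unpacking order in the diff implies the contract (modified_cached_weights, modified_weights): a set of parameter keys whose originals already live in cached_weights, plus original copies of any other parameters the extension modified. Below is a minimal sketch, not the actual InvokeAI extension API, of how a LoRA-style extension could satisfy that contract; the class name LoRAPatcherExt, the _lora_deltas attribute, and the pre-merged delta tensors are all hypothetical.

# Minimal sketch of an extension satisfying the yielded-tuple contract.
# NOT the real InvokeAI extension class; names here are hypothetical.
from contextlib import contextmanager
from typing import Dict, Iterator, Optional, Set, Tuple

import torch
from diffusers import UNet2DConditionModel


class LoRAPatcherExt:
    def __init__(self, lora_deltas: Dict[str, torch.Tensor], weight: float = 1.0):
        # param_key -> additive delta (e.g. an already-merged lora_up @ lora_down)
        self._lora_deltas = lora_deltas
        self._weight = weight

    @contextmanager
    def patch_unet(
        self,
        unet: UNet2DConditionModel,
        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Iterator[Tuple[Set[str], Dict[str, torch.Tensor]]]:
        cached_weights = cached_weights or {}
        modified_cached_weights: Set[str] = set()
        modified_weights: Dict[str, torch.Tensor] = {}

        with torch.no_grad():
            for param_key, delta in self._lora_deltas.items():
                param = unet.get_parameter(param_key)
                if param_key in cached_weights:
                    # ExtensionsManager can restore this key from its own cache.
                    modified_cached_weights.add(param_key)
                else:
                    # Keep a copy of the original so the manager can copy it back.
                    modified_weights[param_key] = param.detach().clone()
                param += delta.to(device=param.device, dtype=param.dtype) * self._weight

        # ExtensionsManager records these and undoes the patch after denoising.
        yield modified_cached_weights, modified_weights

Because restoration happens in the manager's finally block, an extension only has to report what it changed. And since the manager keeps only the first saved original per key, stacking several extensions on the same parameter still restores the true pre-patch weight.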