Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Optimize weights handling
commit 86f705bf48
parent 1fd9631f2d
invokeai/backend/model_patcher.py
@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 """
 loras = [
@@ -123,9 +124,7 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for lora_model, lora_weight in loras:
                 LoRAExt.patch_model(
@@ -141,7 +140,7 @@ class ModelPatcher:
 
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     model.get_parameter(param_key).copy_(weight)
 
     @classmethod
invokeai/backend/stable_diffusion/extensions/base.py
@@ -4,12 +4,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Dict, List
 
-import torch
 from diffusers import UNet2DConditionModel
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 @dataclass
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
-        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
-        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
-        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
-        context manager.
+        diffusion process. Weight unpatching is handled upstream and is achieved by saving the original weights via the
+        `original_weights.save` function before they are modified. Note that this enables some performance optimization
+        by avoiding redundant operations. All other patches (e.g. changes to tensor shapes, function monkey-patches,
+        etc.) should be unpatched by this context manager.
 
         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
-                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
-                access original weight value.
+            original_weights (OriginalWeightsStorage): Storage holding CPU copies of the model's original weights, for
+                unpatching purposes. An extension should save each tensor it modifies to this storage; extensions can
+                also read the original weight values from it.
         """
         yield
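For illustration, a minimal extension following this contract could look like the sketch below. The `ScaleWeightExt` class, its parameter key, and the scaling operation are hypothetical; the `patch_unet` signature and the `original_weights.save` call are taken from the diff above.

    from contextlib import contextmanager

    from diffusers import UNet2DConditionModel

    from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


    class ScaleWeightExt(ExtensionBase):
        """Hypothetical extension that scales a single UNet parameter in place."""

        def __init__(self, param_key: str, scale: float):
            super().__init__()
            self._param_key = param_key
            self._scale = scale

        @contextmanager
        def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
            param = unet.get_parameter(self._param_key)
            # Save the original value before modifying it; the storage keeps only
            # the first copy and marks the key as changed for upstream unpatching.
            original_weights.save(self._param_key, param)
            param.data *= self._scale
            yield  # weight unpatching is handled upstream via the storage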
invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -1,15 +1,15 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
-import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
 
 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
invokeai/backend/stable_diffusion/extensions/lora.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
     from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 
 class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase):
         self._weight = weight
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         self.patch_model(
             model=unet,
@@ -49,7 +50,7 @@ class LoRAExt(ExtensionBase):
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        original_weights: Dict[str, torch.Tensor],
+        original_weights: OriginalWeightsStorage,
     ):
         """
         Apply one or more LoRAs to a model.
@@ -57,9 +58,12 @@ class LoRAExt(ExtensionBase):
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
+        :param original_weights: Storage of original weights, populated with the weights the LoRA patches; used for unpatching.
         """
 
+        if lora_weight == 0:
+            return
+
         # assert lora.device.type == "cpu"
         for layer_key, layer in lora.layers.items():
             if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)
 
                 # save original weight
-                if param_key not in original_weights:
-                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+                original_weights.save(param_key, module_param)
 
                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
invokeai/backend/stable_diffusion/extensions_manager.py
@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel
 
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
 
 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
-
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
@@ -80,5 +78,5 @@ class ExtensionsManager:
 
         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     unet.get_parameter(param_key).copy_(weight)
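Both `ModelPatcher` and `ExtensionsManager` now share the same save-then-restore lifecycle. Below is a condensed sketch of that pattern; `apply_loras_temporarily` is a hypothetical wrapper name, while the `OriginalWeightsStorage(...)`, `LoRAExt.patch_model(...)`, and `get_changed_weights()` calls all appear in the hunks above.

    from contextlib import contextmanager

    import torch

    from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


    @contextmanager
    def apply_loras_temporarily(model, loras, prefix, cached_weights=None):
        # Hypothetical standalone wrapper; the real logic lives in
        # ModelPatcher and ExtensionsManager, as shown in the diff.
        original_weights = OriginalWeightsStorage(cached_weights)
        try:
            for lora_model, lora_weight in loras:
                LoRAExt.patch_model(
                    model=model,
                    prefix=prefix,
                    lora=lora_model,
                    lora_weight=lora_weight,
                    original_weights=original_weights,
                )
            yield model
        finally:
            # Restore only the parameters that were actually patched.
            with torch.no_grad():
                for param_key, weight in original_weights.get_changed_weights():
                    model.get_parameter(param_key).copy_(weight)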
invokeai/backend/util/original_weights_storage.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        self._weights = {}
+        self._changed_weights = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
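A short usage sketch of the new class; the toy `Linear` model and the `"weight"` key are illustrative, everything else is the API added above.

    import torch

    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

    model = torch.nn.Linear(4, 4)
    storage = OriginalWeightsStorage()  # no cached_weights; copies are made lazily by save()

    key = "weight"
    param = model.get_parameter(key)
    storage.save(key, param)   # first save: detaches, copies to CPU, marks the key as changed
    with torch.no_grad():
        param.mul_(2.0)        # patch the weight in place
    storage.save(key, param)   # later saves keep the original copy, not the patched value

    # Restore only what changed, mirroring the `finally` blocks above.
    with torch.no_grad():
        for k, w in storage.get_changed_weights():
            model.get_parameter(k).copy_(w)

    assert storage.contains(key) and storage.get(key) is not None

Because `save()` stores a copy only on the first call for a given key, repeated patches of the same parameter cost a single CPU copy, and `get_changed_weights()` keeps the restore loop proportional to the number of patched tensors rather than the size of the full state dict.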