Optimize weights handling

commit 86f705bf48 (parent 1fd9631f2d)
invokeai/backend/model_patcher.py

@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 """
 loras = [
@@ -123,9 +124,7 @@ class ModelPatcher:
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             for lora_model, lora_weight in loras:
                 LoRAExt.patch_model(
@@ -141,7 +140,7 @@ class ModelPatcher:

         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     model.get_parameter(param_key).copy_(weight)

     @classmethod
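For context, the new unpatch flow in ModelPatcher boils down to the pattern below. This is a minimal sketch, assuming a plain torch module and a caller-supplied patch function (`patch_fn` is hypothetical); only the weights the storage recorded as changed are copied back, which is the optimization this commit is after.

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


def run_patched(model: torch.nn.Module, patch_fn) -> None:
    original_weights = OriginalWeightsStorage()  # optionally seeded with a cached state dict
    try:
        # The patch function mutates model weights, recording each original
        # tensor via original_weights.save(key, tensor) before changing it.
        patch_fn(model, original_weights)
        # ... run inference with the patched model here ...
    finally:
        with torch.no_grad():
            # Restore only the weights that were actually changed, instead of
            # iterating over the model's full state dict.
            for param_key, weight in original_weights.get_changed_weights():
                model.get_parameter(param_key).copy_(weight)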
invokeai/backend/stable_diffusion/extensions/base.py

@@ -4,12 +4,12 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Callable, Dict, List

-import torch
 from diffusers import UNet2DConditionModel

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
     from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 @dataclass
@@ -56,17 +56,17 @@ class ExtensionBase:
         yield None

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
-        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
-        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
-        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
-        context manager.
+        diffusion process. Weight unpatching is handled upstream, and is achieved by saving the original weights via
+        the `original_weights.save` function. Note that this enables some performance optimization by avoiding
+        redundant operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should
+        be unpatched by this context manager.

         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            original_weights (Dict[str, torch.Tensor]): A read-only copy of the model's original weights in CPU, for
-                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
-                access original weight value.
+            original_weights (OriginalWeightsStorage): A storage with a copy of the model's original weights on CPU,
+                for unpatching purposes. Extensions should save each tensor they modify to this storage; they can
+                also read the original weight values from it.
         """
         yield
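To make the contract concrete, here is a minimal hypothetical extension that follows the new docstring: snapshot the original tensor before mutating it, then simply yield, since restoration happens upstream. `ScaleConvExt`, the target parameter key, and the scale factor are all invented for illustration.

from contextlib import contextmanager

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


class ScaleConvExt(ExtensionBase):
    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
        key = "conv_in.weight"  # hypothetical target parameter
        # Snapshot the original weight first. The storage keeps only the first
        # snapshot per key, so stacking patches on one tensor stays cheap.
        original_weights.save(key, unet.get_parameter(key))
        with torch.no_grad():
            unet.get_parameter(key).mul_(1.1)
        # No unpatching here: the manager restores every changed weight from
        # original_weights after the diffusion process finishes.
        yield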
invokeai/backend/stable_diffusion/extensions/freeu.py

@@ -1,15 +1,15 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

-import torch
 from diffusers import UNet2DConditionModel

 from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase

 if TYPE_CHECKING:
     from invokeai.app.shared.models import FreeUConfig
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ class FreeUExt(ExtensionBase):
         self._freeu_config = freeu_config

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
invokeai/backend/stable_diffusion/extensions/lora.py

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Tuple
+from typing import TYPE_CHECKING, Tuple

 import torch
 from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@ if TYPE_CHECKING:
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.shared.invocation_context import InvocationContext
     from invokeai.backend.lora import LoRAModelRaw
+    from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


 class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ class LoRAExt(ExtensionBase):
         self._weight = weight

     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
         lora_model = self._node_context.models.load(self._model_id).model
         self.patch_model(
             model=unet,
@@ -49,7 +50,7 @@ class LoRAExt(ExtensionBase):
         prefix: str,
         lora: LoRAModelRaw,
         lora_weight: float,
-        original_weights: Dict[str, torch.Tensor],
+        original_weights: OriginalWeightsStorage,
     ):
         """
         Apply one or more LoRAs to a model.
@@ -57,9 +58,12 @@ class LoRAExt(ExtensionBase):
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
+        :param original_weights: Storage of original weights, filled with the weights that the LoRA patches; used for unpatching.
         """

+        if lora_weight == 0:
+            return
+
         # assert lora.device.type == "cpu"
         for layer_key, layer in lora.layers.items():
             if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@ class LoRAExt(ExtensionBase):
                 module_param = module.get_parameter(param_name)

                 # save original weight
-                if param_key not in original_weights:
-                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+                original_weights.save(param_key, module_param)

                 if module_param.shape != lora_param_weight.shape:
                     # TODO: debug on lycoris
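A usage sketch of the classmethod above. The loaded `unet` and `lora` objects are assumed to come from the model manager, and the "lora_unet_" key prefix is an assumption for illustration, not taken from this diff.

from diffusers import UNet2DConditionModel

from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage


def apply_unet_lora(unet: UNet2DConditionModel, lora: LoRAModelRaw, weight: float) -> OriginalWeightsStorage:
    original_weights = OriginalWeightsStorage()
    LoRAExt.patch_model(
        model=unet,
        prefix="lora_unet_",  # assumed UNet key prefix
        lora=lora,
        lora_weight=weight,  # a weight of 0 now returns immediately (see the early exit above)
        original_weights=original_weights,
    )
    return original_weights  # caller restores later via get_changed_weights()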
invokeai/backend/stable_diffusion/extensions_manager.py

@@ -7,6 +7,7 @@ import torch
 from diffusers import UNet2DConditionModel

 from invokeai.app.services.session_processor.session_processor_common import CanceledException
+from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

 if TYPE_CHECKING:
     from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ class ExtensionsManager:
         if self._is_canceled and self._is_canceled():
             raise CanceledException

-        original_weights: Dict[str, torch.Tensor] = {}
-        if cached_weights:
-            original_weights.update(cached_weights)
-
+        original_weights = OriginalWeightsStorage(cached_weights)
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
@@ -80,5 +78,5 @@ class ExtensionsManager:

         finally:
             with torch.no_grad():
-                for param_key, weight in original_weights.items():
+                for param_key, weight in original_weights.get_changed_weights():
                     unet.get_parameter(param_key).copy_(weight)
invokeai/backend/util/original_weights_storage.py (new file, 35 lines)

@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from typing import Dict, Iterator, Optional, Tuple
+
+import torch
+
+from invokeai.backend.util.devices import TorchDevice
+
+
+class OriginalWeightsStorage:
+    def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        self._weights = {}
+        self._changed_weights = set()
+        if cached_weights:
+            self._weights.update(cached_weights)
+
+    def save(self, key: str, weight: torch.Tensor, copy: bool = True):
+        self._changed_weights.add(key)
+        if key in self._weights:
+            return
+
+        self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
+
+    def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
+        weight = self._weights.get(key, None)
+        if weight is not None and copy:
+            weight = weight.clone()
+        return weight
+
+    def contains(self, key: str) -> bool:
+        return key in self._weights
+
+    def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        for key in self._changed_weights:
+            yield key, self._weights[key]
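A quick sketch of the storage semantics introduced above (tensor values are invented): save() always marks the key as changed but keeps only the first snapshot, entries seeded from cached_weights count as already-saved originals, and get_changed_weights() yields only the touched keys.

import torch

from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage

storage = OriginalWeightsStorage(cached_weights={"a": torch.zeros(2)})

storage.save("a", torch.ones(2))          # "a" marked changed; the cached zeros stay as the original
storage.save("b", torch.full((2,), 3.0))  # "b" snapshotted to CPU and marked changed

assert storage.contains("a") and storage.contains("b")
assert torch.equal(storage.get("a"), torch.zeros(2))  # original value, not the patched one
assert sorted(key for key, _ in storage.get_changed_weights()) == ["a", "b"]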