diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index ada7a06a57..bbe372ff57 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -173,7 +173,7 @@ class CompelInvocation(BaseInvocation): class SDXLPromptInvocationBase: - def run_clip_raw(self, context, clip_field, prompt, get_pooled): + def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -210,8 +210,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -247,7 +247,7 @@ class SDXLPromptInvocationBase: return c, c_pooled, None - def run_clip_compel(self, context, clip_field, prompt, get_pooled): + def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix): tokenizer_info = context.services.model_manager.get_model( **clip_field.tokenizer.dict(), context=context, @@ -284,8 +284,8 @@ class SDXLPromptInvocationBase: # print(traceback.format_exc()) print(f'Warn: trigger: "{trigger}" not found') - with ModelPatcher.apply_lora_text_encoder( - text_encoder_info.context.model, _lora_loader() + with ModelPatcher.apply_lora( + text_encoder_info.context.model, _lora_loader(), lora_prefix ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as ( tokenizer, ti_manager, @@ -357,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -415,7 +415,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -467,11 +468,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False) + c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_") if self.style.strip() == "": - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_") else: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) @@ -525,7 +526,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: - c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True) + # TODO: if there will appear lora for refiner - write proper prefix + c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "") original_size = (self.original_height, self.original_width) crop_coords = (self.crop_top, self.crop_left) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 3edbe86150..6e2e0838bc 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings -from ...backend.model_management.lora import ModelPatcher +from ...backend.model_management import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c19e5c5c9a..d215d500a6 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation): return output +class SDXLLoraLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + # fmt: off + type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" + + unet: Optional[UNetField] = Field(default=None, description="UNet submodel") + clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") + clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") + # fmt: on + + +class SDXLLoraLoaderInvocation(BaseInvocation): + """Apply selected lora to unet and text_encoder.""" + + type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" + + lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") + weight: float = Field(default=0.75, description="With what weight to apply lora") + + unet: Optional[UNetField] = Field(description="UNet model for applying lora") + clip: Optional[ClipField] = Field(description="Clip model for applying lora") + clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") + + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "SDXL Lora Loader", + "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, + }, + } + + def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") + + base_model = self.lora.base_model + lora_name = self.lora.model_name + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + ): + raise Exception(f"Unkown lora name: {lora_name}!") + + if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras): + raise Exception(f'Lora "{lora_name}" already applied to unet') + + if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip') + + if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras): + raise Exception(f'Lora "{lora_name}" already applied to clip2') + + output = SDXLLoraLoaderOutput() + + if self.unet is not None: + output.unet = copy.deepcopy(self.unet) + output.unet.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip is not None: + output.clip = copy.deepcopy(self.clip) + output.clip.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + if self.clip2 is not None: + output.clip2 = copy.deepcopy(self.clip2) + output.clip2.loras.append( + LoraInfo( + base_model=base_model, + model_name=lora_name, + model_type=ModelType.Lora, + submodel=None, + weight=self.weight, + ) + ) + + return output + + class VAEModelField(BaseModel): """Vae model field""" diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 7dfceba853..faa6b59782 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union from pydantic import Field, validator -from ...backend.model_management import ModelType, SubModelType +from ...backend.model_management import ModelType, SubModelType, ModelPatcher from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext @@ -293,10 +293,22 @@ class SDXLTextToLatentsInvocation(BaseInvocation): num_inference_steps = self.steps + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"}), + context=context, + ) + yield (lora_info.context.model, lora.weight) + del lora_info + return + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with unet_info as unet: + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: + scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index cf057f3a89..8e083c1045 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -13,3 +13,4 @@ from .models import ( DuplicateModelException, ) from .model_merge import ModelMerger, MergeInterpolationMethod +from .lora import ModelPatcher diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..56f7a648c9 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -23,421 +23,6 @@ from transformers import CLIPTextModel, CLIPTokenizer # TODO: rename and split this file -class LoRALayerBase: - # rank: Optional[int] - # alpha: Optional[float] - # bias: Optional[torch.Tensor] - # layer_key: str - - # @property - # def scale(self): - # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - - def __init__( - self, - layer_key: str, - values: dict, - ): - if "alpha" in values: - self.alpha = values["alpha"].item() - else: - self.alpha = None - - if "bias_indices" in values and "bias_values" in values and "bias_size" in values: - self.bias = torch.sparse_coo_tensor( - values["bias_indices"], - values["bias_values"], - tuple(values["bias_size"]), - ) - - else: - self.bias = None - - self.rank = None # set in layer implementation - self.layer_key = layer_key - - def forward( - self, - module: torch.nn.Module, - input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure - multiplier: float, - ): - if type(module) == torch.nn.Conv2d: - op = torch.nn.functional.conv2d - extra_args = dict( - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - else: - op = torch.nn.functional.linear - extra_args = {} - - weight = self.get_weight() - - bias = self.bias if self.bias is not None else 0 - scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - return ( - op( - *input_h, - (weight + bias).view(module.weight.shape), - None, - **extra_args, - ) - * multiplier - * scale - ) - - def get_weight(self): - raise NotImplementedError() - - def calc_size(self) -> int: - model_size = 0 - for val in [self.bias]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - if self.bias is not None: - self.bias = self.bias.to(device=device, dtype=dtype) - - -# TODO: find and debug lora/locon with bias -class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.up = values["lora_up.weight"] - self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid = values["lora_mid.weight"] - else: - self.mid = None - - self.rank = self.down.shape[0] - - def get_weight(self): - if self.mid is not None: - up = self.up.reshape(self.up.shape[0], self.up.shape[1]) - down = self.down.reshape(self.down.shape[0], self.down.shape[1]) - weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) - else: - weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.up, self.mid, self.down]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.up = self.up.to(device=device, dtype=dtype) - self.down = self.down.to(device=device, dtype=dtype) - - if self.mid is not None: - self.mid = self.mid.to(device=device, dtype=dtype) - - -class LoHALayer(LoRALayerBase): - # w1_a: torch.Tensor - # w1_b: torch.Tensor - # w2_a: torch.Tensor - # w2_b: torch.Tensor - # t1: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - self.w1_a = values["hada_w1_a"] - self.w1_b = values["hada_w1_b"] - self.w2_a = values["hada_w2_a"] - self.w2_b = values["hada_w2_b"] - - if "hada_t1" in values: - self.t1 = values["hada_t1"] - else: - self.t1 = None - - if "hada_t2" in values: - self.t2 = values["hada_t2"] - else: - self.t2 = None - - self.rank = self.w1_b.shape[0] - - def get_weight(self): - if self.t1 is None: - weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) - - else: - rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) - rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) - weight = rebuild1 * rebuild2 - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - if self.t1 is not None: - self.t1 = self.t1.to(device=device, dtype=dtype) - - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoKRLayer(LoRALayerBase): - # w1: Optional[torch.Tensor] = None - # w1_a: Optional[torch.Tensor] = None - # w1_b: Optional[torch.Tensor] = None - # w2: Optional[torch.Tensor] = None - # w2_a: Optional[torch.Tensor] = None - # w2_b: Optional[torch.Tensor] = None - # t2: Optional[torch.Tensor] = None - - def __init__( - self, - layer_key: str, - values: dict, - ): - super().__init__(layer_key, values) - - if "lokr_w1" in values: - self.w1 = values["lokr_w1"] - self.w1_a = None - self.w1_b = None - else: - self.w1 = None - self.w1_a = values["lokr_w1_a"] - self.w1_b = values["lokr_w1_b"] - - if "lokr_w2" in values: - self.w2 = values["lokr_w2"] - self.w2_a = None - self.w2_b = None - else: - self.w2 = None - self.w2_a = values["lokr_w2_a"] - self.w2_b = values["lokr_w2_b"] - - if "lokr_t2" in values: - self.t2 = values["lokr_t2"] - else: - self.t2 = None - - if "lokr_w1_b" in values: - self.rank = values["lokr_w1_b"].shape[0] - elif "lokr_w2_b" in values: - self.rank = values["lokr_w2_b"].shape[0] - else: - self.rank = None # unscaled - - def get_weight(self): - w1 = self.w1 - if w1 is None: - w1 = self.w1_a @ self.w1_b - - w2 = self.w2 - if w2 is None: - if self.t2 is None: - w2 = self.w2_a @ self.w2_b - else: - w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) - - if len(w2.shape) == 4: - w1 = w1.unsqueeze(2).unsqueeze(2) - w2 = w2.contiguous() - weight = torch.kron(w1, w2) - - return weight - - def calc_size(self) -> int: - model_size = super().calc_size() - for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: - if val is not None: - model_size += val.nelement() * val.element_size() - return model_size - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - super().to(device=device, dtype=dtype) - - if self.w1 is not None: - self.w1 = self.w1.to(device=device, dtype=dtype) - else: - self.w1_a = self.w1_a.to(device=device, dtype=dtype) - self.w1_b = self.w1_b.to(device=device, dtype=dtype) - - if self.w2 is not None: - self.w2 = self.w2.to(device=device, dtype=dtype) - else: - self.w2_a = self.w2_a.to(device=device, dtype=dtype) - self.w2_b = self.w2_b.to(device=device, dtype=dtype) - - if self.t2 is not None: - self.t2 = self.t2.to(device=device, dtype=dtype) - - -class LoRAModel: # (torch.nn.Module): - _name: str - layers: Dict[str, LoRALayer] - _device: torch.device - _dtype: torch.dtype - - def __init__( - self, - name: str, - layers: Dict[str, LoRALayer], - device: torch.device, - dtype: torch.dtype, - ): - self._name = name - self._device = device or torch.cpu - self._dtype = dtype or torch.float32 - self.layers = layers - - @property - def name(self): - return self._name - - @property - def device(self): - return self._device - - @property - def dtype(self): - return self._dtype - - def to( - self, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ) -> LoRAModel: - # TODO: try revert if exception? - for key, layer in self.layers.items(): - layer.to(device=device, dtype=dtype) - self._device = device - self._dtype = dtype - - def calc_size(self) -> int: - model_size = 0 - for _, layer in self.layers.items(): - model_size += layer.calc_size() - return model_size - - @classmethod - def from_checkpoint( - cls, - file_path: Union[str, Path], - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - ): - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - if isinstance(file_path, str): - file_path = Path(file_path) - - model = cls( - device=device, - dtype=dtype, - name=file_path.stem, # TODO: - layers=dict(), - ) - - if file_path.suffix == ".safetensors": - state_dict = load_file(file_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(file_path, map_location="cpu") - - state_dict = cls._group_state(state_dict) - - for layer_key, values in state_dict.items(): - # lora and locon - if "lora_down.weight" in values: - layer = LoRALayer(layer_key, values) - - # loha - elif "hada_w1_b" in values: - layer = LoHALayer(layer_key, values) - - # lokr - elif "lokr_w1_b" in values or "lokr_w1" in values: - layer = LoKRLayer(layer_key, values) - - else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return - - # lower memory consumption by removing already parsed layer values - state_dict[layer_key].clear() - - layer.to(device=device, dtype=dtype) - model.layers[layer_key] = layer - - return model - - @staticmethod - def _group_state(state_dict: dict): - state_dict_groupped = dict() - - for key, value in state_dict.items(): - stem, leaf = key.split(".", 1) - if stem not in state_dict_groupped: - state_dict_groupped[stem] = dict() - state_dict_groupped[stem][leaf] = value - - return state_dict_groupped - - """ loras = [ (lora_model1, 0.7), @@ -516,6 +101,27 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te1_"): + yield + + @classmethod + @contextmanager + def apply_sdxl_lora_text_encoder2( + cls, + text_encoder: CLIPTextModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te2_"): + yield + @classmethod @contextmanager def apply_lora( diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index b4c3e48a48..71e1ebc0d4 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -28,8 +28,6 @@ import torch import logging import invokeai.backend.util.logging as logger -from invokeai.app.services.config import get_invokeai_config -from .lora import LoRAModel, TextualInversionModel from .models import BaseModelType, ModelType, SubModelType, ModelBase # Maximum size of the cache, in gigs diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 642f8bbeec..0351bf2652 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,7 +1,9 @@ import os import torch from enum import Enum -from typing import Optional, Union, Literal +from typing import Optional, Dict, Union, Literal, Any +from pathlib import Path +from safetensors.torch import load_file from .base import ( ModelBase, ModelConfigBase, @@ -13,9 +15,6 @@ from .base import ( ModelNotFoundException, ) -# TODO: naming -from ..lora import LoRAModel as LoRAModelRaw - class LoRAModelFormat(str, Enum): LyCORIS = "lycoris" @@ -50,6 +49,7 @@ class LoRAModel(ModelBase): model = LoRAModelRaw.from_checkpoint( file_path=self.model_path, dtype=torch_dtype, + base_model=self.base_model, ) self.model_size = model.calc_size() @@ -87,3 +87,532 @@ class LoRAModel(ModelBase): raise NotImplementedError("Diffusers lora not supported") else: return model_path + +class LoRALayerBase: + # rank: Optional[int] + # alpha: Optional[float] + # bias: Optional[torch.Tensor] + # layer_key: str + + # @property + # def scale(self): + # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + + def __init__( + self, + layer_key: str, + values: dict, + ): + if "alpha" in values: + self.alpha = values["alpha"].item() + else: + self.alpha = None + + if "bias_indices" in values and "bias_values" in values and "bias_size" in values: + self.bias = torch.sparse_coo_tensor( + values["bias_indices"], + values["bias_values"], + tuple(values["bias_size"]), + ) + + else: + self.bias = None + + self.rank = None # set in layer implementation + self.layer_key = layer_key + + def forward( + self, + module: torch.nn.Module, + input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure + multiplier: float, + ): + if type(module) == torch.nn.Conv2d: + op = torch.nn.functional.conv2d + extra_args = dict( + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + else: + op = torch.nn.functional.linear + extra_args = {} + + weight = self.get_weight() + + bias = self.bias if self.bias is not None else 0 + scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 + return ( + op( + *input_h, + (weight + bias).view(module.weight.shape), + None, + **extra_args, + ) + * multiplier + * scale + ) + + def get_weight(self): + raise NotImplementedError() + + def calc_size(self) -> int: + model_size = 0 + for val in [self.bias]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + if self.bias is not None: + self.bias = self.bias.to(device=device, dtype=dtype) + +# TODO: find and debug lora/locon with bias +class LoRALayer(LoRALayerBase): + # up: torch.Tensor + # mid: Optional[torch.Tensor] + # down: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.up = values["lora_up.weight"] + self.down = values["lora_down.weight"] + if "lora_mid.weight" in values: + self.mid = values["lora_mid.weight"] + else: + self.mid = None + + self.rank = self.down.shape[0] + + def get_weight(self): + if self.mid is not None: + up = self.up.reshape(self.up.shape[0], self.up.shape[1]) + down = self.down.reshape(self.down.shape[0], self.down.shape[1]) + weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) + else: + weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.up, self.mid, self.down]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.up = self.up.to(device=device, dtype=dtype) + self.down = self.down.to(device=device, dtype=dtype) + + if self.mid is not None: + self.mid = self.mid.to(device=device, dtype=dtype) + +class LoHALayer(LoRALayerBase): + # w1_a: torch.Tensor + # w1_b: torch.Tensor + # w2_a: torch.Tensor + # w2_b: torch.Tensor + # t1: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.w1_a = values["hada_w1_a"] + self.w1_b = values["hada_w1_b"] + self.w2_a = values["hada_w2_a"] + self.w2_b = values["hada_w2_b"] + + if "hada_t1" in values: + self.t1 = values["hada_t1"] + else: + self.t1 = None + + if "hada_t2" in values: + self.t2 = values["hada_t2"] + else: + self.t2 = None + + self.rank = self.w1_b.shape[0] + + def get_weight(self): + if self.t1 is None: + weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) + + else: + rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) + rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) + weight = rebuild1 * rebuild2 + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + if self.t1 is not None: + self.t1 = self.t1.to(device=device, dtype=dtype) + + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +class LoKRLayer(LoRALayerBase): + # w1: Optional[torch.Tensor] = None + # w1_a: Optional[torch.Tensor] = None + # w1_b: Optional[torch.Tensor] = None + # w2: Optional[torch.Tensor] = None + # w2_a: Optional[torch.Tensor] = None + # w2_b: Optional[torch.Tensor] = None + # t2: Optional[torch.Tensor] = None + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + if "lokr_w1" in values: + self.w1 = values["lokr_w1"] + self.w1_a = None + self.w1_b = None + else: + self.w1 = None + self.w1_a = values["lokr_w1_a"] + self.w1_b = values["lokr_w1_b"] + + if "lokr_w2" in values: + self.w2 = values["lokr_w2"] + self.w2_a = None + self.w2_b = None + else: + self.w2 = None + self.w2_a = values["lokr_w2_a"] + self.w2_b = values["lokr_w2_b"] + + if "lokr_t2" in values: + self.t2 = values["lokr_t2"] + else: + self.t2 = None + + if "lokr_w1_b" in values: + self.rank = values["lokr_w1_b"].shape[0] + elif "lokr_w2_b" in values: + self.rank = values["lokr_w2_b"].shape[0] + else: + self.rank = None # unscaled + + def get_weight(self): + w1 = self.w1 + if w1 is None: + w1 = self.w1_a @ self.w1_b + + w2 = self.w2 + if w2 is None: + if self.t2 is None: + w2 = self.w2_a @ self.w2_b + else: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + weight = torch.kron(w1, w2) + + return weight + + def calc_size(self) -> int: + model_size = super().calc_size() + for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: + if val is not None: + model_size += val.nelement() * val.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + if self.w1 is not None: + self.w1 = self.w1.to(device=device, dtype=dtype) + else: + self.w1_a = self.w1_a.to(device=device, dtype=dtype) + self.w1_b = self.w1_b.to(device=device, dtype=dtype) + + if self.w2 is not None: + self.w2 = self.w2.to(device=device, dtype=dtype) + else: + self.w2_a = self.w2_a.to(device=device, dtype=dtype) + self.w2_b = self.w2_b.to(device=device, dtype=dtype) + + if self.t2 is not None: + self.t2 = self.t2.to(device=device, dtype=dtype) + +# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix +class LoRAModelRaw: # (torch.nn.Module): + _name: str + layers: Dict[str, LoRALayer] + _device: torch.device + _dtype: torch.dtype + + def __init__( + self, + name: str, + layers: Dict[str, LoRALayer], + device: torch.device, + dtype: torch.dtype, + ): + self._name = name + self._device = device or torch.cpu + self._dtype = dtype or torch.float32 + self.layers = layers + + @property + def name(self): + return self._name + + @property + def device(self): + return self._device + + @property + def dtype(self): + return self._dtype + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + # TODO: try revert if exception? + for key, layer in self.layers.items(): + layer.to(device=device, dtype=dtype) + self._device = device + self._dtype = dtype + + def calc_size(self) -> int: + model_size = 0 + for _, layer in self.layers.items(): + model_size += layer.calc_size() + return model_size + + @classmethod + def _convert_sdxl_compvis_keys(cls, state_dict): + new_state_dict = dict() + for full_key, value in state_dict.items(): + if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + continue # clip same + + if not full_key.startswith("lora_unet_"): + raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") + src_key = full_key.replace("lora_unet_", "") + try: + dst_key = None + while "_" in src_key: + if src_key in SDXL_UNET_COMPVIS_MAP: + dst_key = SDXL_UNET_COMPVIS_MAP[src_key] + break + src_key = "_".join(src_key.split('_')[:-1]) + + if dst_key is None: + raise Exception(f"Unknown sdxl lora key - {full_key}") + new_key = full_key.replace(src_key, dst_key) + except: + print(SDXL_UNET_COMPVIS_MAP) + raise + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def from_checkpoint( + cls, + file_path: Union[str, Path], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + base_model: Optional[BaseModelType] = None, + ): + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + if isinstance(file_path, str): + file_path = Path(file_path) + + model = cls( + device=device, + dtype=dtype, + name=file_path.stem, # TODO: + layers=dict(), + ) + + if file_path.suffix == ".safetensors": + state_dict = load_file(file_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(file_path, map_location="cpu") + + state_dict = cls._group_state(state_dict) + + if base_model == BaseModelType.StableDiffusionXL: + state_dict = cls._convert_sdxl_compvis_keys(state_dict) + + for layer_key, values in state_dict.items(): + # lora and locon + if "lora_down.weight" in values: + layer = LoRALayer(layer_key, values) + + # loha + elif "hada_w1_b" in values: + layer = LoHALayer(layer_key, values) + + # lokr + elif "lokr_w1_b" in values or "lokr_w1" in values: + layer = LoKRLayer(layer_key, values) + + else: + # TODO: diff/ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") + return + + # lower memory consumption by removing already parsed layer values + state_dict[layer_key].clear() + + layer.to(device=device, dtype=dtype) + model.layers[layer_key] = layer + + return model + + @staticmethod + def _group_state(state_dict: dict): + state_dict_groupped = dict() + + for key, value in state_dict.items(): + stem, leaf = key.split(".", 1) + if stem not in state_dict_groupped: + state_dict_groupped[stem] = dict() + state_dict_groupped[stem][leaf] = value + + return state_dict_groupped + + +def make_sdxl_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + +#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}