add sdxl lora support

This commit is contained in:
Sergey Borisov 2023-07-31 23:18:02 +03:00 committed by Kent Keirsey
parent cfc3a20565
commit 1ac14a1e43
8 changed files with 683 additions and 438 deletions

View File

@ -173,7 +173,7 @@ class CompelInvocation(BaseInvocation):
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
def run_clip_raw(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@ -210,8 +210,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@ -247,7 +247,7 @@ class SDXLPromptInvocationBase:
return c, c_pooled, None
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
def run_clip_compel(self, context, clip_field, prompt, get_pooled, lora_prefix):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
@ -284,8 +284,8 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
@ -357,11 +357,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False)
c1, c1_pooled, ec1 = self.run_clip_compel(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True)
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.prompt, True, "lora_te2_")
else:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "lora_te2_")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -415,7 +415,8 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True)
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_compel(context, self.clip2, self.style, True, "<NONE>")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -467,11 +468,11 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False)
c1, c1_pooled, ec1 = self.run_clip_raw(context, self.clip, self.prompt, False, "lora_te1_")
if self.style.strip() == "":
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True)
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.prompt, True, "lora_te2_")
else:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "lora_te2_")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
@ -525,7 +526,8 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True)
# TODO: if there will appear lora for refiner - write proper prefix
c2, c2_pooled, ec2 = self.run_clip_raw(context, self.clip2, self.style, True, "<NONE>")
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)

View File

@ -14,7 +14,7 @@ from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,

View File

@ -262,6 +262,103 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class SDXLLoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output"
unet: Optional[UNetField] = Field(default=None, description="UNet submodel")
clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels")
# fmt: on
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
clip: Optional[ClipField] = Field(description="Clip model for applying lora")
clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "SDXL Lora Loader",
"tags": ["lora", "loader"],
"type_hints": {"lora": "lora_model"},
},
}
def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput:
if self.lora is None:
raise Exception("No LoRA provided")
base_model = self.lora.base_model
lora_name = self.lora.model_name
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
):
raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip')
if self.clip2 is not None and any(lora.model_name == lora_name for lora in self.clip2.loras):
raise Exception(f'Lora "{lora_name}" already applied to clip2')
output = SDXLLoraLoaderOutput()
if self.unet is not None:
output.unet = copy.deepcopy(self.unet)
output.unet.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip is not None:
output.clip = copy.deepcopy(self.clip)
output.clip.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
if self.clip2 is not None:
output.clip2 = copy.deepcopy(self.clip2)
output.clip2.loras.append(
LoraInfo(
base_model=base_model,
model_name=lora_name,
model_type=ModelType.Lora,
submodel=None,
weight=self.weight,
)
)
return output
class VAEModelField(BaseModel):
"""Vae model field"""

View File

@ -5,7 +5,7 @@ from typing import List, Literal, Optional, Union
from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from ...backend.model_management import ModelType, SubModelType, ModelPatcher
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
@ -293,10 +293,22 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
num_inference_steps = self.steps
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps

View File

@ -13,3 +13,4 @@ from .models import (
DuplicateModelException,
)
from .model_merge import ModelMerger, MergeInterpolationMethod
from .lora import ModelPatcher

View File

@ -23,421 +23,6 @@ from transformers import CLIPTextModel, CLIPTokenizer
# TODO: rename and split this file
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoRAModel: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> LoRAModel:
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
"""
loras = [
(lora_model1, 0.7),
@ -516,6 +101,27 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModel, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod
@contextmanager
def apply_lora(

View File

@ -28,8 +28,6 @@ import torch
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
from .lora import LoRAModel, TextualInversionModel
from .models import BaseModelType, ModelType, SubModelType, ModelBase
# Maximum size of the cache, in gigs

View File

@ -1,7 +1,9 @@
import os
import torch
from enum import Enum
from typing import Optional, Union, Literal
from typing import Optional, Dict, Union, Literal, Any
from pathlib import Path
from safetensors.torch import load_file
from .base import (
ModelBase,
ModelConfigBase,
@ -13,9 +15,6 @@ from .base import (
ModelNotFoundException,
)
# TODO: naming
from ..lora import LoRAModel as LoRAModelRaw
class LoRAModelFormat(str, Enum):
LyCORIS = "lycoris"
@ -50,6 +49,7 @@ class LoRAModel(ModelBase):
model = LoRAModelRaw.from_checkpoint(
file_path=self.model_path,
dtype=torch_dtype,
base_model=self.base_model,
)
self.model_size = model.calc_size()
@ -87,3 +87,532 @@ class LoRAModel(ModelBase):
raise NotImplementedError("Diffusers lora not supported")
else:
return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
# bias: Optional[torch.Tensor]
# layer_key: str
# @property
# def scale(self):
# return self.alpha / self.rank if (self.alpha and self.rank) else 1.0
def __init__(
self,
layer_key: str,
values: dict,
):
if "alpha" in values:
self.alpha = values["alpha"].item()
else:
self.alpha = None
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
self.bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
)
else:
self.bias = None
self.rank = None # set in layer implementation
self.layer_key = layer_key
def forward(
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError()
def calc_size(self) -> int:
model_size = 0
for val in [self.bias]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
# mid: Optional[torch.Tensor]
# down: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.up = values["lora_up.weight"]
self.down = values["lora_down.weight"]
if "lora_mid.weight" in values:
self.mid = values["lora_mid.weight"]
else:
self.mid = None
self.rank = self.down.shape[0]
def get_weight(self):
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down)
else:
weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.up, self.mid, self.down]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
# w2_a: torch.Tensor
# w2_b: torch.Tensor
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
self.w1_b = values["hada_w1_b"]
self.w2_a = values["hada_w2_a"]
self.w2_b = values["hada_w2_b"]
if "hada_t1" in values:
self.t1 = values["hada_t1"]
else:
self.t1 = None
if "hada_t2" in values:
self.t2 = values["hada_t2"]
else:
self.t2 = None
self.rank = self.w1_b.shape[0]
def get_weight(self):
if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a)
weight = rebuild1 * rebuild2
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
# w1_b: Optional[torch.Tensor] = None
# w2: Optional[torch.Tensor] = None
# w2_a: Optional[torch.Tensor] = None
# w2_b: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
if "lokr_w1" in values:
self.w1 = values["lokr_w1"]
self.w1_a = None
self.w1_b = None
else:
self.w1 = None
self.w1_a = values["lokr_w1_a"]
self.w1_b = values["lokr_w1_b"]
if "lokr_w2" in values:
self.w2 = values["lokr_w2"]
self.w2_a = None
self.w2_b = None
else:
self.w2 = None
self.w2_a = values["lokr_w2_a"]
self.w2_b = values["lokr_w2_b"]
if "lokr_t2" in values:
self.t2 = values["lokr_t2"]
else:
self.t2 = None
if "lokr_w1_b" in values:
self.rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
self.rank = values["lokr_w2_b"].shape[0]
else:
self.rank = None # unscaled
def get_weight(self):
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2)
return weight
def calc_size(self) -> int:
model_size = super().calc_size()
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
if val is not None:
model_size += val.nelement() * val.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
if self.w1 is not None:
self.w1 = self.w1.to(device=device, dtype=dtype)
else:
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
else:
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
for _, layer in self.layers.items():
model_size += layer.calc_size()
return model_size
@classmethod
def _convert_sdxl_compvis_keys(cls, state_dict):
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
src_key = full_key.replace("lora_unet_", "")
try:
dst_key = None
while "_" in src_key:
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split('_')[:-1])
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
new_key = full_key.replace(src_key, dst_key)
except:
print(SDXL_UNET_COMPVIS_MAP)
raise
new_state_dict[new_key] = value
return new_state_dict
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
base_model: Optional[BaseModelType] = None,
):
device = device or torch.device("cpu")
dtype = dtype or torch.float32
if isinstance(file_path, str):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_compvis_keys(state_dict)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
model.layers[layer_key] = layer
return model
@staticmethod
def _group_state(state_dict: dict):
state_dict_groupped = dict()
for key, value in state_dict.items():
stem, leaf = key.split(".", 1)
if stem not in state_dict_groupped:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
return state_dict_groupped
def make_sdxl_unet_conversion_map():
unet_conversion_map_layer = []
for i in range(3): # num_blocks is 3 in sdxl
# loop over downblocks/upblocks
for j in range(2):
# loop over resnets/attentions for downblocks
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
if i < 3:
# no attention layers in down_blocks.3
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
for j in range(3):
# loop over resnets/attentions for upblocks
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
# if i > 0: commentout for sdxl
# no attention layers in up_blocks.0
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
if i < 3:
# no downsample in down_blocks.3
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
# no upsample in up_blocks.3
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
for j in range(2):
hf_mid_res_prefix = f"mid_block.resnets.{j}."
sd_mid_res_prefix = f"middle_block.{2*j}."
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
unet_conversion_map_resnet = [
# (stable-diffusion, HF Diffusers)
("in_layers.0.", "norm1."),
("in_layers.2.", "conv1."),
("out_layers.0.", "norm2."),
("out_layers.3.", "conv2."),
("emb_layers.1.", "time_emb_proj."),
("skip_connection.", "conv_shortcut."),
]
unet_conversion_map = []
for sd, hf in unet_conversion_map_layer:
if "resnets" in hf:
for sd_res, hf_res in unet_conversion_map_resnet:
unet_conversion_map.append((sd + sd_res, hf + hf_res))
else:
unet_conversion_map.append((sd, hf))
for j in range(2):
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
sd_time_embed_prefix = f"time_embed.{j*2}."
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
for j in range(2):
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
sd_label_embed_prefix = f"label_emb.0.{j*2}."
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
unet_conversion_map.append(("out.0.", "conv_norm_out."))
unet_conversion_map.append(("out.2.", "conv_out."))
return unet_conversion_map
#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}