mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for LyCORIS IA3 format
This commit is contained in:
parent
2ef6a8995b
commit
56023bc725
@ -143,7 +143,7 @@ class ModelPatcher:
|
|||||||
# with torch.autocast(device_type="cpu"):
|
# with torch.autocast(device_type="cpu"):
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32)
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
layer_weight = layer.get_weight() * lora_weight * layer_scale
|
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
|
||||||
|
|
||||||
if module.weight.shape != layer_weight.shape:
|
if module.weight.shape != layer_weight.shape:
|
||||||
# TODO: debug on lycoris
|
# TODO: debug on lycoris
|
||||||
@ -361,7 +361,8 @@ class ONNXModelPatcher:
|
|||||||
|
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32)
|
||||||
layer_key = layer_key.replace(prefix, "")
|
layer_key = layer_key.replace(prefix, "")
|
||||||
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
|
# TODO: rewrite to pass original tensor weight(required by ia3)
|
||||||
|
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
|
||||||
if layer_key is blended_loras:
|
if layer_key is blended_loras:
|
||||||
blended_loras[layer_key] += layer_weight
|
blended_loras[layer_key] += layer_weight
|
||||||
else:
|
else:
|
||||||
|
@ -122,41 +122,7 @@ class LoRALayerBase:
|
|||||||
self.rank = None # set in layer implementation
|
self.rank = None # set in layer implementation
|
||||||
self.layer_key = layer_key
|
self.layer_key = layer_key
|
||||||
|
|
||||||
def forward(
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
self,
|
|
||||||
module: torch.nn.Module,
|
|
||||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
|
||||||
multiplier: float,
|
|
||||||
):
|
|
||||||
if type(module) == torch.nn.Conv2d:
|
|
||||||
op = torch.nn.functional.conv2d
|
|
||||||
extra_args = dict(
|
|
||||||
stride=module.stride,
|
|
||||||
padding=module.padding,
|
|
||||||
dilation=module.dilation,
|
|
||||||
groups=module.groups,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
op = torch.nn.functional.linear
|
|
||||||
extra_args = {}
|
|
||||||
|
|
||||||
weight = self.get_weight()
|
|
||||||
|
|
||||||
bias = self.bias if self.bias is not None else 0
|
|
||||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
|
||||||
return (
|
|
||||||
op(
|
|
||||||
*input_h,
|
|
||||||
(weight + bias).view(module.weight.shape),
|
|
||||||
None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
* multiplier
|
|
||||||
* scale
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_weight(self):
|
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase):
|
|||||||
|
|
||||||
self.rank = self.down.shape[0]
|
self.rank = self.down.shape[0]
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||||
@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase):
|
|||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
self.rank = self.w1_b.shape[0]
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
if self.t1 is None:
|
if self.t1 is None:
|
||||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||||
|
|
||||||
@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
w1 = self.w1
|
w1 = self.w1
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
w1 = self.w1_a @ self.w1_b
|
w1 = self.w1_a @ self.w1_b
|
||||||
@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase):
|
|||||||
|
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
return self.weight
|
return self.weight
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
@ -427,6 +393,44 @@ class FullLayer(LoRALayerBase):
|
|||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
class IA3Layer(LoRALayerBase):
|
||||||
|
# weight: torch.Tensor
|
||||||
|
# on_input: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.weight = values["weight"]
|
||||||
|
self.on_input = values["on_input"]
|
||||||
|
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self, orig_weight: torch.Tensor):
|
||||||
|
weight = self.weight
|
||||||
|
if not self.on_input:
|
||||||
|
weight = weight.reshape(-1, 1)
|
||||||
|
return orig_weight * weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
model_size += self.weight.nelement() * self.weight.element_size()
|
||||||
|
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||||
class LoRAModelRaw: # (torch.nn.Module):
|
class LoRAModelRaw: # (torch.nn.Module):
|
||||||
@ -547,11 +551,15 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||||
layer = LoKRLayer(layer_key, values)
|
layer = LoKRLayer(layer_key, values)
|
||||||
|
|
||||||
|
# diff
|
||||||
elif "diff" in values:
|
elif "diff" in values:
|
||||||
layer = FullLayer(layer_key, values)
|
layer = FullLayer(layer_key, values)
|
||||||
|
|
||||||
|
# ia3
|
||||||
|
elif "weight" in values and "on_input" in values:
|
||||||
|
layer = IA3Layer(layer_key, values)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# TODO: ia3/... format
|
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
raise Exception("Unknown lora format!")
|
raise Exception("Unknown lora format!")
|
||||||
|
|
||||||
|
@ -12,37 +12,43 @@ def lora_token_vector_length(checkpoint: dict) -> int:
|
|||||||
def _get_shape_1(key, tensor, checkpoint):
|
def _get_shape_1(key, tensor, checkpoint):
|
||||||
lora_token_vector_length = None
|
lora_token_vector_length = None
|
||||||
|
|
||||||
|
if "." not in key:
|
||||||
|
return lora_token_vector_length # wrong key format
|
||||||
|
model_key, lora_key = key.split(".", 1)
|
||||||
|
|
||||||
# check lora/locon
|
# check lora/locon
|
||||||
if ".lora_down.weight" in key:
|
if lora_key == "lora_down.weight":
|
||||||
lora_token_vector_length = tensor.shape[1]
|
lora_token_vector_length = tensor.shape[1]
|
||||||
|
|
||||||
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
|
||||||
elif ".hada_w1_b" in key or ".hada_w2_b" in key:
|
elif lora_key in ["hada_w1_b", "hada_w2_b"]:
|
||||||
lora_token_vector_length = tensor.shape[1]
|
lora_token_vector_length = tensor.shape[1]
|
||||||
|
|
||||||
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
|
||||||
elif ".lokr_" in key:
|
elif "lokr_" in lora_key:
|
||||||
_lokr_key = key.split(".")[0]
|
if model_key + ".lokr_w1" in checkpoint:
|
||||||
|
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
|
||||||
if _lokr_key + ".lokr_w1" in checkpoint:
|
elif model_key + "lokr_w1_b" in checkpoint:
|
||||||
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"]
|
_lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
|
||||||
elif _lokr_key + "lokr_w1_b" in checkpoint:
|
|
||||||
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
|
|
||||||
else:
|
else:
|
||||||
return lora_token_vector_length # unknown format
|
return lora_token_vector_length # unknown format
|
||||||
|
|
||||||
if _lokr_key + ".lokr_w2" in checkpoint:
|
if model_key + ".lokr_w2" in checkpoint:
|
||||||
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"]
|
_lokr_w2 = checkpoint[model_key + ".lokr_w2"]
|
||||||
elif _lokr_key + "lokr_w2_b" in checkpoint:
|
elif model_key + "lokr_w2_b" in checkpoint:
|
||||||
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"]
|
_lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
|
||||||
else:
|
else:
|
||||||
return lora_token_vector_length # unknown format
|
return lora_token_vector_length # unknown format
|
||||||
|
|
||||||
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
|
||||||
|
|
||||||
elif ".diff" in key:
|
elif lora_key == "diff":
|
||||||
lora_token_vector_length = tensor.shape[1]
|
lora_token_vector_length = tensor.shape[1]
|
||||||
|
|
||||||
|
# ia3 can be detected only by shape[0] in text encoder
|
||||||
|
elif lora_key == "weight" and "lora_unet_" not in model_key:
|
||||||
|
lora_token_vector_length = tensor.shape[0]
|
||||||
|
|
||||||
return lora_token_vector_length
|
return lora_token_vector_length
|
||||||
|
|
||||||
lora_token_vector_length = None
|
lora_token_vector_length = None
|
||||||
|
Loading…
Reference in New Issue
Block a user