Add support for LyCORIS IA3 format (#4234)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [ ] Yes
- [x] No

      
## Have you updated all relevant documentation?
- [ ] Yes
- [x] No


## Description
Add support for LyCORIS IA3 format

## Related Tickets & Documents
- Closes #4229 

## Added/updated tests?

- [ ] Yes
- [x] No
This commit is contained in:
StAlKeR7779 2023-08-11 03:30:35 +03:00 committed by GitHub
commit 591838a84b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 56 deletions

View File

@ -143,7 +143,7 @@ class ModelPatcher:
# with torch.autocast(device_type="cpu"): # with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight() * lora_weight * layer_scale layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris # TODO: debug on lycoris
@ -361,7 +361,8 @@ class ONNXModelPatcher:
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "") layer_key = layer_key.replace(prefix, "")
layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight # TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key is blended_loras: if layer_key is blended_loras:
blended_loras[layer_key] += layer_weight blended_loras[layer_key] += layer_weight
else: else:

View File

@ -122,41 +122,7 @@ class LoRALayerBase:
self.rank = None # set in layer implementation self.rank = None # set in layer implementation
self.layer_key = layer_key self.layer_key = layer_key
def forward( def get_weight(self, orig_weight: torch.Tensor):
self,
module: torch.nn.Module,
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
multiplier: float,
):
if type(module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
weight = self.get_weight()
bias = self.bias if self.bias is not None else 0
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
return (
op(
*input_h,
(weight + bias).view(module.weight.shape),
None,
**extra_args,
)
* multiplier
* scale
)
def get_weight(self):
raise NotImplementedError() raise NotImplementedError()
def calc_size(self) -> int: def calc_size(self) -> int:
@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0] self.rank = self.down.shape[0]
def get_weight(self): def get_weight(self, orig_weight: torch.Tensor):
if self.mid is not None: if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1]) up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0] self.rank = self.w1_b.shape[0]
def get_weight(self): def get_weight(self, orig_weight: torch.Tensor):
if self.t1 is None: if self.t1 is None:
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase):
else: else:
self.rank = None # unscaled self.rank = None # unscaled
def get_weight(self): def get_weight(self, orig_weight: torch.Tensor):
w1 = self.w1 w1 = self.w1
if w1 is None: if w1 is None:
w1 = self.w1_a @ self.w1_b w1 = self.w1_a @ self.w1_b
@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase):
self.rank = None # unscaled self.rank = None # unscaled
def get_weight(self): def get_weight(self, orig_weight: torch.Tensor):
return self.weight return self.weight
def calc_size(self) -> int: def calc_size(self) -> int:
@ -428,6 +394,45 @@ class FullLayer(LoRALayerBase):
self.weight = self.weight.to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype)
class IA3Layer(LoRALayerBase):
# weight: torch.Tensor
# on_input: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["weight"]
self.on_input = values["on_input"]
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
return orig_weight * weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module): class LoRAModelRaw: # (torch.nn.Module):
_name: str _name: str
@ -547,11 +552,15 @@ class LoRAModelRaw: # (torch.nn.Module):
elif "lokr_w1_b" in values or "lokr_w1" in values: elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values) layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values: elif "diff" in values:
layer = FullLayer(layer_key, values) layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else: else:
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!") raise Exception("Unknown lora format!")

View File

@ -12,37 +12,43 @@ def lora_token_vector_length(checkpoint: dict) -> int:
def _get_shape_1(key, tensor, checkpoint): def _get_shape_1(key, tensor, checkpoint):
lora_token_vector_length = None lora_token_vector_length = None
if "." not in key:
return lora_token_vector_length # wrong key format
model_key, lora_key = key.split(".", 1)
# check lora/locon # check lora/locon
if ".lora_down.weight" in key: if lora_key == "lora_down.weight":
lora_token_vector_length = tensor.shape[1] lora_token_vector_length = tensor.shape[1]
# check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes)
elif ".hada_w1_b" in key or ".hada_w2_b" in key: elif lora_key in ["hada_w1_b", "hada_w2_b"]:
lora_token_vector_length = tensor.shape[1] lora_token_vector_length = tensor.shape[1]
# check lokr (don't worry about lokr_t2 as it used only in 4d shapes) # check lokr (don't worry about lokr_t2 as it used only in 4d shapes)
elif ".lokr_" in key: elif "lokr_" in lora_key:
_lokr_key = key.split(".")[0] if model_key + ".lokr_w1" in checkpoint:
_lokr_w1 = checkpoint[model_key + ".lokr_w1"]
if _lokr_key + ".lokr_w1" in checkpoint: elif model_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"] _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"]
elif _lokr_key + "lokr_w1_b" in checkpoint:
_lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"]
else: else:
return lora_token_vector_length # unknown format return lora_token_vector_length # unknown format
if _lokr_key + ".lokr_w2" in checkpoint: if model_key + ".lokr_w2" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"] _lokr_w2 = checkpoint[model_key + ".lokr_w2"]
elif _lokr_key + "lokr_w2_b" in checkpoint: elif model_key + "lokr_w2_b" in checkpoint:
_lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"] _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"]
else: else:
return lora_token_vector_length # unknown format return lora_token_vector_length # unknown format
lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1]
elif ".diff" in key: elif lora_key == "diff":
lora_token_vector_length = tensor.shape[1] lora_token_vector_length = tensor.shape[1]
# ia3 can be detected only by shape[0] in text encoder
elif lora_key == "weight" and "lora_unet_" not in model_key:
lora_token_vector_length = tensor.shape[0]
return lora_token_vector_length return lora_token_vector_length
lora_token_vector_length = None lora_token_vector_length = None