mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix merge conflict resolve - restore full/diff layer support
This commit is contained in:
parent
0e9f92b868
commit
f0613bb0ef
@ -391,6 +391,43 @@ class LoKRLayer(LoRALayerBase):
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["diff"]
|
||||
|
||||
if len(values.keys()) > 1:
|
||||
_keys = list(values.keys())
|
||||
_keys.remove("diff")
|
||||
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self):
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
_name: str
|
||||
@ -510,10 +547,13 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
else:
|
||||
# TODO: diff/ia3/... format
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
||||
return
|
||||
# TODO: ia3/... format
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
@ -536,6 +576,8 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
return state_dict_groupped
|
||||
|
||||
|
||||
# code from
|
||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||
def make_sdxl_unet_conversion_map():
|
||||
unet_conversion_map_layer = []
|
||||
|
||||
@ -620,7 +662,6 @@ def make_sdxl_unet_conversion_map():
|
||||
return unet_conversion_map
|
||||
|
||||
|
||||
# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
|
||||
SDXL_UNET_COMPVIS_MAP = {
|
||||
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
|
||||
for sd, hf in make_sdxl_unet_conversion_map()
|
||||
|
Loading…
Reference in New Issue
Block a user