Fix merge conflict resolve - restore full/diff layer support

This commit is contained in:
Sergey Borisov 2023-08-04 19:53:27 +03:00
parent 0e9f92b868
commit f0613bb0ef

View File

@ -391,6 +391,43 @@ class LoKRLayer(LoRALayerBase):
self.t2 = self.t2.to(device=device, dtype=dtype)
class FullLayer(LoRALayerBase):
# weight: torch.Tensor
def __init__(
self,
layer_key: str,
values: dict,
):
super().__init__(layer_key, values)
self.weight = values["diff"]
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled
def get_weight(self):
return self.weight
def calc_size(self) -> int:
model_size = super().calc_size()
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@ -510,10 +547,13 @@ class LoRAModelRaw: # (torch.nn.Module):
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
elif "diff" in values:
layer = FullLayer(layer_key, values)
else:
# TODO: diff/ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
return
# TODO: ia3/... format
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
@ -536,6 +576,8 @@ class LoRAModelRaw: # (torch.nn.Module):
return state_dict_groupped
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map():
unet_conversion_map_layer = []
@ -620,7 +662,6 @@ def make_sdxl_unet_conversion_map():
return unet_conversion_map
# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()