model probe detects sdxl lora models

This commit is contained in:
Lincoln Stein
2023-08-03 10:26:52 -04:00
committed by Kent Keirsey
parent 1ac14a1e43
commit 1d5d187ba1
5 changed files with 39 additions and 14 deletions

View File

@ -88,6 +88,7 @@ class LoRAModel(ModelBase):
else:
return model_path
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
@ -173,6 +174,7 @@ class LoRALayerBase:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase):
# up: torch.Tensor
@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase):
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor
# w1_b: torch.Tensor
@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase):
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None
@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase):
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module):
new_state_dict = dict()
for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same
continue # clip same
if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module):
if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break
src_key = "_".join(src_key.split('_')[:-1])
src_key = "_".join(src_key.split("_")[:-1])
if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}")
@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
return unet_conversion_map
#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
}