model probe detects sdxl lora models

This commit is contained in:
Lincoln Stein 2023-08-03 10:26:52 -04:00 committed by Kent Keirsey
parent 1ac14a1e43
commit 1d5d187ba1
5 changed files with 39 additions and 14 deletions

View File

@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
do_classifier_free_guidance = True do_classifier_free_guidance = True
cross_attention_kwargs = None cross_attention_kwargs = None
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device) scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps timesteps = scheduler.timesteps

View File

@ -101,7 +101,6 @@ class ModelPatcher:
with cls.apply_lora(text_encoder, loras, "lora_te_"): with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield yield
@classmethod @classmethod
@contextmanager @contextmanager
def apply_sdxl_lora_text_encoder( def apply_sdxl_lora_text_encoder(

View File

@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint checkpoint = self.checkpoint
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
# misclassified as SD-1
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
if key in checkpoint and checkpoint[key].shape[0] == 320:
return BaseModelType.StableDiffusion2
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
if key in checkpoint:
return BaseModelType.StableDiffusionXL
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = ( lora_token_vector_length = (
checkpoint[key1].shape[1] checkpoint[key1].shape[1]
if key1 in checkpoint if key1 in checkpoint
else checkpoint[key2].shape[0] else checkpoint[key2].shape[1]
if key2 in checkpoint if key2 in checkpoint
else 768 else checkpoint[key3].shape[0]
if key3 in checkpoint
else None
) )
if lora_token_vector_length == 768: if lora_token_vector_length == 768:
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
elif lora_token_vector_length == 1024: elif lora_token_vector_length == 1024:
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
else: else:
return None raise InvalidModelException(f"Unknown LoRA type")
class TextualInversionCheckpointProbe(CheckpointProbeBase): class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@ -88,6 +88,7 @@ class LoRAModel(ModelBase):
else: else:
return model_path return model_path
class LoRALayerBase: class LoRALayerBase:
# rank: Optional[int] # rank: Optional[int]
# alpha: Optional[float] # alpha: Optional[float]
@ -173,6 +174,7 @@ class LoRALayerBase:
if self.bias is not None: if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype) self.bias = self.bias.to(device=device, dtype=dtype)
# TODO: find and debug lora/locon with bias # TODO: find and debug lora/locon with bias
class LoRALayer(LoRALayerBase): class LoRALayer(LoRALayerBase):
# up: torch.Tensor # up: torch.Tensor
@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase):
if self.mid is not None: if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype) self.mid = self.mid.to(device=device, dtype=dtype)
class LoHALayer(LoRALayerBase): class LoHALayer(LoRALayerBase):
# w1_a: torch.Tensor # w1_a: torch.Tensor
# w1_b: torch.Tensor # w1_b: torch.Tensor
@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase):
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype)
class LoKRLayer(LoRALayerBase): class LoKRLayer(LoRALayerBase):
# w1: Optional[torch.Tensor] = None # w1: Optional[torch.Tensor] = None
# w1_a: Optional[torch.Tensor] = None # w1_a: Optional[torch.Tensor] = None
@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase):
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype)
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module): class LoRAModelRaw: # (torch.nn.Module):
_name: str _name: str
@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module):
new_state_dict = dict() new_state_dict = dict()
for full_key, value in state_dict.items(): for full_key, value in state_dict.items():
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
continue # clip same continue # clip same
if not full_key.startswith("lora_unet_"): if not full_key.startswith("lora_unet_"):
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module):
if src_key in SDXL_UNET_COMPVIS_MAP: if src_key in SDXL_UNET_COMPVIS_MAP:
dst_key = SDXL_UNET_COMPVIS_MAP[src_key] dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
break break
src_key = "_".join(src_key.split('_')[:-1]) src_key = "_".join(src_key.split("_")[:-1])
if dst_key is None: if dst_key is None:
raise Exception(f"Unknown sdxl lora key - {full_key}") raise Exception(f"Unknown sdxl lora key - {full_key}")
@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
return unet_conversion_map return unet_conversion_map
#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} # _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
SDXL_UNET_COMPVIS_MAP = {
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
for sd, hf in make_sdxl_unet_conversion_map()
}

View File

@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument( parser.add_argument(
"model_path", "model_path",
type=Path, type=Path,
nargs="+",
) )
args = parser.parse_args() args = parser.parse_args()
info = ModelProbe().probe(args.model_path) for path in args.model_path:
print(info) info = ModelProbe().probe(path)
print(f"{path}: {info}")