From 1d5d187ba10f4d2dab3d2bb1213e41cae37d701c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 3 Aug 2023 10:26:52 -0400 Subject: [PATCH] model probe detects sdxl lora models --- invokeai/app/invocations/sdxl.py | 4 +-- invokeai/backend/model_management/lora.py | 1 - .../backend/model_management/model_probe.py | 25 ++++++++++++++++--- .../backend/model_management/models/lora.py | 17 ++++++++++--- scripts/probe-model.py | 6 +++-- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index faa6b59782..aaa616a378 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation): unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context) do_classifier_free_guidance = True cross_attention_kwargs = None - with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ - unet_info as unet: - + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet: scheduler.set_timesteps(num_inference_steps, device=unet.device) timesteps = scheduler.timesteps diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 56f7a648c9..0a0ab3d629 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -101,7 +101,6 @@ class ModelPatcher: with cls.apply_lora(text_encoder, loras, "lora_te_"): yield - @classmethod @contextmanager def apply_sdxl_lora_text_encoder( diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index c3964d760c..21462cf6e6 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: checkpoint = self.checkpoint + + # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future + # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be + # misclassified as SD-1 + key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" + if key in checkpoint and checkpoint[key].shape[0] == 320: + return BaseModelType.StableDiffusion2 + + key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight" + if key in checkpoint: + return BaseModelType.StableDiffusionXL + key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight" - key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a" + lora_token_vector_length = ( checkpoint[key1].shape[1] if key1 in checkpoint - else checkpoint[key2].shape[0] + else checkpoint[key2].shape[1] if key2 in checkpoint - else 768 + else checkpoint[key3].shape[0] + if key3 in checkpoint + else None ) + if lora_token_vector_length == 768: return BaseModelType.StableDiffusion1 elif lora_token_vector_length == 1024: return BaseModelType.StableDiffusion2 else: - return None + raise InvalidModelException(f"Unknown LoRA type") class TextualInversionCheckpointProbe(CheckpointProbeBase): diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 0351bf2652..0870e78469 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -88,6 +88,7 @@ class LoRAModel(ModelBase): else: return model_path + class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] @@ -173,6 +174,7 @@ class LoRALayerBase: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) + # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): # up: torch.Tensor @@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase): if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) + class LoHALayer(LoRALayerBase): # w1_a: torch.Tensor # w1_b: torch.Tensor @@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + class LoKRLayer(LoRALayerBase): # w1: Optional[torch.Tensor] = None # w1_a: Optional[torch.Tensor] = None @@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase): if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str @@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module): new_state_dict = dict() for full_key, value in state_dict.items(): if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - continue # clip same + continue # clip same if not full_key.startswith("lora_unet_"): raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") @@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module): if src_key in SDXL_UNET_COMPVIS_MAP: dst_key = SDXL_UNET_COMPVIS_MAP[src_key] break - src_key = "_".join(src_key.split('_')[:-1]) + src_key = "_".join(src_key.split("_")[:-1]) if dst_key is None: raise Exception(f"Unknown sdxl lora key - {full_key}") @@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} -SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} + +# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()} +SDXL_UNET_COMPVIS_MAP = { + f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") + for sd, hf in make_sdxl_unet_conversion_map() +} diff --git a/scripts/probe-model.py b/scripts/probe-model.py index 7281dafc3f..4cf2c50263 100755 --- a/scripts/probe-model.py +++ b/scripts/probe-model.py @@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type") parser.add_argument( "model_path", type=Path, + nargs="+", ) args = parser.parse_args() -info = ModelProbe().probe(args.model_path) -print(info) +for path in args.model_path: + info = ModelProbe().probe(path) + print(f"{path}: {info}")