mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model probe detects sdxl lora models
This commit is contained in:
parent
1ac14a1e43
commit
1d5d187ba1
@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
|
||||
do_classifier_free_guidance = True
|
||||
cross_attention_kwargs = None
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
|
||||
unet_info as unet:
|
||||
|
||||
with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
|
||||
scheduler.set_timesteps(num_inference_steps, device=unet.device)
|
||||
timesteps = scheduler.timesteps
|
||||
|
||||
|
@ -101,7 +101,6 @@ class ModelPatcher:
|
||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
||||
yield
|
||||
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def apply_sdxl_lora_text_encoder(
|
||||
|
@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
|
||||
# SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
|
||||
# There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
|
||||
# misclassified as SD-1
|
||||
key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||
if key in checkpoint and checkpoint[key].shape[0] == 320:
|
||||
return BaseModelType.StableDiffusion2
|
||||
|
||||
key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
|
||||
if key in checkpoint:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
|
||||
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
|
||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
||||
key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
|
||||
|
||||
lora_token_vector_length = (
|
||||
checkpoint[key1].shape[1]
|
||||
if key1 in checkpoint
|
||||
else checkpoint[key2].shape[0]
|
||||
else checkpoint[key2].shape[1]
|
||||
if key2 in checkpoint
|
||||
else 768
|
||||
else checkpoint[key3].shape[0]
|
||||
if key3 in checkpoint
|
||||
else None
|
||||
)
|
||||
|
||||
if lora_token_vector_length == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif lora_token_vector_length == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
else:
|
||||
return None
|
||||
raise InvalidModelException(f"Unknown LoRA type")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
|
@ -88,6 +88,7 @@ class LoRAModel(ModelBase):
|
||||
else:
|
||||
return model_path
|
||||
|
||||
|
||||
class LoRALayerBase:
|
||||
# rank: Optional[int]
|
||||
# alpha: Optional[float]
|
||||
@ -173,6 +174,7 @@ class LoRALayerBase:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
class LoRALayer(LoRALayerBase):
|
||||
# up: torch.Tensor
|
||||
@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase):
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
# w1_a: torch.Tensor
|
||||
# w1_b: torch.Tensor
|
||||
@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase):
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
# w1: Optional[torch.Tensor] = None
|
||||
# w1_a: Optional[torch.Tensor] = None
|
||||
@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
_name: str
|
||||
@ -439,7 +444,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
new_state_dict = dict()
|
||||
for full_key, value in state_dict.items():
|
||||
if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
|
||||
continue # clip same
|
||||
continue # clip same
|
||||
|
||||
if not full_key.startswith("lora_unet_"):
|
||||
raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
|
||||
@ -450,7 +455,7 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
if src_key in SDXL_UNET_COMPVIS_MAP:
|
||||
dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
|
||||
break
|
||||
src_key = "_".join(src_key.split('_')[:-1])
|
||||
src_key = "_".join(src_key.split("_")[:-1])
|
||||
|
||||
if dst_key is None:
|
||||
raise Exception(f"Unknown sdxl lora key - {full_key}")
|
||||
@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
|
||||
|
||||
return unet_conversion_map
|
||||
|
||||
#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
|
||||
SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
|
||||
|
||||
# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
|
||||
SDXL_UNET_COMPVIS_MAP = {
|
||||
f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
|
||||
for sd, hf in make_sdxl_unet_conversion_map()
|
||||
}
|
||||
|
@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type")
|
||||
parser.add_argument(
|
||||
"model_path",
|
||||
type=Path,
|
||||
nargs="+",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
info = ModelProbe().probe(args.model_path)
|
||||
print(info)
|
||||
for path in args.model_path:
|
||||
info = ModelProbe().probe(path)
|
||||
print(f"{path}: {info}")
|
||||
|
Loading…
Reference in New Issue
Block a user