mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

model probe detects sdxl lora models

commit 1d5d187ba1
parent 1ac14a1e43
@@ -306,9 +306,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
         do_classifier_free_guidance = True
         cross_attention_kwargs = None
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
-            unet_info as unet:
-
+        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
             scheduler.set_timesteps(num_inference_steps, device=unet.device)
             timesteps = scheduler.timesteps
 
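The rewritten `with` statement enters the LoRA patcher and the model handle as two comma-separated context managers in one statement, instead of splitting them across a backslash continuation. A minimal sketch of the pattern, assuming stand-in context managers (`patch_model` and `model_handle` are illustrative names, not InvokeAI APIs):

from contextlib import contextmanager

@contextmanager
def patch_model(name):
    # stands in for ModelPatcher.apply_lora_unet: patch on entry, restore on exit
    print(f"patching {name}")
    try:
        yield
    finally:
        print(f"unpatching {name}")

@contextmanager
def model_handle(name):
    # stands in for unet_info, which yields the loaded model
    yield name

# both managers are entered left to right and exited in reverse order
with patch_model("unet"), model_handle("unet") as unet:
    print(f"running inference with {unet}")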
@@ -101,7 +101,6 @@ class ModelPatcher:
             with cls.apply_lora(text_encoder, loras, "lora_te_"):
                 yield
 
-
     @classmethod
     @contextmanager
     def apply_sdxl_lora_text_encoder(
@@ -315,21 +315,38 @@ class LoRACheckpointProbe(CheckpointProbeBase):
 
     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
 
+        # SD-2 models are very hard to probe. These probes are brittle and likely to fail in the future
+        # There are also some "SD-2 LoRAs" that have identical keys and shapes to SD-1 and will be
+        # misclassified as SD-1
+        key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
+        if key in checkpoint and checkpoint[key].shape[0] == 320:
+            return BaseModelType.StableDiffusion2
+
+        key = "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight"
+        if key in checkpoint:
+            return BaseModelType.StableDiffusionXL
+
         key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
-        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
+        key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+        key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
 
         lora_token_vector_length = (
             checkpoint[key1].shape[1]
             if key1 in checkpoint
-            else checkpoint[key2].shape[0]
+            else checkpoint[key2].shape[1]
             if key2 in checkpoint
-            else 768
+            else checkpoint[key3].shape[0]
+            if key3 in checkpoint
+            else None
         )
 
         if lora_token_vector_length == 768:
             return BaseModelType.StableDiffusion1
         elif lora_token_vector_length == 1024:
             return BaseModelType.StableDiffusion2
         else:
-            return None
+            raise InvalidModelException(f"Unknown LoRA type")
 
 
 class TextualInversionCheckpointProbe(CheckpointProbeBase):
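Taken together, the probe now classifies a LoRA checkpoint purely from which keys its state dict contains and their shapes: a brittle SD-2 shape check first, then an SDXL-only unet key, then the CLIP token-vector length recovered from text-encoder keys (768 for SD-1, 1024 for SD-2). A standalone sketch of that decision logic, with a plain dict standing in for a real checkpoint (`guess_lora_base` and its string labels are illustrative, not InvokeAI's API):

import torch

def guess_lora_base(checkpoint: dict) -> str:
    # Brittle SD-2 heuristic: some SD-2 LoRAs put a 320-row down-projection here
    key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    if key in checkpoint and checkpoint[key].shape[0] == 320:
        return "sd-2"

    # SDXL LoRAs carry unet output-block keys that SD-1/SD-2 LoRAs lack
    if "lora_unet_output_blocks_5_1_transformer_blocks_1_ff_net_2.lora_up.weight" in checkpoint:
        return "sdxl"

    # Otherwise infer the CLIP hidden size from whichever text-encoder key exists;
    # key3 covers the LoHA (hada) weight layout
    key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
    key3 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
    if key1 in checkpoint:
        length = checkpoint[key1].shape[1]
    elif key2 in checkpoint:
        length = checkpoint[key2].shape[1]
    elif key3 in checkpoint:
        length = checkpoint[key3].shape[0]
    else:
        length = None

    if length == 768:
        return "sd-1"
    if length == 1024:
        return "sd-2"
    raise ValueError("Unknown LoRA type")

# an SD-1 LoRA: the fc1 down-projection consumes a 768-wide hidden state
ckpt = {"lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(8, 768)}
print(guess_lora_base(ckpt))  # sd-1

The if/elif chain here is the readable equivalent of the stacked conditional expression the commit uses to compute lora_token_vector_length.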
@@ -88,6 +88,7 @@ class LoRAModel(ModelBase):
         else:
             return model_path
 
+
 class LoRALayerBase:
     # rank: Optional[int]
     # alpha: Optional[float]
@@ -173,6 +174,7 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
+
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
     # up: torch.Tensor
@@ -225,6 +227,7 @@ class LoRALayer(LoRALayerBase):
         if self.mid is not None:
             self.mid = self.mid.to(device=device, dtype=dtype)
 
+
 class LoHALayer(LoRALayerBase):
     # w1_a: torch.Tensor
     # w1_b: torch.Tensor
@@ -292,6 +295,7 @@ class LoHALayer(LoRALayerBase):
         if self.t2 is not None:
             self.t2 = self.t2.to(device=device, dtype=dtype)
 
+
 class LoKRLayer(LoRALayerBase):
     # w1: Optional[torch.Tensor] = None
     # w1_a: Optional[torch.Tensor] = None
@@ -386,6 +390,7 @@ class LoKRLayer(LoRALayerBase):
         if self.t2 is not None:
             self.t2 = self.t2.to(device=device, dtype=dtype)
 
+
 # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
@@ -439,7 +444,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         new_state_dict = dict()
         for full_key, value in state_dict.items():
             if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
-                continue # clip same
+                continue  # clip same
 
             if not full_key.startswith("lora_unet_"):
                 raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
@@ -450,7 +455,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
                 if src_key in SDXL_UNET_COMPVIS_MAP:
                     dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
                     break
-                src_key = "_".join(src_key.split('_')[:-1])
+                src_key = "_".join(src_key.split("_")[:-1])
 
             if dst_key is None:
                 raise Exception(f"Unknown sdxl lora key - {full_key}")
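The break/split pair above implements a longest-prefix match: if the current candidate key is not in the map, the last underscore-separated segment is dropped and the lookup retries. The loop header itself is outside this hunk, so the condition below is an assumption, and `compvis_map` is a toy stand-in for `SDXL_UNET_COMPVIS_MAP`:

def map_compvis_key(full_key: str, compvis_map: dict) -> tuple:
    src_key = full_key
    dst_key = None
    while "_" in src_key:  # assumed loop condition; not shown in the hunk
        if src_key in compvis_map:
            dst_key = compvis_map[src_key]
            break
        src_key = "_".join(src_key.split("_")[:-1])  # drop the last segment
    if dst_key is None:
        raise Exception(f"Unknown sdxl lora key - {full_key}")
    suffix = full_key[len(src_key):]  # whatever trailed the matched prefix
    return dst_key, suffix

compvis_map = {"output_blocks_5_1": "up_blocks_1_attentions_2"}
print(map_compvis_key("output_blocks_5_1_transformer_blocks_1", compvis_map))
# ('up_blocks_1_attentions_2', '_transformer_blocks_1')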
@@ -614,5 +619,9 @@ def make_sdxl_unet_conversion_map():
 
     return unet_conversion_map
 
-#_sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
-SDXL_UNET_COMPVIS_MAP = {f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
+# _sdxl_conversion_map = {f"lora_unet_{sd}".rstrip(".").replace(".", "_"): f"lora_unet_{hf}".rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()}
+SDXL_UNET_COMPVIS_MAP = {
+    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
+    for sd, hf in make_sdxl_unet_conversion_map()
+}
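The reformatted comprehension builds `SDXL_UNET_COMPVIS_MAP` by normalizing each pair from `make_sdxl_unet_conversion_map()`: conversion-map entries are dotted module paths with a trailing dot, while LoRA keys use underscores. A small sketch of that normalization (the example paths are illustrative):

def normalize(key: str) -> str:
    # drop trailing dots, then turn the dotted path into underscore form
    return key.rstrip(".").replace(".", "_")

print(normalize("output_blocks.5.1."))         # output_blocks_5_1
print(normalize("up_blocks.1.attentions.2."))  # up_blocks_1_attentions_2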
@@ -9,8 +9,10 @@ parser = argparse.ArgumentParser(description="Probe model type")
 parser.add_argument(
     "model_path",
     type=Path,
+    nargs="+",
 )
 args = parser.parse_args()
 
-info = ModelProbe().probe(args.model_path)
-print(info)
+for path in args.model_path:
+    info = ModelProbe().probe(path)
+    print(f"{path}: {info}")
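With nargs="+", argparse collects one or more positional paths into a list, so the probe script can report several models in a single invocation. A minimal standalone sketch of that behavior (the file names are illustrative):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser(description="Probe model type")
parser.add_argument("model_path", type=Path, nargs="+")  # one or more paths
args = parser.parse_args(["sd15_lora.safetensors", "sdxl_lora.safetensors"])
for path in args.model_path:
    print(path)  # each entry arrives as a Path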