feat: add base model recognition for ip adapter safetensor files

This commit is contained in:
blessedcoolant
2024-03-24 01:58:46 +05:30
parent 67afb1763e
commit b1c8266e22
2 changed files with 34 additions and 40 deletions

View File

@ -535,7 +535,18 @@ class IPAdapterCheckpointProbe(CheckpointProbeBase):
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue
return BaseModelType.StableDiffusionXL
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
print(cross_attention_dim)
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")