probe for required encoder for IPAdapters and add to config

This commit is contained in:
Lincoln Stein
2024-02-09 20:46:47 -05:00
committed by Brandon Rising
parent dbd2f8dc5f
commit 8f1b7355df
3 changed files with 15 additions and 23 deletions

View File

@ -263,6 +263,7 @@ class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]

View File

@ -78,6 +78,10 @@ class ProbeBase(object):
"""Get model scheduler prediction type."""
return None
def get_image_encoder_model_id(self) -> Optional[str]:
"""Get image encoder (IP adapters only)."""
return None
class ModelProbe(object):
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
@ -153,6 +157,7 @@ class ModelProbe(object):
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
fields["description"] = (
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
@ -669,6 +674,14 @@ class IPAdapterFolderProbe(FolderProbeBase):
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
def get_image_encoder_model_id(self) -> Optional[str]:
encoder_id_path = self.model_path / "image_encoder.txt"
if not encoder_id_path.exists():
return None
with open(encoder_id_path, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: