mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
probe for required encoder for IPAdapters and add to config
This commit is contained in:
parent
db340bc253
commit
8db01ab1b3
@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
from builtins import float
|
from builtins import float
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
@ -52,16 +51,6 @@ class IPAdapterField(BaseModel):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
def get_ip_adapter_image_encoder_model_id(model_path: str):
|
|
||||||
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
|
|
||||||
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
|
|
||||||
|
|
||||||
with open(image_encoder_config_file, "r") as f:
|
|
||||||
image_encoder_model = f.readline().strip()
|
|
||||||
|
|
||||||
return image_encoder_model
|
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("ip_adapter_output")
|
@invocation_output("ip_adapter_output")
|
||||||
class IPAdapterOutput(BaseInvocationOutput):
|
class IPAdapterOutput(BaseInvocationOutput):
|
||||||
# Outputs
|
# Outputs
|
||||||
@ -102,18 +91,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||||
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
|
ip_adapter_info = context.services.model_records.get_model(self.ip_adapter_model.key)
|
||||||
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
|
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||||
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
|
|
||||||
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
|
|
||||||
# is currently messy due to differences between how the model info is generated when installing a model from
|
|
||||||
# disk vs. downloading the model.
|
|
||||||
# TODO (LS): Fix the issue above by:
|
|
||||||
# 1. Change IPAdapterConfig definition to include a field for the repo_id of the image encoder model.
|
|
||||||
# 2. Update probe.py to read `image_encoder.txt` and store it in the config.
|
|
||||||
# 3. Change below to get the image encoder from the configuration record.
|
|
||||||
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
|
|
||||||
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info.path)
|
|
||||||
)
|
|
||||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||||
image_encoder_models = context.services.model_records.search_by_attr(
|
image_encoder_models = context.services.model_records.search_by_attr(
|
||||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||||
|
@ -263,6 +263,7 @@ class IPAdapterConfig(ModelConfigBase):
|
|||||||
"""Model config for IP Adaptor format models."""
|
"""Model config for IP Adaptor format models."""
|
||||||
|
|
||||||
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
|
||||||
|
image_encoder_model_id: str
|
||||||
format: Literal[ModelFormat.InvokeAI]
|
format: Literal[ModelFormat.InvokeAI]
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,6 +78,10 @@ class ProbeBase(object):
|
|||||||
"""Get model scheduler prediction type."""
|
"""Get model scheduler prediction type."""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_image_encoder_model_id(self) -> Optional[str]:
|
||||||
|
"""Get image encoder (IP adapters only)."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ModelProbe(object):
|
class ModelProbe(object):
|
||||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||||
@ -153,6 +157,7 @@ class ModelProbe(object):
|
|||||||
fields["base"] = fields.get("base") or probe.get_base_type()
|
fields["base"] = fields.get("base") or probe.get_base_type()
|
||||||
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
fields["variant"] = fields.get("variant") or probe.get_variant_type()
|
||||||
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
|
||||||
|
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
|
||||||
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
|
||||||
fields["description"] = (
|
fields["description"] = (
|
||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
@ -669,6 +674,14 @@ class IPAdapterFolderProbe(FolderProbeBase):
|
|||||||
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
|
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_image_encoder_model_id(self) -> Optional[str]:
|
||||||
|
encoder_id_path = self.model_path / "image_encoder.txt"
|
||||||
|
if not encoder_id_path.exists():
|
||||||
|
return None
|
||||||
|
with open(encoder_id_path, "r") as f:
|
||||||
|
image_encoder_model = f.readline().strip()
|
||||||
|
return image_encoder_model
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
Loading…
Reference in New Issue
Block a user