mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Let users pick CLIP Vision model for Checkpoint IP Adapters
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
from builtins import float
|
||||
from typing import List, Union
|
||||
from typing import List, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
@ -49,12 +49,15 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
# Inputs
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
|
||||
ip_adapter_model: ModelIdentifierField = InputField(
|
||||
description="The IP-Adapter model.",
|
||||
title="IP-Adapter Model",
|
||||
@ -62,7 +65,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ui_order=-1,
|
||||
ui_type=UIType.IPAdapterModel,
|
||||
)
|
||||
|
||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
|
||||
description="CLIP Vision model to use", default="ViT-H", ui_order=2
|
||||
)
|
||||
weight: Union[float, List[float]] = InputField(
|
||||
default=1, description="The weight given to the IP-Adapter", title="Weight"
|
||||
)
|
||||
@ -89,12 +94,12 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig))
|
||||
|
||||
image_encoder_model_id = (
|
||||
ip_adapter_info.image_encoder_model_id
|
||||
if isinstance(ip_adapter_info, IPAdapterDiffusersConfig)
|
||||
else "ip_adapter_sd_image_encoder"
|
||||
)
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
if isinstance(ip_adapter_info, IPAdapterDiffusersConfig):
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
else:
|
||||
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
|
||||
|
||||
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
|
||||
|
||||
return IPAdapterOutput(
|
||||
@ -109,19 +114,25 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
|
||||
found = False
|
||||
while not found:
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
|
||||
if not len(image_encoder_models) > 0:
|
||||
context.logger.warning(
|
||||
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
|
||||
Downloading and installing now. This may take a while."
|
||||
)
|
||||
|
||||
installer = context._services.model_manager.install
|
||||
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
|
||||
)
|
||||
found = len(image_encoder_models) > 0
|
||||
if not found:
|
||||
context.logger.warning(
|
||||
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
|
||||
)
|
||||
context.logger.warning("Downloading and installing now. This may take a while.")
|
||||
installer = context._services.model_manager.install
|
||||
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
|
||||
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
|
||||
assert len(image_encoder_models) == 1
|
||||
|
||||
if len(image_encoder_models) == 0:
|
||||
context.logger.error("Error while fetching CLIP Vision Image Encoder")
|
||||
assert len(image_encoder_models) == 1
|
||||
|
||||
return image_encoder_models[0]
|
||||
|
Reference in New Issue
Block a user