install missing clip_vision encoders if required by an ip adapter (#5982)

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-03-17 22:19:53 -04:00 committed by GitHub
parent 7ea9cac9a3
commit 7387b0bdc9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,7 @@ from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType, IPAdapterConfig, ModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
@ -89,17 +89,32 @@ class IPAdapterInvocation(BaseInvocation):
assert isinstance(ip_adapter_info, IPAdapterConfig) assert isinstance(ip_adapter_info, IPAdapterConfig)
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs( image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=self.ip_adapter_model, ip_adapter_model=self.ip_adapter_model,
image_encoder_model=ModelIdentifierField.from_config(image_encoder_models[0]), image_encoder_model=ModelIdentifierField.from_config(image_encoder_model),
weight=self.weight, weight=self.weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
), ),
) )
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
found = len(image_encoder_models) > 0
if not found:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed."
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0]