Get CLIPVision model download from HF working.

This commit is contained in:
Ryan Dick 2023-09-14 09:54:10 -04:00
parent 2c1100509f
commit 6d0ea42a94
3 changed files with 13 additions and 5 deletions

View File

@ -418,7 +418,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
image_encoder_model_info = context.services.model_manager.get_model(
# TODO(ryand): Get this model_name from the IPAdapterField.
model_name="ip_adapter_clip_vision",
model_name="ip_adapter_sd_image_encoder",
model_type=ModelType.CLIPVision,
base_model=BaseModelType.Any,
context=context,

View File

@ -318,7 +318,6 @@ class ModelInstall(object):
location = self._download_hf_pipeline(repo_id, staging) # pipeline
elif "unet/model.onnx" in files:
location = self._download_hf_model(repo_id, files, staging)
# TODO(ryand): Add special handling for ip_adapter?
else:
for suffix in ["safetensors", "bin"]:
if f"pytorch_lora_weights.{suffix}" in files:
@ -337,6 +336,11 @@ class ModelInstall(object):
elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break
elif f"model.{suffix}" in files and "config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}

View File

@ -54,8 +54,7 @@ class ModelProbe(object):
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"IPAdapterModel": ModelType.IPAdapter,
"CLIPVision": ModelType.CLIPVision,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
}
@classmethod
@ -196,7 +195,12 @@ class ModelProbe(object):
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
class_name = conf["_class_name"]
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type