Update _get_hf_load_class to support clipvision models

This commit is contained in:
Brandon Rising 2024-02-14 13:07:11 -05:00 committed by psychedelicious
parent 0b1c2acd61
commit 8e51392910

View File

@ -163,8 +163,12 @@ class ModelLoader(ModelLoaderBase):
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
class_name = config.get("_class_name", None)
if class_name:
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")[0]
return self._hf_definition_to_type(module="transformers", class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e