mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add source_type
to model configs
This commit is contained in:
@ -120,6 +120,15 @@ class ModelRepoVariant(str, Enum):
|
||||
FLAX = "flax"
|
||||
|
||||
|
||||
class ModelSourceType(str, Enum):
|
||||
"""Model source type."""
|
||||
|
||||
Path = "path"
|
||||
Url = "url"
|
||||
HFRepoID = "hf_repo_id"
|
||||
CivitAI = "civitai"
|
||||
|
||||
|
||||
class ModelConfigBase(BaseModel):
|
||||
"""Base class for model configuration information."""
|
||||
|
||||
@ -128,7 +137,9 @@ class ModelConfigBase(BaseModel):
|
||||
base: BaseModelType = Field(description="base model")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
description: Optional[str] = Field(description="Model description", default=None)
|
||||
source: str = Field(description="The source of the model (e.g. path, URL, HF Repo ID)")
|
||||
source_type: ModelSourceType = Field(description="The type of source")
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
|
@ -147,7 +147,6 @@ class ModelProbe(object):
|
||||
if not probe_class:
|
||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||
|
||||
hash = ModelHash().hash(model_path)
|
||||
probe = probe_class(model_path)
|
||||
|
||||
fields["path"] = model_path.as_posix()
|
||||
@ -161,13 +160,16 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or hash
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint:
|
||||
if (
|
||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae]
|
||||
and fields["format"] is ModelFormat.Checkpoint
|
||||
):
|
||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||
model_path,
|
||||
model_type=fields["type"],
|
||||
|
Reference in New Issue
Block a user