mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add source_type
to model configs
This commit is contained in:
parent
4471ea8ad1
commit
9378e47a06
@ -18,6 +18,7 @@ from invokeai.app.services.events.events_base import EventServiceBase
|
|||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||||
|
from invokeai.backend.model_manager.config import ModelSourceType
|
||||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||||
|
|
||||||
from ..model_metadata import ModelMetadataStoreBase
|
from ..model_metadata import ModelMetadataStoreBase
|
||||||
@ -151,6 +152,13 @@ ModelSource = Annotated[
|
|||||||
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
Union[LocalModelSource, HFModelSource, CivitaiModelSource, URLModelSource], Field(discriminator="type")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||||
|
URLModelSource: ModelSourceType.Url,
|
||||||
|
HFModelSource: ModelSourceType.HFRepoID,
|
||||||
|
CivitaiModelSource: ModelSourceType.CivitAI,
|
||||||
|
LocalModelSource: ModelSourceType.Path,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallJob(BaseModel):
|
class ModelInstallJob(BaseModel):
|
||||||
"""Object that tracks the current status of an install request."""
|
"""Object that tracks the current status of an install request."""
|
||||||
|
@ -27,6 +27,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
CheckpointConfigBase,
|
CheckpointConfigBase,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata import (
|
from invokeai.backend.model_manager.metadata import (
|
||||||
@ -42,6 +43,7 @@ from invokeai.backend.util import Chdir, InvokeAILogger
|
|||||||
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
from invokeai.backend.util.devices import choose_precision, choose_torch_device
|
||||||
|
|
||||||
from .model_install_base import (
|
from .model_install_base import (
|
||||||
|
MODEL_SOURCE_TO_TYPE_MAP,
|
||||||
CivitaiModelSource,
|
CivitaiModelSource,
|
||||||
HFModelSource,
|
HFModelSource,
|
||||||
InstallStatus,
|
InstallStatus,
|
||||||
@ -140,6 +142,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config = config or {}
|
config = config or {}
|
||||||
if not config.get("source"):
|
if not config.get("source"):
|
||||||
config["source"] = model_path.resolve().as_posix()
|
config["source"] = model_path.resolve().as_posix()
|
||||||
|
config["source_type"] = ModelSourceType.Path
|
||||||
return self._register(model_path, config)
|
return self._register(model_path, config)
|
||||||
|
|
||||||
def install_path(
|
def install_path(
|
||||||
@ -153,7 +156,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config["source"] = model_path.resolve().as_posix()
|
config["source"] = model_path.resolve().as_posix()
|
||||||
config["key"] = config.get("key", uuid_string())
|
config["key"] = config.get("key", uuid_string())
|
||||||
|
|
||||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config)
|
||||||
|
|
||||||
if preferred_name := config.get("name"):
|
if preferred_name := config.get("name"):
|
||||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||||
@ -375,6 +378,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.bytes = job.total_bytes
|
job.bytes = job.total_bytes
|
||||||
self._signal_job_running(job)
|
self._signal_job_running(job)
|
||||||
job.config_in["source"] = str(job.source)
|
job.config_in["source"] = str(job.source)
|
||||||
|
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
|
||||||
if job.inplace:
|
if job.inplace:
|
||||||
key = self.register_path(job.local_path, job.config_in)
|
key = self.register_path(job.local_path, job.config_in)
|
||||||
else:
|
else:
|
||||||
@ -521,13 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
move(old_path, new_path)
|
move(old_path, new_path)
|
||||||
return new_path
|
return new_path
|
||||||
|
|
||||||
def _probe_model(self, model_path: Path, config: Optional[Dict[str, Any]] = None) -> AnyModelConfig:
|
|
||||||
info: AnyModelConfig = ModelProbe.probe(Path(model_path))
|
|
||||||
if config: # used to override probe fields
|
|
||||||
for key, value in config.items():
|
|
||||||
setattr(info, key, value)
|
|
||||||
return info
|
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -538,8 +535,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
info = info or ModelProbe.probe(model_path, config)
|
info = info or ModelProbe.probe(model_path, config)
|
||||||
override_key: Optional[str] = config.get("key") if config else None
|
override_key: Optional[str] = config.get("key") if config else None
|
||||||
|
|
||||||
assert info.original_hash # always assigned by probe()
|
assert info.hash # always assigned by probe()
|
||||||
info.key = override_key or info.original_hash
|
info.key = override_key or info.hash
|
||||||
|
|
||||||
model_path = model_path.absolute()
|
model_path = model_path.absolute()
|
||||||
if model_path.is_relative_to(self.app_config.models_path):
|
if model_path.is_relative_to(self.app_config.models_path):
|
||||||
@ -573,7 +570,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
source=source,
|
source=source,
|
||||||
config_in=config or {},
|
config_in=config or {},
|
||||||
local_path=Path(source.path),
|
local_path=Path(source.path),
|
||||||
inplace=source.inplace,
|
inplace=source.inplace or False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
def _import_from_civitai(self, source: CivitaiModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
|
||||||
@ -630,7 +627,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _import_remote_model(
|
def _import_remote_model(
|
||||||
self,
|
self,
|
||||||
source: ModelSource,
|
source: HFModelSource | CivitaiModelSource | URLModelSource,
|
||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[Dict[str, Any]],
|
config: Optional[Dict[str, Any]],
|
||||||
@ -658,7 +655,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# In the event that there is a subfolder specified in the source,
|
# In the event that there is a subfolder specified in the source,
|
||||||
# we need to remove it from the destination path in order to avoid
|
# we need to remove it from the destination path in order to avoid
|
||||||
# creating unwanted subfolders
|
# creating unwanted subfolders
|
||||||
if hasattr(source, "subfolder") and source.subfolder:
|
if isinstance(source, HFModelSource) and source.subfolder:
|
||||||
root = Path(remote_files[0].path.parts[0])
|
root = Path(remote_files[0].path.parts[0])
|
||||||
subfolder = root / source.subfolder
|
subfolder = root / source.subfolder
|
||||||
else:
|
else:
|
||||||
|
@ -120,6 +120,15 @@ class ModelRepoVariant(str, Enum):
|
|||||||
FLAX = "flax"
|
FLAX = "flax"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSourceType(str, Enum):
|
||||||
|
"""Model source type."""
|
||||||
|
|
||||||
|
Path = "path"
|
||||||
|
Url = "url"
|
||||||
|
HFRepoID = "hf_repo_id"
|
||||||
|
CivitAI = "civitai"
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigBase(BaseModel):
|
class ModelConfigBase(BaseModel):
|
||||||
"""Base class for model configuration information."""
|
"""Base class for model configuration information."""
|
||||||
|
|
||||||
@ -128,7 +137,9 @@ class ModelConfigBase(BaseModel):
|
|||||||
base: BaseModelType = Field(description="base model")
|
base: BaseModelType = Field(description="base model")
|
||||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||||
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
|
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
|
||||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
description: Optional[str] = Field(description="Model description", default=None)
|
||||||
|
source: str = Field(description="The source of the model (e.g. path, URL, HF Repo ID)")
|
||||||
|
source_type: ModelSourceType = Field(description="The type of source")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
|
@ -147,7 +147,6 @@ class ModelProbe(object):
|
|||||||
if not probe_class:
|
if not probe_class:
|
||||||
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}")
|
||||||
|
|
||||||
hash = ModelHash().hash(model_path)
|
|
||||||
probe = probe_class(model_path)
|
probe = probe_class(model_path)
|
||||||
|
|
||||||
fields["path"] = model_path.as_posix()
|
fields["path"] = model_path.as_posix()
|
||||||
@ -161,13 +160,16 @@ class ModelProbe(object):
|
|||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or hash
|
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint:
|
if (
|
||||||
|
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae]
|
||||||
|
and fields["format"] is ModelFormat.Checkpoint
|
||||||
|
):
|
||||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||||
model_path,
|
model_path,
|
||||||
model_type=fields["type"],
|
model_type=fields["type"],
|
||||||
|
Loading…
x
Reference in New Issue
Block a user