Fix minor bugs involving model manager handling of model paths (#6024)

* Fix minor bugs involving model manager handling of model paths

- Leave models found in the `autoimport` directory there. Do not move them
  into the `models` hierarchy.
- If model name, type or base is updated and model is in the `models` directory,
  update its path as appropriate.
- On startup during model scanning, if a model's path is a symbolic link, then resolve
  to an absolute path before deciding it is a new model that must be hashed and
  registered. (This prevents needless hashing at startup time).

* fix issue with dropped suffix

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-03-21 21:14:45 -04:00 committed by GitHub
parent 4687739319
commit eb558d72d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 13 deletions

View File

@ -21,10 +21,11 @@ from typing_extensions import Annotated
from invokeai.app.services.model_install import ModelInstallJob from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException, InvalidModelException,
ModelRecordChanges,
UnknownModelException, UnknownModelException,
) )
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
from invokeai.app.util.suppress_output import SuppressOutput from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
@ -309,8 +310,10 @@ async def update_model_record(
"""Update a model's config.""" """Update a model's config."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
record_store = ApiDependencies.invoker.services.model_manager.store record_store = ApiDependencies.invoker.services.model_manager.store
installer = ApiDependencies.invoker.services.model_manager.install
try: try:
model_response: AnyModelConfig = record_store.update_model(key, changes=changes) record_store.update_model(key, changes=changes)
model_response: AnyModelConfig = installer.sync_model_path(key)
logger.info(f"Updated model: {key}") logger.info(f"Updated model: {key}")
except UnknownModelException as e: except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))

View File

@ -468,6 +468,19 @@ class ModelInstallServiceBase(ABC):
def sync_to_config(self) -> None: def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database.""" """Synchronize models on disk to those in the model record database."""
@abstractmethod
def sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.
Call this after updating a model's attributes in order to move
the model's path into the location indicated by its basetype, type and
name. Applies only to models whose paths are within the root `models_dir`
directory.
May raise an UnknownModelException.
"""
@abstractmethod @abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
""" """

View File

@ -526,7 +526,7 @@ class ModelInstallService(ModelInstallServiceBase):
installed.update(self.scan_directory(models_dir)) installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig: def sync_model_path(self, key: str) -> AnyModelConfig:
""" """
Move model into the location indicated by its basetype, type and name. Move model into the location indicated by its basetype, type and name.
@ -538,16 +538,13 @@ class ModelInstallService(ModelInstallServiceBase):
May raise an UnknownModelException. May raise an UnknownModelException.
""" """
model = self.record_store.get_model(key) model = self.record_store.get_model(key)
old_path = Path(model.path) old_path = Path(model.path).resolve()
models_dir = self.app_config.models_path models_dir = self.app_config.models_path.resolve()
try: if not old_path.is_relative_to(models_dir):
old_path.relative_to(models_dir)
return model return model
except ValueError:
pass
new_path = models_dir / model.base.value / model.type.value / old_path.name new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
if old_path == new_path or new_path.exists() and old_path == new_path.resolve(): if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
return model return model
@ -559,11 +556,11 @@ class ModelInstallService(ModelInstallServiceBase):
return model return model
def _scan_register(self, model: Path) -> bool: def _scan_register(self, model: Path) -> bool:
if model in self._cached_model_paths: if model.resolve() in self._cached_model_paths:
return True return True
try: try:
id = self.register_path(model) id = self.register_path(model)
self._sync_model_path(id) # possibly move it to right place in `models` self.sync_model_path(id) # possibly move it to right place in `models`
self._logger.info(f"Registered {model.name} with id {id}") self._logger.info(f"Registered {model.name} with id {id}")
self._models_installed.add(id) self._models_installed.add(id)
except DuplicateModelException: except DuplicateModelException:

View File

@ -6,6 +6,7 @@ from .model_records_base import ( # noqa F401
ModelRecordServiceBase, ModelRecordServiceBase,
UnknownModelException, UnknownModelException,
ModelSummary, ModelSummary,
ModelRecordChanges,
ModelRecordOrderBy, ModelRecordOrderBy,
) )
from .model_records_sql import ModelRecordServiceSQL # noqa F401 from .model_records_sql import ModelRecordServiceSQL # noqa F401
@ -17,5 +18,6 @@ __all__ = [
"InvalidModelException", "InvalidModelException",
"UnknownModelException", "UnknownModelException",
"ModelSummary", "ModelSummary",
"ModelRecordChanges",
"ModelRecordOrderBy", "ModelRecordOrderBy",
] ]

View File

@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
ModelInstallServiceBase, ModelInstallServiceBase,
URLModelSource, URLModelSource,
) )
from invokeai.app.services.model_records import UnknownModelException from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@ -82,6 +82,18 @@ def test_install(
assert model_record.source == embedding_file.as_posix() assert model_record.source == embedding_file.as_posix()
def test_rename(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
new_model_record = mm2_installer.sync_model_path(key)
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fixture_name,size,destination", "fixture_name,size,destination",
[ [