mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix minor bugs involving model manager handling of model paths (#6024)
* Fix minor bugs involving model manager handling of model paths - Leave models found in the `autoimport` directory there. Do not move them into the `models` hierarchy. - If model name, type or base is updated and model is in the `models` directory, update its path as appropriate. - On startup during model scanning, if a model's path is a symbolic link, then resolve to an absolute path before deciding it is a new model that must be hashed and registered. (This prevents needless hashing at startup time). * fix issue with dropped suffix --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent
4687739319
commit
eb558d72d8
@ -21,10 +21,11 @@ from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
ModelRecordChanges,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
||||
from invokeai.app.util.suppress_output import SuppressOutput
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
@ -309,8 +310,10 @@ async def update_model_record(
|
||||
"""Update a model's config."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
try:
|
||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
||||
record_store.update_model(key, changes=changes)
|
||||
model_response: AnyModelConfig = installer.sync_model_path(key)
|
||||
logger.info(f"Updated model: {key}")
|
||||
except UnknownModelException as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
@ -468,6 +468,19 @@ class ModelInstallServiceBase(ABC):
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
Call this after updating a model's attributes in order to move
|
||||
the model's path into the location indicated by its basetype, type and
|
||||
name. Applies only to models whose paths are within the root `models_dir`
|
||||
directory.
|
||||
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
"""
|
||||
|
@ -526,7 +526,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
def sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
"""
|
||||
Move model into the location indicated by its basetype, type and name.
|
||||
|
||||
@ -538,16 +538,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
May raise an UnknownModelException.
|
||||
"""
|
||||
model = self.record_store.get_model(key)
|
||||
old_path = Path(model.path)
|
||||
models_dir = self.app_config.models_path
|
||||
old_path = Path(model.path).resolve()
|
||||
models_dir = self.app_config.models_path.resolve()
|
||||
|
||||
try:
|
||||
old_path.relative_to(models_dir)
|
||||
if not old_path.is_relative_to(models_dir):
|
||||
return model
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
||||
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
|
||||
|
||||
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||
return model
|
||||
@ -559,11 +556,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return model
|
||||
|
||||
def _scan_register(self, model: Path) -> bool:
|
||||
if model in self._cached_model_paths:
|
||||
if model.resolve() in self._cached_model_paths:
|
||||
return True
|
||||
try:
|
||||
id = self.register_path(model)
|
||||
self._sync_model_path(id) # possibly move it to right place in `models`
|
||||
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||
self._logger.info(f"Registered {model.name} with id {id}")
|
||||
self._models_installed.add(id)
|
||||
except DuplicateModelException:
|
||||
|
@ -6,6 +6,7 @@ from .model_records_base import ( # noqa F401
|
||||
ModelRecordServiceBase,
|
||||
UnknownModelException,
|
||||
ModelSummary,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
)
|
||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||
@ -17,5 +18,6 @@ __all__ = [
|
||||
"InvalidModelException",
|
||||
"UnknownModelException",
|
||||
"ModelSummary",
|
||||
"ModelRecordChanges",
|
||||
"ModelRecordOrderBy",
|
||||
]
|
||||
|
@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
@ -82,6 +82,18 @@ def test_install(
|
||||
assert model_record.source == embedding_file.as_posix()
|
||||
|
||||
|
||||
def test_rename(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
|
||||
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
|
||||
new_model_record = mm2_installer.sync_model_path(key)
|
||||
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fixture_name,size,destination",
|
||||
[
|
||||
|
Loading…
Reference in New Issue
Block a user