mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix minor bugs involving model manager handling of model paths (#6024)
* Fix minor bugs involving model manager handling of model paths - Leave models found in the `autoimport` directory there. Do not move them into the `models` hierarchy. - If model name, type or base is updated and model is in the `models` directory, update its path as appropriate. - On startup during model scanning, if a model's path is a symbolic link, then resolve to an absolute path before deciding it is a new model that must be hashed and registered. (This prevents needless hashing at startup time). * fix issue with dropped suffix --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent
4687739319
commit
eb558d72d8
@ -21,10 +21,11 @@ from typing_extensions import Annotated
|
|||||||
|
|
||||||
from invokeai.app.services.model_install import ModelInstallJob
|
from invokeai.app.services.model_install import ModelInstallJob
|
||||||
from invokeai.app.services.model_records import (
|
from invokeai.app.services.model_records import (
|
||||||
|
DuplicateModelException,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
|
ModelRecordChanges,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records.model_records_base import DuplicateModelException, ModelRecordChanges
|
|
||||||
from invokeai.app.util.suppress_output import SuppressOutput
|
from invokeai.app.util.suppress_output import SuppressOutput
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -309,8 +310,10 @@ async def update_model_record(
|
|||||||
"""Update a model's config."""
|
"""Update a model's config."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||||
|
installer = ApiDependencies.invoker.services.model_manager.install
|
||||||
try:
|
try:
|
||||||
model_response: AnyModelConfig = record_store.update_model(key, changes=changes)
|
record_store.update_model(key, changes=changes)
|
||||||
|
model_response: AnyModelConfig = installer.sync_model_path(key)
|
||||||
logger.info(f"Updated model: {key}")
|
logger.info(f"Updated model: {key}")
|
||||||
except UnknownModelException as e:
|
except UnknownModelException as e:
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
@ -468,6 +468,19 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def sync_to_config(self) -> None:
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in the model record database."""
|
"""Synchronize models on disk to those in the model record database."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def sync_model_path(self, key: str) -> AnyModelConfig:
|
||||||
|
"""
|
||||||
|
Move model into the location indicated by its basetype, type and name.
|
||||||
|
|
||||||
|
Call this after updating a model's attributes in order to move
|
||||||
|
the model's path into the location indicated by its basetype, type and
|
||||||
|
name. Applies only to models whose paths are within the root `models_dir`
|
||||||
|
directory.
|
||||||
|
|
||||||
|
May raise an UnknownModelException.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||||
"""
|
"""
|
||||||
|
@ -526,7 +526,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
installed.update(self.scan_directory(models_dir))
|
installed.update(self.scan_directory(models_dir))
|
||||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
def sync_model_path(self, key: str) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Move model into the location indicated by its basetype, type and name.
|
Move model into the location indicated by its basetype, type and name.
|
||||||
|
|
||||||
@ -538,16 +538,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
May raise an UnknownModelException.
|
May raise an UnknownModelException.
|
||||||
"""
|
"""
|
||||||
model = self.record_store.get_model(key)
|
model = self.record_store.get_model(key)
|
||||||
old_path = Path(model.path)
|
old_path = Path(model.path).resolve()
|
||||||
models_dir = self.app_config.models_path
|
models_dir = self.app_config.models_path.resolve()
|
||||||
|
|
||||||
try:
|
if not old_path.is_relative_to(models_dir):
|
||||||
old_path.relative_to(models_dir)
|
|
||||||
return model
|
return model
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
new_path = models_dir / model.base.value / model.type.value / old_path.name
|
new_path = (models_dir / model.base.value / model.type.value / model.name).with_suffix(old_path.suffix)
|
||||||
|
|
||||||
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
if old_path == new_path or new_path.exists() and old_path == new_path.resolve():
|
||||||
return model
|
return model
|
||||||
@ -559,11 +556,11 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def _scan_register(self, model: Path) -> bool:
|
def _scan_register(self, model: Path) -> bool:
|
||||||
if model in self._cached_model_paths:
|
if model.resolve() in self._cached_model_paths:
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
id = self.register_path(model)
|
id = self.register_path(model)
|
||||||
self._sync_model_path(id) # possibly move it to right place in `models`
|
self.sync_model_path(id) # possibly move it to right place in `models`
|
||||||
self._logger.info(f"Registered {model.name} with id {id}")
|
self._logger.info(f"Registered {model.name} with id {id}")
|
||||||
self._models_installed.add(id)
|
self._models_installed.add(id)
|
||||||
except DuplicateModelException:
|
except DuplicateModelException:
|
||||||
|
@ -6,6 +6,7 @@ from .model_records_base import ( # noqa F401
|
|||||||
ModelRecordServiceBase,
|
ModelRecordServiceBase,
|
||||||
UnknownModelException,
|
UnknownModelException,
|
||||||
ModelSummary,
|
ModelSummary,
|
||||||
|
ModelRecordChanges,
|
||||||
ModelRecordOrderBy,
|
ModelRecordOrderBy,
|
||||||
)
|
)
|
||||||
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
from .model_records_sql import ModelRecordServiceSQL # noqa F401
|
||||||
@ -17,5 +18,6 @@ __all__ = [
|
|||||||
"InvalidModelException",
|
"InvalidModelException",
|
||||||
"UnknownModelException",
|
"UnknownModelException",
|
||||||
"ModelSummary",
|
"ModelSummary",
|
||||||
|
"ModelRecordChanges",
|
||||||
"ModelRecordOrderBy",
|
"ModelRecordOrderBy",
|
||||||
]
|
]
|
||||||
|
@ -20,7 +20,7 @@ from invokeai.app.services.model_install import (
|
|||||||
ModelInstallServiceBase,
|
ModelInstallServiceBase,
|
||||||
URLModelSource,
|
URLModelSource,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.model_records import UnknownModelException
|
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
@ -82,6 +82,18 @@ def test_install(
|
|||||||
assert model_record.source == embedding_file.as_posix()
|
assert model_record.source == embedding_file.as_posix()
|
||||||
|
|
||||||
|
|
||||||
|
def test_rename(
|
||||||
|
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||||
|
) -> None:
|
||||||
|
store = mm2_installer.record_store
|
||||||
|
key = mm2_installer.install_path(embedding_file)
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
|
||||||
|
store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
|
||||||
|
new_model_record = mm2_installer.sync_model_path(key)
|
||||||
|
assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"fixture_name,size,destination",
|
"fixture_name,size,destination",
|
||||||
[
|
[
|
||||||
|
Loading…
Reference in New Issue
Block a user