mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): yaml migration fixup
- If the metadata yaml has an invalid version, exist the app. If we don't, the app will crawl the models dir and add models to the db without having first parsed `models.yaml`. This should not happen often, as the vast majority of users are on v3.0.0 models.yaml files. - Fix off-by-one error with models count (need to pop the `__metadata__` stanza - After a successful migration, rename `models.yaml` to `models.yaml.bak` to prevent the migration logic from re-running on subsequent app startups.
This commit is contained in:
parent
67163c2224
commit
d4686b7f64
@ -1,7 +1,6 @@
|
|||||||
"""Model installation class."""
|
"""Model installation class."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -297,38 +296,40 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
except OSError:
|
except OSError:
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(db_models) == 0 and len(yaml.items()) != 0:
|
yaml_metadata = yaml.pop("__metadata__")
|
||||||
|
yaml_version = yaml_metadata.get("version")
|
||||||
|
|
||||||
|
if yaml_version != "3.0.0":
|
||||||
|
raise ValueError(
|
||||||
|
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
|
||||||
|
)
|
||||||
|
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
f"Starting one-time migration of {len(yaml.items())} models in yaml file to DB. This may take a few minutes."
|
f"Starting one-time migration of {len(yaml.items())} models from `models.yaml` to database. This may take a few minutes."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(db_models) == 0 and len(yaml.items()) != 0:
|
||||||
for model_key, stanza in yaml.items():
|
for model_key, stanza in yaml.items():
|
||||||
if model_key == "__metadata__":
|
|
||||||
try:
|
|
||||||
assert (
|
|
||||||
stanza["version"] == "3.0.0"
|
|
||||||
), f"This script works on version 3.0.0 yaml files, but your configuration points to a {stanza['version']} version"
|
|
||||||
except AssertionError:
|
|
||||||
self._logger.warn(
|
|
||||||
f"Skipping entry with path {stanza.get('path', '')} with outdated metadata version"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
continue
|
|
||||||
|
|
||||||
_, _, model_name = str(model_key).split("/")
|
_, _, model_name = str(model_key).split("/")
|
||||||
model_path = pathlib.Path(stanza["path"])
|
model_path = Path(stanza["path"])
|
||||||
if not model_path.is_absolute():
|
if not model_path.is_absolute():
|
||||||
model_path = self._app_config.models_path / model_path
|
model_path = self._app_config.models_path / model_path
|
||||||
model_path = model_path.resolve()
|
model_path = model_path.resolve()
|
||||||
description = stanza.get("description", "")
|
|
||||||
config_path = stanza.get("config", "")
|
config: dict[str, Any] = {}
|
||||||
model_info = {"name": model_name, "description": description, "config_path": config_path}
|
config["name"] = model_name
|
||||||
|
config["description"] = stanza.get("description")
|
||||||
|
config["config_path"] = stanza.get("config")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
id = self.register_path(model_path=model_path, config=model_info)
|
id = self.register_path(model_path=model_path, config=config)
|
||||||
self._logger.info(f"Registered {model_name} with id {id}")
|
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.warning(f"Model at {model_path} could not be loaded into database: {e}")
|
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||||
|
|
||||||
|
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||||
|
yaml_path = self._app_config.model_conf_path
|
||||||
|
yaml_path.rename(yaml_path.with_suffix(".bak"))
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||||
|
Loading…
Reference in New Issue
Block a user