mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(config): edge cases in models.yaml migration
When running the configurator, the `legacy_models_conf_path` was stripped when saving the config file. Then the migration logic didn't fire correctly, and the custom models.yaml paths weren't migrated into the db. - Rework the logic to migrate this path by adding it to the config object as a normal field that is not excluded from serialization. - Rearrange the models.yaml migration logic to remove the legacy path after migrating, then write the config file. This way, the legacy path doesn't stick around. - Move the schema version into the config object. - Back up the config file before attempting migration. - Add tests to cover this edge case
This commit is contained in:
parent
1ed1c1fb24
commit
e76cc71e81
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
@ -46,12 +47,6 @@ class URLRegexTokenPair(BaseModel):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
class ConfigMeta(BaseModel):
|
|
||||||
"""Metadata for the config file. This is not stored in the config object."""
|
|
||||||
|
|
||||||
schema_version: int = CONFIG_SCHEMA_VERSION
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIAppConfig(BaseSettings):
|
class InvokeAIAppConfig(BaseSettings):
|
||||||
"""Invoke's global app configuration.
|
"""Invoke's global app configuration.
|
||||||
|
|
||||||
@ -109,6 +104,10 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|
||||||
|
# INTERNAL
|
||||||
|
schema_version: int = Field(default=CONFIG_SCHEMA_VERSION, description="Schema version of the config file. This is not a user-configurable setting.")
|
||||||
|
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="Path to the legacy models.yaml file. This is not a user-configurable setting.")
|
||||||
|
|
||||||
# WEB
|
# WEB
|
||||||
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.")
|
host: str = Field(default="127.0.0.1", description="IP address to bind to. Use `0.0.0.0` to serve to your local network.")
|
||||||
port: int = Field(default=9090, description="Port to bind to.")
|
port: int = Field(default=9090, description="Port to bind to.")
|
||||||
@ -175,11 +174,6 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
|
hashing_algorithm: HASHING_ALGORITHMS = Field(default="blake3", description="Model hashing algorthim for model installs. 'blake3' is best for SSDs. 'blake3_single' is best for spinning disk HDDs. 'random' disables hashing, instead assigning a UUID to models. Useful when using a memory db to reduce model installation time, or if you don't care about storing stable hashes for models. Alternatively, any other hashlib algorithm is accepted, though these are not nearly as performant as blake3.")
|
||||||
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
|
remote_api_tokens: Optional[list[URLRegexTokenPair]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.")
|
||||||
|
|
||||||
# HIDDEN FIELDS
|
|
||||||
# v4 (MM2) doesn't use `models.yaml` files, but users were able to set paths in the v3 config. When we migrate a
|
|
||||||
# v3 config, we need to save the path to the models.yaml. This is only used during migration.
|
|
||||||
legacy_models_yaml_path: Optional[Path] = Field(default=None, description="The `conf_path` setting from a v3 `invokeai.yaml` file. Only present this app session migrated a config file, and it had `conf_test` on it.", exclude=True)
|
|
||||||
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
|
model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True)
|
||||||
@ -217,8 +211,20 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
dest_path: Path to write the config to.
|
dest_path: Path to write the config to.
|
||||||
"""
|
"""
|
||||||
with open(dest_path, "w") as file:
|
with open(dest_path, "w") as file:
|
||||||
meta_dict = {"meta": ConfigMeta().model_dump()}
|
# Meta fields should be written in a separate stanza
|
||||||
config_dict = self.model_dump(mode="json", exclude_unset=True, exclude_defaults=True)
|
meta_dict = self.model_dump(mode="json", include={"schema_version"})
|
||||||
|
# Only include the legacy_models_yaml_path if it's set
|
||||||
|
if self.legacy_models_yaml_path:
|
||||||
|
meta_dict.update(self.model_dump(mode="json", include={"legacy_models_yaml_path"}))
|
||||||
|
|
||||||
|
# User settings
|
||||||
|
config_dict = self.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_unset=True,
|
||||||
|
exclude_defaults=True,
|
||||||
|
exclude={"schema_version", "legacy_models_yaml_path"},
|
||||||
|
)
|
||||||
|
|
||||||
file.write("# Internal metadata - do not edit:\n")
|
file.write("# Internal metadata - do not edit:\n")
|
||||||
file.write(yaml.dump(meta_dict, sort_keys=False))
|
file.write(yaml.dump(meta_dict, sort_keys=False))
|
||||||
file.write("\n")
|
file.write("\n")
|
||||||
@ -370,11 +376,12 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
|
|
||||||
if "InvokeAI" in loaded_config_dict:
|
if "InvokeAI" in loaded_config_dict:
|
||||||
# This is a v3 config file, attempt to migrate it
|
# This is a v3 config file, attempt to migrate it
|
||||||
|
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||||
try:
|
try:
|
||||||
config = migrate_v3_config_dict(loaded_config_dict)
|
config = migrate_v3_config_dict(loaded_config_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||||
config_path.rename(config_path.with_suffix(".yaml.bak"))
|
|
||||||
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
|
# By excluding defaults, we ensure that the new config file only contains the settings that were explicitly set
|
||||||
config.write_file(config_path)
|
config.write_file(config_path)
|
||||||
return config
|
return config
|
||||||
@ -382,11 +389,11 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
# Attempt to load as a v4 config file
|
# Attempt to load as a v4 config file
|
||||||
try:
|
try:
|
||||||
# Meta is not included in the model fields, so we need to validate it separately
|
# Meta is not included in the model fields, so we need to validate it separately
|
||||||
config_meta = ConfigMeta.model_validate(loaded_config_dict.pop("meta"))
|
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||||
assert (
|
assert (
|
||||||
config_meta.schema_version == CONFIG_SCHEMA_VERSION
|
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config_meta.schema_version}"
|
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||||
return InvokeAIAppConfig.model_validate(loaded_config_dict)
|
return config
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||||
|
|
||||||
|
@ -292,45 +292,46 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not legacy_models_yaml_path.exists():
|
if legacy_models_yaml_path.exists():
|
||||||
# No yaml to migrate
|
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
||||||
return
|
|
||||||
|
|
||||||
legacy_models_yaml = yaml.safe_load(legacy_models_yaml_path.read_text())
|
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
||||||
|
yaml_version = yaml_metadata.get("version")
|
||||||
|
|
||||||
yaml_metadata = legacy_models_yaml.pop("__metadata__")
|
if yaml_version != "3.0.0":
|
||||||
yaml_version = yaml_metadata.get("version")
|
raise ValueError(
|
||||||
|
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
|
||||||
|
)
|
||||||
|
|
||||||
if yaml_version != "3.0.0":
|
self._logger.info(
|
||||||
raise ValueError(
|
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
|
||||||
f"Attempted migration of unsupported `models.yaml` v{yaml_version}. Only v3.0.0 is supported. Exiting."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._logger.info(
|
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
|
||||||
f"Starting one-time migration of {len(legacy_models_yaml.items())} models from {str(legacy_models_yaml_path)}. This may take a few minutes."
|
for model_key, stanza in legacy_models_yaml.items():
|
||||||
)
|
_, _, model_name = str(model_key).split("/")
|
||||||
|
model_path = Path(stanza["path"])
|
||||||
|
if not model_path.is_absolute():
|
||||||
|
model_path = self._app_config.models_path / model_path
|
||||||
|
model_path = model_path.resolve()
|
||||||
|
|
||||||
if len(db_models) == 0 and len(legacy_models_yaml.items()) != 0:
|
config: dict[str, Any] = {}
|
||||||
for model_key, stanza in legacy_models_yaml.items():
|
config["name"] = model_name
|
||||||
_, _, model_name = str(model_key).split("/")
|
config["description"] = stanza.get("description")
|
||||||
model_path = Path(stanza["path"])
|
config["config_path"] = stanza.get("config")
|
||||||
if not model_path.is_absolute():
|
|
||||||
model_path = self._app_config.models_path / model_path
|
|
||||||
model_path = model_path.resolve()
|
|
||||||
|
|
||||||
config: dict[str, Any] = {}
|
try:
|
||||||
config["name"] = model_name
|
id = self.register_path(model_path=model_path, config=config)
|
||||||
config["description"] = stanza.get("description")
|
self._logger.info(f"Migrated {model_name} with id {id}")
|
||||||
config["config_path"] = stanza.get("config")
|
except Exception as e:
|
||||||
|
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
||||||
|
|
||||||
try:
|
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
||||||
id = self.register_path(model_path=model_path, config=config)
|
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
||||||
self._logger.info(f"Migrated {model_name} with id {id}")
|
|
||||||
except Exception as e:
|
|
||||||
self._logger.warning(f"Model at {model_path} could not be migrated: {e}")
|
|
||||||
|
|
||||||
# Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
|
# Remove `legacy_models_yaml_path` from the config file - we are done with it either way
|
||||||
legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
|
self._app_config.legacy_models_yaml_path = None
|
||||||
|
self._app_config.write_file(self._app_config.init_file_path)
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).resolve() for x in self.record_store.all_models()}
|
||||||
|
@ -34,6 +34,7 @@ from transformers import AutoFeatureExtractor
|
|||||||
|
|
||||||
import invokeai.configs as model_configs
|
import invokeai.configs as model_configs
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||||
from invokeai.backend.model_manager import ModelType
|
from invokeai.backend.model_manager import ModelType
|
||||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||||
@ -63,8 +64,7 @@ def get_literal_fields(field: str) -> Tuple[Any]:
|
|||||||
|
|
||||||
# --------------------------globals-----------------------
|
# --------------------------globals-----------------------
|
||||||
|
|
||||||
# Start from a fresh config object - we will read the user's config from file later, and update it with their choices
|
config = get_config()
|
||||||
config = InvokeAIAppConfig()
|
|
||||||
|
|
||||||
PRECISION_CHOICES = get_literal_fields("precision")
|
PRECISION_CHOICES = get_literal_fields("precision")
|
||||||
DEVICE_CHOICES = get_literal_fields("device")
|
DEVICE_CHOICES = get_literal_fields("device")
|
||||||
|
@ -3,6 +3,8 @@ from typing import Literal, get_args, get_type_hints
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||||
|
|
||||||
|
_excluded = {"schema_version", "legacy_models_yaml_path"}
|
||||||
|
|
||||||
|
|
||||||
def generate_config_docstrings() -> str:
|
def generate_config_docstrings() -> str:
|
||||||
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
|
"""Helper function for mkdocs. Generates a docstring for the InvokeAIAppConfig class.
|
||||||
@ -20,7 +22,7 @@ def generate_config_docstrings() -> str:
|
|||||||
type_hints = get_type_hints(InvokeAIAppConfig)
|
type_hints = get_type_hints(InvokeAIAppConfig)
|
||||||
|
|
||||||
for k, v in InvokeAIAppConfig.model_fields.items():
|
for k, v in InvokeAIAppConfig.model_fields.items():
|
||||||
if v.exclude:
|
if v.exclude or k in _excluded:
|
||||||
continue
|
continue
|
||||||
field_type = type_hints.get(k)
|
field_type = type_hints.get(k)
|
||||||
extra = ""
|
extra = ""
|
||||||
|
@ -9,16 +9,14 @@ from pydantic import ValidationError
|
|||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
|
||||||
|
|
||||||
v4_config = """
|
v4_config = """
|
||||||
meta:
|
schema_version: 4
|
||||||
schema_version: 4
|
|
||||||
|
|
||||||
host: "192.168.1.1"
|
host: "192.168.1.1"
|
||||||
port: 8080
|
port: 8080
|
||||||
"""
|
"""
|
||||||
|
|
||||||
invalid_v5_config = """
|
invalid_v5_config = """
|
||||||
meta:
|
schema_version: 5
|
||||||
schema_version: 5
|
|
||||||
|
|
||||||
host: "192.168.1.1"
|
host: "192.168.1.1"
|
||||||
port: 8080
|
port: 8080
|
||||||
@ -44,6 +42,12 @@ InvokeAI:
|
|||||||
max_vram_cache_size: 50
|
max_vram_cache_size: 50
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
v3_config_with_bad_values = """
|
||||||
|
InvokeAI:
|
||||||
|
Web Server:
|
||||||
|
port: "ice cream"
|
||||||
|
"""
|
||||||
|
|
||||||
invalid_config = """
|
invalid_config = """
|
||||||
i like turtles
|
i like turtles
|
||||||
"""
|
"""
|
||||||
@ -88,6 +92,29 @@ def test_migrate_v3_config_from_file(tmp_path: Path):
|
|||||||
assert not hasattr(config, "esrgan")
|
assert not hasattr(config, "esrgan")
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_v3_backup(tmp_path: Path):
|
||||||
|
"""Test the backup of the config file."""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(v3_config)
|
||||||
|
|
||||||
|
load_and_migrate_config(temp_config_file)
|
||||||
|
assert temp_config_file.with_suffix(".yaml.bak").exists()
|
||||||
|
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_failed_migrate_backup(tmp_path: Path):
|
||||||
|
"""Test the failed migration of the config file."""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(v3_config_with_bad_values)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
load_and_migrate_config(temp_config_file)
|
||||||
|
assert temp_config_file.with_suffix(".yaml.bak").exists()
|
||||||
|
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config_with_bad_values
|
||||||
|
assert temp_config_file.exists()
|
||||||
|
assert temp_config_file.read_text() == v3_config_with_bad_values
|
||||||
|
|
||||||
|
|
||||||
def test_bails_on_invalid_config(tmp_path: Path):
|
def test_bails_on_invalid_config(tmp_path: Path):
|
||||||
"""Test reading configuration from a file."""
|
"""Test reading configuration from a file."""
|
||||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
Loading…
Reference in New Issue
Block a user