allow the model record migrate script to update existing model records

This commit is contained in:
Lincoln Stein
2023-12-11 22:47:19 -05:00
parent 11f4a48144
commit b0cfa58526
3 changed files with 30 additions and 27 deletions

View File

@ -2,6 +2,7 @@
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from logging import Logger
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
logger: Logger
def __init__(self):
def __init__(self) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf,DictConfig)
return omegaconf
def migrate(self):
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
db.add_model(new_key, new_config)
if original_record := db.search_by_path(stanza.path):
key = original_record[0].key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
db.update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def main():