allow the model record migrate script to update existing model records

This commit is contained in:
Lincoln Stein 2023-12-11 22:47:19 -05:00
parent 11f4a48144
commit b0cfa58526
3 changed files with 30 additions and 27 deletions

View File

@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
def stop(self) -> None:
def stop(self, *args, **kwargs) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB)

View File

@ -96,10 +96,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL,
type TEXT NOT NULL,
name TEXT NOT NULL,
path TEXT NOT NULL,
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
@ -175,21 +176,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql
INSERT INTO model_config (
id,
base,
type,
name,
path,
original_hash,
config
)
VALUES (?,?,?,?,?,?,?);
VALUES (?,?,?);
""",
(
key,
record.base,
record.type,
record.name,
record.path,
record.original_hash,
json_serialized,
),
@ -269,14 +262,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
UPDATE model_config
SET base=?,
type=?,
name=?,
path=?,
SET
config=?
WHERE id=?;
""",
(record.base, record.type, record.name, record.path, json_serialized, key),
(json_serialized, key),
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
@ -374,7 +364,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
WHERE path=?;
""",
(str(path),),
)

View File

@ -2,6 +2,7 @@
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1
from logging import Logger
from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordServiceSQL,
UnknownModelException,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import (
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
logger: Logger
def __init__(self):
def __init__(self) -> None:
self.config = InvokeAIAppConfig.get_config()
self.config.parse_args()
self.logger = InvokeAILogger.get_logger()
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf,DictConfig)
return omegaconf
def migrate(self):
def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path)
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza)
self.logger.info(f"Adding model {model_name} with key {model_key}")
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
try:
db.add_model(new_key, new_config)
if original_record := db.search_by_path(stanza.path):
key = original_record[0].key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
db.update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
db.add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def main():