allow the model record migrate script to update existing model records

This commit is contained in:
Lincoln Stein 2023-12-11 22:47:19 -05:00
parent 11f4a48144
commit b0cfa58526
3 changed files with 30 additions and 27 deletions

View File

@ -85,7 +85,7 @@ class ModelInstallService(ModelInstallServiceBase):
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102 def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus return self._event_bus
def stop(self) -> None: def stop(self, *args, **kwargs) -> None:
"""Stop the install thread; after this the object can be deleted and garbage collected.""" """Stop the install thread; after this the object can be deleted and garbage collected."""
self._install_queue.put(STOP_JOB) self._install_queue.put(STOP_JOB)

View File

@ -96,10 +96,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
CREATE TABLE IF NOT EXISTS model_config ( CREATE TABLE IF NOT EXISTS model_config (
id TEXT NOT NULL PRIMARY KEY, id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here -- The next 3 fields are enums in python, unrestricted string here
base TEXT NOT NULL, base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT NOT NULL, type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT NOT NULL, name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
path TEXT NOT NULL, path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
original_hash TEXT, -- could be null original_hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object, -- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses -- which will contain additional fields from subclasses
@ -175,21 +176,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""--sql """--sql
INSERT INTO model_config ( INSERT INTO model_config (
id, id,
base,
type,
name,
path,
original_hash, original_hash,
config config
) )
VALUES (?,?,?,?,?,?,?); VALUES (?,?,?);
""", """,
( (
key, key,
record.base,
record.type,
record.name,
record.path,
record.original_hash, record.original_hash,
json_serialized, json_serialized,
), ),
@ -269,14 +262,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
UPDATE model_config UPDATE model_config
SET base=?, SET
type=?,
name=?,
path=?,
config=? config=?
WHERE id=?; WHERE id=?;
""", """,
(record.base, record.type, record.name, record.path, json_serialized, key), (json_serialized, key),
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException("model not found") raise UnknownModelException("model not found")
@ -374,7 +364,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config FROM model_config
WHERE model_path=?; WHERE path=?;
""", """,
(str(path),), (str(path),),
) )

View File

@ -2,6 +2,7 @@
"""Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format.""" """Migrate from the InvokeAI v2 models.yaml format to the v3 sqlite format."""
from hashlib import sha1 from hashlib import sha1
from logging import Logger
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pydantic import TypeAdapter from pydantic import TypeAdapter
@ -10,6 +11,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ( from invokeai.app.services.model_records import (
DuplicateModelException, DuplicateModelException,
ModelRecordServiceSQL, ModelRecordServiceSQL,
UnknownModelException,
) )
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -38,9 +40,9 @@ class MigrateModelYamlToDb:
""" """
config: InvokeAIAppConfig config: InvokeAIAppConfig
logger: InvokeAILogger logger: Logger
def __init__(self): def __init__(self) -> None:
self.config = InvokeAIAppConfig.get_config() self.config = InvokeAIAppConfig.get_config()
self.config.parse_args() self.config.parse_args()
self.logger = InvokeAILogger.get_logger() self.logger = InvokeAILogger.get_logger()
@ -53,9 +55,11 @@ class MigrateModelYamlToDb:
def get_yaml(self) -> DictConfig: def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation.""" """Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path) omegaconf = OmegaConf.load(yaml_path)
assert isinstance(omegaconf,DictConfig)
return omegaconf
def migrate(self): def migrate(self) -> None:
"""Do the migration from models.yaml to invokeai.db.""" """Do the migration from models.yaml to invokeai.db."""
db = self.get_db() db = self.get_db()
yaml = self.get_yaml() yaml = self.get_yaml()
@ -69,6 +73,7 @@ class MigrateModelYamlToDb:
base_type, model_type, model_name = str(model_key).split("/") base_type, model_type, model_name = str(model_key).split("/")
hash = FastModelHash.hash(self.config.models_path / stanza.path) hash = FastModelHash.hash(self.config.models_path / stanza.path)
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest() new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
@ -77,12 +82,20 @@ class MigrateModelYamlToDb:
stanza["original_hash"] = hash stanza["original_hash"] = hash
stanza["current_hash"] = hash stanza["current_hash"] = hash
new_config = ModelsValidator.validate_python(stanza) new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
self.logger.info(f"Adding model {model_name} with key {model_key}")
try: try:
if original_record := db.search_by_path(stanza.path):
key = original_record[0].key
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
db.update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
db.add_model(new_key, new_config) db.add_model(new_key, new_config)
except DuplicateModelException: except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database") self.logger.warning(f"Model {model_name} is already in the database")
except UnknownModelException:
self.logger.warning(f"Model at {stanza.path} could not be found in database")
def main(): def main():