From 69af099532f0ba0751e5b3a843e5a0ed28aaa342 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 10 Jul 2024 09:41:46 -0400 Subject: [PATCH] Warn on invalid model configs in the DB rather than crashing. --- docs/contributing/MODEL_MANAGER.md | 2 +- invokeai/app/api/dependencies.py | 2 +- .../model_records/model_records_sql.py | 21 +++++++++++++++++-- .../model_records/test_model_records_sql.py | 2 +- .../model_manager/model_manager_fixtures.py | 4 ++-- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 9699db4f1a..52b75d8c39 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -408,7 +408,7 @@ config = get_config() logger = InvokeAILogger.get_logger(config=config) db = SqliteDatabase(config.db_path, logger) -record_store = ModelRecordServiceSQL(db) +record_store = ModelRecordServiceSQL(db, logger) queue = DownloadQueueService() queue.start() diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index 27ab030d4c..6e049399db 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -99,7 +99,7 @@ class ApiDependencies: model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_manager = ModelManagerService.build_model_manager( app_config=configuration, - model_record_service=ModelRecordServiceSQL(db=db), + model_record_service=ModelRecordServiceSQL(db=db, logger=logger), download_queue=download_queue_service, events=events, ) diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 2f9829dad4..1d0780efe1 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -40,11 +40,14 @@ Typical usage: """ import json +import logging import sqlite3 from math import ceil from pathlib import Path from typing import List, Optional, Union +import pydantic + from invokeai.app.services.model_records.model_records_base import ( DuplicateModelException, ModelRecordChanges, @@ -67,7 +70,7 @@ from invokeai.backend.model_manager.config import ( class ModelRecordServiceSQL(ModelRecordServiceBase): """Implementation of the ModelConfigStore ABC using a SQL database.""" - def __init__(self, db: SqliteDatabase): + def __init__(self, db: SqliteDatabase, logger: logging.Logger): """ Initialize a new object from preexisting sqlite3 connection and threading lock objects. @@ -76,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): super().__init__() self._db = db self._cursor = db.conn.cursor() + self._logger = logger @property def db(self) -> SqliteDatabase: @@ -291,7 +295,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): tuple(bindings), ) result = self._cursor.fetchall() - results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result] + + # Parse the model configs. + results: list[AnyModelConfig] = [] + for row in result: + try: + model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) + except pydantic.ValidationError: + # We catch this error so that the app can still run if there are invalid model configs in the database. + # One reason that an invalid model config might be in the database is if someone had to rollback from a + # newer version of the app that added a new model type. + self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})") + else: + results.append(model_config) + return results def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index d39e95ab3d..e6a89dff06 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -40,7 +40,7 @@ def store( config._root = datadir logger = InvokeAILogger.get_logger(config=config) db = create_mock_sqlite_database(config, logger) - return ModelRecordServiceSQL(db) + return ModelRecordServiceSQL(db, logger) def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig: diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 6fd8c51b54..621b7c65b4 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -110,7 +110,7 @@ def mm2_installer( logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) events = TestEventService() - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, logger) installer = ModelInstallService( app_config=mm2_app_config, @@ -128,7 +128,7 @@ def mm2_installer( def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: logger = InvokeAILogger.get_logger(config=mm2_app_config) db = create_mock_sqlite_database(mm2_app_config, logger) - store = ModelRecordServiceSQL(db) + store = ModelRecordServiceSQL(db, logger) # add five simple config records to the database config1 = VAEDiffusersConfig( key="test_config_1",