Warn on invalid model configs in the DB rather than crashing.

This commit is contained in:
Ryan Dick 2024-07-10 09:41:46 -04:00
parent 5795617f86
commit 69af099532
5 changed files with 24 additions and 7 deletions

View File

@ -408,7 +408,7 @@ config = get_config()
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger) db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db) record_store = ModelRecordServiceSQL(db, logger)
queue = DownloadQueueService() queue = DownloadQueueService()
queue.start() queue.start()

View File

@ -99,7 +99,7 @@ class ApiDependencies:
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )

View File

@ -40,11 +40,14 @@ Typical usage:
""" """
import json import json
import logging
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
import pydantic
from invokeai.app.services.model_records.model_records_base import ( from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException, DuplicateModelException,
ModelRecordChanges, ModelRecordChanges,
@ -67,7 +70,7 @@ from invokeai.backend.model_manager.config import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase, logger: logging.Logger):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -76,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._logger = logger
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
@ -291,7 +295,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
tuple(bindings), tuple(bindings),
) )
result = self._cursor.fetchall() result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
# Parse the model configs.
results: list[AnyModelConfig] = []
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
else:
results.append(model_config)
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:

View File

@ -40,7 +40,7 @@ def store(
config._root = datadir config._root = datadir
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger) db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db, logger)
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig: def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:

View File

@ -110,7 +110,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
events = TestEventService() events = TestEventService()
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, logger)
installer = ModelInstallService( installer = ModelInstallService(
app_config=mm2_app_config, app_config=mm2_app_config,
@ -128,7 +128,7 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config) logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, logger)
# add five simple config records to the database # add five simple config records to the database
config1 = VAEDiffusersConfig( config1 = VAEDiffusersConfig(
key="test_config_1", key="test_config_1",