Warn on invalid model configs in the DB rather than crashing.

This commit is contained in:
Ryan Dick 2024-07-10 09:41:46 -04:00
parent 5795617f86
commit 69af099532
5 changed files with 24 additions and 7 deletions

View File

@ -408,7 +408,7 @@ config = get_config()
logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db)
record_store = ModelRecordServiceSQL(db, logger)
queue = DownloadQueueService()
queue.start()

View File

@ -99,7 +99,7 @@ class ApiDependencies:
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db),
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
download_queue=download_queue_service,
events=events,
)

View File

@ -40,11 +40,14 @@ Typical usage:
"""
import json
import logging
import sqlite3
from math import ceil
from pathlib import Path
from typing import List, Optional, Union
import pydantic
from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException,
ModelRecordChanges,
@ -67,7 +70,7 @@ from invokeai.backend.model_manager.config import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase):
def __init__(self, db: SqliteDatabase, logger: logging.Logger):
"""
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -76,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__()
self._db = db
self._cursor = db.conn.cursor()
self._logger = logger
@property
def db(self) -> SqliteDatabase:
@ -291,7 +295,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
tuple(bindings),
)
result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
# Parse the model configs.
results: list[AnyModelConfig] = []
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
else:
results.append(model_config)
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:

View File

@ -40,7 +40,7 @@ def store(
config._root = datadir
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db)
return ModelRecordServiceSQL(db, logger)
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:

View File

@ -110,7 +110,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = TestEventService()
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, logger)
installer = ModelInstallService(
app_config=mm2_app_config,
@ -128,7 +128,7 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db)
store = ModelRecordServiceSQL(db, logger)
# add five simple config records to the database
config1 = VAEDiffusersConfig(
key="test_config_1",