Do not crash if there are invalid model configs in the DB (#6593)

## Summary

This PR changes the handling of invalid model configs in the DB to log a
warning rather than crashing the app.

This change is being made in preparation for some upcoming new model
additions. Previously, if a user rolled back from an app version that
added a new model type, the app would not launch until the DB was fixed.
This PR changes this behaviour to allow rollbacks of this type (with
warnings).

**Keep in mind that this change is only helpful to users _rolling back
to a version that has this fix_. I.e. it offers no help in the first
version that includes it.**

## QA Instructions

1. Run the Spandrel model branch, which adds a new model type
https://github.com/invoke-ai/InvokeAI/pull/6556.
2. Add a spandrel model via the model manager.
3. Rollback to main. The app will crash on launch due to the invalid
spandrel model config.
4. Checkout this branch. The app should now run with warnings about the
invalid model config.


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-11 21:15:51 -04:00 committed by GitHub
commit 2320701929
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 7 deletions

View File

@ -408,7 +408,7 @@ config = get_config()
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = SqliteDatabase(config.db_path, logger) db = SqliteDatabase(config.db_path, logger)
record_store = ModelRecordServiceSQL(db) record_store = ModelRecordServiceSQL(db, logger)
queue = DownloadQueueService() queue = DownloadQueueService()
queue.start() queue.start()

View File

@ -99,7 +99,7 @@ class ApiDependencies:
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db), model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
download_queue=download_queue_service, download_queue=download_queue_service,
events=events, events=events,
) )

View File

@ -40,11 +40,14 @@ Typical usage:
""" """
import json import json
import logging
import sqlite3 import sqlite3
from math import ceil from math import ceil
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
import pydantic
from invokeai.app.services.model_records.model_records_base import ( from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException, DuplicateModelException,
ModelRecordChanges, ModelRecordChanges,
@ -67,7 +70,7 @@ from invokeai.backend.model_manager.config import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase, logger: logging.Logger):
""" """
Initialize a new object from preexisting sqlite3 connection and threading lock objects. Initialize a new object from preexisting sqlite3 connection and threading lock objects.
@ -76,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
super().__init__() super().__init__()
self._db = db self._db = db
self._cursor = db.conn.cursor() self._cursor = db.conn.cursor()
self._logger = logger
@property @property
def db(self) -> SqliteDatabase: def db(self) -> SqliteDatabase:
@ -291,7 +295,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
tuple(bindings), tuple(bindings),
) )
result = self._cursor.fetchall() result = self._cursor.fetchall()
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
# Parse the model configs.
results: list[AnyModelConfig] = []
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
else:
results.append(model_config)
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]: def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:

View File

@ -40,7 +40,7 @@ def store(
config._root = datadir config._root = datadir
logger = InvokeAILogger.get_logger(config=config) logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger) db = create_mock_sqlite_database(config, logger)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db, logger)
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig: def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:

View File

@ -110,7 +110,7 @@ def mm2_installer(
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
events = TestEventService() events = TestEventService()
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, logger)
installer = ModelInstallService( installer = ModelInstallService(
app_config=mm2_app_config, app_config=mm2_app_config,
@ -128,7 +128,7 @@ def mm2_installer(
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase: def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config) logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db, logger)
# add five simple config records to the database # add five simple config records to the database
config1 = VAEDiffusersConfig( config1 = VAEDiffusersConfig(
key="test_config_1", key="test_config_1",