mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Warn on invalid model configs in the DB rather than crashing.
This commit is contained in:
parent
5795617f86
commit
69af099532
@ -408,7 +408,7 @@ config = get_config()
|
|||||||
|
|
||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db = SqliteDatabase(config.db_path, logger)
|
db = SqliteDatabase(config.db_path, logger)
|
||||||
record_store = ModelRecordServiceSQL(db)
|
record_store = ModelRecordServiceSQL(db, logger)
|
||||||
queue = DownloadQueueService()
|
queue = DownloadQueueService()
|
||||||
queue.start()
|
queue.start()
|
||||||
|
|
||||||
|
@ -99,7 +99,7 @@ class ApiDependencies:
|
|||||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration,
|
app_config=configuration,
|
||||||
model_record_service=ModelRecordServiceSQL(db=db),
|
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
|
||||||
download_queue=download_queue_service,
|
download_queue=download_queue_service,
|
||||||
events=events,
|
events=events,
|
||||||
)
|
)
|
||||||
|
@ -40,11 +40,14 @@ Typical usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from invokeai.app.services.model_records.model_records_base import (
|
from invokeai.app.services.model_records.model_records_base import (
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
ModelRecordChanges,
|
ModelRecordChanges,
|
||||||
@ -67,7 +70,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||||
|
|
||||||
def __init__(self, db: SqliteDatabase):
|
def __init__(self, db: SqliteDatabase, logger: logging.Logger):
|
||||||
"""
|
"""
|
||||||
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
Initialize a new object from preexisting sqlite3 connection and threading lock objects.
|
||||||
|
|
||||||
@ -76,6 +79,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._db = db
|
self._db = db
|
||||||
self._cursor = db.conn.cursor()
|
self._cursor = db.conn.cursor()
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db(self) -> SqliteDatabase:
|
def db(self) -> SqliteDatabase:
|
||||||
@ -291,7 +295,20 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
tuple(bindings),
|
tuple(bindings),
|
||||||
)
|
)
|
||||||
result = self._cursor.fetchall()
|
result = self._cursor.fetchall()
|
||||||
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in result]
|
|
||||||
|
# Parse the model configs.
|
||||||
|
results: list[AnyModelConfig] = []
|
||||||
|
for row in result:
|
||||||
|
try:
|
||||||
|
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
|
||||||
|
except pydantic.ValidationError:
|
||||||
|
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||||
|
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||||
|
# newer version of the app that added a new model type.
|
||||||
|
self._logger.warning(f"Found an invalid model config in the database. Ignoring this model. ({row[0]})")
|
||||||
|
else:
|
||||||
|
results.append(model_config)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
|
||||||
|
@ -40,7 +40,7 @@ def store(
|
|||||||
config._root = datadir
|
config._root = datadir
|
||||||
logger = InvokeAILogger.get_logger(config=config)
|
logger = InvokeAILogger.get_logger(config=config)
|
||||||
db = create_mock_sqlite_database(config, logger)
|
db = create_mock_sqlite_database(config, logger)
|
||||||
return ModelRecordServiceSQL(db)
|
return ModelRecordServiceSQL(db, logger)
|
||||||
|
|
||||||
|
|
||||||
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
|
def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
|
||||||
|
@ -110,7 +110,7 @@ def mm2_installer(
|
|||||||
logger = InvokeAILogger.get_logger()
|
logger = InvokeAILogger.get_logger()
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
events = TestEventService()
|
events = TestEventService()
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db, logger)
|
||||||
|
|
||||||
installer = ModelInstallService(
|
installer = ModelInstallService(
|
||||||
app_config=mm2_app_config,
|
app_config=mm2_app_config,
|
||||||
@ -128,7 +128,7 @@ def mm2_installer(
|
|||||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db, logger)
|
||||||
# add five simple config records to the database
|
# add five simple config records to the database
|
||||||
config1 = VAEDiffusersConfig(
|
config1 = VAEDiffusersConfig(
|
||||||
key="test_config_1",
|
key="test_config_1",
|
||||||
|
Loading…
Reference in New Issue
Block a user