multiple small stylistic changes requested by reviewers

This commit is contained in:
Lincoln Stein 2023-11-08 16:45:26 -05:00
parent ce22c0fbaa
commit 6b173cc66f
5 changed files with 110 additions and 59 deletions

View File

@ -10,7 +10,7 @@ from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
# should match the InvokeAI version when this is first released. # should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2" CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception): class DuplicateModelException(Exception):

View File

@ -66,9 +66,8 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase): class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database.""" """Implementation of the ModelConfigStore ABC using a SQL database."""
_conn: sqlite3.Connection _db: SqliteDatabase
_cursor: sqlite3.Cursor _cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, db: SqliteDatabase): def __init__(self, db: SqliteDatabase):
""" """
@ -78,16 +77,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param lock: threading Lock object :param lock: threading Lock object
""" """
super().__init__() super().__init__()
self._conn = db.conn self._db = db
self._lock = db.lock self._db.conn.row_factory = sqlite3.Row
self._conn.row_factory = sqlite3.Row self._cursor = self._db.conn.cursor()
self._cursor = self._conn.cursor()
with self._lock: with self._db.lock:
# Enable foreign keys # Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;") self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables() self._create_tables()
self._conn.commit() self._db.conn.commit()
assert ( assert (
str(self.version) == CONFIG_FILE_VERSION str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}" ), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@ -172,7 +170,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect. record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -197,10 +195,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized, json_serialized,
), ),
) )
self._conn.commit() self._db.conn.commit()
except sqlite3.IntegrityError as e: except sqlite3.IntegrityError as e:
self._conn.rollback() self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e): if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e): if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed" msg = f"A model with path '{record.path}' is already installed"
@ -210,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
else: else:
raise e raise e
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._db.conn.rollback()
raise e raise e
return self.get_model(key) return self.get_model(key)
@ -218,7 +216,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
@property @property
def version(self) -> str: def version(self) -> str:
"""Return the version of the database schema.""" """Return the version of the database schema."""
with self._lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT metadata_value FROM model_manager_metadata SELECT metadata_value FROM model_manager_metadata
@ -239,7 +237,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException Can raise an UnknownModelException
""" """
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -250,9 +248,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException("model not found") raise UnknownModelException("model not found")
self._conn.commit() self._db.conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._db.conn.rollback()
raise e raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase: def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
@ -265,7 +263,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""" """
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string. json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -281,9 +279,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
) )
if self._cursor.rowcount == 0: if self._cursor.rowcount == 0:
raise UnknownModelException("model not found") raise UnknownModelException("model not found")
self._conn.commit() self._db.conn.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._db.conn.rollback()
raise e raise e
return self.get_model(key) return self.get_model(key)
@ -296,7 +294,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException Exceptions: UnknownModelException
""" """
with self._lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
SELECT config FROM model_config SELECT config FROM model_config
@ -317,7 +315,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted :param key: Unique key for the model to be deleted
""" """
count = 0 count = 0
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -360,7 +358,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("type=?") where_clause.append("type=?")
bindings.append(model_type) bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
f"""--sql f"""--sql
@ -377,7 +375,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]: def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path.""" """Return models with the indicated path."""
results = [] results = []
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
@ -394,7 +392,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_hash(self, hash: str) -> List[ModelConfigBase]: def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash.""" """Return models with the indicated original_hash."""
results = [] results = []
with self._lock: with self._db.lock:
try: try:
self._cursor.execute( self._cursor.execute(
"""--sql """--sql

View File

@ -16,8 +16,19 @@ from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig) ModelsValidator = TypeAdapter(AnyModelConfig)
class Migrate: class MigrateModelYamlToDb:
"""Migration class.""" """
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig config: InvokeAIAppConfig
logger: InvokeAILogger logger: InvokeAILogger
@ -28,14 +39,17 @@ class Migrate:
self.logger = InvokeAILogger.get_logger() self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL: def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger) db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db) return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig: def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path) return OmegaConf.load(yaml_path)
def migrate(self): def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db() db = self.get_db()
yaml = self.get_yaml() yaml = self.get_yaml()
@ -65,7 +79,7 @@ class Migrate:
def main(): def main():
Migrate().migrate() MigrateModelYamlToDb().migrate()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -386,13 +386,10 @@ class Chdir(object):
class SilenceWarnings(object): class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages.""" """Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __init__(self):
"""Set up context, save current transformers and diffusers verbosity settings."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self): def __enter__(self):
"""Set verbosity to error.""" """Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error() transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error() diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore") warnings.simplefilter("ignore")

View File

@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase):
def test_dup(store: ModelRecordServiceBase): def test_dup(store: ModelRecordServiceBase):
config = example_config() config = example_config()
store.add_model("key1", example_config()) store.add_model("key1", example_config())
try: with pytest.raises(DuplicateModelException):
store.add_model("key1", config) store.add_model("key1", config)
assert False, "Duplicate model key should have been caught" with pytest.raises(DuplicateModelException):
except DuplicateModelException:
assert True
try:
store.add_model("key2", config) store.add_model("key2", config)
assert False, "Duplicate model path should have been caught"
except DuplicateModelException:
assert True
def test_update(store: ModelRecordServiceBase): def test_update(store: ModelRecordServiceBase):
@ -90,11 +84,12 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1") new_config = store.get_model("key1")
assert new_config.name == "new name" assert new_config.name == "new name"
try:
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config) store.update_model("unknown_key", config)
assert False, "expected UnknownModelException"
except UnknownModelException:
assert True
def test_delete(store: ModelRecordServiceBase): def test_delete(store: ModelRecordServiceBase):
@ -102,20 +97,8 @@ def test_delete(store: ModelRecordServiceBase):
store.add_model("key1", config) store.add_model("key1", config)
config = store.get_model("key1") config = store.get_model("key1")
store.del_model("key1") store.del_model("key1")
try: with pytest.raises(UnknownModelException):
config = store.get_model("key1") config = store.get_model("key1")
assert False, "expected fetch of deleted model to raise exception"
except UnknownModelException:
assert True
# a bug in sqlite3 in python 3.9 prevents DEL from returning number of
# deleted rows!
if sys.version_info.major == 3 and sys.version_info.minor > 9:
try:
store.del_model("unknown")
assert False, "expected delete of unknown model to raise exception"
except UnknownModelException:
assert True
def test_exists(store: ModelRecordServiceBase): def test_exists(store: ModelRecordServiceBase):
@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.all_models() matches = store.all_models()
assert len(matches) == 3 assert len(matches) == 3
def test_filter_2(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
)
config3 = DiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config4 = DiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config5 = VaeDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType("sd-1"),
type=ModelType("vae"),
original_hash="CONFIG3HASH",
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(
model_type=ModelType("main"),
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_name(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
)
assert len(matches) == 2
matches = store.search_by_name(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
model_name="dup_name1",
)
assert len(matches) == 1