multiple small stylistic changes requested by reviewers

This commit is contained in:
Lincoln Stein 2023-11-08 16:45:26 -05:00
parent ce22c0fbaa
commit 6b173cc66f
5 changed files with 110 additions and 59 deletions

View File

@ -10,7 +10,7 @@ from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
# should match the InvokeAI version when this is first released.
CONFIG_FILE_VERSION = "3.2"
CONFIG_FILE_VERSION = "3.2.0"
class DuplicateModelException(Exception):

View File

@ -66,9 +66,8 @@ from .model_records_base import (
class ModelRecordServiceSQL(ModelRecordServiceBase):
"""Implementation of the ModelConfigStore ABC using a SQL database."""
_conn: sqlite3.Connection
_db: SqliteDatabase
_cursor: sqlite3.Cursor
_lock: threading.Lock
def __init__(self, db: SqliteDatabase):
"""
@ -78,16 +77,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param lock: threading Lock object
"""
super().__init__()
self._conn = db.conn
self._lock = db.lock
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._db = db
self._db.conn.row_factory = sqlite3.Row
self._cursor = self._db.conn.cursor()
with self._lock:
with self._db.lock:
# Enable foreign keys
self._conn.execute("PRAGMA foreign_keys = ON;")
self._db.conn.execute("PRAGMA foreign_keys = ON;")
self._create_tables()
self._conn.commit()
self._db.conn.commit()
assert (
str(self.version) == CONFIG_FILE_VERSION
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@ -172,7 +170,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql
@ -197,10 +195,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized,
),
)
self._conn.commit()
self._db.conn.commit()
except sqlite3.IntegrityError as e:
self._conn.rollback()
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
@ -210,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
else:
raise e
except sqlite3.Error as e:
self._conn.rollback()
self._db.conn.rollback()
raise e
return self.get_model(key)
@ -218,7 +216,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
@property
def version(self) -> str:
"""Return the version of the database schema."""
with self._lock:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT metadata_value FROM model_manager_metadata
@ -239,7 +237,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql
@ -250,9 +248,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._conn.commit()
self._db.conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
self._db.conn.rollback()
raise e
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
@ -265,7 +263,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql
@ -281,9 +279,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if self._cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._conn.commit()
self._db.conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
self._db.conn.rollback()
raise e
return self.get_model(key)
@ -296,7 +294,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
with self._lock:
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
@ -317,7 +315,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
count = 0
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql
@ -360,7 +358,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
where_clause.append("type=?")
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
f"""--sql
@ -377,7 +375,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql
@ -394,7 +392,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._lock:
with self._db.lock:
try:
self._cursor.execute(
"""--sql

View File

@ -16,8 +16,19 @@ from invokeai.backend.util.logging import InvokeAILogger
ModelsValidator = TypeAdapter(AnyModelConfig)
class Migrate:
"""Migration class."""
class MigrateModelYamlToDb:
"""
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
The class has one externally useful method, migrate(), which scans the
currently models.yaml file and imports all its entries into invokeai.db.
Use this way:
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
MigrateModelYamlToDb().migrate()
"""
config: InvokeAIAppConfig
logger: InvokeAILogger
@ -28,14 +39,17 @@ class Migrate:
self.logger = InvokeAILogger.get_logger()
def get_db(self) -> ModelRecordServiceSQL:
"""Fetch the sqlite3 database for this installation."""
db = SqliteDatabase(self.config, self.logger)
return ModelRecordServiceSQL(db)
def get_yaml(self) -> DictConfig:
"""Fetch the models.yaml DictConfig for this installation."""
yaml_path = self.config.model_conf_path
return OmegaConf.load(yaml_path)
def migrate(self):
"""Do the migration from models.yaml to invokeai.db."""
db = self.get_db()
yaml = self.get_yaml()
@ -65,7 +79,7 @@ class Migrate:
def main():
Migrate().migrate()
MigrateModelYamlToDb().migrate()
if __name__ == "__main__":

View File

@ -386,13 +386,10 @@ class Chdir(object):
class SilenceWarnings(object):
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
def __init__(self):
"""Set up context, save current transformers and diffusers verbosity settings."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
"""Set verbosity to error."""
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")

View File

@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase):
def test_dup(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", example_config())
try:
with pytest.raises(DuplicateModelException):
store.add_model("key1", config)
assert False, "Duplicate model key should have been caught"
except DuplicateModelException:
assert True
try:
with pytest.raises(DuplicateModelException):
store.add_model("key2", config)
assert False, "Duplicate model path should have been caught"
except DuplicateModelException:
assert True
def test_update(store: ModelRecordServiceBase):
@ -90,11 +84,12 @@ def test_update(store: ModelRecordServiceBase):
new_config = store.get_model("key1")
assert new_config.name == "new name"
try:
def test_unknown_key(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
with pytest.raises(UnknownModelException):
store.update_model("unknown_key", config)
assert False, "expected UnknownModelException"
except UnknownModelException:
assert True
def test_delete(store: ModelRecordServiceBase):
@ -102,20 +97,8 @@ def test_delete(store: ModelRecordServiceBase):
store.add_model("key1", config)
config = store.get_model("key1")
store.del_model("key1")
try:
with pytest.raises(UnknownModelException):
config = store.get_model("key1")
assert False, "expected fetch of deleted model to raise exception"
except UnknownModelException:
assert True
# a bug in sqlite3 in python 3.9 prevents DEL from returning number of
# deleted rows!
if sys.version_info.major == 3 and sys.version_info.minor > 9:
try:
store.del_model("unknown")
assert False, "expected delete of unknown model to raise exception"
except UnknownModelException:
assert True
def test_exists(store: ModelRecordServiceBase):
@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.all_models()
assert len(matches) == 3
def test_filter_2(store: ModelRecordServiceBase):
config1 = DiffusersConfig(
path="/tmp/config1",
name="config1",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG1HASH",
)
config2 = DiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType("sd-1"),
type=ModelType("main"),
original_hash="CONFIG2HASH",
)
config3 = DiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config4 = DiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType("main"),
original_hash="CONFIG3HASH",
)
config5 = VaeDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType("sd-1"),
type=ModelType("vae"),
original_hash="CONFIG3HASH",
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
matches = store.search_by_name(
model_type=ModelType("main"),
model_name="dup_name1",
)
assert len(matches) == 2
matches = store.search_by_name(
base_model=BaseModelType("sd-1"),
model_type=ModelType("main"),
)
assert len(matches) == 2
matches = store.search_by_name(
base_model=BaseModelType("sd-1"),
model_type=ModelType("vae"),
model_name="dup_name1",
)
assert len(matches) == 1