mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
multiple small stylistic changes requested by reviewers
This commit is contained in:
parent
ce22c0fbaa
commit
6b173cc66f
@ -10,7 +10,7 @@ from typing import List, Optional, Union
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
|
||||
|
||||
# should match the InvokeAI version when this is first released.
|
||||
CONFIG_FILE_VERSION = "3.2"
|
||||
CONFIG_FILE_VERSION = "3.2.0"
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
|
@ -66,9 +66,8 @@ from .model_records_base import (
|
||||
class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""Implementation of the ModelConfigStore ABC using a SQL database."""
|
||||
|
||||
_conn: sqlite3.Connection
|
||||
_db: SqliteDatabase
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
|
||||
def __init__(self, db: SqliteDatabase):
|
||||
"""
|
||||
@ -78,16 +77,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param lock: threading Lock object
|
||||
"""
|
||||
super().__init__()
|
||||
self._conn = db.conn
|
||||
self._lock = db.lock
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._db = db
|
||||
self._db.conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._db.conn.cursor()
|
||||
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
# Enable foreign keys
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._db.conn.execute("PRAGMA foreign_keys = ON;")
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
self._db.conn.commit()
|
||||
assert (
|
||||
str(self.version) == CONFIG_FILE_VERSION
|
||||
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
|
||||
@ -172,7 +170,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
|
||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -197,10 +195,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._db.conn.commit()
|
||||
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._conn.rollback()
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
@ -210,7 +208,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
else:
|
||||
raise e
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
@ -218,7 +216,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
@property
|
||||
def version(self) -> str:
|
||||
"""Return the version of the database schema."""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT metadata_value FROM model_manager_metadata
|
||||
@ -239,7 +237,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Can raise an UnknownModelException
|
||||
"""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -250,9 +248,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._conn.commit()
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
|
||||
@ -265,7 +263,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
"""
|
||||
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
|
||||
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -281,9 +279,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
)
|
||||
if self._cursor.rowcount == 0:
|
||||
raise UnknownModelException("model not found")
|
||||
self._conn.commit()
|
||||
self._db.conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
self._db.conn.rollback()
|
||||
raise e
|
||||
|
||||
return self.get_model(key)
|
||||
@ -296,7 +294,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
|
||||
Exceptions: UnknownModelException
|
||||
"""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config FROM model_config
|
||||
@ -317,7 +315,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param key: Unique key for the model to be deleted
|
||||
"""
|
||||
count = 0
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -360,7 +358,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
@ -377,7 +375,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
|
||||
"""Return models with the indicated path."""
|
||||
results = []
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
@ -394,7 +392,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
results = []
|
||||
with self._lock:
|
||||
with self._db.lock:
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
|
@ -16,8 +16,19 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
ModelsValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
class Migrate:
|
||||
"""Migration class."""
|
||||
class MigrateModelYamlToDb:
|
||||
"""
|
||||
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
|
||||
|
||||
The class has one externally useful method, migrate(), which scans the
|
||||
currently models.yaml file and imports all its entries into invokeai.db.
|
||||
|
||||
Use this way:
|
||||
|
||||
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
|
||||
MigrateModelYamlToDb().migrate()
|
||||
|
||||
"""
|
||||
|
||||
config: InvokeAIAppConfig
|
||||
logger: InvokeAILogger
|
||||
@ -28,14 +39,17 @@ class Migrate:
|
||||
self.logger = InvokeAILogger.get_logger()
|
||||
|
||||
def get_db(self) -> ModelRecordServiceSQL:
|
||||
"""Fetch the sqlite3 database for this installation."""
|
||||
db = SqliteDatabase(self.config, self.logger)
|
||||
return ModelRecordServiceSQL(db)
|
||||
|
||||
def get_yaml(self) -> DictConfig:
|
||||
"""Fetch the models.yaml DictConfig for this installation."""
|
||||
yaml_path = self.config.model_conf_path
|
||||
return OmegaConf.load(yaml_path)
|
||||
|
||||
def migrate(self):
|
||||
"""Do the migration from models.yaml to invokeai.db."""
|
||||
db = self.get_db()
|
||||
yaml = self.get_yaml()
|
||||
|
||||
@ -65,7 +79,7 @@ class Migrate:
|
||||
|
||||
|
||||
def main():
|
||||
Migrate().migrate()
|
||||
MigrateModelYamlToDb().migrate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -386,13 +386,10 @@ class Chdir(object):
|
||||
class SilenceWarnings(object):
|
||||
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
|
||||
|
||||
def __init__(self):
|
||||
"""Set up context, save current transformers and diffusers verbosity settings."""
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
|
||||
def __enter__(self):
|
||||
"""Set verbosity to error."""
|
||||
self.transformers_verbosity = transformers_logging.get_verbosity()
|
||||
self.diffusers_verbosity = diffusers_logging.get_verbosity()
|
||||
transformers_logging.set_verbosity_error()
|
||||
diffusers_logging.set_verbosity_error()
|
||||
warnings.simplefilter("ignore")
|
||||
|
@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase):
|
||||
def test_dup(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", example_config())
|
||||
try:
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model("key1", config)
|
||||
assert False, "Duplicate model key should have been caught"
|
||||
except DuplicateModelException:
|
||||
assert True
|
||||
try:
|
||||
with pytest.raises(DuplicateModelException):
|
||||
store.add_model("key2", config)
|
||||
assert False, "Duplicate model path should have been caught"
|
||||
except DuplicateModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_update(store: ModelRecordServiceBase):
|
||||
@ -90,11 +84,12 @@ def test_update(store: ModelRecordServiceBase):
|
||||
new_config = store.get_model("key1")
|
||||
assert new_config.name == "new name"
|
||||
|
||||
try:
|
||||
|
||||
def test_unknown_key(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.update_model("unknown_key", config)
|
||||
assert False, "expected UnknownModelException"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_delete(store: ModelRecordServiceBase):
|
||||
@ -102,20 +97,8 @@ def test_delete(store: ModelRecordServiceBase):
|
||||
store.add_model("key1", config)
|
||||
config = store.get_model("key1")
|
||||
store.del_model("key1")
|
||||
try:
|
||||
with pytest.raises(UnknownModelException):
|
||||
config = store.get_model("key1")
|
||||
assert False, "expected fetch of deleted model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
# a bug in sqlite3 in python 3.9 prevents DEL from returning number of
|
||||
# deleted rows!
|
||||
if sys.version_info.major == 3 and sys.version_info.minor > 9:
|
||||
try:
|
||||
store.del_model("unknown")
|
||||
assert False, "expected delete of unknown model to raise exception"
|
||||
except UnknownModelException:
|
||||
assert True
|
||||
|
||||
|
||||
def test_exists(store: ModelRecordServiceBase):
|
||||
@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
|
||||
matches = store.all_models()
|
||||
assert len(matches) == 3
|
||||
|
||||
|
||||
def test_filter_2(store: ModelRecordServiceBase):
|
||||
config1 = DiffusersConfig(
|
||||
path="/tmp/config1",
|
||||
name="config1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = DiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = DiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config4 = DiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType("main"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
config5 = VaeDiffusersConfig(
|
||||
path="/tmp/config5",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-1"),
|
||||
type=ModelType("vae"),
|
||||
original_hash="CONFIG3HASH",
|
||||
)
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
||||
matches = store.search_by_name(
|
||||
model_type=ModelType("main"),
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("main"),
|
||||
)
|
||||
assert len(matches) == 2
|
||||
|
||||
matches = store.search_by_name(
|
||||
base_model=BaseModelType("sd-1"),
|
||||
model_type=ModelType("vae"),
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 1
|
Loading…
Reference in New Issue
Block a user