From a8cd3dfc993935f11cd49b5336f769faba53e9c8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 1 Mar 2024 16:13:29 +1100 Subject: [PATCH] refactor(mm): add `models` table (schema WIP), rename "original_hash" -> "hash" --- docs/contributing/MODEL_MANAGER.md | 6 +- invokeai/app/api/routers/model_manager.py | 4 +- .../model_records/model_records_sql.py | 34 ++++---- .../app/services/shared/sqlite/sqlite_util.py | 2 + .../sqlite_migrator/migrations/migration_7.py | 81 +++++++++++++++++++ .../migrations/util/migrate_yaml_config_1.py | 2 +- invokeai/backend/model_manager/config.py | 6 +- invokeai/backend/model_manager/probe.py | 2 +- .../model_records/test_model_records_sql.py | 30 +++---- 9 files changed, 124 insertions(+), 43 deletions(-) create mode 100644 invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 8f30ed70a3..e654a52d2c 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -66,12 +66,12 @@ provides the following fields: | `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator | | `base_model` | BaseModelType | The base model that the model is compatible with | | `path` | str | Location of model on disk | -| `original_hash` | str | Hash of the model when it was first installed | +| `hash` | str | Hash of the model | | `description` | str | Human-readable description of the model (optional) | | `source` | str | Model's source URL or repo id (optional) | The `key` is a unique 32-character random ID which was generated at -install time. The `original_hash` field stores a hash of the model's +install time. The `hash` field stores a hash of the model's contents at install time obtained by sampling several parts of the model's files using the `imohash` library. Over the course of the model's lifetime it may be transformed in various ways, such as @@ -373,7 +373,7 @@ functionality: moving it into the InvokeAI root directory under the `models` folder (or wherever config parameter `models_dir` specifies). - + * Probing of models to determine their type, base type and other key information. diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index ea8a0481b5..5eaab13f21 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -70,7 +70,7 @@ example_model_config = { "format": "checkpoint", "config_path": "string", "key": "string", - "original_hash": "string", + "hash": "string", "description": "string", "source": "string", "converted_at": 0, @@ -705,7 +705,7 @@ async def convert_model( config={ "name": original_name, "description": model_config.description, - "original_hash": model_config.original_hash, + "hash": model_config.hash, "source": model_config.source, }, ) diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 60f0ad86a8..4ab7a6f21a 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -101,16 +101,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): try: self._cursor.execute( """--sql - INSERT INTO model_config ( + INSERT INTO models ( id, - original_hash, + hash, config ) VALUES (?,?,?); """, ( key, - record.original_hash, + record.hash, json_serialized, ), ) @@ -119,9 +119,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): except sqlite3.IntegrityError as e: self._db.conn.rollback() if "UNIQUE constraint failed" in str(e): - if "model_config.path" in str(e): + if "models.path" in str(e): msg = f"A model with path '{record.path}' is already installed" - elif "model_config.name" in str(e): + elif "models.name" in str(e): msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed" else: msg = f"A model with key '{key}' is already installed" @@ -146,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): try: self._cursor.execute( """--sql - DELETE FROM model_config + DELETE FROM models WHERE id=?; """, (key,), @@ -172,7 +172,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): try: self._cursor.execute( """--sql - UPDATE model_config + UPDATE models SET config=? WHERE id=?; @@ -199,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM model_config + SELECT config, strftime('%s',updated_at) FROM models WHERE id=?; """, (key,), @@ -220,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - select count(*) FROM model_config + select count(*) FROM models WHERE id=?; """, (key,), @@ -265,7 +265,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( f"""--sql - select config, strftime('%s',updated_at) FROM model_config + select config, strftime('%s',updated_at) FROM models {where}; """, tuple(bindings), @@ -281,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): with self._db.lock: self._cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM model_config + SELECT config, strftime('%s',updated_at) FROM models WHERE path=?; """, (str(path),), @@ -292,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): return results def search_by_hash(self, hash: str) -> List[AnyModelConfig]: - """Return models with the indicated original_hash.""" + """Return models with the indicated hash.""" results = [] with self._db.lock: self._cursor.execute( """--sql - SELECT config, strftime('%s',updated_at) FROM model_config - WHERE original_hash=?; + SELECT config, strftime('%s',updated_at) FROM models + WHERE hash=?; """, (hash,), ) @@ -370,19 +370,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): # query1: get the total number of model configs self._cursor.execute( """--sql - select count(*) from model_config; + select count(*) from models; """, (), ) total = int(self._cursor.fetchone()[0]) - # query2: fetch key fields from the join of model_config and model_metadata + # query2: fetch key fields from the join of models and model_metadata self._cursor.execute( f"""--sql SELECT a.id as key, a.type, a.base, a.format, a.name, json_extract(a.config, '$.description') as description, json_extract(b.metadata, '$.tags') as tags - FROM model_config AS a + FROM models AS a LEFT JOIN model_metadata AS b on a.id=b.id ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason LIMIT ? diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 681886eacd..25006f5aba 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_4()) migrator.register_migration(build_migration_5()) migrator.register_migration(build_migration_6()) + migrator.register_migration(build_migration_7()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py new file mode 100644 index 0000000000..eaf457743b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_7.py @@ -0,0 +1,81 @@ +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration7Callback: + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._create_models_table(cursor) + + def _create_models_table(self, cursor: sqlite3.Cursor) -> None: + """ + Adds the timestamp trigger to the model_config table. + + This trigger was inadvertently dropped in earlier migration scripts. + """ + + tables = [ + """--sql + CREATE TABLE IF NOT EXISTS models ( + id TEXT NOT NULL PRIMARY KEY, + -- The next 3 fields are enums in python, unrestricted string here + base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL, + type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL, + name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL, + description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL NOT NULL, + path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL, + format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL, + hash TEXT, -- could be null + -- Serialized JSON representation of the whole config object, + -- which will contain additional fields from subclasses + config TEXT NOT NULL, + created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- Updated via trigger + updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), + -- unique constraint on combo of name, base and type + UNIQUE(name, base, type) + ); + """ + ] + + # Add trigger for `updated_at`. + triggers = [ + """--sql + CREATE TRIGGER IF NOT EXISTS models_updated_at + AFTER UPDATE + ON models FOR EACH ROW + BEGIN + UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW') + WHERE id = old.id; + END; + """ + ] + + # Add indexes for searchable fields + indices = [ + "CREATE INDEX IF NOT EXISTS base_index ON models(base);", + "CREATE INDEX IF NOT EXISTS type_index ON models(type);", + "CREATE INDEX IF NOT EXISTS name_index ON models(name);", + "CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);", + ] + + for stmt in tables + indices + triggers: + cursor.execute(stmt) + + +def build_migration_7() -> Migration: + """ + Build the migration from database version 5 to 6. + + This migration does the following: + - Adds the model_config_updated_at trigger if it does not exist + - Delete all ip_adapter models so that the model prober can find and + update with the correct image processor model. + """ + migration_7 = Migration( + from_version=6, + to_version=7, + callback=Migration7Callback(), + ) + + return migration_7 diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py index be4d5f0140..bba609ba35 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/util/migrate_yaml_config_1.py @@ -150,7 +150,7 @@ class MigrateModelYamlToDb1: """, ( key, - record.original_hash, + record.hash, json_serialized, ), ) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b737c22030..521435aae5 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -127,15 +127,13 @@ class ModelConfigBase(BaseModel): name: str = Field(description="model name") base: BaseModelType = Field(description="base model") key: str = Field(description="unique key for model", default="") - original_hash: Optional[str] = Field( - description="original fasthash of model contents", default=None - ) # this is assigned at install time and will not change + hash: Optional[str] = Field(description="original fasthash of model contents", default=None) description: Optional[str] = Field(description="human readable description of the model", default=None) source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: - schema["required"].extend(["key", "base", "type", "format", "original_hash", "source"]) + schema["required"].extend(["key", "base", "type", "format", "hash", "source"]) model_config = ConfigDict( use_enum_values=False, diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index e52a47b1e9..5d3c982f7f 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -161,7 +161,7 @@ class ModelProbe(object): fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" ) fields["format"] = fields.get("format") or probe.get_format() - fields["original_hash"] = fields.get("original_hash") or hash + fields["hash"] = fields.get("hash") or hash if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 6bf5e027eb..8263d77d79 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -48,7 +48,7 @@ def example_config() -> TextualInversionFileConfig: base=BaseModelType.StableDiffusion1, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile, - original_hash="ABC123", + hash="ABC123", ) @@ -76,7 +76,7 @@ def test_add(store: ModelRecordServiceBase): assert type(config1) == MainCheckpointConfig assert config1.base == BaseModelType.StableDiffusion1 assert config1.name == "model1" - assert config1.original_hash == "111222333444" + assert config1.hash == "111222333444" def test_dup(store: ModelRecordServiceBase): @@ -140,21 +140,21 @@ def test_filter(store: ModelRecordServiceBase): name="config1", base=BaseModelType.StableDiffusion1, type=ModelType.Main, - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", base=BaseModelType.StableDiffusion1, type=ModelType.Main, - original_hash="CONFIG2HASH", + hash="CONFIG2HASH", ) config3 = VaeDiffusersConfig( path="/tmp/config3", name="config3", base=BaseModelType("sd-2"), type=ModelType.Vae, - original_hash="CONFIG3HASH", + hash="CONFIG3HASH", ) for c in config1, config2, config3: store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c) @@ -170,7 +170,7 @@ def test_filter(store: ModelRecordServiceBase): matches = store.search_by_hash("CONFIG1HASH") assert len(matches) == 1 - assert matches[0].original_hash == "CONFIG1HASH" + assert matches[0].hash == "CONFIG1HASH" matches = store.all_models() assert len(matches) == 3 @@ -182,28 +182,28 @@ def test_unique(store: ModelRecordServiceBase): base=BaseModelType.StableDiffusion1, type=ModelType.Main, name="nonuniquename", - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", base=BaseModelType("sd-2"), type=ModelType.Main, name="nonuniquename", - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) config3 = VaeDiffusersConfig( path="/tmp/config3", base=BaseModelType("sd-2"), type=ModelType.Vae, name="nonuniquename", - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) config4 = MainDiffusersConfig( path="/tmp/config4", base=BaseModelType.StableDiffusion1, type=ModelType.Main, name="nonuniquename", - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) # config1, config2 and config3 are compatible because they have unique combos # of name, type and base @@ -221,35 +221,35 @@ def test_filter_2(store: ModelRecordServiceBase): name="config1", base=BaseModelType.StableDiffusion1, type=ModelType.Main, - original_hash="CONFIG1HASH", + hash="CONFIG1HASH", ) config2 = MainDiffusersConfig( path="/tmp/config2", name="config2", base=BaseModelType.StableDiffusion1, type=ModelType.Main, - original_hash="CONFIG2HASH", + hash="CONFIG2HASH", ) config3 = MainDiffusersConfig( path="/tmp/config3", name="dup_name1", base=BaseModelType("sd-2"), type=ModelType.Main, - original_hash="CONFIG3HASH", + hash="CONFIG3HASH", ) config4 = MainDiffusersConfig( path="/tmp/config4", name="dup_name1", base=BaseModelType("sdxl"), type=ModelType.Main, - original_hash="CONFIG3HASH", + hash="CONFIG3HASH", ) config5 = VaeDiffusersConfig( path="/tmp/config5", name="dup_name1", base=BaseModelType.StableDiffusion1, type=ModelType.Vae, - original_hash="CONFIG3HASH", + hash="CONFIG3HASH", ) for c in config1, config2, config3, config4, config5: store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)