refactor(mm): add models table (schema WIP), rename "original_hash" -> "hash"

This commit is contained in:
psychedelicious 2024-03-01 16:13:29 +11:00
parent 0cce582f2f
commit a8cd3dfc99
9 changed files with 124 additions and 43 deletions

View File

@ -66,12 +66,12 @@ provides the following fields:
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
| `base_model` | BaseModelType | The base model that the model is compatible with |
| `path` | str | Location of model on disk |
| `original_hash` | str | Hash of the model when it was first installed |
| `hash` | str | Hash of the model |
| `description` | str | Human-readable description of the model (optional) |
| `source` | str | Model's source URL or repo id (optional) |
The `key` is a unique 32-character random ID which was generated at
install time. The `original_hash` field stores a hash of the model's
install time. The `hash` field stores a hash of the model's
contents at install time obtained by sampling several parts of the
model's files using the `imohash` library. Over the course of the
model's lifetime it may be transformed in various ways, such as

View File

@ -70,7 +70,7 @@ example_model_config = {
"format": "checkpoint",
"config_path": "string",
"key": "string",
"original_hash": "string",
"hash": "string",
"description": "string",
"source": "string",
"converted_at": 0,
@ -705,7 +705,7 @@ async def convert_model(
config={
"name": original_name,
"description": model_config.description,
"original_hash": model_config.original_hash,
"hash": model_config.hash,
"source": model_config.source,
},
)

View File

@ -101,16 +101,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try:
self._cursor.execute(
"""--sql
INSERT INTO model_config (
INSERT INTO models (
id,
original_hash,
hash,
config
)
VALUES (?,?,?);
""",
(
key,
record.original_hash,
record.hash,
json_serialized,
),
)
@ -119,9 +119,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "model_config.path" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{record.path}' is already installed"
elif "model_config.name" in str(e):
elif "models.name" in str(e):
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
else:
msg = f"A model with key '{key}' is already installed"
@ -146,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try:
self._cursor.execute(
"""--sql
DELETE FROM model_config
DELETE FROM models
WHERE id=?;
""",
(key,),
@ -172,7 +172,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
try:
self._cursor.execute(
"""--sql
UPDATE model_config
UPDATE models
SET
config=?
WHERE id=?;
@ -199,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
@ -220,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
select count(*) FROM models
WHERE id=?;
""",
(key,),
@ -265,7 +265,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
f"""--sql
select config, strftime('%s',updated_at) FROM model_config
select config, strftime('%s',updated_at) FROM models
{where};
""",
tuple(bindings),
@ -281,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
@ -292,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash."""
"""Return models with the indicated hash."""
results = []
with self._db.lock:
self._cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM model_config
WHERE original_hash=?;
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
@ -370,19 +370,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
# query1: get the total number of model configs
self._cursor.execute(
"""--sql
select count(*) from model_config;
select count(*) from models;
""",
(),
)
total = int(self._cursor.fetchone()[0])
# query2: fetch key fields from the join of model_config and model_metadata
# query2: fetch key fields from the join of models and model_metadata
self._cursor.execute(
f"""--sql
SELECT a.id as key, a.type, a.base, a.format, a.name,
json_extract(a.config, '$.description') as description,
json_extract(b.metadata, '$.tags') as tags
FROM model_config AS a
FROM models AS a
LEFT JOIN model_metadata AS b on a.id=b.id
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?

View File

@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_4())
migrator.register_migration(build_migration_5())
migrator.register_migration(build_migration_6())
migrator.register_migration(build_migration_7())
migrator.run_migrations()
return db

View File

@ -0,0 +1,81 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration7Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._create_models_table(cursor)
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
"""
Adds the timestamp trigger to the model_config table.
This trigger was inadvertently dropped in earlier migration scripts.
"""
tables = [
"""--sql
CREATE TABLE IF NOT EXISTS models (
id TEXT NOT NULL PRIMARY KEY,
-- The next 3 fields are enums in python, unrestricted string here
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL NOT NULL,
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
hash TEXT, -- could be null
-- Serialized JSON representation of the whole config object,
-- which will contain additional fields from subclasses
config TEXT NOT NULL,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- unique constraint on combo of name, base and type
UNIQUE(name, base, type)
);
"""
]
# Add trigger for `updated_at`.
triggers = [
"""--sql
CREATE TRIGGER IF NOT EXISTS models_updated_at
AFTER UPDATE
ON models FOR EACH ROW
BEGIN
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE id = old.id;
END;
"""
]
# Add indexes for searchable fields
indices = [
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
]
for stmt in tables + indices + triggers:
cursor.execute(stmt)
def build_migration_7() -> Migration:
"""
Build the migration from database version 5 to 6.
This migration does the following:
- Adds the model_config_updated_at trigger if it does not exist
- Delete all ip_adapter models so that the model prober can find and
update with the correct image processor model.
"""
migration_7 = Migration(
from_version=6,
to_version=7,
callback=Migration7Callback(),
)
return migration_7

View File

@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
""",
(
key,
record.original_hash,
record.hash,
json_serialized,
),
)

View File

@ -127,15 +127,13 @@ class ModelConfigBase(BaseModel):
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(["key", "base", "type", "format", "original_hash", "source"])
schema["required"].extend(["key", "base", "type", "format", "hash", "source"])
model_config = ConfigDict(
use_enum_values=False,

View File

@ -161,7 +161,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["original_hash"] = fields.get("original_hash") or hash
fields["hash"] = fields.get("hash") or hash
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()

View File

@ -48,7 +48,7 @@ def example_config() -> TextualInversionFileConfig:
base=BaseModelType.StableDiffusion1,
type=ModelType.TextualInversion,
format=ModelFormat.EmbeddingFile,
original_hash="ABC123",
hash="ABC123",
)
@ -76,7 +76,7 @@ def test_add(store: ModelRecordServiceBase):
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType.StableDiffusion1
assert config1.name == "model1"
assert config1.original_hash == "111222333444"
assert config1.hash == "111222333444"
def test_dup(store: ModelRecordServiceBase):
@ -140,21 +140,21 @@ def test_filter(store: ModelRecordServiceBase):
name="config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
hash="CONFIG2HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
name="config3",
base=BaseModelType("sd-2"),
type=ModelType.Vae,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
@ -170,7 +170,7 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1
assert matches[0].original_hash == "CONFIG1HASH"
assert matches[0].hash == "CONFIG1HASH"
matches = store.all_models()
assert len(matches) == 3
@ -182,28 +182,28 @@ def test_unique(store: ModelRecordServiceBase):
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType.Vae,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
@ -221,35 +221,35 @@ def test_filter_2(store: ModelRecordServiceBase):
name="config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
hash="CONFIG2HASH",
)
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sdxl"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
config5 = VaeDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Vae,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)