mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): add models
table (schema WIP), rename "original_hash" -> "hash"
This commit is contained in:
parent
0cce582f2f
commit
a8cd3dfc99
@ -66,12 +66,12 @@ provides the following fields:
|
||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||
| `path` | str | Location of model on disk |
|
||||
| `original_hash` | str | Hash of the model when it was first installed |
|
||||
| `hash` | str | Hash of the model |
|
||||
| `description` | str | Human-readable description of the model (optional) |
|
||||
| `source` | str | Model's source URL or repo id (optional) |
|
||||
|
||||
The `key` is a unique 32-character random ID which was generated at
|
||||
install time. The `original_hash` field stores a hash of the model's
|
||||
install time. The `hash` field stores a hash of the model's
|
||||
contents at install time obtained by sampling several parts of the
|
||||
model's files using the `imohash` library. Over the course of the
|
||||
model's lifetime it may be transformed in various ways, such as
|
||||
@ -373,7 +373,7 @@ functionality:
|
||||
moving it into the InvokeAI root directory under the
|
||||
`models` folder (or wherever config parameter `models_dir`
|
||||
specifies).
|
||||
|
||||
|
||||
* Probing of models to determine their type, base type and other key
|
||||
information.
|
||||
|
||||
|
@ -70,7 +70,7 @@ example_model_config = {
|
||||
"format": "checkpoint",
|
||||
"config_path": "string",
|
||||
"key": "string",
|
||||
"original_hash": "string",
|
||||
"hash": "string",
|
||||
"description": "string",
|
||||
"source": "string",
|
||||
"converted_at": 0,
|
||||
@ -705,7 +705,7 @@ async def convert_model(
|
||||
config={
|
||||
"name": original_name,
|
||||
"description": model_config.description,
|
||||
"original_hash": model_config.original_hash,
|
||||
"hash": model_config.hash,
|
||||
"source": model_config.source,
|
||||
},
|
||||
)
|
||||
|
@ -101,16 +101,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO model_config (
|
||||
INSERT INTO models (
|
||||
id,
|
||||
original_hash,
|
||||
hash,
|
||||
config
|
||||
)
|
||||
VALUES (?,?,?);
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
record.hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
@ -119,9 +119,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
except sqlite3.IntegrityError as e:
|
||||
self._db.conn.rollback()
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
if "model_config.path" in str(e):
|
||||
if "models.path" in str(e):
|
||||
msg = f"A model with path '{record.path}' is already installed"
|
||||
elif "model_config.name" in str(e):
|
||||
elif "models.name" in str(e):
|
||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||
else:
|
||||
msg = f"A model with key '{key}' is already installed"
|
||||
@ -146,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM model_config
|
||||
DELETE FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -172,7 +172,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
try:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE model_config
|
||||
UPDATE models
|
||||
SET
|
||||
config=?
|
||||
WHERE id=?;
|
||||
@ -199,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -220,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) FROM model_config
|
||||
select count(*) FROM models
|
||||
WHERE id=?;
|
||||
""",
|
||||
(key,),
|
||||
@ -265,7 +265,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
select config, strftime('%s',updated_at) FROM model_config
|
||||
select config, strftime('%s',updated_at) FROM models
|
||||
{where};
|
||||
""",
|
||||
tuple(bindings),
|
||||
@ -281,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE path=?;
|
||||
""",
|
||||
(str(path),),
|
||||
@ -292,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
return results
|
||||
|
||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||
"""Return models with the indicated original_hash."""
|
||||
"""Return models with the indicated hash."""
|
||||
results = []
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
||||
WHERE original_hash=?;
|
||||
SELECT config, strftime('%s',updated_at) FROM models
|
||||
WHERE hash=?;
|
||||
""",
|
||||
(hash,),
|
||||
)
|
||||
@ -370,19 +370,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
# query1: get the total number of model configs
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
select count(*) from model_config;
|
||||
select count(*) from models;
|
||||
""",
|
||||
(),
|
||||
)
|
||||
total = int(self._cursor.fetchone()[0])
|
||||
|
||||
# query2: fetch key fields from the join of model_config and model_metadata
|
||||
# query2: fetch key fields from the join of models and model_metadata
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||
json_extract(a.config, '$.description') as description,
|
||||
json_extract(b.metadata, '$.tags') as tags
|
||||
FROM model_config AS a
|
||||
FROM models AS a
|
||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||
LIMIT ?
|
||||
|
@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||
|
||||
|
||||
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
||||
migrator.register_migration(build_migration_4())
|
||||
migrator.register_migration(build_migration_5())
|
||||
migrator.register_migration(build_migration_6())
|
||||
migrator.register_migration(build_migration_7())
|
||||
migrator.run_migrations()
|
||||
|
||||
return db
|
||||
|
@ -0,0 +1,81 @@
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
class Migration7Callback:
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
self._create_models_table(cursor)
|
||||
|
||||
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||
"""
|
||||
Adds the timestamp trigger to the model_config table.
|
||||
|
||||
This trigger was inadvertently dropped in earlier migration scripts.
|
||||
"""
|
||||
|
||||
tables = [
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS models (
|
||||
id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The next 3 fields are enums in python, unrestricted string here
|
||||
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL NOT NULL,
|
||||
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||
hash TEXT, -- could be null
|
||||
-- Serialized JSON representation of the whole config object,
|
||||
-- which will contain additional fields from subclasses
|
||||
config TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- unique constraint on combo of name, base and type
|
||||
UNIQUE(name, base, type)
|
||||
);
|
||||
"""
|
||||
]
|
||||
|
||||
# Add trigger for `updated_at`.
|
||||
triggers = [
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||
AFTER UPDATE
|
||||
ON models FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE id = old.id;
|
||||
END;
|
||||
"""
|
||||
]
|
||||
|
||||
# Add indexes for searchable fields
|
||||
indices = [
|
||||
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||
]
|
||||
|
||||
for stmt in tables + indices + triggers:
|
||||
cursor.execute(stmt)
|
||||
|
||||
|
||||
def build_migration_7() -> Migration:
|
||||
"""
|
||||
Build the migration from database version 5 to 6.
|
||||
|
||||
This migration does the following:
|
||||
- Adds the model_config_updated_at trigger if it does not exist
|
||||
- Delete all ip_adapter models so that the model prober can find and
|
||||
update with the correct image processor model.
|
||||
"""
|
||||
migration_7 = Migration(
|
||||
from_version=6,
|
||||
to_version=7,
|
||||
callback=Migration7Callback(),
|
||||
)
|
||||
|
||||
return migration_7
|
@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
|
||||
""",
|
||||
(
|
||||
key,
|
||||
record.original_hash,
|
||||
record.hash,
|
||||
json_serialized,
|
||||
),
|
||||
)
|
||||
|
@ -127,15 +127,13 @@ class ModelConfigBase(BaseModel):
|
||||
name: str = Field(description="model name")
|
||||
base: BaseModelType = Field(description="base model")
|
||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||
original_hash: Optional[str] = Field(
|
||||
description="original fasthash of model contents", default=None
|
||||
) # this is assigned at install time and will not change
|
||||
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
|
||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
schema["required"].extend(["key", "base", "type", "format", "original_hash", "source"])
|
||||
schema["required"].extend(["key", "base", "type", "format", "hash", "source"])
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
|
@ -161,7 +161,7 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["hash"] = fields.get("hash") or hash
|
||||
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
@ -48,7 +48,7 @@ def example_config() -> TextualInversionFileConfig:
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.TextualInversion,
|
||||
format=ModelFormat.EmbeddingFile,
|
||||
original_hash="ABC123",
|
||||
hash="ABC123",
|
||||
)
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ def test_add(store: ModelRecordServiceBase):
|
||||
assert type(config1) == MainCheckpointConfig
|
||||
assert config1.base == BaseModelType.StableDiffusion1
|
||||
assert config1.name == "model1"
|
||||
assert config1.original_hash == "111222333444"
|
||||
assert config1.hash == "111222333444"
|
||||
|
||||
|
||||
def test_dup(store: ModelRecordServiceBase):
|
||||
@ -140,21 +140,21 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
name="config1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG2HASH",
|
||||
hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType.Vae,
|
||||
original_hash="CONFIG3HASH",
|
||||
hash="CONFIG3HASH",
|
||||
)
|
||||
for c in config1, config2, config3:
|
||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||
@ -170,7 +170,7 @@ def test_filter(store: ModelRecordServiceBase):
|
||||
|
||||
matches = store.search_by_hash("CONFIG1HASH")
|
||||
assert len(matches) == 1
|
||||
assert matches[0].original_hash == "CONFIG1HASH"
|
||||
assert matches[0].hash == "CONFIG1HASH"
|
||||
|
||||
matches = store.all_models()
|
||||
assert len(matches) == 3
|
||||
@ -182,28 +182,28 @@ def test_unique(store: ModelRecordServiceBase):
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
config3 = VaeDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType.Vae,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
name="nonuniquename",
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
# config1, config2 and config3 are compatible because they have unique combos
|
||||
# of name, type and base
|
||||
@ -221,35 +221,35 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
name="config1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG1HASH",
|
||||
hash="CONFIG1HASH",
|
||||
)
|
||||
config2 = MainDiffusersConfig(
|
||||
path="/tmp/config2",
|
||||
name="config2",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG2HASH",
|
||||
hash="CONFIG2HASH",
|
||||
)
|
||||
config3 = MainDiffusersConfig(
|
||||
path="/tmp/config3",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sd-2"),
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG3HASH",
|
||||
hash="CONFIG3HASH",
|
||||
)
|
||||
config4 = MainDiffusersConfig(
|
||||
path="/tmp/config4",
|
||||
name="dup_name1",
|
||||
base=BaseModelType("sdxl"),
|
||||
type=ModelType.Main,
|
||||
original_hash="CONFIG3HASH",
|
||||
hash="CONFIG3HASH",
|
||||
)
|
||||
config5 = VaeDiffusersConfig(
|
||||
path="/tmp/config5",
|
||||
name="dup_name1",
|
||||
base=BaseModelType.StableDiffusion1,
|
||||
type=ModelType.Vae,
|
||||
original_hash="CONFIG3HASH",
|
||||
hash="CONFIG3HASH",
|
||||
)
|
||||
for c in config1, config2, config3, config4, config5:
|
||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||
|
Loading…
Reference in New Issue
Block a user