mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor(mm): add models
table (schema WIP), rename "original_hash" -> "hash"
This commit is contained in:
@ -66,12 +66,12 @@ provides the following fields:
|
|||||||
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
| `model_format` | ModelFormat | The format of the model (e.g. "diffusers"); also used as a Union discriminator |
|
||||||
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
| `base_model` | BaseModelType | The base model that the model is compatible with |
|
||||||
| `path` | str | Location of model on disk |
|
| `path` | str | Location of model on disk |
|
||||||
| `original_hash` | str | Hash of the model when it was first installed |
|
| `hash` | str | Hash of the model |
|
||||||
| `description` | str | Human-readable description of the model (optional) |
|
| `description` | str | Human-readable description of the model (optional) |
|
||||||
| `source` | str | Model's source URL or repo id (optional) |
|
| `source` | str | Model's source URL or repo id (optional) |
|
||||||
|
|
||||||
The `key` is a unique 32-character random ID which was generated at
|
The `key` is a unique 32-character random ID which was generated at
|
||||||
install time. The `original_hash` field stores a hash of the model's
|
install time. The `hash` field stores a hash of the model's
|
||||||
contents at install time obtained by sampling several parts of the
|
contents at install time obtained by sampling several parts of the
|
||||||
model's files using the `imohash` library. Over the course of the
|
model's files using the `imohash` library. Over the course of the
|
||||||
model's lifetime it may be transformed in various ways, such as
|
model's lifetime it may be transformed in various ways, such as
|
||||||
@ -373,7 +373,7 @@ functionality:
|
|||||||
moving it into the InvokeAI root directory under the
|
moving it into the InvokeAI root directory under the
|
||||||
`models` folder (or wherever config parameter `models_dir`
|
`models` folder (or wherever config parameter `models_dir`
|
||||||
specifies).
|
specifies).
|
||||||
|
|
||||||
* Probing of models to determine their type, base type and other key
|
* Probing of models to determine their type, base type and other key
|
||||||
information.
|
information.
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ example_model_config = {
|
|||||||
"format": "checkpoint",
|
"format": "checkpoint",
|
||||||
"config_path": "string",
|
"config_path": "string",
|
||||||
"key": "string",
|
"key": "string",
|
||||||
"original_hash": "string",
|
"hash": "string",
|
||||||
"description": "string",
|
"description": "string",
|
||||||
"source": "string",
|
"source": "string",
|
||||||
"converted_at": 0,
|
"converted_at": 0,
|
||||||
@ -705,7 +705,7 @@ async def convert_model(
|
|||||||
config={
|
config={
|
||||||
"name": original_name,
|
"name": original_name,
|
||||||
"description": model_config.description,
|
"description": model_config.description,
|
||||||
"original_hash": model_config.original_hash,
|
"hash": model_config.hash,
|
||||||
"source": model_config.source,
|
"source": model_config.source,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -101,16 +101,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
INSERT INTO model_config (
|
INSERT INTO models (
|
||||||
id,
|
id,
|
||||||
original_hash,
|
hash,
|
||||||
config
|
config
|
||||||
)
|
)
|
||||||
VALUES (?,?,?);
|
VALUES (?,?,?);
|
||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.original_hash,
|
record.hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -119,9 +119,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
except sqlite3.IntegrityError as e:
|
except sqlite3.IntegrityError as e:
|
||||||
self._db.conn.rollback()
|
self._db.conn.rollback()
|
||||||
if "UNIQUE constraint failed" in str(e):
|
if "UNIQUE constraint failed" in str(e):
|
||||||
if "model_config.path" in str(e):
|
if "models.path" in str(e):
|
||||||
msg = f"A model with path '{record.path}' is already installed"
|
msg = f"A model with path '{record.path}' is already installed"
|
||||||
elif "model_config.name" in str(e):
|
elif "models.name" in str(e):
|
||||||
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
msg = f"A model with name='{record.name}', type='{record.type}', base='{record.base}' is already installed"
|
||||||
else:
|
else:
|
||||||
msg = f"A model with key '{key}' is already installed"
|
msg = f"A model with key '{key}' is already installed"
|
||||||
@ -146,7 +146,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
DELETE FROM model_config
|
DELETE FROM models
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(key,),
|
(key,),
|
||||||
@ -172,7 +172,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
try:
|
try:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE model_config
|
UPDATE models
|
||||||
SET
|
SET
|
||||||
config=?
|
config=?
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
@ -199,7 +199,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(key,),
|
(key,),
|
||||||
@ -220,7 +220,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
select count(*) FROM model_config
|
select count(*) FROM models
|
||||||
WHERE id=?;
|
WHERE id=?;
|
||||||
""",
|
""",
|
||||||
(key,),
|
(key,),
|
||||||
@ -265,7 +265,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
select config, strftime('%s',updated_at) FROM model_config
|
select config, strftime('%s',updated_at) FROM models
|
||||||
{where};
|
{where};
|
||||||
""",
|
""",
|
||||||
tuple(bindings),
|
tuple(bindings),
|
||||||
@ -281,7 +281,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
WHERE path=?;
|
WHERE path=?;
|
||||||
""",
|
""",
|
||||||
(str(path),),
|
(str(path),),
|
||||||
@ -292,13 +292,13 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
|
||||||
"""Return models with the indicated original_hash."""
|
"""Return models with the indicated hash."""
|
||||||
results = []
|
results = []
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
SELECT config, strftime('%s',updated_at) FROM model_config
|
SELECT config, strftime('%s',updated_at) FROM models
|
||||||
WHERE original_hash=?;
|
WHERE hash=?;
|
||||||
""",
|
""",
|
||||||
(hash,),
|
(hash,),
|
||||||
)
|
)
|
||||||
@ -370,19 +370,19 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
# query1: get the total number of model configs
|
# query1: get the total number of model configs
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
select count(*) from model_config;
|
select count(*) from models;
|
||||||
""",
|
""",
|
||||||
(),
|
(),
|
||||||
)
|
)
|
||||||
total = int(self._cursor.fetchone()[0])
|
total = int(self._cursor.fetchone()[0])
|
||||||
|
|
||||||
# query2: fetch key fields from the join of model_config and model_metadata
|
# query2: fetch key fields from the join of models and model_metadata
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
f"""--sql
|
f"""--sql
|
||||||
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
SELECT a.id as key, a.type, a.base, a.format, a.name,
|
||||||
json_extract(a.config, '$.description') as description,
|
json_extract(a.config, '$.description') as description,
|
||||||
json_extract(b.metadata, '$.tags') as tags
|
json_extract(b.metadata, '$.tags') as tags
|
||||||
FROM model_config AS a
|
FROM models AS a
|
||||||
LEFT JOIN model_metadata AS b on a.id=b.id
|
LEFT JOIN model_metadata AS b on a.id=b.id
|
||||||
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
|
@ -9,6 +9,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_3 import
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_4 import build_migration_4
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_5 import build_migration_5
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import build_migration_6
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +36,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator.register_migration(build_migration_4())
|
migrator.register_migration(build_migration_4())
|
||||||
migrator.register_migration(build_migration_5())
|
migrator.register_migration(build_migration_5())
|
||||||
migrator.register_migration(build_migration_6())
|
migrator.register_migration(build_migration_6())
|
||||||
|
migrator.register_migration(build_migration_7())
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
@ -0,0 +1,81 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
|
|
||||||
|
class Migration7Callback:
|
||||||
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
self._create_models_table(cursor)
|
||||||
|
|
||||||
|
def _create_models_table(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""
|
||||||
|
Adds the timestamp trigger to the model_config table.
|
||||||
|
|
||||||
|
This trigger was inadvertently dropped in earlier migration scripts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tables = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TABLE IF NOT EXISTS models (
|
||||||
|
id TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The next 3 fields are enums in python, unrestricted string here
|
||||||
|
base TEXT GENERATED ALWAYS as (json_extract(config, '$.base')) VIRTUAL NOT NULL,
|
||||||
|
type TEXT GENERATED ALWAYS as (json_extract(config, '$.type')) VIRTUAL NOT NULL,
|
||||||
|
name TEXT GENERATED ALWAYS as (json_extract(config, '$.name')) VIRTUAL NOT NULL,
|
||||||
|
description TEXT GENERATED ALWAYS as (json_extract(config, '$.description')) VIRTUAL NOT NULL,
|
||||||
|
path TEXT GENERATED ALWAYS as (json_extract(config, '$.path')) VIRTUAL NOT NULL,
|
||||||
|
format TEXT GENERATED ALWAYS as (json_extract(config, '$.format')) VIRTUAL NOT NULL,
|
||||||
|
hash TEXT, -- could be null
|
||||||
|
-- Serialized JSON representation of the whole config object,
|
||||||
|
-- which will contain additional fields from subclasses
|
||||||
|
config TEXT NOT NULL,
|
||||||
|
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- Updated via trigger
|
||||||
|
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||||
|
-- unique constraint on combo of name, base and type
|
||||||
|
UNIQUE(name, base, type)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add trigger for `updated_at`.
|
||||||
|
triggers = [
|
||||||
|
"""--sql
|
||||||
|
CREATE TRIGGER IF NOT EXISTS models_updated_at
|
||||||
|
AFTER UPDATE
|
||||||
|
ON models FOR EACH ROW
|
||||||
|
BEGIN
|
||||||
|
UPDATE models SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||||
|
WHERE id = old.id;
|
||||||
|
END;
|
||||||
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add indexes for searchable fields
|
||||||
|
indices = [
|
||||||
|
"CREATE INDEX IF NOT EXISTS base_index ON models(base);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS type_index ON models(type);",
|
||||||
|
"CREATE INDEX IF NOT EXISTS name_index ON models(name);",
|
||||||
|
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON models(path);",
|
||||||
|
]
|
||||||
|
|
||||||
|
for stmt in tables + indices + triggers:
|
||||||
|
cursor.execute(stmt)
|
||||||
|
|
||||||
|
|
||||||
|
def build_migration_7() -> Migration:
|
||||||
|
"""
|
||||||
|
Build the migration from database version 5 to 6.
|
||||||
|
|
||||||
|
This migration does the following:
|
||||||
|
- Adds the model_config_updated_at trigger if it does not exist
|
||||||
|
- Delete all ip_adapter models so that the model prober can find and
|
||||||
|
update with the correct image processor model.
|
||||||
|
"""
|
||||||
|
migration_7 = Migration(
|
||||||
|
from_version=6,
|
||||||
|
to_version=7,
|
||||||
|
callback=Migration7Callback(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return migration_7
|
@ -150,7 +150,7 @@ class MigrateModelYamlToDb1:
|
|||||||
""",
|
""",
|
||||||
(
|
(
|
||||||
key,
|
key,
|
||||||
record.original_hash,
|
record.hash,
|
||||||
json_serialized,
|
json_serialized,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -127,15 +127,13 @@ class ModelConfigBase(BaseModel):
|
|||||||
name: str = Field(description="model name")
|
name: str = Field(description="model name")
|
||||||
base: BaseModelType = Field(description="base model")
|
base: BaseModelType = Field(description="base model")
|
||||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||||
original_hash: Optional[str] = Field(
|
hash: Optional[str] = Field(description="original fasthash of model contents", default=None)
|
||||||
description="original fasthash of model contents", default=None
|
|
||||||
) # this is assigned at install time and will not change
|
|
||||||
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
description: Optional[str] = Field(description="human readable description of the model", default=None)
|
||||||
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||||
schema["required"].extend(["key", "base", "type", "format", "original_hash", "source"])
|
schema["required"].extend(["key", "base", "type", "format", "hash", "source"])
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
use_enum_values=False,
|
use_enum_values=False,
|
||||||
|
@ -161,7 +161,7 @@ class ModelProbe(object):
|
|||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["original_hash"] = fields.get("original_hash") or hash
|
fields["hash"] = fields.get("hash") or hash
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
@ -48,7 +48,7 @@ def example_config() -> TextualInversionFileConfig:
|
|||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.TextualInversion,
|
type=ModelType.TextualInversion,
|
||||||
format=ModelFormat.EmbeddingFile,
|
format=ModelFormat.EmbeddingFile,
|
||||||
original_hash="ABC123",
|
hash="ABC123",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -76,7 +76,7 @@ def test_add(store: ModelRecordServiceBase):
|
|||||||
assert type(config1) == MainCheckpointConfig
|
assert type(config1) == MainCheckpointConfig
|
||||||
assert config1.base == BaseModelType.StableDiffusion1
|
assert config1.base == BaseModelType.StableDiffusion1
|
||||||
assert config1.name == "model1"
|
assert config1.name == "model1"
|
||||||
assert config1.original_hash == "111222333444"
|
assert config1.hash == "111222333444"
|
||||||
|
|
||||||
|
|
||||||
def test_dup(store: ModelRecordServiceBase):
|
def test_dup(store: ModelRecordServiceBase):
|
||||||
@ -140,21 +140,21 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG2HASH",
|
hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="config3",
|
name="config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
original_hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3:
|
for c in config1, config2, config3:
|
||||||
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
|
||||||
@ -170,7 +170,7 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
matches = store.search_by_hash("CONFIG1HASH")
|
matches = store.search_by_hash("CONFIG1HASH")
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].original_hash == "CONFIG1HASH"
|
assert matches[0].hash == "CONFIG1HASH"
|
||||||
|
|
||||||
matches = store.all_models()
|
matches = store.all_models()
|
||||||
assert len(matches) == 3
|
assert len(matches) == 3
|
||||||
@ -182,28 +182,28 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VaeDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
# config1, config2 and config3 are compatible because they have unique combos
|
# config1, config2 and config3 are compatible because they have unique combos
|
||||||
# of name, type and base
|
# of name, type and base
|
||||||
@ -221,35 +221,35 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
name="config1",
|
name="config1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
)
|
)
|
||||||
config2 = MainDiffusersConfig(
|
config2 = MainDiffusersConfig(
|
||||||
path="/tmp/config2",
|
path="/tmp/config2",
|
||||||
name="config2",
|
name="config2",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG2HASH",
|
hash="CONFIG2HASH",
|
||||||
)
|
)
|
||||||
config3 = MainDiffusersConfig(
|
config3 = MainDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
config4 = MainDiffusersConfig(
|
config4 = MainDiffusersConfig(
|
||||||
path="/tmp/config4",
|
path="/tmp/config4",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType("sdxl"),
|
base=BaseModelType("sdxl"),
|
||||||
type=ModelType.Main,
|
type=ModelType.Main,
|
||||||
original_hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
config5 = VaeDiffusersConfig(
|
config5 = VaeDiffusersConfig(
|
||||||
path="/tmp/config5",
|
path="/tmp/config5",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Vae,
|
type=ModelType.Vae,
|
||||||
original_hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
)
|
)
|
||||||
for c in config1, config2, config3, config4, config5:
|
for c in config1, config2, config3, config4, config5:
|
||||||
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
|
||||||
|
Reference in New Issue
Block a user