refactor(mm): add models table (schema WIP), rename "original_hash" -> "hash"

This commit is contained in:
psychedelicious
2024-03-01 16:13:29 +11:00
parent 0cce582f2f
commit a8cd3dfc99
9 changed files with 124 additions and 43 deletions

View File

@ -48,7 +48,7 @@ def example_config() -> TextualInversionFileConfig:
base=BaseModelType.StableDiffusion1,
type=ModelType.TextualInversion,
format=ModelFormat.EmbeddingFile,
original_hash="ABC123",
hash="ABC123",
)
@ -76,7 +76,7 @@ def test_add(store: ModelRecordServiceBase):
assert type(config1) == MainCheckpointConfig
assert config1.base == BaseModelType.StableDiffusion1
assert config1.name == "model1"
assert config1.original_hash == "111222333444"
assert config1.hash == "111222333444"
def test_dup(store: ModelRecordServiceBase):
@ -140,21 +140,21 @@ def test_filter(store: ModelRecordServiceBase):
name="config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
hash="CONFIG2HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
name="config3",
base=BaseModelType("sd-2"),
type=ModelType.Vae,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
for c in config1, config2, config3:
store.add_model(sha256(c.name.encode("utf-8")).hexdigest(), c)
@ -170,7 +170,7 @@ def test_filter(store: ModelRecordServiceBase):
matches = store.search_by_hash("CONFIG1HASH")
assert len(matches) == 1
assert matches[0].original_hash == "CONFIG1HASH"
assert matches[0].hash == "CONFIG1HASH"
matches = store.all_models()
assert len(matches) == 3
@ -182,28 +182,28 @@ def test_unique(store: ModelRecordServiceBase):
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
base=BaseModelType("sd-2"),
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config3 = VaeDiffusersConfig(
path="/tmp/config3",
base=BaseModelType("sd-2"),
type=ModelType.Vae,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
name="nonuniquename",
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
# config1, config2 and config3 are compatible because they have unique combos
# of name, type and base
@ -221,35 +221,35 @@ def test_filter_2(store: ModelRecordServiceBase):
name="config1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG1HASH",
hash="CONFIG1HASH",
)
config2 = MainDiffusersConfig(
path="/tmp/config2",
name="config2",
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
original_hash="CONFIG2HASH",
hash="CONFIG2HASH",
)
config3 = MainDiffusersConfig(
path="/tmp/config3",
name="dup_name1",
base=BaseModelType("sd-2"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
config4 = MainDiffusersConfig(
path="/tmp/config4",
name="dup_name1",
base=BaseModelType("sdxl"),
type=ModelType.Main,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
config5 = VaeDiffusersConfig(
path="/tmp/config5",
name="dup_name1",
base=BaseModelType.StableDiffusion1,
type=ModelType.Vae,
original_hash="CONFIG3HASH",
hash="CONFIG3HASH",
)
for c in config1, config2, config3, config4, config5:
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)