make model key assignment deterministic

- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
This commit is contained in:
Lincoln Stein 2024-02-24 10:22:22 -05:00 committed by psychedelicious
parent 98a13aa7dc
commit 2b1cb569eb
4 changed files with 41 additions and 16 deletions

View File

@ -7,7 +7,6 @@ import time
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@ -536,16 +535,16 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value) setattr(info, key, value)
return info return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register( def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str: ) -> str:
key = self._create_key() # the model key is either the forced key specified in config,
if config and not config.get("key", None): # or it is the file/directory hash computed by probe
config["key"] = key
info = info or ModelProbe.probe(model_path, config) info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
model_path = model_path.absolute() model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path): if model_path.is_relative_to(self.app_config.models_path):

View File

@ -3,7 +3,6 @@
import json import json
import sqlite3 import sqlite3
from hashlib import sha1
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -78,14 +77,22 @@ class MigrateModelYamlToDb1:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type) stanza["type"] = ModelType(model_type)
stanza["name"] = model_name stanza["name"] = model_name
stanza["original_hash"] = hash stanza["original_hash"] = hash
stanza["current_hash"] = hash stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094 new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}") self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config) self._update_model(key, new_config)
else: else:
self.logger.info(f"Adding model {model_name} with key {model_key}") self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config) self._add_model(new_key, new_config)
except DuplicateModelException: except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database") self.logger.warning(f"Model {model_name} is already in the database")
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
) )
except sqlite3.IntegrityError as exc: except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -28,14 +28,28 @@ class FastModelHash(object):
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file(model_location) return cls._hash_file_sha1(model_location)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location) return cls._hash_dir(model_location)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod @classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str: def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str:
"""
Compute full sha1 hash over a single file and return its hexdigest.
:param model_location: Path to the model file
"""
BLOCK_SIZE = 65536
file_hash = hashlib.sha1()
with open(model_location, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
@classmethod
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
""" """
Fasthash a single file and return its hexdigest. Fasthash a single file and return its hexdigest.
@ -56,7 +70,7 @@ class FastModelHash(object):
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue continue
path = (Path(root) / file).as_posix() path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path) fast_hash = cls._hash_file_fast(path)
components.update({path: fast_hash}) components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order # hash all the model hashes together, using alphabetic file order

View File

@ -31,7 +31,7 @@ def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Pa
assert len(matches) == 0 assert len(matches) == 0
key = mm2_installer.register_path(embedding_file) key = mm2_installer.register_path(embedding_file)
assert key is not None assert key is not None
assert len(key) == 32 assert len(key) == 40 # length of the sha1 hash
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None: