make model key assignment deterministic

- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

fixup yaml->sql db migration script to assign deterministic key

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
This commit is contained in:
Lincoln Stein 2024-02-24 10:22:22 -05:00 committed by psychedelicious
parent d8d7ddf43a
commit a72056e0df
3 changed files with 38 additions and 12 deletions

View File

@ -7,7 +7,6 @@ import time
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union
@ -526,9 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value)
return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
@ -536,6 +532,10 @@ class ModelInstallService(ModelInstallServiceBase):
# in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", self._create_key())
info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path):

View File

@ -3,7 +3,6 @@
import json
import sqlite3
from hashlib import sha1
from logging import Logger
from pathlib import Path
from typing import Optional
@ -78,14 +77,22 @@ class MigrateModelYamlToDb1:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type)
stanza["name"] = model_name
stanza["original_hash"] = hash
stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config)
else:
self.logger.info(f"Adding model {model_name} with key {model_key}")
self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config)
except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database")
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
)
except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
with open(model_path / "image_encoder.txt") as f:
encoder = f.read()
return encoder.strip()

View File

@ -28,14 +28,28 @@ class FastModelHash(object):
"""
model_location = Path(model_location)
if model_location.is_file():
return cls._hash_file(model_location)
return cls._hash_file_sha1(model_location)
elif model_location.is_dir():
return cls._hash_dir(model_location)
else:
raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file(cls, model_location: Union[str, Path]) -> str:
def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str:
"""
Compute full sha1 hash over a single file and return its hexdigest.
:param model_location: Path to the model file
"""
BLOCK_SIZE = 65536
file_hash = hashlib.sha1()
with open(model_location, "rb") as f:
data = f.read(BLOCK_SIZE)
file_hash.update(data)
return file_hash.hexdigest()
@classmethod
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
"""
Fasthash a single file and return its hexdigest.
@ -56,7 +70,7 @@ class FastModelHash(object):
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue
path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path)
fast_hash = cls._hash_file_fast(path)
components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order