make model key assignment deterministic

- When installing, model keys are now calculated from the model contents.
- .safetensors, .ckpt and other single file models are hashed with sha1
- The contents of diffusers directories are hashed using imohash (faster)

Fix up the yaml->sql db migration script to assign deterministic keys

- this commit also detects and assigns the correct image encoder for
  ip adapter models.
This commit is contained in:
Lincoln Stein 2024-02-24 10:22:22 -05:00 committed by Ryan Dick
parent d16b9a25bb
commit 64f8535ef5
3 changed files with 38 additions and 12 deletions

View File

@ -7,7 +7,6 @@ import time
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from random import randbytes
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Dict, List, Optional, Set, Union
@ -526,9 +525,6 @@ class ModelInstallService(ModelInstallServiceBase):
setattr(info, key, value) setattr(info, key, value)
return info return info
def _create_key(self) -> str:
return sha256(randbytes(100)).hexdigest()[0:32]
def _register( def _register(
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str: ) -> str:
@ -536,6 +532,10 @@ class ModelInstallService(ModelInstallServiceBase):
# in which case the key field should have been populated by the caller (e.g. in `install_path`). # in which case the key field should have been populated by the caller (e.g. in `install_path`).
config["key"] = config.get("key", self._create_key()) config["key"] = config.get("key", self._create_key())
info = info or ModelProbe.probe(model_path, config) info = info or ModelProbe.probe(model_path, config)
override_key: Optional[str] = config.get("key") if config else None
assert info.original_hash # always assigned by probe()
info.key = override_key or info.original_hash
model_path = model_path.absolute() model_path = model_path.absolute()
if model_path.is_relative_to(self.app_config.models_path): if model_path.is_relative_to(self.app_config.models_path):

View File

@ -3,7 +3,6 @@
import json import json
import sqlite3 import sqlite3
from hashlib import sha1
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -78,14 +77,22 @@ class MigrateModelYamlToDb1:
self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.") self.logger.warning(f"The model at {stanza.path} is not a valid file or directory. Skipping migration.")
continue continue
assert isinstance(model_key, str)
new_key = sha1(model_key.encode("utf-8")).hexdigest()
stanza["base"] = BaseModelType(base_type) stanza["base"] = BaseModelType(base_type)
stanza["type"] = ModelType(model_type) stanza["type"] = ModelType(model_type)
stanza["name"] = model_name stanza["name"] = model_name
stanza["original_hash"] = hash stanza["original_hash"] = hash
stanza["current_hash"] = hash stanza["current_hash"] = hash
new_key = hash # deterministic key assignment
# special case for ip adapters, which need the new `image_encoder_model_id` field
if stanza["type"] == ModelType.IPAdapter:
try:
stanza["image_encoder_model_id"] = self._get_image_encoder_model_id(
self.config.models_path / stanza.path
)
except OSError:
self.logger.warning(f"Could not determine image encoder for {stanza.path}. Skipping.")
continue
new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094 new_config: AnyModelConfig = ModelsValidator.validate_python(stanza) # type: ignore # see https://github.com/pydantic/pydantic/discussions/7094
@ -95,7 +102,7 @@ class MigrateModelYamlToDb1:
self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}") self.logger.info(f"Updating model {model_name} with information from models.yaml using key {key}")
self._update_model(key, new_config) self._update_model(key, new_config)
else: else:
self.logger.info(f"Adding model {model_name} with key {model_key}") self.logger.info(f"Adding model {model_name} with key {new_key}")
self._add_model(new_key, new_config) self._add_model(new_key, new_config)
except DuplicateModelException: except DuplicateModelException:
self.logger.warning(f"Model {model_name} is already in the database") self.logger.warning(f"Model {model_name} is already in the database")
@ -149,3 +156,8 @@ class MigrateModelYamlToDb1:
) )
except sqlite3.IntegrityError as exc: except sqlite3.IntegrityError as exc:
raise DuplicateModelException(f"{record.name}: model is already in database") from exc raise DuplicateModelException(f"{record.name}: model is already in database") from exc
def _get_image_encoder_model_id(self, model_path: Path) -> str:
    """
    Read the image encoder identifier stored alongside an IP adapter model.

    :param model_path: Directory that contains an `image_encoder.txt` file
    :return: Contents of `image_encoder.txt` with surrounding whitespace removed
    :raises OSError: If the file cannot be opened or read
    """
    encoder_file = model_path / "image_encoder.txt"
    return encoder_file.read_text().strip()

View File

@ -28,14 +28,28 @@ class FastModelHash(object):
""" """
model_location = Path(model_location) model_location = Path(model_location)
if model_location.is_file(): if model_location.is_file():
return cls._hash_file(model_location) return cls._hash_file_sha1(model_location)
elif model_location.is_dir(): elif model_location.is_dir():
return cls._hash_dir(model_location) return cls._hash_dir(model_location)
else: else:
raise OSError(f"Not a valid file or directory: {model_location}") raise OSError(f"Not a valid file or directory: {model_location}")
@classmethod
def _hash_file_sha1(cls, model_location: Union[str, Path]) -> str:
    """
    Compute full sha1 hash over a single file and return its hexdigest.

    The file is read in fixed-size blocks until end-of-file, so every byte
    contributes to the digest without loading the whole file into memory.

    :param model_location: Path to the model file
    :return: Hexadecimal sha1 digest of the complete file contents
    :raises OSError: If the file cannot be opened or read
    """
    BLOCK_SIZE = 65536
    file_hash = hashlib.sha1()
    with open(model_location, "rb") as f:
        # Loop until read() returns b"" (EOF). A single read() would hash only
        # the first BLOCK_SIZE bytes, so files sharing a 64 KiB prefix would
        # collide.
        for data in iter(lambda: f.read(BLOCK_SIZE), b""):
            file_hash.update(data)
    return file_hash.hexdigest()
@classmethod
def _hash_file_fast(cls, model_location: Union[str, Path]) -> str:
""" """
Fasthash a single file and return its hexdigest. Fasthash a single file and return its hexdigest.
@ -56,7 +70,7 @@ class FastModelHash(object):
if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if not file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
continue continue
path = (Path(root) / file).as_posix() path = (Path(root) / file).as_posix()
fast_hash = cls._hash_file(path) fast_hash = cls._hash_file_fast(path)
components.update({path: fast_hash}) components.update({path: fast_hash})
# hash all the model hashes together, using alphabetic file order # hash all the model hashes together, using alphabetic file order