mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""
|
|
Fast hashing of diffusers and checkpoint-style models.
|
|
|
|
Usage:
|
|
from invokeai.backend.model_manager.model_hash import FastModelHash
|
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
|
'a8e693a126ea5b831c96064dc569956f'
|
|
"""
|
|
|
|
import hashlib
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Dict, Union
|
|
|
|
from imohash import hashfile
|
|
|
|
|
|
class FastModelHash(object):
    """Compute quick fingerprints for model files and model directories.

    The single public entry point is the class method hash(); the
    underscore-prefixed class methods are implementation helpers.
    """

    @classmethod
    def hash(cls, model_location: Union[str, Path]) -> str:
        """
        Return the hexdigest fingerprint of the model at model_location.

        A regular file is fingerprinted directly; a directory is fingerprinted
        from the tensor files found anywhere beneath it.

        :param model_location: Path to a model file or a model directory
        :raises OSError: if the path is neither an existing file nor a directory
        """
        target = Path(model_location)
        if target.is_file():
            return cls._hash_file(target)
        if target.is_dir():
            return cls._hash_dir(target)
        raise OSError(f"Not a valid file or directory: {target}")

    @classmethod
    def _hash_file(cls, model_location: Union[str, Path]) -> str:
        """
        Fingerprint one file and return the result as a hexdigest string.

        :param model_location: Path to the model file
        """
        # The raw fast hash is wrapped in MD5 purely to get a shorter string;
        # nothing here relies on MD5 being cryptographically secure.
        quick_digest = hashfile(model_location)
        return hashlib.md5(quick_digest).hexdigest()

    @classmethod
    def _hash_dir(cls, model_location: Union[str, Path]) -> str:
        """
        Fingerprint a model directory and return the result as a hexdigest string.

        :param model_location: Path to the model directory
        """
        # Only tensor files take part: diffusers config files differ slightly
        # depending on how the model was downloaded or converted.
        tensor_suffixes = (".ckpt", ".safetensors", ".bin", ".pt", ".pth")

        components: Dict[str, str] = {}
        for root, _dirs, files in os.walk(model_location):
            for name in files:
                if name.endswith(tensor_suffixes):
                    posix_path = (Path(root) / name).as_posix()
                    components[posix_path] = cls._hash_file(posix_path)

        # Fold the per-file hashes into one digest, visiting the files in
        # alphabetical path order so the walk order cannot affect the result.
        combined = hashlib.md5()
        for posix_path in sorted(components):
            combined.update(components[posix_path].encode("utf-8"))
        return combined.hexdigest()
|