feat(mm): make hash.py a script for testing

This commit is contained in:
psychedelicious 2024-02-27 20:54:46 +11:00 committed by Ryan Dick
parent efceee5128
commit 4b073157b8

View File

@ -7,8 +7,12 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5') >>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
'a8e693a126ea5b831c96064dc569956f' 'a8e693a126ea5b831c96064dc569956f'
""" """
import cProfile
import os import os
import pstats
import threading
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union from typing import Union
from blake3 import blake3 from blake3 import blake3
@ -58,7 +62,7 @@ class FastModelHash(object):
# only tally tensor files because diffusers config files change slightly # only tally tensor files because diffusers config files change slightly
# depending on how the model was downloaded/converted. # depending on how the model was downloaded/converted.
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")): if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
components.append((Path(root, file).as_posix())) components.append((Path(root, file).resolve().as_posix()))
component_hashes: list[str] = [] component_hashes: list[str] = []
@ -68,3 +72,20 @@ class FastModelHash(object):
component_hashes.append(file_hasher.hexdigest()) component_hashes.append(file_hasher.hexdigest())
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest() return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
if __name__ == "__main__":
with TemporaryDirectory() as tempdir:
profile_path = Path(tempdir, "profile_results.pstats").as_posix()
profiler = cProfile.Profile()
profiler.enable()
t = threading.Thread(
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",)
)
t.start()
t.join()
profiler.disable()
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
stats.dump_stats(profile_path)
os.system(f"snakeviz {profile_path}")