mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): make hash.py a script for testing
This commit is contained in:
parent
efceee5128
commit
4b073157b8
@ -7,8 +7,12 @@ from invokeai.backend.model_managre.model_hash import FastModelHash
|
|||||||
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
>>> FastModelHash.hash('/home/models/stable-diffusion-v1.5')
|
||||||
'a8e693a126ea5b831c96064dc569956f'
|
'a8e693a126ea5b831c96064dc569956f'
|
||||||
"""
|
"""
|
||||||
|
import cProfile
|
||||||
import os
|
import os
|
||||||
|
import pstats
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from blake3 import blake3
|
from blake3 import blake3
|
||||||
@ -58,7 +62,7 @@ class FastModelHash(object):
|
|||||||
# only tally tensor files because diffusers config files change slightly
|
# only tally tensor files because diffusers config files change slightly
|
||||||
# depending on how the model was downloaded/converted.
|
# depending on how the model was downloaded/converted.
|
||||||
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
if file.endswith((".ckpt", ".safetensors", ".bin", ".pt", ".pth")):
|
||||||
components.append((Path(root, file).as_posix()))
|
components.append((Path(root, file).resolve().as_posix()))
|
||||||
|
|
||||||
component_hashes: list[str] = []
|
component_hashes: list[str] = []
|
||||||
|
|
||||||
@ -68,3 +72,20 @@ class FastModelHash(object):
|
|||||||
component_hashes.append(file_hasher.hexdigest())
|
component_hashes.append(file_hasher.hexdigest())
|
||||||
|
|
||||||
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
|
return blake3(b"".join([bytes.fromhex(h) for h in component_hashes])).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with TemporaryDirectory() as tempdir:
|
||||||
|
profile_path = Path(tempdir, "profile_results.pstats").as_posix()
|
||||||
|
profiler = cProfile.Profile()
|
||||||
|
profiler.enable()
|
||||||
|
t = threading.Thread(
|
||||||
|
target=FastModelHash.hash, args=("/media/rhino/invokeai/models/sd-1/main/stable-diffusion-v1-5-inpainting",)
|
||||||
|
)
|
||||||
|
t.start()
|
||||||
|
t.join()
|
||||||
|
profiler.disable()
|
||||||
|
stats = pstats.Stats(profiler).sort_stats(pstats.SortKey.TIME)
|
||||||
|
stats.dump_stats(profile_path)
|
||||||
|
|
||||||
|
os.system(f"snakeviz {profile_path}")
|
||||||
|
Loading…
Reference in New Issue
Block a user