Compare commits

...

5 Commits

3 changed files with 34 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
@@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport)
installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
@@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
[installed.update(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig:

View File

@@ -12,6 +12,8 @@ import hashlib
import os
from pathlib import Path
from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
from blake3 import blake3
@@ -105,13 +107,14 @@ class ModelHash:
"""
model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = []
for component in sorted(model_component_paths):
component_hashes.append(self._hash_file(component))
# Use ThreadPoolExecutor to hash files in parallel
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)]
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm
# for the composite hash
# BLAKE3 to hash the hashes
composite_hasher = blake3()
component_hashes.sort()
for h in component_hashes:
composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest()
@@ -129,10 +132,12 @@ class ModelHash:
"""
files: list[Path] = []
for root, _dirs, _files in os.walk(model_path):
for file in _files:
if file_filter(file):
files.append(Path(root, file))
entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
dirs = [entry for entry in entries if entry.is_dir()]
file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
files.extend([Path(file) for file in file_paths])
for dir in dirs:
files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
return files
@staticmethod
@@ -161,13 +166,11 @@ class ModelHash:
"""
def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory."""
"""Hashes a file using a hashlib algorithm."""
hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024)
mv = memoryview(buffer)
with open(file_path, "rb", buffering=0) as f:
while n := f.readinto(mv):
hasher.update(mv[:n])
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(8 * 1024), b""):
hasher.update(chunk)
return hasher.hexdigest()
return hashlib_hasher

View File

@@ -84,6 +84,9 @@ class ProbeBase(object):
class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
@@ -157,7 +160,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()