Compare commits

...

5 Commits

3 changed files with 34 additions and 21 deletions

View File

@ -4,6 +4,7 @@ import os
import re import re
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory() self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir: if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models") self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport) installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered") self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized") self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config) search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear() self._models_installed.clear()
@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key) self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models") self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType: # Use ThreadPoolExecutor to scan dirs in parallel
for cur_model_type in ModelType: with ThreadPoolExecutor() as executor:
models_dir = Path(cur_base_model.value, cur_model_type.value) future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
installed.update(self.scan_directory(models_dir)) [installed.update(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig: def _sync_model_path(self, key: str) -> AnyModelConfig:

View File

@ -12,6 +12,8 @@ import hashlib
import os import os
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Optional, Union from typing import Callable, Literal, Optional, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
from blake3 import blake3 from blake3 import blake3
@ -105,13 +107,14 @@ class ModelHash:
""" """
model_component_paths = self._get_file_paths(dir, self._file_filter) model_component_paths = self._get_file_paths(dir, self._file_filter)
component_hashes: list[str] = [] # Use ThreadPoolExecutor to hash files in parallel
for component in sorted(model_component_paths): with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
component_hashes.append(self._hash_file(component)) future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)]
# BLAKE3 is cryptographically secure. We may as well fall back on a secure algorithm # BLAKE3 to hash the hashes
# for the composite hash
composite_hasher = blake3() composite_hasher = blake3()
component_hashes.sort()
for h in component_hashes: for h in component_hashes:
composite_hasher.update(h.encode("utf-8")) composite_hasher.update(h.encode("utf-8"))
return composite_hasher.hexdigest() return composite_hasher.hexdigest()
@ -129,10 +132,12 @@ class ModelHash:
""" """
files: list[Path] = [] files: list[Path] = []
for root, _dirs, _files in os.walk(model_path): entries = [entry for entry in os.scandir(model_path.as_posix()) if not entry.name.startswith(".")]
for file in _files: dirs = [entry for entry in entries if entry.is_dir()]
if file_filter(file): file_paths = [entry.path for entry in entries if entry.is_file() and file_filter(entry.path)]
files.append(Path(root, file)) files.extend([Path(file) for file in file_paths])
for dir in dirs:
files.extend(ModelHash._get_file_paths(Path(dir.path), file_filter))
return files return files
@staticmethod @staticmethod
@ -161,13 +166,11 @@ class ModelHash:
""" """
def hashlib_hasher(file_path: Path) -> str: def hashlib_hasher(file_path: Path) -> str:
"""Hashes a file using a hashlib algorithm. Uses `memoryview` to avoid reading the entire file into memory.""" """Hashes a file using a hashlib algorithm."""
hasher = hashlib.new(algorithm) hasher = hashlib.new(algorithm)
buffer = bytearray(128 * 1024) with open(file_path, "rb") as f:
mv = memoryview(buffer) for chunk in iter(lambda: f.read(8 * 1024), b""):
with open(file_path, "rb", buffering=0) as f: hasher.update(chunk)
while n := f.readinto(mv):
hasher.update(mv[:n])
return hasher.hexdigest() return hasher.hexdigest()
return hashlib_hasher return hashlib_hasher

View File

@ -84,6 +84,9 @@ class ProbeBase(object):
class ModelProbe(object): class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {}, "diffusers": {},
"checkpoint": {}, "checkpoint": {},
@ -157,7 +160,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
) )
fields["format"] = fields.get("format") or probe.get_format() fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()