Add threading to scan dir calls, cap thread pool in hash function to number of files

This commit is contained in:
Brandon Rising 2024-03-07 15:21:45 -05:00
parent 119d26e102
commit bb3f1b9ca6
3 changed files with 17 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import os
import re import re
import threading import threading
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory() self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir: if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models") self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport) installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered") self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized") self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config) search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear() self._models_installed.clear()
@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key) self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models") self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType: # Use ThreadPoolExecutor to scan dirs in parallel
for cur_model_type in ModelType: with ThreadPoolExecutor() as executor:
models_dir = Path(cur_base_model.value, cur_model_type.value) future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
installed.update(self.scan_directory(models_dir)) [installed.update(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig: def _sync_model_path(self, key: str) -> AnyModelConfig:

View File

@ -108,7 +108,7 @@ class ModelHash:
model_component_paths = self._get_file_paths(dir, self._file_filter) model_component_paths = self._get_file_paths(dir, self._file_filter)
# Use ThreadPoolExecutor to hash files in parallel # Use ThreadPoolExecutor to hash files in parallel
with ThreadPoolExecutor() as executor: with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)} future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)] component_hashes = [future.result() for future in as_completed(future_to_component)]

View File

@ -84,6 +84,9 @@ class ProbeBase(object):
class ModelProbe(object): class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {}, "diffusers": {},
"checkpoint": {}, "checkpoint": {},
@ -157,7 +160,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
) )
fields["format"] = fields.get("format") or probe.get_format() fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()