From bb3f1b9ca6444c14ede4dcb7c659fdd08150a98f Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 7 Mar 2024 15:21:45 -0500 Subject: [PATCH] Add threading to scan dir calls, cap thread pool in hash function to number of files --- .../model_install/model_install_default.py | 17 ++++++++++++----- invokeai/backend/model_manager/hash.py | 2 +- invokeai/backend/model_manager/probe.py | 5 ++++- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 670da572d2..d8170156eb 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -4,6 +4,7 @@ import os import re import threading import time +from concurrent.futures import ThreadPoolExecutor, as_completed from hashlib import sha256 from pathlib import Path from queue import Empty, Queue @@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase): self._scan_models_directory() if autoimport := self._app_config.autoimport_dir: self._logger.info("Scanning autoimport directory for new models") - installed = self.scan_directory(self._app_config.root_path / autoimport) + installed: List[str] = [] + # Use ThreadPoolExecutor to scan dirs in parallel + with ThreadPoolExecutor() as executor: + future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType] + [installed.extend(models.result()) for models in as_completed(future_models)] self._logger.info(f"{len(installed)} new models registered") self._logger.info("Model installer (re)initialized") def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()} + if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) 
== 0: + return [] callback = self._scan_install if install else self._scan_register search = ModelSearch(on_model_found=callback, config=self._app_config) self._models_installed.clear() @@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase): self.unregister(key) self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models") - for cur_base_model in BaseModelType: - for cur_model_type in ModelType: - models_dir = Path(cur_base_model.value, cur_model_type.value) - installed.update(self.scan_directory(models_dir)) + # Use ThreadPoolExecutor to scan dirs in parallel + with ThreadPoolExecutor() as executor: + future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType] + [installed.update(models.result()) for models in as_completed(future_models)] self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered") def _sync_model_path(self, key: str) -> AnyModelConfig: diff --git a/invokeai/backend/model_manager/hash.py b/invokeai/backend/model_manager/hash.py index 89d4d954e6..c379b5ee77 100644 --- a/invokeai/backend/model_manager/hash.py +++ b/invokeai/backend/model_manager/hash.py @@ -108,7 +108,7 @@ class ModelHash: model_component_paths = self._get_file_paths(dir, self._file_filter) # Use ThreadPoolExecutor to hash files in parallel - with ThreadPoolExecutor() as executor: + with ThreadPoolExecutor(max(1, min(((os.cpu_count() or 1) + 4), len(model_component_paths)))) as executor: future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)} component_hashes = [future.result() for future in as_completed(future_to_component)] diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 75925dcf0b..a9201034e4 100644 --- a/invokeai/backend/model_manager/probe.py +++ 
b/invokeai/backend/model_manager/probe.py @@ -84,6 +84,9 @@ class ProbeBase(object): class ModelProbe(object): + + hasher = ModelHash() + PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { "diffusers": {}, "checkpoint": {}, @@ -157,7 +160,7 @@ class ModelProbe(object): fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}" ) fields["format"] = fields.get("format") or probe.get_format() - fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) + fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path) if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()