Add threading to scan dir calls, cap thread pool in hash function to number of files

This commit is contained in:
Brandon Rising 2024-03-07 15:21:45 -05:00
parent 119d26e102
commit bb3f1b9ca6
3 changed files with 17 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pathlib import Path
from queue import Empty, Queue
@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
self._scan_models_directory()
if autoimport := self._app_config.autoimport_dir:
self._logger.info("Scanning autoimport directory for new models")
installed = self.scan_directory(self._app_config.root_path / autoimport)
installed: List[str] = []
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
[installed.extend(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered")
self._logger.info("Model installer (re)initialized")
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
return []
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
self.unregister(key)
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
for cur_base_model in BaseModelType:
for cur_model_type in ModelType:
models_dir = Path(cur_base_model.value, cur_model_type.value)
installed.update(self.scan_directory(models_dir))
# Use ThreadPoolExecutor to scan dirs in parallel
with ThreadPoolExecutor() as executor:
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
[installed.update(models.result()) for models in as_completed(future_models)]
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
def _sync_model_path(self, key: str) -> AnyModelConfig:

View File

@ -108,7 +108,7 @@ class ModelHash:
model_component_paths = self._get_file_paths(dir, self._file_filter)
# Use ThreadPoolExecutor to hash files in parallel
with ThreadPoolExecutor() as executor:
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
component_hashes = [future.result() for future in as_completed(future_to_component)]

View File

@ -84,6 +84,9 @@ class ProbeBase(object):
class ModelProbe(object):
hasher = ModelHash()
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
"diffusers": {},
"checkpoint": {},
@ -157,7 +160,7 @@ class ModelProbe(object):
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
)
fields["format"] = fields.get("format") or probe.get_format()
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()