mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add threading to scan dir calls, cap thread pool in hash function to number of files
This commit is contained in:
parent
119d26e102
commit
bb3f1b9ca6
@ -4,6 +4,7 @@ import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._scan_models_directory()
|
||||
if autoimport := self._app_config.autoimport_dir:
|
||||
self._logger.info("Scanning autoimport directory for new models")
|
||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
||||
installed: List[str] = []
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
|
||||
[installed.extend(models.result()) for models in as_completed(future_models)]
|
||||
self._logger.info(f"{len(installed)} new models registered")
|
||||
self._logger.info("Model installer (re)initialized")
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
|
||||
return []
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||
self._models_installed.clear()
|
||||
@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self.unregister(key)
|
||||
|
||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||
for cur_base_model in BaseModelType:
|
||||
for cur_model_type in ModelType:
|
||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
||||
installed.update(self.scan_directory(models_dir))
|
||||
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
|
||||
[installed.update(models.result()) for models in as_completed(future_models)]
|
||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||
|
||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||
|
@ -108,7 +108,7 @@ class ModelHash:
|
||||
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||
|
||||
# Use ThreadPoolExecutor to hash files in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
|
||||
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
|
||||
component_hashes = [future.result() for future in as_completed(future_to_component)]
|
||||
|
||||
|
@ -84,6 +84,9 @@ class ProbeBase(object):
|
||||
|
||||
|
||||
class ModelProbe(object):
|
||||
|
||||
hasher = ModelHash()
|
||||
|
||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||
"diffusers": {},
|
||||
"checkpoint": {},
|
||||
@ -157,7 +160,7 @@ class ModelProbe(object):
|
||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||
)
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
Loading…
Reference in New Issue
Block a user