mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add threading to scan dir calls, cap thread pool in hash function to number of files
This commit is contained in:
parent
119d26e102
commit
bb3f1b9ca6
@ -4,6 +4,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
@ -280,12 +281,18 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._scan_models_directory()
|
self._scan_models_directory()
|
||||||
if autoimport := self._app_config.autoimport_dir:
|
if autoimport := self._app_config.autoimport_dir:
|
||||||
self._logger.info("Scanning autoimport directory for new models")
|
self._logger.info("Scanning autoimport directory for new models")
|
||||||
installed = self.scan_directory(self._app_config.root_path / autoimport)
|
installed: List[str] = []
|
||||||
|
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||||
|
with ThreadPoolExecutor() as executor:
|
||||||
|
future_models = [executor.submit(self.scan_directory, self._app_config.root_path / autoimport / cur_model_type.value) for cur_model_type in ModelType]
|
||||||
|
[installed.extend(models.result()) for models in as_completed(future_models)]
|
||||||
self._logger.info(f"{len(installed)} new models registered")
|
self._logger.info(f"{len(installed)} new models registered")
|
||||||
self._logger.info("Model installer (re)initialized")
|
self._logger.info("Model installer (re)initialized")
|
||||||
|
|
||||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||||
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
self._cached_model_paths = {Path(x.path).absolute() for x in self.record_store.all_models()}
|
||||||
|
if len([entry for entry in os.scandir(scan_dir) if not entry.name.startswith(".")]) == 0:
|
||||||
|
return []
|
||||||
callback = self._scan_install if install else self._scan_register
|
callback = self._scan_install if install else self._scan_register
|
||||||
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
search = ModelSearch(on_model_found=callback, config=self._app_config)
|
||||||
self._models_installed.clear()
|
self._models_installed.clear()
|
||||||
@ -448,10 +455,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
self._logger.info(f"Scanning {self._app_config.models_path} for new and orphaned models")
|
||||||
for cur_base_model in BaseModelType:
|
# Use ThreadPoolExecutor to scan dirs in parallel
|
||||||
for cur_model_type in ModelType:
|
with ThreadPoolExecutor() as executor:
|
||||||
models_dir = Path(cur_base_model.value, cur_model_type.value)
|
future_models = [executor.submit(self.scan_directory, Path(cur_base_model.value, cur_model_type.value)) for cur_base_model in BaseModelType for cur_model_type in ModelType]
|
||||||
installed.update(self.scan_directory(models_dir))
|
[installed.update(models.result()) for models in as_completed(future_models)]
|
||||||
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
|
||||||
|
|
||||||
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
def _sync_model_path(self, key: str) -> AnyModelConfig:
|
||||||
|
@ -108,7 +108,7 @@ class ModelHash:
|
|||||||
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
model_component_paths = self._get_file_paths(dir, self._file_filter)
|
||||||
|
|
||||||
# Use ThreadPoolExecutor to hash files in parallel
|
# Use ThreadPoolExecutor to hash files in parallel
|
||||||
with ThreadPoolExecutor() as executor:
|
with ThreadPoolExecutor(min(((os.cpu_count() or 1) + 4), len(model_component_paths))) as executor:
|
||||||
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
|
future_to_component = {executor.submit(self._hash_file, component): component for component in sorted(model_component_paths)}
|
||||||
component_hashes = [future.result() for future in as_completed(future_to_component)]
|
component_hashes = [future.result() for future in as_completed(future_to_component)]
|
||||||
|
|
||||||
|
@ -84,6 +84,9 @@ class ProbeBase(object):
|
|||||||
|
|
||||||
|
|
||||||
class ModelProbe(object):
|
class ModelProbe(object):
|
||||||
|
|
||||||
|
hasher = ModelHash()
|
||||||
|
|
||||||
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = {
|
||||||
"diffusers": {},
|
"diffusers": {},
|
||||||
"checkpoint": {},
|
"checkpoint": {},
|
||||||
@ -157,7 +160,7 @@ class ModelProbe(object):
|
|||||||
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
fields.get("description") or f"{fields['base'].value} {fields['type'].value} model {fields['name']}"
|
||||||
)
|
)
|
||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
fields["hash"] = fields.get("hash") or cls.hasher.hash(model_path)
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
Loading…
Reference in New Issue
Block a user