mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch ModelSearch from os.walk to os.scandir
This commit is contained in:
parent
8c6860a2c5
commit
a8c3efd98a
@ -25,6 +25,7 @@ from abc import ABC, abstractmethod
|
|||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, Set, Union
|
from typing import Callable, Optional, Set, Union
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@ -117,13 +118,10 @@ class ModelSearch(ModelSearchBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
models_found: Set[Path] = Field(default_factory=set)
|
models_found: Set[Path] = Field(default_factory=set)
|
||||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
||||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
self.scanned_dirs = set()
|
|
||||||
self.pruned_paths = set()
|
|
||||||
if self.on_search_started:
|
if self.on_search_started:
|
||||||
self.on_search_started(self._directory)
|
self.on_search_started(self._directory)
|
||||||
|
|
||||||
@ -139,53 +137,51 @@ class ModelSearch(ModelSearchBase):
|
|||||||
|
|
||||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||||
self._directory = Path(directory)
|
self._directory = Path(directory)
|
||||||
|
if not self._directory.is_absolute():
|
||||||
|
self._directory = self.config.models_path / self._directory
|
||||||
self.stats = SearchStats() # zero out
|
self.stats = SearchStats() # zero out
|
||||||
self.search_started() # This will initialize _models_found to empty
|
self.search_started() # This will initialize _models_found to empty
|
||||||
self._walk_directory(directory)
|
self._walk_directory(directory)
|
||||||
self.search_completed()
|
self.search_completed()
|
||||||
return self.models_found
|
return self.models_found
|
||||||
|
|
||||||
def _walk_directory(self, path: Union[Path, str]) -> None:
|
def _walk_directory(self, path: Union[Path, str], max_depth: int = 20) -> None:
|
||||||
for root, dirs, files in os.walk(path, followlinks=True):
|
absolute_path = Path(path)
|
||||||
# don't descend into directories that start with a "."
|
if len(absolute_path.parts) - len(self._directory.parts) > max_depth \
|
||||||
# to avoid the Mac .DS_STORE issue.
|
or not absolute_path.exists() \
|
||||||
if str(Path(root).name).startswith("."):
|
or absolute_path.parent in self.models_found:
|
||||||
self.pruned_paths.add(Path(root))
|
return
|
||||||
if any(Path(root).is_relative_to(x) for x in self.pruned_paths):
|
entries = os.scandir(absolute_path)
|
||||||
continue
|
entries = [entry for entry in entries if not entry.name.startswith(".")]
|
||||||
|
dirs = [entry for entry in entries if entry.is_dir()]
|
||||||
|
file_names = [entry.name for entry in entries if entry.is_file()]
|
||||||
|
if any(
|
||||||
|
x in file_names
|
||||||
|
for x in [
|
||||||
|
"config.json",
|
||||||
|
"model_index.json",
|
||||||
|
"learned_embeds.bin",
|
||||||
|
"pytorch_lora_weights.bin",
|
||||||
|
"image_encoder.txt",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self.model_found(absolute_path)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning(str(e))
|
||||||
|
finally:
|
||||||
|
return
|
||||||
|
|
||||||
self.stats.items_scanned += len(dirs) + len(files)
|
for n in file_names:
|
||||||
for d in dirs:
|
if any([n.endswith(suffix) for suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}]):
|
||||||
path = Path(root) / d
|
try:
|
||||||
if path.parent in self.scanned_dirs:
|
self.model_found(absolute_path / n)
|
||||||
self.scanned_dirs.add(path)
|
except KeyboardInterrupt:
|
||||||
continue
|
raise
|
||||||
if any(
|
except Exception as e:
|
||||||
(path / x).exists()
|
self.logger.warning(str(e))
|
||||||
for x in [
|
|
||||||
"config.json",
|
|
||||||
"model_index.json",
|
|
||||||
"learned_embeds.bin",
|
|
||||||
"pytorch_lora_weights.bin",
|
|
||||||
"image_encoder.txt",
|
|
||||||
]
|
|
||||||
):
|
|
||||||
self.scanned_dirs.add(path)
|
|
||||||
try:
|
|
||||||
self.model_found(path)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
|
||||||
for f in files:
|
for d in dirs:
|
||||||
path = Path(root) / f
|
self._walk_directory(absolute_path / d)
|
||||||
if path.parent in self.scanned_dirs:
|
|
||||||
continue
|
|
||||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
|
||||||
try:
|
|
||||||
self.model_found(path)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(str(e))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user