mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
195 lines
6.7 KiB
Python
195 lines
6.7 KiB
Python
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
|
"""
|
|
Abstract base class and implementation for recursive directory search for models.
|
|
|
|
Example usage:
|
|
```
|
|
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
|
|
|
def find_main_models(model: Path) -> bool:
|
|
info = ModelProbe.probe(model)
|
|
if info.model_type == 'main' and info.base_type == 'sd-1':
|
|
return True
|
|
else:
|
|
return False
|
|
|
|
search = ModelSearch(on_model_found=find_main_models)
|
|
found = search.search('/tmp/models')
|
|
print(found) # list of matching model paths
|
|
print(search.stats) # search stats
|
|
```
|
|
"""
|
|
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Callable, Optional, Set, Union
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
|
|
# Module-level logger; used as the default value of the `logger` field on
# ModelSearchBase when the caller does not supply one.
default_logger = InvokeAILogger.get_logger()
|
|
|
|
|
|
class SearchStats(BaseModel):
    """Counters summarising a single model search; all start at zero."""

    # Number of directory entries (sub-directories plus files) seen in the
    # directories that were actually scanned.
    items_scanned: int = 0

    # Number of candidate models encountered, before any filtering.
    models_found: int = 0

    # Number of candidates that were kept in the result set (i.e. accepted by
    # the `on_model_found` callback, or all of them when no callback is set).
    models_filtered: int = 0
|
|
|
|
|
|
class ModelSearchBase(ABC, BaseModel):
    """
    Abstract interface for a callback-driven directory search for models.

    This class only declares the optional callbacks, the statistics record and
    the hooks a search goes through; the traversal itself is supplied by a
    concrete subclass.

    Usage (through a concrete subclass such as `ModelSearch`):
       search = ModelSearch(
           on_search_started = search_started_callback,
           on_search_completed = search_completed_callback,
           on_model_found = model_found_callback,
       )
       models_found = search.search('/path/to/directory')
    """

    # Optional hook: receives the root search directory.
    on_search_started: Optional[Callable[[Path], None]] = Field(
        default=None,
        description="Called just before the search starts.",
    )

    # Optional hook: receives a candidate path and returns whether to keep it.
    on_model_found: Optional[Callable[[Path], bool]] = Field(
        default=None,
        description="Called when a model is found.",
    )

    # Optional hook: receives the set of model paths that were found.
    on_search_completed: Optional[Callable[[Set[Path]], None]] = Field(
        default=None,
        description="Called when search is complete.",
    )

    # A fresh SearchStats is created for every instance.
    stats: SearchStats = Field(
        default_factory=SearchStats,
        description="Summary statistics after search",
    )

    # Falls back to the module-level default logger.
    logger: InvokeAILogger = Field(
        default=default_logger,
        description="Logger instance.",
    )

    class Config:
        # Needed so that non-pydantic types (the logger) are accepted as fields.
        arbitrary_types_allowed = True

    @abstractmethod
    def search_started(self) -> None:
        """
        Hook invoked before the scan starts.

        Implementations pass the root search directory to the Callable
        `on_search_started`.
        """

    @abstractmethod
    def model_found(self, model: Path) -> None:
        """
        Hook invoked for each model discovered during the search.

        :param model: Model to process - could be a directory or checkpoint.

        Implementations pass the model's Path to the Callable `on_model_found`.
        That Callable receives the path to the model and returns a boolean
        indicating whether the model should be included in the search results.
        """

    @abstractmethod
    def search_completed(self) -> None:
        """
        Hook invoked after the scan has finished.

        Implementations pass the Set of found model Paths to the Callable
        `on_search_completed`.
        """

    @abstractmethod
    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search `directory` for models and return the set of model paths.

        When provided, the `on_search_started`, `on_model_found` and
        `on_search_completed` Callables are invoked during the search.
        """
|
|
|
|
|
|
class ModelSearch(ModelSearchBase):
    """
    Implementation of ModelSearchBase that walks a directory tree with callbacks.

    Usage:
       search = ModelSearch(on_model_found=lambda path: 'anime' in path.as_posix())
       found = search.search('/tmp/models')
       # returns all models that have 'anime' in the path
    """

    # Working state of the current search. Each of these is reset to an empty
    # set by `search_started()`; they are None until the first search.
    models_found: Set[Path] = Field(default=None)
    scanned_dirs: Set[Path] = Field(default=None)
    pruned_paths: Set[Path] = Field(default=None)

    def search_started(self) -> None:
        """
        Reset the per-search state and notify `on_search_started`, if set.

        Reads `self._directory`, which is assigned by `search()`.
        """
        self.models_found = set()
        self.scanned_dirs = set()
        self.pruned_paths = set()
        if self.on_search_started:
            self.on_search_started(self._directory)

    def model_found(self, model: Path) -> None:
        """
        Record a candidate model, consulting `on_model_found` when it is set.

        :param model: Path of the candidate - a directory or a checkpoint file.

        `stats.models_found` is incremented for every candidate. The candidate
        is added to `models_found` (and `stats.models_filtered` incremented)
        when no callback is configured or when the callback returns a truthy
        value.
        """
        self.stats.models_found += 1
        if self.on_model_found and not self.on_model_found(model):
            return
        self.stats.models_filtered += 1
        self.models_found.add(model)

    def search_completed(self) -> None:
        """Pass the set of found model paths to `on_search_completed`, if set."""
        if self.on_search_completed:
            # The field is `models_found`; there is no `_models_found` attribute.
            self.on_search_completed(self.models_found)

    def search(self, directory: Union[Path, str]) -> Set[Path]:
        """
        Recursively search `directory` for models.

        :param directory: Root of the tree to scan.
        :return: The set of model paths that were kept.

        Statistics and the result sets are reset at the start of every call.
        """
        self._directory = Path(directory)
        self.stats = SearchStats()  # zero out
        self.search_started()  # resets models_found, scanned_dirs and pruned_paths
        self._walk_directory(directory)
        self.search_completed()
        return self.models_found

    def _report_model(self, candidate: Path) -> None:
        """
        Call `model_found(candidate)`, logging any Exception as a warning.

        A failure on one candidate therefore does not abort the walk.
        KeyboardInterrupt is not an Exception subclass and still propagates.
        """
        try:
            self.model_found(candidate)
        except Exception as e:
            self.logger.warning(str(e))

    def _walk_directory(self, path: Union[Path, str]) -> None:
        """
        Walk the tree rooted at `path` (following symlinks) and report models.

        A sub-directory is reported as a model when it contains one of the
        marker files listed below; a file is reported when it has one of the
        listed suffixes. Entries whose parent directory is already in
        `scanned_dirs` are not reported again.
        """
        # Files whose presence marks a directory as a model.
        model_dir_markers = (
            "config.json",
            "model_index.json",
            "learned_embeds.bin",
            "pytorch_lora_weights.bin",
            "image_encoder.txt",
        )
        # File extensions that are treated as single-file models.
        model_file_suffixes = {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}

        for root, dirs, files in os.walk(path, followlinks=True):
            root_path = Path(root)

            # don't descend into directories that start with a "."
            # to avoid the Mac .DS_STORE issue.
            if root_path.name.startswith("."):
                self.pruned_paths.add(root_path)
            if any(root_path.is_relative_to(x) for x in self.pruned_paths):
                continue

            self.stats.items_scanned += len(dirs) + len(files)

            for d in dirs:
                candidate = root_path / d
                if candidate.parent in self.scanned_dirs:
                    # Inside a directory already handled: mark this one as
                    # scanned too so that its own contents are skipped.
                    self.scanned_dirs.add(candidate)
                    continue
                if any((candidate / marker).exists() for marker in model_dir_markers):
                    self.scanned_dirs.add(candidate)
                    self._report_model(candidate)

            for f in files:
                candidate = root_path / f
                if candidate.parent in self.scanned_dirs:
                    continue
                if candidate.suffix in model_file_suffixes:
                    self._report_model(candidate)
|