make ModelSearch pydantic

Lincoln Stein
2023-08-24 13:37:49 -04:00
parent 93cef55964
commit 97f2e778ee
4 changed files with 145 additions and 95 deletions

View File

@@ -16,3 +16,4 @@ from .config import ( # noqa F401
from .install import ModelInstall # noqa F401
from .probe import ModelProbe, InvalidModelException # noqa F401
from .storage import DuplicateModelException # noqa F401
from .search import ModelSearch

View File

@@ -254,10 +254,10 @@ class ModelInstall(ModelInstallBase):
self.unregister(id)
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
search = ModelSearch()
search.model_found = self._scan_install if install else self._scan_register
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
self._installed = set()
search.search([scan_dir])
search.search(scan_dir)
return list(self._installed)
def garbage_collect(self) -> List[str]: # noqa D102
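For context, a minimal sketch of the new call-site pattern shown in the hunk above: the callback is passed through the pydantic constructor instead of being assigned after construction, and search() now takes a single directory rather than a list. The register_model callback and the /tmp/models scan root are illustrative assumptions, not part of the commit.
```
from pathlib import Path

from invokeai.backend.model_manager import ModelSearch

# Hypothetical callback: report each candidate and accept it into the results.
def register_model(model: Path) -> bool:
    print(f"found {model}")
    return True

search = ModelSearch(on_model_found=register_model)  # callback supplied as a pydantic field
found = search.search(Path("/tmp/models"))           # single directory, not a list
print(f"{len(found)} models found")
```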

View File

@@ -362,7 +362,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
elif token_vector_length == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"Unknown LoRA type: {self.model}")
raise InvalidModelException(f"Unsupported LoRA type: {self.model}")
class TextualInversionCheckpointProbe(CheckpointProbeBase):

View File

@@ -1,97 +1,108 @@
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
"""
Abstract base class for recursive directory search for models.
Abstract base class and implementation for recursive directory search for models.
Example usage:
```
from pathlib import Path
from invokeai.backend.model_manager import ModelSearch, ModelProbe
def find_main_models(model: Path) -> bool:
info = ModelProbe.probe(model)
if info.model_type == 'main' and info.base_type == 'sd-1':
return True
else:
return False
search = ModelSearch(on_model_found=find_main_models)
found = search.search('/tmp/models')
print(found) # set of matching model paths
print(search.stats) # search stats
```
"""
import os
from abc import ABC, abstractmethod
from typing import List, Set, Optional, Callable, Union, types
from typing import Set, Optional, Callable, Union
from pathlib import Path
import invokeai.backend.util.logging as logger
from invokeai.backend.util.logging import InvokeAILogger
from pydantic import Field, BaseModel
default_logger = InvokeAILogger.getLogger()
class ModelSearchBase(ABC):
"""Hierarchical directory model search class"""
class SearchStats(BaseModel):
items_scanned: int = 0
models_found: int = 0
models_filtered: int = 0
def __init__(self, logger: types.ModuleType = logger):
"""
Initialize a recursive model directory search.
:param directories: List of directory Paths to recurse through
:param logger: Logger to use
"""
self.logger = logger
self._items_scanned = 0
self._models_found = 0
self._scanned_dirs = set()
self._scanned_paths = set()
self._pruned_paths = set()
class ModelSearchBase(ABC, BaseModel):
"""
Abstract directory traversal model search class
Usage:
search = ModelSearchBase(
on_search_started = search_started_callback,
on_search_completed = search_completed_callback,
on_model_found = model_found_callback,
)
models_found = search.search('/path/to/directory')
"""
# fmt: off
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221
# fmt: on
class Config:
underscore_attrs_are_private = True
arbitrary_types_allowed = True
@abstractmethod
def on_search_started(self):
def search_started(self):
"""
Called before the scan starts.
Passes the root search directory to the Callable `on_search_started`.
"""
pass
@abstractmethod
def on_model_found(self, model: Path):
def model_found(self, model: Path):
"""
Process a found model. Raise an exception if something goes wrong.
Called when a model is found during search.
:param model: Model to process - could be a directory or checkpoint.
Passes the model's Path to the Callable `on_model_found`.
This Callable receives the path to the model and returns a boolean
to indicate whether the model should be returned in the search
results.
"""
pass
@abstractmethod
def on_search_completed(self):
def search_completed(self):
"""
Perform some activity when the scan is completed. May use instance
variables, items_scanned and models_found
Called when the search completes.
Passes the Set of found model Paths to the Callable `on_search_completed`.
"""
pass
def search(self, directories: List[Union[Path, str]]):
self.on_search_started()
for dir in directories:
self.walk_directory(dir)
self.on_search_completed()
@abstractmethod
def search(self, directory: Union[Path, str]) -> Set[Path]:
"""
Recursively search for models in `directory` and return a set of model paths.
def walk_directory(self, path: Union[Path, str]):
for root, dirs, files in os.walk(path, followlinks=True):
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self._items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path in self._scanned_paths or path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
]
):
try:
self.on_model_found(path)
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
Callables will be invoked during the search.
"""
pass
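To show how the three callback fields declared above fit together, here is a sketch that wires all of them into the concrete ModelSearch subclass defined below. The callback bodies and the /opt/models scan root are assumptions for illustration only.
```
from pathlib import Path
from typing import Set

from invokeai.backend.model_manager import ModelSearch

def announce_start(root: Path) -> None:
    # receives the root directory just before the walk begins
    print(f"searching {root} ...")

def accept_all(model: Path) -> bool:
    # return True to include the candidate in the returned set
    return True

def summarize(found: Set[Path]) -> None:
    # receives the final set of accepted model paths
    print(f"search complete: {len(found)} models")

search = ModelSearch(
    on_search_started=announce_start,
    on_model_found=accept_all,
    on_search_completed=summarize,
)
search.search(Path("/opt/models"))
```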
class ModelSearch(ModelSearchBase):
@@ -104,35 +115,73 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path
"""
_model_set: Set[Path]
search_started: Callable[[Path], None]
search_completed: Callable[[Set[Path]], None]
model_found: Callable[[Path], bool]
_directory: Path = Field(default=None)
_models_found: Set[Path] = Field(default=None)
_scanned_dirs: Set[Path] = Field(default=None)
_pruned_paths: Set[Path] = Field(default=None)
def __init__(self, logger: types.ModuleType = logger):
super().__init__(logger)
self._model_set = set()
self.search_started = None
self.search_completed = None
self.model_found = None
def search_started(self):
self._models_found = set()
self._scanned_dirs = set()
self._pruned_paths = set()
if self.on_search_started:
self.on_search_started(self._directory)
def on_search_started(self):
self._model_set = set()
if self.search_started:
self.search_started()
def on_model_found(self, model: Path):
if not self.model_found:
self._model_set.add(model)
def model_found(self, model: Path):
self.stats.models_found += 1
if not self.on_model_found:
self.stats.models_filtered += 1
self._models_found.add(model)
return
if self.model_found(model):
self._model_set.add(model)
if self.on_model_found(model):
self.stats.models_filtered += 1
self._models_found.add(model)
def on_search_completed(self):
if self.search_completed:
self.search_completed(self._model_set)
def search_completed(self):
if self.on_search_completed:
self.on_search_completed(self._models_found)
def list_models(self, directories: List[Union[Path, str]]) -> List[Path]:
"""Return list of models found"""
self.search(directories)
return list(self._model_set)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = directory
self.stats = SearchStats() # zero out
self.search_started() # This will initialize _models_found to empty
self._walk_directory(directory)
self.search_completed()
return self._models_found
def _walk_directory(self, path: Union[Path, str]):
for root, dirs, files in os.walk(path, followlinks=True):
# don't descend into directories that start with a "."
# to avoid the Mac .DS_STORE issue.
if str(Path(root).name).startswith("."):
self._pruned_paths.add(root)
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
continue
self.stats.items_scanned += len(dirs) + len(files)
for d in dirs:
path = Path(root) / d
if path.parent in self._scanned_dirs:
self._scanned_dirs.add(path)
continue
if any(
[
(path / x).exists()
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
]
):
self._scanned_dirs.add(path)
try:
self.model_found(path)
except Exception as e:
self.logger.warning(str(e))
for f in files:
path = Path(root) / f
if path.parent in self._scanned_dirs:
continue
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
try:
self.model_found(path)
except Exception as e:
self.logger.warning(str(e))
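Tying the traversal and filtering together, a short sketch of how model_found() and the stats counters behave: every candidate increments stats.models_found, and only candidates for which on_model_found returns True are counted in stats.models_filtered and included in the returned set. The suffix filter and the /tmp/models scan root are assumptions for illustration.
```
from pathlib import Path

from invokeai.backend.model_manager import ModelSearch

def only_safetensors(model: Path) -> bool:
    # keep directories (diffusers-style models) and .safetensors checkpoints;
    # everything else is still scanned and counted, but left out of the results
    return model.is_dir() or model.suffix == ".safetensors"

search = ModelSearch(on_model_found=only_safetensors)
found = search.search("/tmp/models")  # assumed scan root

print(f"items scanned : {search.stats.items_scanned}")
print(f"models found  : {search.stats.models_found}")
print(f"models kept   : {search.stats.models_filtered}")
for model in sorted(found):
    print(model)
```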