mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make ModelSearch pydantic
This commit is contained in:
@ -16,3 +16,4 @@ from .config import ( # noqa F401
|
||||
from .install import ModelInstall # noqa F401
|
||||
from .probe import ModelProbe, InvalidModelException # noqa F401
|
||||
from .storage import DuplicateModelException # noqa F401
|
||||
from .search import ModelSearch
|
||||
|
@ -254,10 +254,10 @@ class ModelInstall(ModelInstallBase):
|
||||
self.unregister(id)
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
search = ModelSearch()
|
||||
search.model_found = self._scan_install if install else self._scan_register
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
self._installed = set()
|
||||
search.search([scan_dir])
|
||||
search.search(scan_dir)
|
||||
return list(self._installed)
|
||||
|
||||
def garbage_collect(self) -> List[str]: # noqa D102
|
||||
|
@ -362,7 +362,7 @@ class LoRACheckpointProbe(CheckpointProbeBase):
|
||||
elif token_vector_length == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"Unknown LoRA type: {self.model}")
|
||||
raise InvalidModelException(f"Unsupported LoRA type: {self.model}")
|
||||
|
||||
|
||||
class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
|
@ -1,97 +1,108 @@
|
||||
# Copyright 2023, Lincoln D. Stein and the InvokeAI Team
|
||||
"""
|
||||
Abstract base class for recursive directory search for models.
|
||||
Abstract base class and implementation for recursive directory search for models.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
from invokeai.backend.model_manager import ModelSearch, ModelProbe
|
||||
|
||||
def find_main_models(model: Path) -> bool:
|
||||
info = ModelProbe.probe(model)
|
||||
if info.model_type == 'main' and info.base_type == 'sd-1':
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
search = ModelSearch(on_model_found=report_it)
|
||||
found = search.search('/tmp/models')
|
||||
print(found) # list of matching model paths
|
||||
print(search.stats) # search stats
|
||||
```
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Set, Optional, Callable, Union, types
|
||||
from typing import Set, Optional, Callable, Union, types
|
||||
from pathlib import Path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
default_logger = InvokeAILogger.getLogger()
|
||||
|
||||
|
||||
class ModelSearchBase(ABC):
|
||||
"""Hierarchical directory model search class"""
|
||||
class SearchStats(BaseModel):
|
||||
items_scanned: int = 0
|
||||
models_found: int = 0
|
||||
models_filtered: int = 0
|
||||
|
||||
def __init__(self, logger: types.ModuleType = logger):
|
||||
"""
|
||||
Initialize a recursive model directory search.
|
||||
:param directories: List of directory Paths to recurse through
|
||||
:param logger: Logger to use
|
||||
"""
|
||||
self.logger = logger
|
||||
self._items_scanned = 0
|
||||
self._models_found = 0
|
||||
self._scanned_dirs = set()
|
||||
self._scanned_paths = set()
|
||||
self._pruned_paths = set()
|
||||
|
||||
class ModelSearchBase(ABC, BaseModel):
|
||||
"""
|
||||
Abstract directory traversal model search class
|
||||
|
||||
Usage:
|
||||
search = ModelSearchBase(
|
||||
on_search_started = search_started_callback,
|
||||
on_search_completed = search_completed_callback,
|
||||
on_model_found = model_found_callback,
|
||||
)
|
||||
models_found = search.search('/path/to/directory')
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
|
||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||
logger : InvokeAILogger = Field(default=default_logger, description="InvokeAILogger instance.") # noqa E221
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
underscore_attrs_are_private = True
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@abstractmethod
|
||||
def on_search_started(self):
|
||||
def search_started(self):
|
||||
"""
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the root search directory to the Callable `on_search_started`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_model_found(self, model: Path):
|
||||
def model_found(self, model: Path):
|
||||
"""
|
||||
Process a found model. Raise an exception if something goes wrong.
|
||||
Called when a model is found during search.
|
||||
|
||||
:param model: Model to process - could be a directory or checkpoint.
|
||||
|
||||
Passes the model's Path to the Callable `on_model_found`.
|
||||
This Callable receives the path to the model and returns a boolean
|
||||
to indicate whether the model should be returned in the search
|
||||
results.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_search_completed(self):
|
||||
def search_completed(self):
|
||||
"""
|
||||
Perform some activity when the scan is completed. May use instance
|
||||
variables, items_scanned and models_found
|
||||
Called before the scan starts.
|
||||
|
||||
Passes the Set of found model Paths to the Callable `on_search_completed`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def search(self, directories: List[Union[Path, str]]):
|
||||
self.on_search_started()
|
||||
for dir in directories:
|
||||
self.walk_directory(dir)
|
||||
self.on_search_completed()
|
||||
@abstractmethod
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
"""
|
||||
Recursively search for models in `directory` and return a set of model paths.
|
||||
|
||||
def walk_directory(self, path: Union[Path, str]):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self._items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path in self._scanned_paths or path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
|
||||
]
|
||||
):
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
self._scanned_dirs.add(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.on_model_found(path)
|
||||
self._models_found += 1
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
|
||||
Callables will be invoked during the search.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ModelSearch(ModelSearchBase):
|
||||
@ -104,35 +115,73 @@ class ModelSearch(ModelSearchBase):
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
_model_set: Set[Path]
|
||||
search_started: Callable[[Path], None]
|
||||
search_completed: Callable[[Set[Path]], None]
|
||||
model_found: Callable[[Path], bool]
|
||||
_directory: Path = Field(default=None)
|
||||
_models_found: Set[Path] = Field(default=None)
|
||||
_scanned_dirs: Set[Path] = Field(default=None)
|
||||
_pruned_paths: Set[Path] = Field(default=None)
|
||||
|
||||
def __init__(self, logger: types.ModuleType = logger):
|
||||
super().__init__(logger)
|
||||
self._model_set = set()
|
||||
self.search_started = None
|
||||
self.search_completed = None
|
||||
self.model_found = None
|
||||
def search_started(self):
|
||||
self._models_found = set()
|
||||
self._scanned_dirs = set()
|
||||
self._pruned_paths = set()
|
||||
if self.on_search_started:
|
||||
self.on_search_started(self._directory)
|
||||
|
||||
def on_search_started(self):
|
||||
self._model_set = set()
|
||||
if self.search_started:
|
||||
self.search_started()
|
||||
|
||||
def on_model_found(self, model: Path):
|
||||
if not self.model_found:
|
||||
self._model_set.add(model)
|
||||
def model_found(self, model: Path):
|
||||
self.stats.models_found += 1
|
||||
if not self.on_model_found:
|
||||
self.stats.models_filtered += 1
|
||||
self._models_found.add(model)
|
||||
return
|
||||
if self.model_found(model):
|
||||
self._model_set.add(model)
|
||||
if self.on_model_found(model):
|
||||
self.stats.models_filtered += 1
|
||||
self._models_found.add(model)
|
||||
|
||||
def on_search_completed(self):
|
||||
if self.search_completed:
|
||||
self.search_completed(self._model_set)
|
||||
def search_completed(self):
|
||||
if self.on_search_completed:
|
||||
self.on_search_completed(self._models_found)
|
||||
|
||||
def list_models(self, directories: List[Union[Path, str]]) -> List[Path]:
|
||||
"""Return list of models found"""
|
||||
self.search(directories)
|
||||
return list(self._model_set)
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = directory
|
||||
self.stats = SearchStats() # zero out
|
||||
self.search_started() # This will initialize _models_found to empty
|
||||
self._walk_directory(directory)
|
||||
self.search_completed()
|
||||
return self._models_found
|
||||
|
||||
def _walk_directory(self, path: Union[Path, str]):
|
||||
for root, dirs, files in os.walk(path, followlinks=True):
|
||||
# don't descend into directories that start with a "."
|
||||
# to avoid the Mac .DS_STORE issue.
|
||||
if str(Path(root).name).startswith("."):
|
||||
self._pruned_paths.add(root)
|
||||
if any([Path(root).is_relative_to(x) for x in self._pruned_paths]):
|
||||
continue
|
||||
|
||||
self.stats.items_scanned += len(dirs) + len(files)
|
||||
for d in dirs:
|
||||
path = Path(root) / d
|
||||
if path.parent in self._scanned_dirs:
|
||||
self._scanned_dirs.add(path)
|
||||
continue
|
||||
if any(
|
||||
[
|
||||
(path / x).exists()
|
||||
for x in ["config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"]
|
||||
]
|
||||
):
|
||||
self._scanned_dirs.add(path)
|
||||
try:
|
||||
self.model_found(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
||||
for f in files:
|
||||
path = Path(root) / f
|
||||
if path.parent in self._scanned_dirs:
|
||||
continue
|
||||
if path.suffix in {".ckpt", ".bin", ".pth", ".safetensors", ".pt"}:
|
||||
try:
|
||||
self.model_found(path)
|
||||
except Exception as e:
|
||||
self.logger.warning(str(e))
|
||||
|
Reference in New Issue
Block a user