mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): ModelSearch cleanup
- No need for it to by a pydantic model. Just a class now. - Remove ABC, it made it hard to understand what was going on as attributes were spread across the ABC and implementation. Also, there is no other implementation. - Add tests
This commit is contained in:
parent
631e789195
commit
bd5b43c00d
@ -21,104 +21,69 @@ Example usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, Set, Union
|
from typing import Callable, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
|
|
||||||
default_logger: Logger = InvokeAILogger.get_logger()
|
default_logger: Logger = InvokeAILogger.get_logger()
|
||||||
|
|
||||||
|
|
||||||
class SearchStats(BaseModel):
|
@dataclass
|
||||||
items_scanned: int = 0
|
class SearchStats:
|
||||||
models_found: int = 0
|
"""Statistics about the search.
|
||||||
models_filtered: int = 0
|
|
||||||
|
|
||||||
|
Attributes:
|
||||||
class ModelSearchBase(ABC, BaseModel):
|
items_scanned: number of items scanned
|
||||||
|
models_found: number of models found
|
||||||
|
models_filtered: number of models that passed the filter
|
||||||
"""
|
"""
|
||||||
Abstract directory traversal model search class
|
|
||||||
|
items_scanned = 0
|
||||||
|
models_found = 0
|
||||||
|
models_filtered = 0
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSearch:
|
||||||
|
"""Searches a directory tree for models, using a callback to filter the results.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
search = ModelSearchBase(
|
search = ModelSearch()
|
||||||
on_search_started = search_started_callback,
|
search.model_found = lambda path : 'anime' in path.as_posix()
|
||||||
on_search_completed = search_completed_callback,
|
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
||||||
on_model_found = model_found_callback,
|
# returns all models that have 'anime' in the path
|
||||||
)
|
|
||||||
models_found = search.search('/path/to/directory')
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# fmt: off
|
def __init__(
|
||||||
on_search_started : Optional[Callable[[Path], None]] = Field(default=None, description="Called just before the search starts.") # noqa E221
|
self,
|
||||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
stats: Optional[SearchStats] = None,
|
||||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
logger: Optional[Logger] = None,
|
||||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
on_search_started: Optional[Callable[[Path], None]] = None,
|
||||||
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
on_model_found: Optional[Callable[[Path], bool]] = None,
|
||||||
# fmt: on
|
on_search_completed: Optional[Callable[[Set[Path]], None]] = None,
|
||||||
|
config: Optional[InvokeAIAppConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Create a new ModelSearch object.
|
||||||
|
|
||||||
class Config:
|
Args:
|
||||||
arbitrary_types_allowed = True
|
stats: SearchStats object to hold statistics about the search
|
||||||
|
logger: Logger object to use for logging
|
||||||
@abstractmethod
|
on_search_started: callback to be invoked when the search starts
|
||||||
def search_started(self) -> None:
|
on_model_found: callback to be invoked when a model is found. The callback should return True if the model
|
||||||
|
should be included in the results.
|
||||||
|
on_search_completed: callback to be invoked when the search is completed
|
||||||
|
config: configuration object
|
||||||
"""
|
"""
|
||||||
Called before the scan starts.
|
self.stats = stats or SearchStats()
|
||||||
|
self.logger = logger or default_logger
|
||||||
Passes the root search directory to the Callable `on_search_started`.
|
self.on_search_started = on_search_started
|
||||||
"""
|
self.on_model_found = on_model_found
|
||||||
pass
|
self.on_search_completed = on_search_completed
|
||||||
|
self.models_found: set[Path] = set()
|
||||||
@abstractmethod
|
self.config = config or InvokeAIAppConfig.get_config()
|
||||||
def model_found(self, model: Path) -> None:
|
|
||||||
"""
|
|
||||||
Called when a model is found during search.
|
|
||||||
|
|
||||||
:param model: Model to process - could be a directory or checkpoint.
|
|
||||||
|
|
||||||
Passes the model's Path to the Callable `on_model_found`.
|
|
||||||
This Callable receives the path to the model and returns a boolean
|
|
||||||
to indicate whether the model should be returned in the search
|
|
||||||
results.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search_completed(self) -> None:
|
|
||||||
"""
|
|
||||||
Called before the scan starts.
|
|
||||||
|
|
||||||
Passes the Set of found model Paths to the Callable `on_search_completed`.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
|
||||||
"""
|
|
||||||
Recursively search for models in `directory` and return a set of model paths.
|
|
||||||
|
|
||||||
If provided, the `on_search_started`, `on_model_found` and `on_search_completed`
|
|
||||||
Callables will be invoked during the search.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSearch(ModelSearchBase):
|
|
||||||
"""
|
|
||||||
Implementation of ModelSearch with callbacks.
|
|
||||||
Usage:
|
|
||||||
search = ModelSearch()
|
|
||||||
search.model_found = lambda path : 'anime' in path.as_posix()
|
|
||||||
found = search.list_models(['/tmp/models1','/tmp/models2'])
|
|
||||||
# returns all models that have 'anime' in the path
|
|
||||||
"""
|
|
||||||
|
|
||||||
models_found: Set[Path] = Field(default_factory=set)
|
|
||||||
config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
def search_started(self) -> None:
|
def search_started(self) -> None:
|
||||||
self.models_found = set()
|
self.models_found = set()
|
||||||
|
142
tests/test_model_search.py
Normal file
142
tests/test_model_search.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_search(tmp_path: Path) -> tuple[ModelSearch, Path]:
|
||||||
|
search = ModelSearch()
|
||||||
|
return search, tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_search_on_search_started(model_search: tuple[ModelSearch, Path]):
|
||||||
|
search, tmp_path = model_search
|
||||||
|
on_search_started_called_with: Path | None = None
|
||||||
|
|
||||||
|
def on_search_started_callback(path: Path) -> None:
|
||||||
|
nonlocal on_search_started_called_with
|
||||||
|
on_search_started_called_with = path
|
||||||
|
|
||||||
|
search.on_search_started = on_search_started_callback
|
||||||
|
search.search(tmp_path)
|
||||||
|
|
||||||
|
assert on_search_started_called_with == tmp_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_search_on_completed(model_search: tuple[ModelSearch, Path]):
|
||||||
|
search, tmp_path = model_search
|
||||||
|
on_search_completed_called_with: set[Path] | None = None
|
||||||
|
file1 = tmp_path / "file1.ckpt"
|
||||||
|
with open(file1, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
def on_search_completed_callback(models: set[Path]) -> None:
|
||||||
|
nonlocal on_search_completed_called_with
|
||||||
|
on_search_completed_called_with = models
|
||||||
|
|
||||||
|
search.on_search_completed = on_search_completed_callback
|
||||||
|
expected = {file1}
|
||||||
|
found = search.search(tmp_path)
|
||||||
|
|
||||||
|
assert found == expected
|
||||||
|
assert on_search_completed_called_with == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_search_handles_files(model_search: tuple[ModelSearch, Path]):
|
||||||
|
search, tmp_path = model_search
|
||||||
|
on_model_found_called_with: set[Path] = set()
|
||||||
|
|
||||||
|
file1 = tmp_path / "file1.ckpt"
|
||||||
|
file2 = tmp_path / "file2.ckpt"
|
||||||
|
file3 = tmp_path / "subfolder" / "file3.ckpt"
|
||||||
|
file4 = tmp_path / "subfolder" / "subfolder" / "file4.ckpt"
|
||||||
|
file5 = tmp_path / "not_a_model_file.txt"
|
||||||
|
|
||||||
|
file4.parent.mkdir(parents=True)
|
||||||
|
for file in [file1, file2, file3, file4, file5]:
|
||||||
|
with open(file, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
def on_model_found_callback(path: Path) -> bool:
|
||||||
|
on_model_found_called_with.add(path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
search.on_model_found = on_model_found_callback
|
||||||
|
|
||||||
|
expected = {file1, file2, file3, file4}
|
||||||
|
found = search.search(tmp_path)
|
||||||
|
|
||||||
|
assert on_model_found_called_with == expected
|
||||||
|
assert found == expected
|
||||||
|
assert search.stats.models_found == 4
|
||||||
|
assert search.stats.models_filtered == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_search_filters_by_on_model_found(model_search: tuple[ModelSearch, Path]):
|
||||||
|
search, tmp_path = model_search
|
||||||
|
on_model_found_called_with: set[Path] = set()
|
||||||
|
|
||||||
|
file1 = tmp_path / "file1.ckpt"
|
||||||
|
file2 = tmp_path / "file2.ckpt" # explicitly ignored
|
||||||
|
|
||||||
|
for file in [file1, file2]:
|
||||||
|
with open(file, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
def on_model_found_callback(path: Path) -> bool:
|
||||||
|
if path == file2:
|
||||||
|
return False
|
||||||
|
on_model_found_called_with.add(path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
search.on_model_found = on_model_found_callback
|
||||||
|
|
||||||
|
expected = {file1}
|
||||||
|
found = search.search(tmp_path)
|
||||||
|
|
||||||
|
assert on_model_found_called_with == expected
|
||||||
|
assert found == expected
|
||||||
|
assert search.stats.models_filtered == 1
|
||||||
|
assert search.stats.models_found == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_search_handles_diffusers_model_dirs(model_search: tuple[ModelSearch, Path]):
|
||||||
|
search, tmp_path = model_search
|
||||||
|
on_model_found_called_with: set[Path] = set()
|
||||||
|
|
||||||
|
diffusers_dir = tmp_path / "diffusers_dir"
|
||||||
|
diffusers_dir_entry_point = diffusers_dir / "model_index.json"
|
||||||
|
diffusers_dir.mkdir()
|
||||||
|
with open(diffusers_dir_entry_point, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
nested_diffusers_dir = tmp_path / "subfolder" / "nested_diffusers_dir"
|
||||||
|
nested_diffusers_dir_entry_point = nested_diffusers_dir / "model_index.json"
|
||||||
|
nested_diffusers_dir_ignore_me_file = nested_diffusers_dir / "ignore_me.ckpt" # totally skipped
|
||||||
|
nested_diffusers_dir.mkdir(parents=True)
|
||||||
|
with open(nested_diffusers_dir_entry_point, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
with open(nested_diffusers_dir_ignore_me_file, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
not_a_diffusers_dir = tmp_path / "not_a_diffusers_dir"
|
||||||
|
not_a_diffusers_dir_entry_point = not_a_diffusers_dir / "not_model_index.json"
|
||||||
|
not_a_diffusers_dir.mkdir()
|
||||||
|
with open(not_a_diffusers_dir_entry_point, "w") as f:
|
||||||
|
f.write("")
|
||||||
|
|
||||||
|
def on_model_found_callback(path: Path) -> bool:
|
||||||
|
on_model_found_called_with.add(path)
|
||||||
|
return True
|
||||||
|
|
||||||
|
search.on_model_found = on_model_found_callback
|
||||||
|
|
||||||
|
expected = {diffusers_dir, nested_diffusers_dir}
|
||||||
|
found = search.search(tmp_path)
|
||||||
|
|
||||||
|
assert found == expected
|
||||||
|
assert on_model_found_called_with == expected
|
||||||
|
assert search.stats.models_found == 2
|
||||||
|
assert search.stats.models_filtered == 2
|
Loading…
Reference in New Issue
Block a user