mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)
* add basic functionality for model metadata fetching from hf and civitai * add storage * start unit tests * add unit tests and documentation * add missing dependency for pytests * remove redundant fetch; add modified/published dates; updated docs * add code to select diffusers files based on the variant type * implement Civitai installs * make huggingface parallel downloading work * add unit tests for model installation manager - Fixed race condition on selection of download destination path - Add fixtures common to several model_manager_2 unit tests - Added dummy model files for testing diffusers and safetensors downloading/probing - Refactored code for selecting proper variant from list of huggingface repo files - Regrouped ordering of methods in model_install_default.py * improve Civitai model downloading - Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects to the HTML of an authorization page -- arrgh) - Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc * add routes for retrieving metadata and tags * code tidying and documentation * fix ruff errors * add file needed to maintain test root diretory in repo for unit tests * fix self->cls in classmethod * add pydantic plugin for mypy * use TestSession instead of requests.Session to prevent any internet activity improve logging fix error message formatting fix logging again fix forward vs reverse slash issue in Windows install tests * Several fixes of problems detected during PR review: - Implement cancel_model_install_job and get_model_install_job routes to allow for better control of model download and install. - Fix thread deadlock that occurred after cancelling an install. - Remove unneeded pytest_plugins section from tests/conftest.py - Remove unused _in_terminal_state() from model_install_default. - Remove outdated documentation from several spots. - Add workaround for Civitai API results which don't return correct URL for the default model. * fix docs and tests to match get_job_by_source() rather than get_job() * Update invokeai/backend/model_manager/metadata/fetch/huggingface.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Call CivitaiMetadata.model_validate_json() directly Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Second round of revisions suggested by @ryanjdick: - Fix type mismatch in `list_all_metadata()` route. - Do not have a default value for the model install job id - Remove static class variable declarations from non Pydantic classes - Change `id` field to `model_id` for the sqlite3 `model_tags` table. - Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables. - Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness. * Code cleanup suggested in PR review: - Narrowed the declaration of the `parts` attribute of the download progress event - Removed auto-conversion of str to Url in Url-containing sources - Fixed handling of `InvalidModelConfigException` - Made unknown sources raise `NotImplementedError` rather than `Exception` - Improved status reporting on cached HuggingFace access tokens * Multiple fixes: - `job.total_size` returns a valid size for locally installed models - new route `list_models` returns a paged summary of model, name, description, tags and other essential info - fix a few type errors * consolidated all invokeai root pytest fixtures into a single location * Update invokeai/backend/model_manager/metadata/metadata_store.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * Small tweaks in response to review comments: - Remove flake8 configuration from pyproject.toml - Use `id` rather than `modelId` for huggingface `ModelInfo` object - Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object - Add `sha256` field to file metadata downloaded from huggingface - Add `Invoker` argument to the model installer `start()` and `stop()` routines (but made it optional in order to facilitate use of the service outside the API) - Removed redundant `PRAGMA foreign_keys` from metadata store initialization code. * Additional tweaks and minor bug fixes - Fix calculation of aggregate diffusers model size to only count the size of files, not files + directories (which gives different unit test results on different filesystems). - Refactor _get_metadata() and _get_download_urls() to have distinct code paths for Civitai, HuggingFace and URL sources. - Forward the `inplace` flag from the source to the job and added unit test for this. - Attach cached model metadata to the job rather than to the model install service. * fix unit test that was breaking on windows due to CR/LF changing size of test json files * fix ruff formatting * a few last minor fixes before merging: - Turn job `error` and `error_type` into properties derived from the exception. - Add TODO comment about the reason for handling temporary directory destruction manually rather than using tempfile.tmpdir(). * add unit tests for reporting HTTP download errors --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
@ -5,11 +5,10 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -19,8 +18,8 @@ TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> requests.sessions.Session:
|
||||
sess = requests.Session()
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
@ -160,7 +159,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
)
|
||||
@ -191,7 +190,7 @@ def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, ca
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
event_bus = DummyEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
|
@ -2,11 +2,12 @@
|
||||
Test the model installer
|
||||
"""
|
||||
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import ValidationError
|
||||
from pydantic.networks import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
@ -14,104 +15,50 @@ from invokeai.app.services.model_install import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_file(datadir: Path) -> Path:
|
||||
return datadir / "test_embedding.safetensors"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_config(datadir: Path) -> InvokeAIAppConfig:
|
||||
return InvokeAIAppConfig(
|
||||
root=datadir / "root",
|
||||
models_dir=datadir / "root/models",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(
|
||||
app_config: InvokeAIAppConfig,
|
||||
) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=app_config)
|
||||
db = create_mock_sqlite_database(app_config, logger)
|
||||
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
|
||||
installer = ModelInstallService(
|
||||
app_config=app_config,
|
||||
record_store=store,
|
||||
event_bus=DummyEventService(),
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = installer.register_path(test_file)
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
assert key is not None
|
||||
assert len(key) == 32
|
||||
|
||||
|
||||
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.register_path(test_file)
|
||||
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.name == "test_embedding"
|
||||
assert model_record.type == ModelType.TextualInversion
|
||||
assert Path(model_record.path) == test_file
|
||||
assert Path(model_record.path) == embedding_file
|
||||
assert model_record.base == BaseModelType("sd-1")
|
||||
assert model_record.description is not None
|
||||
assert model_record.source is not None
|
||||
assert Path(model_record.source) == test_file
|
||||
assert Path(model_record.source) == embedding_file
|
||||
|
||||
|
||||
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
key = None
|
||||
with pytest.raises(ValidationError):
|
||||
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
|
||||
assert key is None
|
||||
|
||||
|
||||
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.register_path(
|
||||
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(
|
||||
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
|
||||
)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.name == "banana_sushi"
|
||||
@ -119,40 +66,59 @@ def test_registration_meta_override_succeed(installer: ModelInstallServiceBase,
|
||||
assert model_record.current_hash == "New Hash"
|
||||
|
||||
|
||||
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||
store = installer.record_store
|
||||
key = installer.install_path(test_file)
|
||||
def test_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||
assert model_record.source == test_file.as_posix()
|
||||
assert model_record.source == embedding_file.as_posix()
|
||||
|
||||
|
||||
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"fixture_name,size,destination",
|
||||
[
|
||||
("embedding_file", 15440, "sd-1/embedding/test_embedding.safetensors"),
|
||||
("diffusers_dir", 8241 if OS == "Windows" else 7907, "sdxl/main/test-diffusers-main"), # EOL chars
|
||||
],
|
||||
)
|
||||
def test_background_install(
|
||||
mm2_installer: ModelInstallServiceBase,
|
||||
fixture_name: str,
|
||||
size: int,
|
||||
destination: str,
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
request: pytest.FixtureRequest,
|
||||
) -> None:
|
||||
"""Note: may want to break this down into several smaller unit tests."""
|
||||
path = test_file
|
||||
path: Path = request.getfixturevalue(fixture_name)
|
||||
description = "Test of metadata assignment"
|
||||
source = LocalModelSource(path=path, inplace=False)
|
||||
job = installer.import_model(source, config={"description": description})
|
||||
job = mm2_installer.import_model(source, config={"description": description})
|
||||
assert job is not None
|
||||
assert isinstance(job, ModelInstallJob)
|
||||
|
||||
# See if job is registered properly
|
||||
assert job in installer.get_job(source)
|
||||
assert job in mm2_installer.get_job_by_source(source)
|
||||
|
||||
# test that the job object tracked installation correctly
|
||||
jobs = installer.wait_for_installs()
|
||||
jobs = mm2_installer.wait_for_installs()
|
||||
assert len(jobs) > 0
|
||||
my_job = [x for x in jobs if x.source == source]
|
||||
assert len(my_job) == 1
|
||||
assert my_job[0].status == InstallStatus.COMPLETED
|
||||
assert job == my_job[0]
|
||||
assert job.status == InstallStatus.COMPLETED
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = installer.event_bus
|
||||
assert bus is not None # sigh - ruff is a stickler for type checking
|
||||
assert isinstance(bus, DummyEventService)
|
||||
bus = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_started" in event_names
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
@ -160,41 +126,134 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
|
||||
assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
model_record = installer.record_store.get_model(key)
|
||||
model_record = mm2_installer.record_store.get_model(key)
|
||||
assert model_record is not None
|
||||
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert model_record.path == destination
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
|
||||
# see if metadata was properly passed through
|
||||
assert model_record.description == description
|
||||
|
||||
# see if job filtering works
|
||||
assert mm2_installer.get_job_by_source(source)[0] == job
|
||||
|
||||
# see if prune works properly
|
||||
installer.prune_jobs()
|
||||
assert not installer.get_job(source)
|
||||
mm2_installer.prune_jobs()
|
||||
assert not mm2_installer.get_job_by_source(source)
|
||||
|
||||
|
||||
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||
store = installer.record_store
|
||||
key = installer.install_path(test_file)
|
||||
def test_not_inplace_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
source = LocalModelSource(path=embedding_file, inplace=False)
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs()
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) != embedding_file
|
||||
assert Path(mm2_app_config.models_dir / job.config_out.path).exists()
|
||||
|
||||
|
||||
def test_inplace_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
source = LocalModelSource(path=embedding_file, inplace=True)
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs()
|
||||
assert job is not None
|
||||
assert job.config_out is not None
|
||||
assert Path(job.config_out.path) == embedding_file
|
||||
|
||||
|
||||
def test_delete_install(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.install_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert test_file.exists() # original should still be there after installation
|
||||
installer.delete(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert not Path(
|
||||
app_config.models_dir / model_record.path
|
||||
mm2_app_config.models_dir / model_record.path
|
||||
).exists() # after deletion, installed copy should not exist
|
||||
assert test_file.exists() # but original should still be there
|
||||
assert embedding_file.exists() # but original should still be there
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
|
||||
store = installer.record_store
|
||||
key = installer.register_path(test_file)
|
||||
def test_delete_register(
|
||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
|
||||
) -> None:
|
||||
store = mm2_installer.record_store
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
model_record = store.get_model(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert test_file.exists() # original should still be there after installation
|
||||
installer.delete(key)
|
||||
assert Path(app_config.models_dir / model_record.path).exists()
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert embedding_file.exists() # original should still be there after installation
|
||||
mm2_installer.delete(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
with pytest.raises(UnknownModelException):
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
assert hasattr(bus, "events") # the dummy event service has this
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
assert job.source == source
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 3
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
|
||||
|
||||
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert Path(mm2_app_config.models_dir / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
mm2_installer.wait_for_installs(timeout=10)
|
||||
assert job.status == InstallStatus.ERROR
|
||||
assert job.errored
|
||||
assert job.error_type == "HTTPError"
|
||||
assert job.error
|
||||
assert "NOT FOUND" in job.error
|
||||
assert "Traceback" in job.error
|
||||
|
@ -1 +0,0 @@
|
||||
This directory is used by pytest-datadir.
|
@ -1 +0,0 @@
|
||||
Dummy file to establish git path.
|
@ -10,6 +10,7 @@ import pytest
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelRecordServiceSQL,
|
||||
UnknownModelException,
|
||||
@ -22,14 +23,16 @@ from invokeai.backend.model_manager.config import (
|
||||
TextualInversionConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import BaseMetadata
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(
|
||||
datadir: Any,
|
||||
) -> ModelRecordServiceBase:
|
||||
) -> ModelRecordServiceSQL:
|
||||
config = InvokeAIAppConfig(root=datadir)
|
||||
logger = InvokeAILogger.get_logger(config=config)
|
||||
db = create_mock_sqlite_database(config, logger)
|
||||
@ -268,3 +271,50 @@ def test_filter_2(store: ModelRecordServiceBase):
|
||||
model_name="dup_name1",
|
||||
)
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
|
||||
# The fixture provides us with five configs.
|
||||
for x in range(1, 5):
|
||||
key = f"test_config_{x}"
|
||||
name = f"name_{x}"
|
||||
author = f"author_{x}"
|
||||
tags = {f"tag{y}" for y in range(1, x)}
|
||||
mm2_record_store.metadata_store.add_metadata(
|
||||
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
|
||||
)
|
||||
# sanity check that the tags sent in all right
|
||||
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
|
||||
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
|
||||
|
||||
# get summary
|
||||
summary1 = mm2_record_store.list_models(page=0, per_page=100)
|
||||
assert summary1.page == 0
|
||||
assert summary1.pages == 1
|
||||
assert summary1.per_page == 100
|
||||
assert summary1.total == 5
|
||||
assert len(summary1.items) == 5
|
||||
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
|
||||
|
||||
# find test_config_3
|
||||
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
|
||||
assert config3.description == "This is test 3"
|
||||
assert config3.tags == {"tag1", "tag2"}
|
||||
|
||||
# find test_config_5
|
||||
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
|
||||
assert config5.tags == set()
|
||||
assert config5.description == ""
|
||||
|
||||
# test paging
|
||||
summary2 = mm2_record_store.list_models(page=1, per_page=2)
|
||||
assert summary2.page == 1
|
||||
assert summary2.per_page == 2
|
||||
assert summary2.pages == 3
|
||||
assert summary1.items[2].name == summary2.items[0].name
|
||||
|
||||
# test sorting
|
||||
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
|
||||
print(summary.items)
|
||||
assert summary.items[0].name == "model1"
|
||||
assert summary.items[-1].name == "test5"
|
||||
|
1
tests/backend/model_manager_2/data/invokeai_root/README
Normal file
1
tests/backend/model_manager_2/data/invokeai_root/README
Normal file
@ -0,0 +1 @@
|
||||
This is an empty invokeai root that is used as a template for model manager tests.
|
@ -0,0 +1 @@
|
||||
This is a template empty invokeai root directory used to test model management.
|
@ -0,0 +1 @@
|
||||
This is a template empty invokeai root directory used to test model management.
|
@ -0,0 +1,34 @@
|
||||
{
|
||||
"_class_name": "StableDiffusionXLPipeline",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "stabilityai/sdxl-turbo",
|
||||
"force_zeros_for_empty_prompt": true,
|
||||
"scheduler": [
|
||||
"diffusers",
|
||||
"EulerAncestralDiscreteScheduler"
|
||||
],
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"text_encoder_2": [
|
||||
"transformers",
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"tokenizer": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"tokenizer_2": [
|
||||
"transformers",
|
||||
"CLIPTokenizer"
|
||||
],
|
||||
"unet": [
|
||||
"diffusers",
|
||||
"UNet2DConditionModel"
|
||||
],
|
||||
"vae": [
|
||||
"diffusers",
|
||||
"AutoencoderKL"
|
||||
]
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
{
|
||||
"_class_name": "EulerAncestralDiscreteScheduler",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"interpolation_type": "linear",
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "epsilon",
|
||||
"sample_max_value": 1.0,
|
||||
"set_alpha_to_one": false,
|
||||
"skip_prk_steps": true,
|
||||
"steps_offset": 1,
|
||||
"timestep_spacing": "trailing",
|
||||
"trained_betas": null
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.35.0",
|
||||
"vocab_size": 49408
|
||||
}
|
@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder_2",
|
||||
"architectures": [
|
||||
"CLIPTextModelWithProjection"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "gelu",
|
||||
"hidden_size": 1280,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 77,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 20,
|
||||
"num_hidden_layers": 32,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 1280,
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.35.0",
|
||||
"vocab_size": 49408
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
@ -0,0 +1,38 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"0": {
|
||||
"content": "!",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "!",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
@ -0,0 +1,73 @@
|
||||
{
|
||||
"_class_name": "UNet2DConditionModel",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/unet",
|
||||
"act_fn": "silu",
|
||||
"addition_embed_type": "text_time",
|
||||
"addition_embed_type_num_heads": 64,
|
||||
"addition_time_embed_dim": 256,
|
||||
"attention_head_dim": [
|
||||
5,
|
||||
10,
|
||||
20
|
||||
],
|
||||
"attention_type": "default",
|
||||
"block_out_channels": [
|
||||
320,
|
||||
640,
|
||||
1280
|
||||
],
|
||||
"center_input_sample": false,
|
||||
"class_embed_type": null,
|
||||
"class_embeddings_concat": false,
|
||||
"conv_in_kernel": 3,
|
||||
"conv_out_kernel": 3,
|
||||
"cross_attention_dim": 2048,
|
||||
"cross_attention_norm": null,
|
||||
"down_block_types": [
|
||||
"DownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D"
|
||||
],
|
||||
"downsample_padding": 1,
|
||||
"dropout": 0.0,
|
||||
"dual_cross_attention": false,
|
||||
"encoder_hid_dim": null,
|
||||
"encoder_hid_dim_type": null,
|
||||
"flip_sin_to_cos": true,
|
||||
"freq_shift": 0,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"mid_block_only_cross_attention": null,
|
||||
"mid_block_scale_factor": 1,
|
||||
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
||||
"norm_eps": 1e-05,
|
||||
"norm_num_groups": 32,
|
||||
"num_attention_heads": null,
|
||||
"num_class_embeds": null,
|
||||
"only_cross_attention": false,
|
||||
"out_channels": 4,
|
||||
"projection_class_embeddings_input_dim": 2816,
|
||||
"resnet_out_scale_factor": 1.0,
|
||||
"resnet_skip_time_act": false,
|
||||
"resnet_time_scale_shift": "default",
|
||||
"reverse_transformer_layers_per_block": null,
|
||||
"sample_size": 64,
|
||||
"time_cond_proj_dim": null,
|
||||
"time_embedding_act_fn": null,
|
||||
"time_embedding_dim": null,
|
||||
"time_embedding_type": "positional",
|
||||
"timestep_post_act": null,
|
||||
"transformer_layers_per_block": [
|
||||
1,
|
||||
2,
|
||||
10
|
||||
],
|
||||
"up_block_types": [
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"UpBlock2D"
|
||||
],
|
||||
"upcast_attention": null,
|
||||
"use_linear_projection": true
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
{
|
||||
"_class_name": "AutoencoderKL",
|
||||
"_diffusers_version": "0.23.0",
|
||||
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/vae",
|
||||
"act_fn": "silu",
|
||||
"block_out_channels": [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
512
|
||||
],
|
||||
"down_block_types": [
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D",
|
||||
"DownEncoderBlock2D"
|
||||
],
|
||||
"force_upcast": true,
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"norm_num_groups": 32,
|
||||
"out_channels": 3,
|
||||
"sample_size": 1024,
|
||||
"scaling_factor": 0.13025,
|
||||
"up_block_types": [
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D",
|
||||
"UpDecoderBlock2D"
|
||||
]
|
||||
}
|
265
tests/backend/model_manager_2/model_manager_2_fixtures.py
Normal file
265
tests/backend/model_manager_2/model_manager_2_fixtures.py
Normal file
@ -0,0 +1,265 @@
|
||||
# Fixtures to support testing of the model_manager v2 installer, metadata and record store
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
RepoHFMetadata1_nofp16,
|
||||
RepoHFModelJson1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
|
||||
@pytest.fixture
|
||||
def mm2_root_dir(tmp_path_factory) -> Path:
|
||||
root_template = Path(__file__).resolve().parent / "data" / "invokeai_root"
|
||||
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
|
||||
shutil.copytree(root_template, temp_dir)
|
||||
return temp_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_model_files(tmp_path_factory) -> Path:
|
||||
root_template = Path(__file__).resolve().parent / "data" / "test_files"
|
||||
temp_dir: Path = tmp_path_factory.mktemp("data") / "test_files"
|
||||
shutil.copytree(root_template, temp_dir)
|
||||
return temp_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test-diffusers-main"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(
|
||||
root=mm2_root_dir,
|
||||
models_dir=mm2_root_dir / "models",
|
||||
)
|
||||
return app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test2",
|
||||
"base": BaseModelType("sd-2"),
|
||||
"type": ModelType("vae"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "stabilityai/sdxl-vae",
|
||||
}
|
||||
raw2 = {
|
||||
"path": "/tmp/foo2.ckpt",
|
||||
"name": "model1",
|
||||
"format": ModelFormat("checkpoint"),
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": "main",
|
||||
"config": "/tmp/foo.yaml",
|
||||
"variant": "normal",
|
||||
"original_hash": "111222333444",
|
||||
"source": "https://civitai.com/models/206883/split",
|
||||
}
|
||||
raw3 = {
|
||||
"path": "/tmp/foo3",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test3",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("main"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author3/model3",
|
||||
"description": "This is test 3",
|
||||
}
|
||||
raw4 = {
|
||||
"path": "/tmp/foo4",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test4",
|
||||
"base": BaseModelType("sdxl"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model4",
|
||||
}
|
||||
raw5 = {
|
||||
"path": "/tmp/foo5",
|
||||
"format": ModelFormat("diffusers"),
|
||||
"name": "test5",
|
||||
"base": BaseModelType("sd-1"),
|
||||
"type": ModelType("lora"),
|
||||
"original_hash": "111222333444",
|
||||
"source": "author4/model5",
|
||||
}
|
||||
store.add_model("test_config_1", raw1)
|
||||
store.add_model("test_config_2", raw2)
|
||||
store.add_model("test_config_3", raw3)
|
||||
store.add_model("test_config_4", raw4)
|
||||
store.add_model("test_config_5", raw5)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||
sess = TestSession()
|
||||
sess.mount(
|
||||
"https://test.com/missing_model.safetensors",
|
||||
TestAdapter(
|
||||
b"missing",
|
||||
status=404,
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1_nofp16,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1_nofp16)},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/model-versions/242807",
|
||||
TestAdapter(
|
||||
RepoCivitaiVersionMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiVersionMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://civitai.com/api/v1/models/215485",
|
||||
TestAdapter(
|
||||
RepoCivitaiModelMetadata1,
|
||||
headers={
|
||||
"Content-Length": len(RepoCivitaiModelMetadata1),
|
||||
},
|
||||
),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
|
||||
TestAdapter(
|
||||
RepoHFModelJson1,
|
||||
headers={
|
||||
"Content-Length": len(RepoHFModelJson1),
|
||||
},
|
||||
),
|
||||
)
|
||||
with open(embedding_file, "rb") as f:
|
||||
data = f.read() # file is small - just 15K
|
||||
sess.mount(
|
||||
"https://www.test.foo/download/test_embedding.safetensors",
|
||||
TestAdapter(data, headers={"Content-Type": "application/octet-stream", "Content-Length": len(data)}),
|
||||
)
|
||||
sess.mount(
|
||||
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
|
||||
TestAdapter(
|
||||
RepoHFMetadata1,
|
||||
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
|
||||
),
|
||||
)
|
||||
for root, _, files in os.walk(diffusers_dir):
|
||||
for name in files:
|
||||
path = Path(root, name)
|
||||
url_base = path.relative_to(diffusers_dir).as_posix()
|
||||
url = f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}"
|
||||
with open(path, "rb") as f:
|
||||
data = f.read()
|
||||
sess.mount(
|
||||
url,
|
||||
TestAdapter(
|
||||
data,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
"Content-Length": len(data),
|
||||
},
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
metadata_store = ModelMetadataStore(db)
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
File diff suppressed because one or more lines are too long
@ -0,0 +1,201 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
CivitaiMetadataFetch,
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags=tags,
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert input_metadata == output_metadata
|
||||
with pytest.raises(UnknownMetadataException):
|
||||
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
|
||||
input_metadata.name = "new-name"
|
||||
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
|
||||
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
|
||||
assert output_metadata.name == "new-name"
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers"},
|
||||
id="stabilityai/sdxl-vae",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata2 = HuggingFaceMetadata(
|
||||
name="model2",
|
||||
author="stabilityai",
|
||||
tags={"text-to-image", "diffusers", "community-contributed"},
|
||||
id="author2/model2",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
metadata3 = HuggingFaceMetadata(
|
||||
name="model3",
|
||||
author="author3",
|
||||
tags={"text-to-image", "checkpoint", "community-contributed"},
|
||||
id="author3/model3",
|
||||
tag_dict={"license": "other"},
|
||||
last_modified=datetime.datetime.now(),
|
||||
)
|
||||
mm2_metadata_store.add_metadata("test_config_1", metadata1)
|
||||
mm2_metadata_store.add_metadata("test_config_2", metadata2)
|
||||
mm2_metadata_store.add_metadata("test_config_3", metadata3)
|
||||
|
||||
matches = mm2_metadata_store.search_by_author("stabilityai")
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
|
||||
assert not matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_name("model3")
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
|
||||
assert len(matches) == 3
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
|
||||
assert len(matches) == 2
|
||||
assert "test_config_1" in matches
|
||||
assert "test_config_2" in matches
|
||||
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
|
||||
assert len(matches) == 1
|
||||
assert "test_config_3" in matches
|
||||
|
||||
# does the tag table update correctly?
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert not matches
|
||||
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
|
||||
metadata3.tags.add("licensed-for-commercial-use")
|
||||
mm2_metadata_store.update_metadata("test_config_3", metadata3)
|
||||
assert mm2_metadata_store.list_tags() == {
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"community-contributed",
|
||||
"checkpoint",
|
||||
"licensed-for-commercial-use",
|
||||
}
|
||||
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
fetcher = CivitaiMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
|
||||
assert isinstance(metadata, CivitaiMetadata)
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
||||
|
||||
def test_metadata_hf_fetch(mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
assert metadata.author == "test_author" # this is not the same as the original
|
||||
assert metadata.files
|
||||
assert metadata.tags == {
|
||||
"diffusers",
|
||||
"onnx",
|
||||
"safetensors",
|
||||
"text-to-image",
|
||||
"license:other",
|
||||
"has_space",
|
||||
"diffusers:StableDiffusionXLPipeline",
|
||||
"region:us",
|
||||
}
|
||||
|
||||
|
||||
def test_metadata_hf_filter(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
|
||||
|
||||
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
|
||||
|
||||
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
|
||||
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
|
||||
|
||||
default_files = select_hf_files.filter_files(files)
|
||||
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
|
||||
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
|
||||
|
||||
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
|
||||
print(openvino_files)
|
||||
assert len(openvino_files) == 0
|
||||
|
||||
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
|
||||
print(flax_files)
|
||||
assert not flax_files
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
|
||||
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
|
||||
)
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
||||
files = [x.path for x in metadata.files]
|
||||
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
|
||||
assert (
|
||||
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
|
||||
) # confirm that default is returned
|
||||
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
|
||||
|
||||
|
||||
def test_metadata_hf_urls(mm2_session: Session) -> None:
|
||||
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
assert isinstance(metadata, HuggingFaceMetadata)
|
239
tests/backend/model_manager_2/util/test_hf_model_select.py
Normal file
239
tests/backend/model_manager_2/util/test_hf_model_select.py
Normal file
@ -0,0 +1,239 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.util.select_hf_files import filter_files
|
||||
|
||||
|
||||
# This is the full list of model paths returned by the HF API for sdxl-base
|
||||
@pytest.fixture
|
||||
def sdxl_base_files() -> List[Path]:
|
||||
return [
|
||||
Path(x)
|
||||
for x in [
|
||||
".gitattributes",
|
||||
"01.png",
|
||||
"LICENSE.md",
|
||||
"README.md",
|
||||
"comparison.png",
|
||||
"model_index.json",
|
||||
"pipeline.png",
|
||||
"scheduler/scheduler_config.json",
|
||||
"sd_xl_base_1.0.safetensors",
|
||||
"sd_xl_base_1.0_0.9vae.safetensors",
|
||||
"sd_xl_offset_example-lora_1.0.safetensors",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/flax_model.msgpack",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder/openvino_model.bin",
|
||||
"text_encoder/openvino_model.xml",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/flax_model.msgpack",
|
||||
"text_encoder_2/model.fp16.safetensors",
|
||||
"text_encoder_2/model.onnx",
|
||||
"text_encoder_2/model.onnx_data",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"text_encoder_2/openvino_model.bin",
|
||||
"text_encoder_2/openvino_model.xml",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"unet/model.onnx",
|
||||
"unet/model.onnx_data",
|
||||
"unet/openvino_model.bin",
|
||||
"unet/openvino_model.xml",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_decoder/openvino_model.bin",
|
||||
"vae_decoder/openvino_model.xml",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/model.onnx",
|
||||
"vae_encoder/openvino_model.bin",
|
||||
"vae_encoder/openvino_model.xml",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
# This are what we expect to get when various diffusers variants are requested
|
||||
@pytest.mark.parametrize(
|
||||
"variant,expected_list",
|
||||
[
|
||||
(
|
||||
None,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.DEFAULT,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.OPENVINO,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/openvino_model.bin",
|
||||
"text_encoder/openvino_model.xml",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/openvino_model.bin",
|
||||
"text_encoder_2/openvino_model.xml",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/openvino_model.bin",
|
||||
"unet/openvino_model.xml",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/openvino_model.bin",
|
||||
"vae_decoder/openvino_model.xml",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/openvino_model.bin",
|
||||
"vae_encoder/openvino_model.xml",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FP16,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.fp16.safetensors",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae_1_0/config.json",
|
||||
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.ONNX,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.onnx",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/model.onnx",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_encoder/config.json",
|
||||
"vae_encoder/model.onnx",
|
||||
],
|
||||
),
|
||||
(
|
||||
ModelRepoVariant.FLAX,
|
||||
[
|
||||
"model_index.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/flax_model.msgpack",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/flax_model.msgpack",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"tokenizer_2/merges.txt",
|
||||
"tokenizer_2/special_tokens_map.json",
|
||||
"tokenizer_2/tokenizer_config.json",
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_flax_model.msgpack",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_flax_model.msgpack",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
|
||||
print(f"testing variant {variant}")
|
||||
filtered_files = filter_files(sdxl_base_files, variant)
|
||||
assert set(filtered_files) == {Path(x) for x in expected_list}
|
@ -1,6 +1,7 @@
|
||||
# conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
|
||||
# without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
|
||||
|
||||
|
||||
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
|
||||
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
|
||||
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401
|
||||
|
Reference in New Issue
Block a user