Model Manager Refactor: Install remote models and store their tags and other metadata (#5361)

* add basic functionality for model metadata fetching from hf and civitai

* add storage

* start unit tests

* add unit tests and documentation

* add missing dependency for pytests

* remove redundant fetch; add modified/published dates; update docs

* add code to select diffusers files based on the variant type
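
A minimal sketch of the idea behind variant selection (illustrative only: the real logic is `filter_files()` in `invokeai.backend.model_manager.util.select_hf_files`, exercised by tests near the end of this diff; the `pick_variant` helper and its grouping rules below are assumptions, not the actual implementation). Interchangeable flavors of each weight file are grouped into one slot, and the flavor matching the requested variant wins:

    from pathlib import Path
    from typing import Dict, List, Optional

    WEIGHT_SUFFIXES = {".safetensors", ".bin", ".onnx", ".msgpack"}

    def pick_variant(files: List[Path], variant: Optional[str] = None) -> List[Path]:
        """Keep one weight file per slot, preferring the requested variant."""
        chosen: Dict[Path, Path] = {}
        for f in files:
            if f.suffix not in WEIGHT_SUFFIXES:
                chosen[f] = f  # configs, tokenizers, etc. pass straight through
                continue
            parts = f.name.split(".")
            slot = f.with_name(parts[0])  # "model.fp16.safetensors" and "model.safetensors" share a slot
            this_variant = parts[1] if len(parts) > 2 else None
            if slot not in chosen or this_variant == variant:
                chosen[slot] = f
        return sorted(chosen.values())

For example, `pick_variant([Path("unet/model.fp16.safetensors"), Path("unet/model.safetensors")], "fp16")` keeps only the fp16 file, while omitting the variant keeps the plain one.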

* implement Civitai installs

* make huggingface parallel downloading work
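
As a generic illustration of the parallel-download idea (the installer actually drives its DownloadQueueService; the helper below is a thread-pool sketch, not the project's code):

    from concurrent.futures import ThreadPoolExecutor
    from pathlib import Path
    from typing import List

    import requests

    def download_all(urls: List[str], dest: Path, session: requests.Session, workers: int = 4) -> None:
        """Fetch each repo file on its own worker thread and write it under dest."""
        dest.mkdir(parents=True, exist_ok=True)

        def fetch(url: str) -> None:
            resp = session.get(url, timeout=60)
            resp.raise_for_status()
            (dest / url.rsplit("/", 1)[-1]).write_bytes(resp.content)

        with ThreadPoolExecutor(max_workers=workers) as pool:
            list(pool.map(fetch, urls))  # list() re-raises any worker exception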

* add unit tests for model installation manager

- Fixed a race condition in the selection of the download destination path
- Added fixtures common to several model_manager_2 unit tests
- Added dummy model files for testing diffusers and safetensors downloading/probing
- Refactored code for selecting the proper variant from a list of huggingface repo files
- Regrouped the ordering of methods in model_install_default.py

* improve Civitai model downloading

- Provide a better error message when Civitai requires an access token (it doesn't return a 403 Forbidden, but redirects
  to the HTML of an authorization page -- arrgh); see the sketch after this list
- Handle the case of Civitai providing a primary download link plus additional links for VAEs, config files, etc.
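
A hedged sketch of that detection (the helper name and error message are hypothetical): when Civitai wants authorization it answers with an HTML login page rather than the expected binary payload, so sniffing the Content-Type is enough to raise a useful error.

    import requests

    def fetch_civitai_file(url: str, session: requests.Session) -> bytes:
        # An unauthorized Civitai download returns a 200 redirect to an HTML
        # login page instead of a 403, so check the content type explicitly.
        resp = session.get(url, allow_redirects=True, timeout=30)
        resp.raise_for_status()
        if "text/html" in resp.headers.get("Content-Type", ""):
            raise RuntimeError(f"{url} redirected to a login page; an access token is probably required")
        return resp.content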

* add routes for retrieving metadata and tags

* code tidying and documentation

* fix ruff errors

* add file needed to maintain test root directory in repo for unit tests

* fix self->cls in classmethod

* add pydantic plugin for mypy

* use TestSession instead of requests.Session to prevent any internet activity (see the sketch after this list)

- improve logging
- fix error message formatting
- fix logging again
- fix forward vs. reverse slash issue in Windows install tests
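
For context, a minimal sketch of the requests_testadapter pattern the tests rely on (the URL and payload are made up): TestSession only answers URLs that were explicitly mounted, so any stray request to a real host fails instead of touching the network.

    from requests_testadapter import TestAdapter, TestSession

    sess = TestSession()
    sess.mount(
        "http://test.local/model.safetensors",
        TestAdapter(b"fake model bytes", status=200),
    )
    resp = sess.get("http://test.local/model.safetensors")
    assert resp.status_code == 200 and resp.content == b"fake model bytes"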

* Several fixes of problems detected during PR review:

- Implement cancel_model_install_job and get_model_install_job routes
  to allow for better control of model download and install.
- Fix thread deadlock that occurred after cancelling an install (see the sketch after this list).
- Remove unneeded pytest_plugins section from tests/conftest.py
- Remove unused _in_terminal_state() from model_install_default.
- Remove outdated documentation from several spots.
- Add workaround for Civitai API results which don't return correct
  URL for the default model.
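
An illustrative sketch of the cancellation deadlock (the class and method names are invented for the example): every terminal state, including cancellation, must wake threads blocked in a wait call, or a wait_for_installs()-style method hangs forever.

    import threading

    class TinyInstallJob:
        """Toy stand-in for an install job; not the actual service code."""

        def __init__(self) -> None:
            self._done = threading.Event()
            self.status = "running"

        def cancel(self) -> None:
            self.status = "cancelled"
            self._done.set()  # the deadlock came from forgetting to signal here

        def complete(self) -> None:
            self.status = "completed"
            self._done.set()

        def wait_for_terminal_state(self, timeout: float = 10.0) -> str:
            if not self._done.wait(timeout):
                raise TimeoutError("job never reached a terminal state")
            return self.status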

* fix docs and tests to match get_job_by_source() rather than get_job()

* Update invokeai/backend/model_manager/metadata/fetch/huggingface.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Call CivitaiMetadata.model_validate_json() directly

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* Second round of revisions suggested by @ryanjdick:

- Fix type mismatch in the `list_all_metadata()` route.
- Do not have a default value for the model install job id.
- Remove static class variable declarations from non-Pydantic classes.
- Change the `id` field to `model_id` for the sqlite3 `model_tags` table.
- Change AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables (see the sketch after this list).
- Make the `id` field of the `model_metadata` table into a primary key to achieve uniqueness.
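
A small sketch of the ON DELETE CASCADE change against an in-memory sqlite3 database (the table and column names are simplified assumptions, not the exact InvokeAI schema):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("PRAGMA foreign_keys = ON")  # sqlite3 leaves FK enforcement off by default
    conn.execute("CREATE TABLE model_config (id TEXT PRIMARY KEY)")
    conn.execute(
        """CREATE TABLE model_metadata (
               id TEXT PRIMARY KEY REFERENCES model_config(id) ON DELETE CASCADE,
               metadata TEXT NOT NULL
           )"""
    )
    conn.execute("INSERT INTO model_config VALUES ('key1')")
    conn.execute("INSERT INTO model_metadata VALUES ('key1', '{}')")
    # Deleting the parent row now removes the metadata row automatically,
    # replacing the old AFTER DELETE trigger.
    conn.execute("DELETE FROM model_config WHERE id = 'key1'")
    assert conn.execute("SELECT COUNT(*) FROM model_metadata").fetchone()[0] == 0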

* Code cleanup suggested in PR review:

- Narrowed the declaration of the `parts` attribute of the download progress event
- Removed auto-conversion of str to Url in Url-containing sources
- Fixed handling of `InvalidModelConfigException`
- Made unknown sources raise `NotImplementedError` rather than `Exception`
- Improved status reporting on cached HuggingFace access tokens

* Multiple fixes:

- `job.total_size` returns a valid size for locally installed models
- new route `list_models` returns a paged summary of model name,
  description, tags, and other essential info (see the paging sketch after this list)
- fix a few type errors
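
A rough sketch of the paging arithmetic behind such a route, against an in-memory sqlite3 store (table and column names are simplified assumptions; compare the page/pages/per_page/total fields checked by test_summary later in this diff):

    import sqlite3

    def list_models_page(conn: sqlite3.Connection, page: int = 0, per_page: int = 10) -> dict:
        total = conn.execute("SELECT COUNT(*) FROM model_config").fetchone()[0]
        items = conn.execute(
            "SELECT id, name, description FROM model_config ORDER BY name LIMIT ? OFFSET ?",
            (per_page, page * per_page),
        ).fetchall()
        pages = -(-total // per_page)  # ceiling division
        return {"page": page, "pages": pages, "per_page": per_page, "total": total, "items": items}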

* consolidated all invokeai root pytest fixtures into a single location

* Update invokeai/backend/model_manager/metadata/metadata_store.py

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

* Small tweaks in response to review comments:

- Remove flake8 configuration from pyproject.toml
- Use `id` rather than `modelId` for huggingface `ModelInfo` object
- Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object
- Add `sha256` field to file metadata downloaded from huggingface
- Add `Invoker` argument to the model installer `start()` and `stop()` routines
  (but made it optional in order to facilitate use of the service outside the API)
- Removed redundant `PRAGMA foreign_keys` from metadata store initialization code.

* Additional tweaks and minor bug fixes

- Fix calculation of aggregate diffusers model size to count only the
  size of files, not files + directories (which gives different unit test
  results on different filesystems); see the sketch after this list.
- Refactor _get_metadata() and _get_download_urls() to have distinct code paths
  for Civitai, HuggingFace and URL sources.
- Forward the `inplace` flag from the source to the job and added unit test for this.
- Attach cached model metadata to the job rather than to the model install service.
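
A minimal sketch of the size fix (the helper name is illustrative): sum only regular files, because directory entries report filesystem-dependent sizes of their own.

    from pathlib import Path

    def on_disk_size(model_dir: Path) -> int:
        # Counting p.is_file() only keeps the total stable across platforms;
        # including directories added their block sizes, which vary by filesystem.
        return sum(p.stat().st_size for p in model_dir.rglob("*") if p.is_file())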

* fix unit test that was breaking on Windows due to CR/LF changing the size of test JSON files

* fix ruff formatting

* a few last minor fixes before merging:

- Turn the job's `error` and `error_type` fields into properties derived from the stored exception (see the sketch after this list).
- Add a TODO comment about the reason for handling temporary directory destruction
  manually rather than using `tempfile.TemporaryDirectory()`.
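
A sketch of the error-property change (attribute names are illustrative; the behavior lines up with the test_404_download assertions later in this diff, where error_type is the exception class name and error contains the traceback):

    import traceback
    from typing import Optional

    class JobWithDerivedError:
        """Toy example: error info is derived from the stored exception."""

        def __init__(self) -> None:
            self._exception: Optional[Exception] = None

        def set_error(self, exc: Exception) -> None:
            self._exception = exc

        @property
        def error_type(self) -> Optional[str]:
            return self._exception.__class__.__name__ if self._exception else None

        @property
        def error(self) -> Optional[str]:
            if self._exception is None:
                return None
            # Format the exception plus its traceback, when one was captured.
            return "".join(
                traceback.format_exception(
                    type(self._exception), self._exception, self._exception.__traceback__
                )
            )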

* add unit tests for reporting HTTP download errors

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Committed by Lincoln Stein on 2024-01-14 14:54:53 -05:00 (via GitHub)
Parent: 426a7b900f
Commit: 4536e4a8b6
67 changed files with 4056 additions and 652 deletions

View File

@@ -5,11 +5,10 @@ from pathlib import Path
from typing import Any, Dict, List
import pytest
import requests
from pydantic import BaseModel
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
@@ -19,8 +18,8 @@ TestAdapter.__test__ = False
@pytest.fixture
def session() -> requests.sessions.Session:
sess = requests.Session()
def session() -> Session:
sess = TestSession()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
@@ -160,7 +159,7 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.stop()
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
)
@@ -191,7 +190,7 @@ def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, ca
queue.stop()
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)

View File

@@ -2,11 +2,12 @@
Test the model installer
"""
import platform
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from pydantic.networks import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
@@ -14,104 +15,50 @@ from invokeai.app.services.model_install import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallService,
ModelInstallServiceBase,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelType
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
OS = platform.uname().system
@pytest.fixture
def test_file(datadir: Path) -> Path:
return datadir / "test_embedding.safetensors"
@pytest.fixture
def app_config(datadir: Path) -> InvokeAIAppConfig:
return InvokeAIAppConfig(
root=datadir / "root",
models_dir=datadir / "root/models",
)
@pytest.fixture
def store(
app_config: InvokeAIAppConfig,
) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=app_config)
db = create_mock_sqlite_database(app_config, logger)
store: ModelRecordServiceBase = ModelRecordServiceSQL(db)
return store
@pytest.fixture
def installer(app_config: InvokeAIAppConfig, store: ModelRecordServiceBase) -> ModelInstallServiceBase:
installer = ModelInstallService(
app_config=app_config,
record_store=store,
event_bus=DummyEventService(),
)
installer.start()
return installer
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
def test_registration(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
def test_registration(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = installer.register_path(test_file)
key = mm2_installer.register_path(embedding_file)
assert key is not None
assert len(key) == 32
def test_registration_meta(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(test_file)
def test_registration_meta(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(embedding_file)
model_record = store.get_model(key)
assert model_record is not None
assert model_record.name == "test_embedding"
assert model_record.type == ModelType.TextualInversion
assert Path(model_record.path) == test_file
assert Path(model_record.path) == embedding_file
assert model_record.base == BaseModelType("sd-1")
assert model_record.description is not None
assert model_record.source is not None
assert Path(model_record.source) == test_file
assert Path(model_record.source) == embedding_file
def test_registration_meta_override_fail(installer: ModelInstallServiceBase, test_file: Path) -> None:
def test_registration_meta_override_fail(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
key = None
with pytest.raises(ValidationError):
key = installer.register_path(test_file, {"name": "banana_sushi", "type": ModelType("lora")})
key = mm2_installer.register_path(embedding_file, {"name": "banana_sushi", "type": ModelType("lora")})
assert key is None
def test_registration_meta_override_succeed(installer: ModelInstallServiceBase, test_file: Path) -> None:
store = installer.record_store
key = installer.register_path(
test_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
def test_registration_meta_override_succeed(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(
embedding_file, {"name": "banana_sushi", "source": "fake/repo_id", "current_hash": "New Hash"}
)
model_record = store.get_model(key)
assert model_record.name == "banana_sushi"
@@ -119,40 +66,59 @@ def test_registration_meta_override_succeed(installer: ModelInstallServiceBase,
assert model_record.current_hash == "New Hash"
def test_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
store = installer.record_store
key = installer.install_path(test_file)
def test_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert model_record.source == test_file.as_posix()
assert model_record.source == embedding_file.as_posix()
def test_background_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig) -> None:
@pytest.mark.parametrize(
"fixture_name,size,destination",
[
("embedding_file", 15440, "sd-1/embedding/test_embedding.safetensors"),
("diffusers_dir", 8241 if OS == "Windows" else 7907, "sdxl/main/test-diffusers-main"), # EOL chars
],
)
def test_background_install(
mm2_installer: ModelInstallServiceBase,
fixture_name: str,
size: int,
destination: str,
mm2_app_config: InvokeAIAppConfig,
request: pytest.FixtureRequest,
) -> None:
"""Note: may want to break this down into several smaller unit tests."""
path = test_file
path: Path = request.getfixturevalue(fixture_name)
description = "Test of metadata assignment"
source = LocalModelSource(path=path, inplace=False)
job = installer.import_model(source, config={"description": description})
job = mm2_installer.import_model(source, config={"description": description})
assert job is not None
assert isinstance(job, ModelInstallJob)
# See if job is registered properly
assert job in installer.get_job(source)
assert job in mm2_installer.get_job_by_source(source)
# test that the job object tracked installation correctly
jobs = installer.wait_for_installs()
jobs = mm2_installer.wait_for_installs()
assert len(jobs) > 0
my_job = [x for x in jobs if x.source == source]
assert len(my_job) == 1
assert my_job[0].status == InstallStatus.COMPLETED
assert job == my_job[0]
assert job.status == InstallStatus.COMPLETED
assert job.total_bytes == size
# test that the expected events were issued
bus = installer.event_bus
assert bus is not None # sigh - ruff is a stickler for type checking
assert isinstance(bus, DummyEventService)
bus = mm2_installer.event_bus
assert bus
assert hasattr(bus, "events")
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_started" in event_names
assert "model_install_running" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
@@ -160,41 +126,134 @@ def test_background_install(installer: ModelInstallServiceBase, test_file: Path,
assert key is not None
# see if the thing actually got installed at the expected location
model_record = installer.record_store.get_model(key)
model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None
assert model_record.path == "sd-1/embedding/test_embedding.safetensors"
assert Path(app_config.models_dir / model_record.path).exists()
assert model_record.path == destination
assert Path(mm2_app_config.models_dir / model_record.path).exists()
# see if metadata was properly passed through
assert model_record.description == description
# see if job filtering works
assert mm2_installer.get_job_by_source(source)[0] == job
# see if prune works properly
installer.prune_jobs()
assert not installer.get_job(source)
mm2_installer.prune_jobs()
assert not mm2_installer.get_job_by_source(source)
def test_delete_install(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.install_path(test_file)
def test_not_inplace_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
source = LocalModelSource(path=embedding_file, inplace=False)
job = mm2_installer.import_model(source)
mm2_installer.wait_for_installs()
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) != embedding_file
assert Path(mm2_app_config.models_dir / job.config_out.path).exists()
def test_inplace_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
source = LocalModelSource(path=embedding_file, inplace=True)
job = mm2_installer.import_model(source)
mm2_installer.wait_for_installs()
assert job is not None
assert job.config_out is not None
assert Path(job.config_out.path) == embedding_file
def test_delete_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert Path(mm2_app_config.models_dir / model_record.path).exists()
assert embedding_file.exists() # original should still be there after installation
mm2_installer.delete(key)
assert not Path(
app_config.models_dir / model_record.path
mm2_app_config.models_dir / model_record.path
).exists() # after deletion, installed copy should not exist
assert test_file.exists() # but original should still be there
assert embedding_file.exists() # but original should still be there
with pytest.raises(UnknownModelException):
store.get_model(key)
def test_delete_register(installer: ModelInstallServiceBase, test_file: Path, app_config: InvokeAIAppConfig):
store = installer.record_store
key = installer.register_path(test_file)
def test_delete_register(
mm2_installer: ModelInstallServiceBase, embedding_file: Path, mm2_app_config: InvokeAIAppConfig
) -> None:
store = mm2_installer.record_store
key = mm2_installer.register_path(embedding_file)
model_record = store.get_model(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert test_file.exists() # original should still be there after installation
installer.delete(key)
assert Path(app_config.models_dir / model_record.path).exists()
assert Path(mm2_app_config.models_dir / model_record.path).exists()
assert embedding_file.exists() # original should still be there after installation
mm2_installer.delete(key)
assert Path(mm2_app_config.models_dir / model_record.path).exists()
with pytest.raises(UnknownModelException):
store.get_model(key)
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert store is not None
assert bus is not None
assert hasattr(bus, "events") # the dummy event service has this
job = mm2_installer.import_model(source)
assert job.source == source
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert Path(mm2_app_config.models_dir / model_record.path).exists()
assert len(bus.events) == 3
event_names = [x.event_name for x in bus.events]
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
job = mm2_installer.import_model(source)
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert Path(mm2_app_config.models_dir / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the DummyEventService has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {"model_install_downloading", "model_install_running", "model_install_completed"}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
mm2_installer.wait_for_installs(timeout=10)
assert job.status == InstallStatus.ERROR
assert job.errored
assert job.error_type == "HTTPError"
assert job.error
assert "NOT FOUND" in job.error
assert "Traceback" in job.error

View File

@@ -1 +0,0 @@
This directory is used by pytest-datadir.

View File

@@ -1 +0,0 @@
Dummy file to establish git path.

View File

@@ -10,6 +10,7 @@ import pytest
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import (
DuplicateModelException,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelRecordServiceSQL,
UnknownModelException,
@@ -22,14 +23,16 @@ from invokeai.backend.model_manager.config import (
TextualInversionConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.metadata import BaseMetadata
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def store(
datadir: Any,
) -> ModelRecordServiceBase:
) -> ModelRecordServiceSQL:
config = InvokeAIAppConfig(root=datadir)
logger = InvokeAILogger.get_logger(config=config)
db = create_mock_sqlite_database(config, logger)
@@ -268,3 +271,50 @@ def test_filter_2(store: ModelRecordServiceBase):
model_name="dup_name1",
)
assert len(matches) == 1
def test_summary(mm2_record_store: ModelRecordServiceSQL) -> None:
# The fixture provides us with five configs.
for x in range(1, 5):
key = f"test_config_{x}"
name = f"name_{x}"
author = f"author_{x}"
tags = {f"tag{y}" for y in range(1, x)}
mm2_record_store.metadata_store.add_metadata(
model_key=key, metadata=BaseMetadata(name=name, author=author, tags=tags)
)
# sanity check that the tags went in all right
assert mm2_record_store.get_metadata("test_config_3").tags == {"tag1", "tag2"}
assert mm2_record_store.get_metadata("test_config_4").tags == {"tag1", "tag2", "tag3"}
# get summary
summary1 = mm2_record_store.list_models(page=0, per_page=100)
assert summary1.page == 0
assert summary1.pages == 1
assert summary1.per_page == 100
assert summary1.total == 5
assert len(summary1.items) == 5
assert summary1.items[0].name == "test5" # lora / sd-1 / diffusers / test5
# find test_config_3
config3 = [x for x in summary1.items if x.key == "test_config_3"][0]
assert config3.description == "This is test 3"
assert config3.tags == {"tag1", "tag2"}
# find test_config_5
config5 = [x for x in summary1.items if x.key == "test_config_5"][0]
assert config5.tags == set()
assert config5.description == ""
# test paging
summary2 = mm2_record_store.list_models(page=1, per_page=2)
assert summary2.page == 1
assert summary2.per_page == 2
assert summary2.pages == 3
assert summary1.items[2].name == summary2.items[0].name
# test sorting
summary = mm2_record_store.list_models(page=0, per_page=100, order_by=ModelRecordOrderBy.Name)
print(summary.items)
assert summary.items[0].name == "model1"
assert summary.items[-1].name == "test5"

View File

@@ -0,0 +1 @@
This is an empty invokeai root that is used as a template for model manager tests.

View File

@@ -0,0 +1 @@
This is a template empty invokeai root directory used to test model management.

View File

@@ -0,0 +1 @@
This is a template empty invokeai root directory used to test model management.

View File

@@ -0,0 +1,34 @@
{
"_class_name": "StableDiffusionXLPipeline",
"_diffusers_version": "0.23.0",
"_name_or_path": "stabilityai/sdxl-turbo",
"force_zeros_for_empty_prompt": true,
"scheduler": [
"diffusers",
"EulerAncestralDiscreteScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"text_encoder_2": [
"transformers",
"CLIPTextModelWithProjection"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"tokenizer_2": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}

View File

@@ -0,0 +1,17 @@
{
"_class_name": "EulerAncestralDiscreteScheduler",
"_diffusers_version": "0.23.0",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"interpolation_type": "linear",
"num_train_timesteps": 1000,
"prediction_type": "epsilon",
"sample_max_value": 1.0,
"set_alpha_to_one": false,
"skip_prk_steps": true,
"steps_offset": 1,
"timestep_spacing": "trailing",
"trained_betas": null
}

View File

@@ -0,0 +1,25 @@
{
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float16",
"transformers_version": "4.35.0",
"vocab_size": 49408
}

View File

@@ -0,0 +1,25 @@
{
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder_2",
"architectures": [
"CLIPTextModelWithProjection"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 5120,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 77,
"model_type": "clip_text_model",
"num_attention_heads": 20,
"num_hidden_layers": 32,
"pad_token_id": 1,
"projection_dim": 1280,
"torch_dtype": "float16",
"transformers_version": "4.35.0",
"vocab_size": 49408
}

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,30 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "<|endoftext|>",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

View File

@@ -0,0 +1,30 @@
{
"bos_token": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"eos_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

View File

@@ -0,0 +1,38 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "!",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
},
"49406": {
"content": "<|startoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"49407": {
"content": "<|endoftext|>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<|startoftext|>",
"clean_up_tokenization_spaces": true,
"do_lower_case": true,
"eos_token": "<|endoftext|>",
"errors": "replace",
"model_max_length": 77,
"pad_token": "!",
"tokenizer_class": "CLIPTokenizer",
"unk_token": "<|endoftext|>"
}

View File

@@ -0,0 +1,73 @@
{
"_class_name": "UNet2DConditionModel",
"_diffusers_version": "0.23.0",
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/unet",
"act_fn": "silu",
"addition_embed_type": "text_time",
"addition_embed_type_num_heads": 64,
"addition_time_embed_dim": 256,
"attention_head_dim": [
5,
10,
20
],
"attention_type": "default",
"block_out_channels": [
320,
640,
1280
],
"center_input_sample": false,
"class_embed_type": null,
"class_embeddings_concat": false,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 2048,
"cross_attention_norm": null,
"down_block_types": [
"DownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D"
],
"downsample_padding": 1,
"dropout": 0.0,
"dual_cross_attention": false,
"encoder_hid_dim": null,
"encoder_hid_dim_type": null,
"flip_sin_to_cos": true,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 2,
"mid_block_only_cross_attention": null,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_attention_heads": null,
"num_class_embeds": null,
"only_cross_attention": false,
"out_channels": 4,
"projection_class_embeddings_input_dim": 2816,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": false,
"resnet_time_scale_shift": "default",
"reverse_transformer_layers_per_block": null,
"sample_size": 64,
"time_cond_proj_dim": null,
"time_embedding_act_fn": null,
"time_embedding_dim": null,
"time_embedding_type": "positional",
"timestep_post_act": null,
"transformer_layers_per_block": [
1,
2,
10
],
"up_block_types": [
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"UpBlock2D"
],
"upcast_attention": null,
"use_linear_projection": true
}

View File

@@ -0,0 +1,32 @@
{
"_class_name": "AutoencoderKL",
"_diffusers_version": "0.23.0",
"_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/vae",
"act_fn": "silu",
"block_out_channels": [
128,
256,
512,
512
],
"down_block_types": [
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D",
"DownEncoderBlock2D"
],
"force_upcast": true,
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 2,
"norm_num_groups": 32,
"out_channels": 3,
"sample_size": 1024,
"scaling_factor": 0.13025,
"up_block_types": [
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D",
"UpDecoderBlock2D"
]
}

View File

@@ -0,0 +1,265 @@
# Fixtures to support testing of the model_manager v2 installer, metadata and record store
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.metadata import ModelMetadataStore
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
RepoHFMetadata1_nofp16,
RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@pytest.fixture
def mm2_root_dir(tmp_path_factory) -> Path:
root_template = Path(__file__).resolve().parent / "data" / "invokeai_root"
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir
@pytest.fixture
def mm2_model_files(tmp_path_factory) -> Path:
root_template = Path(__file__).resolve().parent / "data" / "test_files"
temp_dir: Path = tmp_path_factory.mktemp("data") / "test_files"
shutil.copytree(root_template, temp_dir)
return temp_dir
@pytest.fixture
def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
return mm2_model_files / "test-diffusers-main"
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
app_config = InvokeAIAppConfig(
root=mm2_root_dir,
models_dir=mm2_root_dir / "models",
)
return app_config
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db)
# add five simple config records to the database
raw1 = {
"path": "/tmp/foo1",
"format": ModelFormat("diffusers"),
"name": "test2",
"base": BaseModelType("sd-2"),
"type": ModelType("vae"),
"original_hash": "111222333444",
"source": "stabilityai/sdxl-vae",
}
raw2 = {
"path": "/tmp/foo2.ckpt",
"name": "model1",
"format": ModelFormat("checkpoint"),
"base": BaseModelType("sd-1"),
"type": "main",
"config": "/tmp/foo.yaml",
"variant": "normal",
"original_hash": "111222333444",
"source": "https://civitai.com/models/206883/split",
}
raw3 = {
"path": "/tmp/foo3",
"format": ModelFormat("diffusers"),
"name": "test3",
"base": BaseModelType("sdxl"),
"type": ModelType("main"),
"original_hash": "111222333444",
"source": "author3/model3",
"description": "This is test 3",
}
raw4 = {
"path": "/tmp/foo4",
"format": ModelFormat("diffusers"),
"name": "test4",
"base": BaseModelType("sdxl"),
"type": ModelType("lora"),
"original_hash": "111222333444",
"source": "author4/model4",
}
raw5 = {
"path": "/tmp/foo5",
"format": ModelFormat("diffusers"),
"name": "test5",
"base": BaseModelType("sd-1"),
"type": ModelType("lora"),
"original_hash": "111222333444",
"source": "author4/model5",
}
store.add_model("test_config_1", raw1)
store.add_model("test_config_2", raw2)
store.add_model("test_config_3", raw3)
store.add_model("test_config_4", raw4)
store.add_model("test_config_5", raw5)
return store
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
db = mm2_record_store._db # to ensure we are sharing the same database
return ModelMetadataStore(db)
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
"""This fixtures defines a series of mock URLs for testing download and installation."""
sess = TestSession()
sess.mount(
"https://test.com/missing_model.safetensors",
TestAdapter(
b"missing",
status=404,
),
)
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
TestAdapter(
RepoHFMetadata1,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
TestAdapter(
RepoHFMetadata1_nofp16,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1_nofp16)},
),
)
sess.mount(
"https://civitai.com/api/v1/model-versions/242807",
TestAdapter(
RepoCivitaiVersionMetadata1,
headers={
"Content-Length": len(RepoCivitaiVersionMetadata1),
},
),
)
sess.mount(
"https://civitai.com/api/v1/models/215485",
TestAdapter(
RepoCivitaiModelMetadata1,
headers={
"Content-Length": len(RepoCivitaiModelMetadata1),
},
),
)
sess.mount(
"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
TestAdapter(
RepoHFModelJson1,
headers={
"Content-Length": len(RepoHFModelJson1),
},
),
)
with open(embedding_file, "rb") as f:
data = f.read() # file is small - just 15K
sess.mount(
"https://www.test.foo/download/test_embedding.safetensors",
TestAdapter(data, headers={"Content-Type": "application/octet-stream", "Content-Length": len(data)}),
)
sess.mount(
"https://huggingface.co/api/models/stabilityai/sdxl-turbo",
TestAdapter(
RepoHFMetadata1,
headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
),
)
for root, _, files in os.walk(diffusers_dir):
for name in files:
path = Path(root, name)
url_base = path.relative_to(diffusers_dir).as_posix()
url = f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}"
with open(path, "rb") as f:
data = f.read()
sess.mount(
url,
TestAdapter(
data,
headers={
"Content-Type": "application/json; charset=utf-8",
"Content-Length": len(data),
},
),
)
return sess
@pytest.fixture
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db)
metadata_store = ModelMetadataStore(db)
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=store,
download_queue=download_queue,
metadata_store=metadata_store,
event_bus=events,
session=mm2_session,
)
installer.start()
return installer

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,201 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path
import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
CivitaiMetadata,
CivitaiMetadataFetch,
CommercialUsage,
HuggingFaceMetadata,
HuggingFaceMetadataFetch,
ModelMetadataStore,
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
tags = {"text-to-image", "diffusers"}
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags=tags,
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert input_metadata == output_metadata
with pytest.raises(UnknownMetadataException):
mm2_metadata_store.add_metadata("unknown_key", input_metadata)
assert mm2_metadata_store.list_tags() == tags
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
input_metadata = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", input_metadata)
input_metadata.name = "new-name"
mm2_metadata_store.update_metadata("test_config_1", input_metadata)
output_metadata = mm2_metadata_store.get_metadata("test_config_1")
assert output_metadata.name == "new-name"
assert input_metadata == output_metadata
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
metadata1 = HuggingFaceMetadata(
name="sdxl-vae",
author="stabilityai",
tags={"text-to-image", "diffusers"},
id="stabilityai/sdxl-vae",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata2 = HuggingFaceMetadata(
name="model2",
author="stabilityai",
tags={"text-to-image", "diffusers", "community-contributed"},
id="author2/model2",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
metadata3 = HuggingFaceMetadata(
name="model3",
author="author3",
tags={"text-to-image", "checkpoint", "community-contributed"},
id="author3/model3",
tag_dict={"license": "other"},
last_modified=datetime.datetime.now(),
)
mm2_metadata_store.add_metadata("test_config_1", metadata1)
mm2_metadata_store.add_metadata("test_config_2", metadata2)
mm2_metadata_store.add_metadata("test_config_3", metadata3)
matches = mm2_metadata_store.search_by_author("stabilityai")
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
assert not matches
matches = mm2_metadata_store.search_by_name("model3")
assert len(matches) == 1
assert "test_config_3" in matches
matches = mm2_metadata_store.search_by_tag({"text-to-image"})
assert len(matches) == 3
matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
assert len(matches) == 2
assert "test_config_1" in matches
assert "test_config_2" in matches
matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
assert len(matches) == 1
assert "test_config_3" in matches
# does the tag table update correctly?
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert not matches
assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
metadata3.tags.add("licensed-for-commercial-use")
mm2_metadata_store.update_metadata("test_config_3", metadata3)
assert mm2_metadata_store.list_tags() == {
"text-to-image",
"diffusers",
"community-contributed",
"checkpoint",
"licensed-for-commercial-use",
}
matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
assert len(matches) == 1
def test_metadata_civitai_fetch(mm2_session: Session) -> None:
fetcher = CivitaiMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
assert isinstance(metadata, CivitaiMetadata)
assert metadata.id == 215485
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
assert metadata.version_id == 242807
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
def test_metadata_hf_fetch(mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
assert metadata.author == "test_author" # this is not the same as the original
assert metadata.files
assert metadata.tags == {
"diffusers",
"onnx",
"safetensors",
"text-to-image",
"license:other",
"has_space",
"diffusers:StableDiffusionXLPipeline",
"region:us",
}
def test_metadata_hf_filter(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files
fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files
onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files
default_files = select_hf_files.filter_files(files)
assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files
openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
print(openvino_files)
assert len(openvino_files) == 0
flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
print(flax_files)
assert not flax_files
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
)
assert isinstance(metadata, HuggingFaceMetadata)
files = [x.path for x in metadata.files]
filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
assert (
Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
) # confirm that default is returned
assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files
def test_metadata_hf_urls(mm2_session: Session) -> None:
metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
assert isinstance(metadata, HuggingFaceMetadata)

View File

@@ -0,0 +1,239 @@
from pathlib import Path
from typing import List
import pytest
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files
# This is the full list of model paths returned by the HF API for sdxl-base
@pytest.fixture
def sdxl_base_files() -> List[Path]:
return [
Path(x)
for x in [
".gitattributes",
"01.png",
"LICENSE.md",
"README.md",
"comparison.png",
"model_index.json",
"pipeline.png",
"scheduler/scheduler_config.json",
"sd_xl_base_1.0.safetensors",
"sd_xl_base_1.0_0.9vae.safetensors",
"sd_xl_offset_example-lora_1.0.safetensors",
"text_encoder/config.json",
"text_encoder/flax_model.msgpack",
"text_encoder/model.fp16.safetensors",
"text_encoder/model.onnx",
"text_encoder/model.safetensors",
"text_encoder/openvino_model.bin",
"text_encoder/openvino_model.xml",
"text_encoder_2/config.json",
"text_encoder_2/flax_model.msgpack",
"text_encoder_2/model.fp16.safetensors",
"text_encoder_2/model.onnx",
"text_encoder_2/model.onnx_data",
"text_encoder_2/model.safetensors",
"text_encoder_2/openvino_model.bin",
"text_encoder_2/openvino_model.xml",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/diffusion_flax_model.msgpack",
"unet/diffusion_pytorch_model.fp16.safetensors",
"unet/diffusion_pytorch_model.safetensors",
"unet/model.onnx",
"unet/model.onnx_data",
"unet/openvino_model.bin",
"unet/openvino_model.xml",
"vae/config.json",
"vae/diffusion_flax_model.msgpack",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
"vae_1_0/config.json",
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
"vae_1_0/diffusion_pytorch_model.safetensors",
"vae_decoder/config.json",
"vae_decoder/model.onnx",
"vae_decoder/openvino_model.bin",
"vae_decoder/openvino_model.xml",
"vae_encoder/config.json",
"vae_encoder/model.onnx",
"vae_encoder/openvino_model.bin",
"vae_encoder/openvino_model.xml",
]
]
# These are what we expect to get when the various diffusers variants are requested
@pytest.mark.parametrize(
"variant,expected_list",
[
(
None,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.safetensors",
"text_encoder_2/config.json",
"text_encoder_2/model.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
"vae_1_0/config.json",
"vae_1_0/diffusion_pytorch_model.safetensors",
],
),
(
ModelRepoVariant.DEFAULT,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.safetensors",
"text_encoder_2/config.json",
"text_encoder_2/model.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.safetensors",
"vae_1_0/config.json",
"vae_1_0/diffusion_pytorch_model.safetensors",
],
),
(
ModelRepoVariant.OPENVINO,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/openvino_model.bin",
"text_encoder/openvino_model.xml",
"text_encoder_2/config.json",
"text_encoder_2/openvino_model.bin",
"text_encoder_2/openvino_model.xml",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/openvino_model.bin",
"unet/openvino_model.xml",
"vae_decoder/config.json",
"vae_decoder/openvino_model.bin",
"vae_decoder/openvino_model.xml",
"vae_encoder/config.json",
"vae_encoder/openvino_model.bin",
"vae_encoder/openvino_model.xml",
],
),
(
ModelRepoVariant.FP16,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"text_encoder_2/config.json",
"text_encoder_2/model.fp16.safetensors",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.fp16.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae_1_0/config.json",
"vae_1_0/diffusion_pytorch_model.fp16.safetensors",
],
),
(
ModelRepoVariant.ONNX,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.onnx",
"text_encoder_2/config.json",
"text_encoder_2/model.onnx",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/model.onnx",
"vae_decoder/config.json",
"vae_decoder/model.onnx",
"vae_encoder/config.json",
"vae_encoder/model.onnx",
],
),
(
ModelRepoVariant.FLAX,
[
"model_index.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/flax_model.msgpack",
"text_encoder_2/config.json",
"text_encoder_2/flax_model.msgpack",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"tokenizer_2/merges.txt",
"tokenizer_2/special_tokens_map.json",
"tokenizer_2/tokenizer_config.json",
"tokenizer_2/vocab.json",
"unet/config.json",
"unet/diffusion_flax_model.msgpack",
"vae/config.json",
"vae/diffusion_flax_model.msgpack",
],
),
],
)
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
print(f"testing variant {variant}")
filtered_files = filter_files(sdxl_base_files, variant)
assert set(filtered_files) == {Path(x) for x in expected_list}

View File

@@ -1,6 +1,7 @@
# conftest.py is a special pytest file. Fixtures defined in this file will be accessible to all tests in this directory
# without needing to explicitly import them. (https://docs.pytest.org/en/6.2.x/fixture.html)
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import model_installer, torch_device # noqa: F401