bad implementation of diffusers folder download

This commit is contained in:
Lincoln Stein
2024-05-08 21:21:01 -07:00
parent f211c95dbc
commit b48d4a049d
5 changed files with 64 additions and 13 deletions

View File

@ -14,6 +14,7 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
@ -21,7 +22,13 @@ from invokeai.app.services.model_install import (
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@ -247,7 +254,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
@ -278,6 +285,17 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
}
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
job = mm2_installer.download_diffusers_model(source, tmp_path)
mm2_installer.wait_for_installs(timeout=5)
print(job.local_path)
assert job.status == InstallStatus.DOWNLOADS_DONE
assert (tmp_path / "sdxl-turbo").exists()
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
job = mm2_installer.import_model(source)
@ -327,7 +345,7 @@ def test_other_error_during_install(
},
],
)
@pytest.mark.timeout(timeout=40, method="thread")
@pytest.mark.timeout(timeout=20, method="thread")
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
assert "name" in model_params and "type" in model_params