mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
afd9ae7712
These were added as a stopgap for the nodes API changes and are no longer needed. A follow-up commit will fix the nodes API so that it no longer relies on them.
64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
import contextlib
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
|
from invokeai.app.services.model_records import UnknownModelException
|
|
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
|
|
|
|
|
|
@pytest.fixture(scope="session")
def torch_device():
    """Session-scoped fixture giving the torch device name tests should use.

    Returns:
        str: "cuda" if torch reports that CUDA is available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
|
|
def install_and_load_model(
    model_manager: ModelManagerServiceBase,
    model_path_id_or_url: Union[str, Path],
    model_name: str,
    base_model: BaseModelType,
    model_type: ModelType,
    submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
    """Install a model if it is not already installed, then get the LoadedModel for that model.

    This is intended as a utility function for tests.

    Args:
        model_manager (ModelManagerServiceBase): The model manager.
        model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
            is not already installed.
        model_name (str): The model name, used to look up an existing install via `store.search_by_attr(...)`.
        base_model (BaseModelType): The base model, used to look up an existing install via `store.search_by_attr(...)`.
        model_type (ModelType): The model type, used to look up an existing install via `store.search_by_attr(...)`.
        submodel_type (Optional[SubModelType]): The submodel type. NOTE(review): currently accepted but not used by
            this function; TODO confirm whether it should be forwarded to `load.load_model(...)`.

    Returns:
        LoadedModel: The loaded model, either from an existing install or from a fresh install.

    Raises:
        AssertionError: If the install job does not report completion after `wait_for_job(..., timeout=10)` returns.
        Exception: If loading the freshly installed model raises UnknownModelException (the original error is chained
            as the cause).
    """
    # If the requested model is already installed, return its LoadedModel.
    with contextlib.suppress(UnknownModelException):
        # TODO: Replace with wrapper call
        configs = model_manager.store.search_by_attr(
            model_name=model_name, base_model=base_model, model_type=model_type
        )
        # An empty result means the model is not installed yet. Fall through to the install step below
        # instead of raising IndexError on `configs[0]`, which the suppress() above would not catch.
        if configs:
            loaded_model: LoadedModel = model_manager.load.load_model(configs[0])
            return loaded_model

    # Install the requested model.
    job = model_manager.install.heuristic_import(model_path_id_or_url)
    model_manager.install.wait_for_job(job, timeout=10)
    assert job.complete, f"Install job for '{model_path_id_or_url}' did not complete."

    try:
        return model_manager.load.load_model(job.config_out)
    except UnknownModelException as e:
        # Chain the original error so its traceback is preserved for debugging.
        raise Exception(
            "Failed to get model info after installing it. There could be a mismatch between the requested model and"
            f" the installation id ('{model_path_id_or_url}'). Error: {e}"
        ) from e
|