mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to original model manager. - Caught and fixed most dangling references (still checking) - Rename lora, textual_inversion and model_patcher modules - Introduce a RawModel base class to simplfy the Union returned by the model loaders. - Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
This commit is contained in:
committed by
psychedelicious
parent
ba1f8878dd
commit
2ad0752582
@ -5,10 +5,9 @@ from typing import Optional, Union
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.install.model_install_backend import ModelInstall
|
||||
from invokeai.backend.model_management.model_manager import LoadedModelInfo
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -16,31 +15,20 @@ def torch_device():
|
||||
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def model_installer():
|
||||
"""A global ModelInstall pytest fixture to be used by many tests."""
|
||||
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
|
||||
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
|
||||
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
|
||||
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
|
||||
# a singleton.
|
||||
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
|
||||
|
||||
|
||||
def install_and_load_model(
|
||||
model_installer: ModelInstall,
|
||||
model_manager: ModelManagerServiceBase,
|
||||
model_path_id_or_url: Union[str, Path],
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> LoadedModelInfo:
|
||||
"""Install a model if it is not already installed, then get the LoadedModelInfo for that model.
|
||||
) -> LoadedModel:
|
||||
"""Install a model if it is not already installed, then get the LoadedModel for that model.
|
||||
|
||||
This is intended as a utility function for tests.
|
||||
|
||||
Args:
|
||||
model_installer (ModelInstall): The model installer.
|
||||
mm2_model_manager (ModelManagerServiceBase): The model manager
|
||||
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
|
||||
is not already installed.
|
||||
model_name (str): The model name, forwarded to ModelManager.get_model(...).
|
||||
@ -51,16 +39,23 @@ def install_and_load_model(
|
||||
Returns:
|
||||
LoadedModelInfo
|
||||
"""
|
||||
# If the requested model is already installed, return its LoadedModelInfo.
|
||||
with contextlib.suppress(ModelNotFoundException):
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
# If the requested model is already installed, return its LoadedModel
|
||||
with contextlib.suppress(UnknownModelException):
|
||||
# TODO: Replace with wrapper call
|
||||
loaded_model: LoadedModel = model_manager.load.load_model_by_attr(
|
||||
model_name=model_name, base_model=base_model, model_type=model_type
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
# Install the requested model.
|
||||
model_installer.heuristic_import(model_path_id_or_url)
|
||||
job = model_manager.install.heuristic_import(model_path_id_or_url)
|
||||
model_manager.install.wait_for_job(job, timeout=10)
|
||||
assert job.complete
|
||||
|
||||
try:
|
||||
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
|
||||
except ModelNotFoundException as e:
|
||||
loaded_model = model_manager.load.load_model_by_config(job.config_out)
|
||||
return loaded_model
|
||||
except UnknownModelException as e:
|
||||
raise Exception(
|
||||
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
|
||||
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
|
||||
|
Reference in New Issue
Block a user