feat(nodes): update invocation context for mm2, update nodes model usage

This commit is contained in:
psychedelicious
2024-02-15 20:43:41 +11:00
parent 88d6de4101
commit 539570cc7a
9 changed files with 141 additions and 147 deletions

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from PIL.Image import Image
@ -12,8 +13,9 @@ from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.model_manager import LoadedModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -259,45 +261,95 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
def exists(self, key: str) -> bool:
"""
Checks if a model exists.
:param model_name: The name of the model to check.
:param base_model: The base model of the model to check.
:param model_type: The type of the model to check.
:param key: The key of the model.
"""
return self._services.model_manager.model_exists(model_name, base_model, model_type)
return self._services.model_manager.store.exists(key)
def load(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModelInfo:
def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Loads a model.
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
:param submodel: The submodel of the model to get.
:param key: The key of the model.
:param submodel_type: The submodel of the model to get.
:returns: An object representing the loaded model.
"""
# The model manager emits events as it loads the model. It needs the context data to build
# the event payloads.
return self._services.model_manager.get_model(
model_name, base_model, model_type, submodel, context_data=self._context_data
return self._services.model_manager.load.load_model_by_key(
key=key, submodel_type=submodel_type, context_data=self._context_data
)
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
def load_by_attrs(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
) -> LoadedModel:
"""
Loads a model by its attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.load.load_model_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
context_data=self._context_data,
)
def get_config(self, key: str) -> AnyModelConfig:
"""
Gets a model's info, an dict-like object.
:param model_name: The name of the model to get.
:param base_model: The base model of the model to get.
:param model_type: The type of the model to get.
:param key: The key of the model.
"""
return self._services.model_manager.model_info(model_name, base_model, model_type)
return self._services.model_manager.store.get_model(key=key)
def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
"""
Gets a model's metadata, if it has any.
:param key: The key of the model.
"""
return self._services.model_manager.store.get_metadata(key=key)
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""
Searches for models by path.
:param path: The path to search for.
"""
return self._services.model_manager.store.search_by_path(path)
def search_by_attrs(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""
Searches for models by attributes.
:param model_name: Name of to be fetched.
:param base_model: Base model
:param model_type: Type of the model
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.store.search_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_format=model_format,
)
class ConfigInterface(InvocationContextInterface):