From 014be0ab67347137d38de9ec01a3384a338f817b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 29 Feb 2024 21:34:25 +1100 Subject: [PATCH] feat(nodes): revise model load API args --- invokeai/app/invocations/ip_adapter.py | 2 +- .../app/services/shared/invocation_context.py | 26 +++++++++---------- invokeai/app/util/ti_utils.py | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 15e254010b..bebdc29b86 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -93,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation): image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_models = context.models.search_by_attrs( - model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision + name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision ) assert len(image_encoder_models) == 1 image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 31064a5e7c..cbbe3216cc 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -289,7 +289,7 @@ class ModelsInterface(InvocationContextInterface): ) def load_by_attrs( - self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None + self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None ) -> LoadedModel: """ Loads a model by its attributes. @@ -300,10 +300,10 @@ class ModelsInterface(InvocationContextInterface): :param submodel: For main (pipeline models), the submodel to fetch """ return self._services.model_manager.load_model_by_attr( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=submodel, + model_name=name, + base_model=base, + model_type=type, + submodel=submodel_type, context_data=self._data, ) @@ -333,10 +333,10 @@ class ModelsInterface(InvocationContextInterface): def search_by_attrs( self, - model_name: Optional[str] = None, - base_model: Optional[BaseModelType] = None, - model_type: Optional[ModelType] = None, - model_format: Optional[ModelFormat] = None, + name: Optional[str] = None, + base: Optional[BaseModelType] = None, + type: Optional[ModelType] = None, + format: Optional[ModelFormat] = None, ) -> list[AnyModelConfig]: """ Searches for models by attributes. @@ -348,10 +348,10 @@ class ModelsInterface(InvocationContextInterface): """ return self._services.model_manager.store.search_by_attr( - model_name=model_name, - base_model=base_model, - model_type=model_type, - model_format=model_format, + model_name=name, + base_model=base, + model_type=type, + model_format=format, ) diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py index b5c884c9b7..d204a40183 100644 --- a/invokeai/app/util/ti_utils.py +++ b/invokeai/app/util/ti_utils.py @@ -30,7 +30,7 @@ def generate_ti_list( except UnknownModelException: try: loaded_model = context.models.load_by_attrs( - model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion + name=name_or_key, base=base, type=ModelType.TextualInversion ) model = loaded_model.model assert isinstance(model, TextualInversionModelRaw)