feat(nodes): revise model load API args

This commit is contained in:
psychedelicious 2024-02-29 21:34:25 +11:00 committed by Kent Keirsey
parent e5d9f33f7b
commit 014be0ab67
3 changed files with 15 additions and 15 deletions

View File

@ -93,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs( image_encoder_models = context.models.search_by_attrs(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
) )
assert len(image_encoder_models) == 1 assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key) image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)

View File

@ -289,7 +289,7 @@ class ModelsInterface(InvocationContextInterface):
) )
def load_by_attrs( def load_by_attrs(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel: ) -> LoadedModel:
""" """
Loads a model by its attributes. Loads a model by its attributes.
@ -300,10 +300,10 @@ class ModelsInterface(InvocationContextInterface):
:param submodel: For main (pipeline models), the submodel to fetch :param submodel: For main (pipeline models), the submodel to fetch
""" """
return self._services.model_manager.load_model_by_attr( return self._services.model_manager.load_model_by_attr(
model_name=model_name, model_name=name,
base_model=base_model, base_model=base,
model_type=model_type, model_type=type,
submodel=submodel, submodel=submodel_type,
context_data=self._data, context_data=self._data,
) )
@ -333,10 +333,10 @@ class ModelsInterface(InvocationContextInterface):
def search_by_attrs( def search_by_attrs(
self, self,
model_name: Optional[str] = None, name: Optional[str] = None,
base_model: Optional[BaseModelType] = None, base: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None, format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]: ) -> list[AnyModelConfig]:
""" """
Searches for models by attributes. Searches for models by attributes.
@ -348,10 +348,10 @@ class ModelsInterface(InvocationContextInterface):
""" """
return self._services.model_manager.store.search_by_attr( return self._services.model_manager.store.search_by_attr(
model_name=model_name, model_name=name,
base_model=base_model, base_model=base,
model_type=model_type, model_type=type,
model_format=model_format, model_format=format,
) )

View File

@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException: except UnknownModelException:
try: try:
loaded_model = context.models.load_by_attrs( loaded_model = context.models.load_by_attrs(
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion name=name_or_key, base=base, type=ModelType.TextualInversion
) )
model = loaded_model.model model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw) assert isinstance(model, TextualInversionModelRaw)