feat(nodes): revise model load API args

This commit is contained in:
psychedelicious 2024-02-29 21:34:25 +11:00
parent 39725e9560
commit 0b0128647b
3 changed files with 15 additions and 15 deletions

View File

@ -93,7 +93,7 @@ class IPAdapterInvocation(BaseInvocation):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_models = context.models.search_by_attrs(
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
assert len(image_encoder_models) == 1
image_encoder_model = CLIPVisionModelField(key=image_encoder_models[0].key)

View File

@ -289,7 +289,7 @@ class ModelsInterface(InvocationContextInterface):
)
def load_by_attrs(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Loads a model by its attributes.
@ -300,10 +300,10 @@ class ModelsInterface(InvocationContextInterface):
:param submodel: For main (pipeline models), the submodel to fetch
"""
return self._services.model_manager.load_model_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_name=name,
base_model=base,
model_type=type,
submodel=submodel_type,
context_data=self._data,
)
@ -333,10 +333,10 @@ class ModelsInterface(InvocationContextInterface):
def search_by_attrs(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
name: Optional[str] = None,
base: Optional[BaseModelType] = None,
type: Optional[ModelType] = None,
format: Optional[ModelFormat] = None,
) -> list[AnyModelConfig]:
"""
Searches for models by attributes.
@ -348,10 +348,10 @@ class ModelsInterface(InvocationContextInterface):
"""
return self._services.model_manager.store.search_by_attr(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_format=model_format,
model_name=name,
base_model=base,
model_type=type,
model_format=format,
)

View File

@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
name=name_or_key, base=base, type=ModelType.TextualInversion
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)