feat(nodes): add sd_model_loader node

Loads any pipeline model.

Also introduced is `PipelineModelField`, which includes a model name and base model.
This commit is contained in:
psychedelicious 2023-06-22 17:36:05 +10:00
parent b937b7da01
commit 42a59aa147

View File

@ -43,6 +43,117 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class PipelineModelField(BaseModel):
"""Pipeline model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class SDModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd_model_loader"] = "sd_model_loader"
model: PipelineModelField = Field(description="The model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model": "model"
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.Pipeline
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
)
)
class SD1ModelLoaderInvocation(BaseInvocation): class SD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model.""" """Loading submodels of selected model."""