mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add sd_model_loader
node
Loads any pipeline model. Also introduced is `PipelineModelField`, which includes a model name and base model.
This commit is contained in:
parent
b937b7da01
commit
42a59aa147
@ -43,6 +43,117 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineModelField(BaseModel):
|
||||||
|
"""Pipeline model field"""
|
||||||
|
|
||||||
|
model_name: str = Field(description="Name of the model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
|
class SDModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loading submodels of selected model."""
|
||||||
|
|
||||||
|
type: Literal["sd_model_loader"] = "sd_model_loader"
|
||||||
|
|
||||||
|
model: PipelineModelField = Field(description="The model to load")
|
||||||
|
# TODO: precision?
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"tags": ["model", "loader"],
|
||||||
|
"type_hints": {
|
||||||
|
"model": "model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
||||||
|
|
||||||
|
base_model = self.model.base_model
|
||||||
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.Pipeline
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
|
"""
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.Tokenizer,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.TextEncoder,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=self.model_name,
|
||||||
|
model_type=SDModelType.Diffusers,
|
||||||
|
submodel=SDModelType.UNet,
|
||||||
|
):
|
||||||
|
raise Exception(
|
||||||
|
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
return ModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Vae,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
class SD1ModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loading submodels of selected model."""
|
"""Loading submodels of selected model."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user