mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merged SDXLModelLoader into ModelLoader invocation
This commit is contained in:
parent
358ced6bab
commit
75c5ce46bc
@ -46,22 +46,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
|
||||||
"""SDXL model loader output"""
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
|
||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (2d set)")
|
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class MainModelField(BaseModel):
|
class MainModelField(BaseModel):
|
||||||
"""Main model field"""
|
"""Main model field"""
|
||||||
|
|
||||||
@ -136,79 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
return ModelLoaderOutput(
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
skipped_layers=0,
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
submodel=SubModelType.Vae,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
class SDXLMainModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loads an SDXL main model, outputting its submodels."""
|
|
||||||
|
|
||||||
type: Literal["sdxl_main_model_loader"] = "sdxl_main_model_loader"
|
|
||||||
|
|
||||||
model: MainModelField = Field(description="The SDXL model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"title": "SDXL Model Loader",
|
|
||||||
"tags": ["model", "loader", "sdxl"],
|
|
||||||
"type_hints": {"model": "model"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
|
||||||
base_model = self.model.base_model
|
|
||||||
model_name = self.model.model_name
|
|
||||||
model_type = ModelType.Main
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=model_type,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
|
||||||
|
|
||||||
return SDXLModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -267,7 +182,6 @@ class SDXLMainModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LoraLoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
|
@ -150,7 +150,7 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
field={field}
|
field={field}
|
||||||
template={template}
|
template={template}
|
||||||
base_models={['sd-1', 'sd-2']}
|
base_models={['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user