Mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names
Parent: f2a1a39b33
Commit: 3f5340fa53
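
In practical terms, the FLUX main-model loader node stops locating its VAE, CLIP embed, and T5 encoder submodels by hardcoded model names and instead takes each of them as an explicit input. Below is a minimal, self-contained sketch of the new input shape using stand-in pydantic classes (not InvokeAI's real ModelIdentifierField or InputField); only the field names are taken from the diff that follows, and the key values are placeholders.

from typing import Optional

from pydantic import BaseModel


class ModelIdentifier(BaseModel):
    # Stand-in for InvokeAI's ModelIdentifierField: just enough to show the shape.
    key: str
    submodel_type: Optional[str] = None


class FluxLoaderInputs(BaseModel):
    # Mirrors the input surface of FluxModelLoaderInvocation after this commit:
    # every submodel is selected explicitly instead of being found by name.
    model: ModelIdentifier             # FLUX transformer checkpoint
    t5_encoder_model: ModelIdentifier  # T5 tokenizer + text encoder
    clip_embed_model: ModelIdentifier  # CLIP tokenizer + text encoder
    vae_model: ModelIdentifier         # FLUX VAE


inputs = FluxLoaderInputs(
    model=ModelIdentifier(key="my-flux-main"),
    t5_encoder_model=ModelIdentifier(key="my-t5-encoder"),
    clip_embed_model=ModelIdentifier(key="my-clip-embed"),
    vae_model=ModelIdentifier(key="my-flux-vae"),
)
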
@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
+    FluxVAEModel = "FluxVAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
     T5EncoderModel = "T5EncoderModelField"
+    CLIPEmbedModel = "CLIPEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion
 
@@ -128,6 +130,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
+    clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"
@@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
         input=Input.Direct,
     )
 
-    t5_encoder: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder,
-        ui_type=UIType.T5EncoderModel,
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
         input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
     )
 
     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        model_key = self.model.key
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
 
-        if not context.models.exists(model_key):
-            raise ValueError(f"Unknown model: {model_key}")
-        transformer = self._get_model(context, SubModelType.Transformer)
-        tokenizer = self._get_model(context, SubModelType.Tokenizer)
-        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
-        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
-        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
 
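
The rewritten invoke() above drops the _get_model() indirection: each submodel reference is produced by copying the user-selected identifier and overriding only its submodel_type. A self-contained sketch of that pydantic model_copy pattern follows; the classes and enum values are illustrative stand-ins, not InvokeAI's real ModelIdentifierField or SubModelType.

from enum import Enum
from typing import Optional

from pydantic import BaseModel


class SubModelType(str, Enum):
    # Illustrative values; InvokeAI defines its own SubModelType enum.
    Transformer = "transformer"
    Tokenizer = "tokenizer"
    TextEncoder = "text_encoder"
    VAE = "vae"


class ModelIdentifier(BaseModel):
    # Stand-in for ModelIdentifierField.
    key: str
    submodel_type: Optional[SubModelType] = None


flux_main = ModelIdentifier(key="flux-main")
clip_embed = ModelIdentifier(key="clip-embed")

# model_copy(update=...) returns a new identifier for the same installed model,
# with only the submodel_type changed; this is the same shape as the calls in
# the new invoke() above.
transformer = flux_main.model_copy(update={"submodel_type": SubModelType.Transformer})
tokenizer = clip_embed.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = clip_embed.model_copy(update={"submodel_type": SubModelType.TextEncoder})

assert transformer.key == flux_main.key
assert transformer.submodel_type is SubModelType.Transformer
assert flux_main.submodel_type is None  # the original identifier is left untouched
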
@@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
             max_seq_len=max_seq_lengths[transformer_config.config_path],
         )
 
-    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
-        match submodel:
-            case SubModelType.Transformer:
-                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-            case SubModelType.VAE:
-                return self._pull_model_from_mm(
-                    context,
-                    SubModelType.VAE,
-                    "FLUX.1-schnell_ae",
-                    ModelType.VAE,
-                    BaseModelType.Flux,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    "clip-vit-large-patch14",
-                    ModelType.CLIPEmbed,
-                    BaseModelType.Any,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    self.t5_encoder.name,
-                    ModelType.T5Encoder,
-                    BaseModelType.Any,
-                )
-            case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _pull_model_from_mm(
-        self,
-        context: InvocationContext,
-        submodel: SubModelType,
-        name: str,
-        type: ModelType,
-        base: BaseModelType,
-    ):
-        if models := context.models.search_by_attrs(name=name, base=base, type=type):
-            if len(models) != 1:
-                raise Exception(f"Multiple models detected for selected model with name {name}")
-            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
-        else:
-            raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
-
 
 @invocation(
     "main_model_loader",
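
The removed _get_model() and _pull_model_from_mm() helpers carried the behavior this commit retires: the VAE and CLIP submodels were looked up in the installed-model database by hardcoded names ("FLUX.1-schnell_ae", "clip-vit-large-patch14"), and loading failed if a match was missing or ambiguous. A self-contained sketch of that name-based lookup follows, with a plain dict standing in for the model manager that the removed code queried via context.models.search_by_attrs.

# Stand-in registry: model name -> list of matching installed-model records.
REGISTRY: dict[str, list[dict]] = {
    "FLUX.1-schnell_ae": [{"key": "vae-123", "type": "vae", "base": "flux"}],
    "clip-vit-large-patch14": [{"key": "clip-456", "type": "clip_embed", "base": "any"}],
}


def pull_model_by_name(name: str) -> dict:
    """Old-style resolution: find an installed model by a hardcoded name."""
    models = REGISTRY.get(name, [])
    if not models:
        raise ValueError(f"Please install the model named {name}")
    if len(models) != 1:
        raise ValueError(f"Multiple models detected for name {name}")
    return models[0]


# Before this commit the loader did the equivalent of:
vae_record = pull_model_by_name("FLUX.1-schnell_ae")
# After it, the user selects the exact VAE / CLIP / T5 models in the node's
# inputs and their identifiers arrive already resolved; no name lookup is needed.

The practical consequence, visible in the hunks above, is that workflows must now wire a T5 encoder, a CLIP embed model, and a FLUX VAE into the loader explicitly; the node then only validates that all four selected keys exist before building the submodel references.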