From 3f5340fa5363b91df3d3d763460a8b12d747c736 Mon Sep 17 00:00:00 2001
From: Mary Hipp
Date: Wed, 28 Aug 2024 14:54:36 -0400
Subject: [PATCH] feat(nodes): add submodels as inputs to FLUX main model node instead of hardcoded names

---
 invokeai/app/invocations/fields.py |  3 ++
 invokeai/app/invocations/model.py  | 82 +++++++++---------------------
 2 files changed, 27 insertions(+), 58 deletions(-)

diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 3a4e2cbddb..03654dd78d 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -45,11 +45,13 @@ class UIType(str, Enum, metaclass=MetaEnum):
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
     VAEModel = "VAEModelField"
+    FluxVAEModel = "FluxVAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
     T2IAdapterModel = "T2IAdapterModelField"
     T5EncoderModel = "T5EncoderModelField"
+    CLIPEmbedModel = "CLIPEmbedModelField"
     SpandrelImageToImageModel = "SpandrelImageToImageModelField"
     # endregion
 
@@ -128,6 +130,7 @@ class FieldDescriptions:
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
     t5_encoder = "T5 tokenizer and text encoder"
+    clip_embed_model = "CLIP Embed loader"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"
diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py
index 88874f302a..0427b1677b 100644
--- a/invokeai/app/invocations/model.py
+++ b/invokeai/app/invocations/model.py
@@ -169,23 +169,35 @@ class FluxModelLoaderInvocation(BaseInvocation):
         input=Input.Direct,
     )
 
-    t5_encoder: ModelIdentifierField = InputField(
-        description=FieldDescriptions.t5_encoder,
-        ui_type=UIType.T5EncoderModel,
+    t5_encoder_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
+    )
+
+    clip_embed_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.clip_embed_model,
+        ui_type=UIType.CLIPEmbedModel,
         input=Input.Direct,
+        title="CLIP Embed",
+    )
+
+    vae_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
     )
 
     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
-        model_key = self.model.key
+        for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
+            if not context.models.exists(key):
+                raise ValueError(f"Unknown model: {key}")
+
+        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+        vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
+
+        tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
+        clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
+
+        tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
+        t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
 
-        if not context.models.exists(model_key):
-            raise ValueError(f"Unknown model: {model_key}")
-        transformer = self._get_model(context, SubModelType.Transformer)
-        tokenizer = self._get_model(context, SubModelType.Tokenizer)
-        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
-        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
-        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
 
@@ -197,52 +209,6 @@ class FluxModelLoaderInvocation(BaseInvocation):
             max_seq_len=max_seq_lengths[transformer_config.config_path],
         )
 
-    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
-        match submodel:
-            case SubModelType.Transformer:
-                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-            case SubModelType.VAE:
-                return self._pull_model_from_mm(
-                    context,
-                    SubModelType.VAE,
-                    "FLUX.1-schnell_ae",
-                    ModelType.VAE,
-                    BaseModelType.Flux,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    "clip-vit-large-patch14",
-                    ModelType.CLIPEmbed,
-                    BaseModelType.Any,
-                )
-            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
-                return self._pull_model_from_mm(
-                    context,
-                    submodel,
-                    self.t5_encoder.name,
-                    ModelType.T5Encoder,
-                    BaseModelType.Any,
-                )
-            case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _pull_model_from_mm(
-        self,
-        context: InvocationContext,
-        submodel: SubModelType,
-        name: str,
-        type: ModelType,
-        base: BaseModelType,
-    ):
-        if models := context.models.search_by_attrs(name=name, base=base, type=type):
-            if len(models) != 1:
-                raise Exception(f"Multiple models detected for selected model with name {name}")
-            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
-        else:
-            raise ValueError(f"Please install the {base}:{type} model named {name} via starter models")
-
 
 @invocation(
     "main_model_loader",