From 29595ce059aba37ce7d132153451b60aca0d0c6c Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Fri, 23 Aug 2024 13:50:01 -0400 Subject: [PATCH] Switch flux to using its own conditioning field --- invokeai/app/invocations/fields.py | 6 +++ invokeai/app/invocations/flux_text_encoder.py | 8 ++-- .../app/invocations/flux_text_to_image.py | 4 +- invokeai/app/invocations/primitives.py | 12 ++++++ .../web/src/features/nodes/types/constants.ts | 1 + .../frontend/web/src/services/api/schema.ts | 42 +++++++++++++++---- 6 files changed, 58 insertions(+), 15 deletions(-) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 6b7d7bef63..3a4e2cbddb 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -236,6 +236,12 @@ class ColorField(BaseModel): return (self.r, self.g, self.b, self.a) +class FluxConditioningField(BaseModel): + """A conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + + class ConditioningField(BaseModel): """A conditioning tensor primitive value""" diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 80e13c2270..0e7ebd6d69 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -6,7 +6,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokeniz from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField from invokeai.app.invocations.model import CLIPField, T5EncoderField -from invokeai.app.invocations.primitives import ConditioningOutput +from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo @@ -38,17 +38,15 @@ class FluxTextEncoderInvocation(BaseInvocation): ) prompt: str = InputField(description="Text prompt to encode.") - # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not - # compatible with other ConditioningOutputs. @torch.no_grad() - def invoke(self, context: InvocationContext) -> ConditioningOutput: + def invoke(self, context: InvocationContext) -> FluxConditioningOutput: t5_embeddings, clip_embeddings = self._encode_prompt(context) conditioning_data = ConditioningFieldData( conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)] ) conditioning_name = context.conditioning.save(conditioning_data) - return ConditioningOutput.build(conditioning_name) + return FluxConditioningOutput.build(conditioning_name) def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]: # Load CLIP. diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py index 9504abee3e..b68bb91513 100644 --- a/invokeai/app/invocations/flux_text_to_image.py +++ b/invokeai/app/invocations/flux_text_to_image.py @@ -4,8 +4,8 @@ from PIL import Image from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.fields import ( - ConditioningField, FieldDescriptions, + FluxConditioningField, Input, InputField, WithBoard, @@ -41,7 +41,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.vae, input=Input.Connection, ) - positive_text_conditioning: ConditioningField = InputField( + positive_text_conditioning: FluxConditioningField = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection ) width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 3655554f3b..bb136d62fd 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import ( ConditioningField, DenoiseMaskField, FieldDescriptions, + FluxConditioningField, ImageField, Input, InputField, @@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput): height: int = OutputField(description="The height of the mask in pixels.") +@invocation_output("flux_conditioning_output") +class FluxConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output a single conditioning tensor""" + + conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond) + + @classmethod + def build(cls, conditioning_name: str) -> "FluxConditioningOutput": + return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name)) + + @invocation_output("conditioning_output") class ConditioningOutput(BaseInvocationOutput): """Base class for nodes that output a single conditioning tensor""" diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 19927220f2..100c094c46 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = { CLIPField: 'green.500', ColorField: 'pink.300', ConditioningField: 'cyan.500', + FluxConditioningField: 'cyan.500', ControlField: 'teal.500', ControlNetModelField: 'teal.500', EnumField: 'blue.500', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 8c3849593a..2b506759bd 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5720,6 +5720,32 @@ export type components = { */ type: "float_to_int"; }; + /** + * FluxConditioningField + * @description A conditioning tensor primitive value + */ + FluxConditioningField: { + /** + * Conditioning Name + * @description The name of conditioning tensor + */ + conditioning_name: string; + }; + /** + * FluxConditioningOutput + * @description Base class for nodes that output a single conditioning tensor + */ + FluxConditioningOutput: { + /** @description Conditioning tensor */ + conditioning: components["schemas"]["FluxConditioningField"]; + /** + * type + * @default flux_conditioning_output + * @constant + * @enum {string} + */ + type: "flux_conditioning_output"; + }; /** * Flux Main Model * @description Loads a flux base model, outputting its submodels. @@ -5781,7 +5807,7 @@ export type components = { vae: components["schemas"]["VAEField"]; /** * Max Seq Length - * @description VAE + * @description The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer) * @enum {integer} */ max_seq_len: 256 | 512; @@ -5835,11 +5861,11 @@ export type components = { */ t5_max_seq_len?: 256 | 512; /** - * Positive Prompt - * @description Positive prompt for text-to-image generation. + * Prompt + * @description Text prompt to encode. * @default null */ - positive_prompt?: string; + prompt?: string; /** * type * @default flux_text_encoder @@ -5895,7 +5921,7 @@ export type components = { * @description Positive conditioning tensor * @default null */ - positive_text_conditioning?: components["schemas"]["ConditioningField"]; + positive_text_conditioning?: components["schemas"]["FluxConditioningField"]; /** * Width * @description Width of the generated image. @@ -6105,7 +6131,7 @@ export type components = { * @description The results of node executions */ results?: { - [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + [key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * Errors @@ -8500,7 +8526,7 @@ export type components = { * Result * @description The result of the invocation */ - result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; + result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"]; }; /** * InvocationDenoiseProgressEvent @@ -8675,7 +8701,7 @@ export type components = { float_range: components["schemas"]["FloatCollectionOutput"]; float_to_int: components["schemas"]["IntegerOutput"]; flux_model_loader: components["schemas"]["FluxModelLoaderOutput"]; - flux_text_encoder: components["schemas"]["ConditioningOutput"]; + flux_text_encoder: components["schemas"]["FluxConditioningOutput"]; flux_text_to_image: components["schemas"]["ImageOutput"]; freeu: components["schemas"]["UNetOutput"]; grounding_dino: components["schemas"]["BoundingBoxCollectionOutput"];