mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch flux to using its own conditioning field
This commit is contained in:
parent
1047584b3e
commit
5063be92bf
@ -236,6 +236,12 @@ class ColorField(BaseModel):
|
|||||||
return (self.r, self.g, self.b, self.a)
|
return (self.r, self.g, self.b, self.a)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxConditioningField(BaseModel):
|
||||||
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
|
|
||||||
|
|
||||||
class ConditioningField(BaseModel):
|
class ConditioningField(BaseModel):
|
||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokeniz
|
|||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||||
@ -38,17 +38,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
prompt: str = InputField(description="Text prompt to encode.")
|
prompt: str = InputField(description="Text prompt to encode.")
|
||||||
|
|
||||||
# TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
|
|
||||||
# compatible with other ConditioningOutputs.
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||||
)
|
)
|
||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return FluxConditioningOutput.build(conditioning_name)
|
||||||
|
|
||||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Load CLIP.
|
# Load CLIP.
|
||||||
|
@ -4,8 +4,8 @@ from PIL import Image
|
|||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
ConditioningField,
|
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
|
FluxConditioningField,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
WithBoard,
|
WithBoard,
|
||||||
@ -41,7 +41,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
description=FieldDescriptions.vae,
|
description=FieldDescriptions.vae,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
positive_text_conditioning: ConditioningField = InputField(
|
positive_text_conditioning: FluxConditioningField = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||||
)
|
)
|
||||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||||
|
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
ConditioningField,
|
ConditioningField,
|
||||||
DenoiseMaskField,
|
DenoiseMaskField,
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
|
FluxConditioningField,
|
||||||
ImageField,
|
ImageField,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
|
|||||||
height: int = OutputField(description="The height of the mask in pixels.")
|
height: int = OutputField(description="The height of the mask in pixels.")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("flux_conditioning_output")
|
||||||
|
class FluxConditioningOutput(BaseInvocationOutput):
|
||||||
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
|
||||||
|
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
|
||||||
|
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("conditioning_output")
|
@invocation_output("conditioning_output")
|
||||||
class ConditioningOutput(BaseInvocationOutput):
|
class ConditioningOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single conditioning tensor"""
|
"""Base class for nodes that output a single conditioning tensor"""
|
||||||
|
@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
|||||||
CLIPField: 'green.500',
|
CLIPField: 'green.500',
|
||||||
ColorField: 'pink.300',
|
ColorField: 'pink.300',
|
||||||
ConditioningField: 'cyan.500',
|
ConditioningField: 'cyan.500',
|
||||||
|
FluxConditioningField: 'cyan.500',
|
||||||
ControlField: 'teal.500',
|
ControlField: 'teal.500',
|
||||||
ControlNetModelField: 'teal.500',
|
ControlNetModelField: 'teal.500',
|
||||||
EnumField: 'blue.500',
|
EnumField: 'blue.500',
|
||||||
|
@ -5720,6 +5720,32 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
type: "float_to_int";
|
type: "float_to_int";
|
||||||
};
|
};
|
||||||
|
/**
|
||||||
|
* FluxConditioningField
|
||||||
|
* @description A conditioning tensor primitive value
|
||||||
|
*/
|
||||||
|
FluxConditioningField: {
|
||||||
|
/**
|
||||||
|
* Conditioning Name
|
||||||
|
* @description The name of conditioning tensor
|
||||||
|
*/
|
||||||
|
conditioning_name: string;
|
||||||
|
};
|
||||||
|
/**
|
||||||
|
* FluxConditioningOutput
|
||||||
|
* @description Base class for nodes that output a single conditioning tensor
|
||||||
|
*/
|
||||||
|
FluxConditioningOutput: {
|
||||||
|
/** @description Conditioning tensor */
|
||||||
|
conditioning: components["schemas"]["FluxConditioningField"];
|
||||||
|
/**
|
||||||
|
* type
|
||||||
|
* @default flux_conditioning_output
|
||||||
|
* @constant
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
type: "flux_conditioning_output";
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* Flux Main Model
|
* Flux Main Model
|
||||||
* @description Loads a flux base model, outputting its submodels.
|
* @description Loads a flux base model, outputting its submodels.
|
||||||
@ -5781,7 +5807,7 @@ export type components = {
|
|||||||
vae: components["schemas"]["VAEField"];
|
vae: components["schemas"]["VAEField"];
|
||||||
/**
|
/**
|
||||||
* Max Seq Length
|
* Max Seq Length
|
||||||
* @description VAE
|
* @description The max sequence length to use for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)
|
||||||
* @enum {integer}
|
* @enum {integer}
|
||||||
*/
|
*/
|
||||||
max_seq_len: 256 | 512;
|
max_seq_len: 256 | 512;
|
||||||
@ -5835,11 +5861,11 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
t5_max_seq_len?: 256 | 512;
|
t5_max_seq_len?: 256 | 512;
|
||||||
/**
|
/**
|
||||||
* Positive Prompt
|
* Prompt
|
||||||
* @description Positive prompt for text-to-image generation.
|
* @description Text prompt to encode.
|
||||||
* @default null
|
* @default null
|
||||||
*/
|
*/
|
||||||
positive_prompt?: string;
|
prompt?: string;
|
||||||
/**
|
/**
|
||||||
* type
|
* type
|
||||||
* @default flux_text_encoder
|
* @default flux_text_encoder
|
||||||
@ -5895,7 +5921,7 @@ export type components = {
|
|||||||
* @description Positive conditioning tensor
|
* @description Positive conditioning tensor
|
||||||
* @default null
|
* @default null
|
||||||
*/
|
*/
|
||||||
positive_text_conditioning?: components["schemas"]["ConditioningField"];
|
positive_text_conditioning?: components["schemas"]["FluxConditioningField"];
|
||||||
/**
|
/**
|
||||||
* Width
|
* Width
|
||||||
* @description Width of the generated image.
|
* @description Width of the generated image.
|
||||||
@ -6105,7 +6131,7 @@ export type components = {
|
|||||||
* @description The results of node executions
|
* @description The results of node executions
|
||||||
*/
|
*/
|
||||||
results?: {
|
results?: {
|
||||||
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | 
components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | 
components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* Errors
|
* Errors
|
||||||
@ -8500,7 +8526,7 @@ export type components = {
|
|||||||
* Result
|
* Result
|
||||||
* @description The result of the invocation
|
* @description The result of the invocation
|
||||||
*/
|
*/
|
||||||
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | 
components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | 
components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||||
};
|
};
|
||||||
/**
|
/**
|
||||||
* InvocationDenoiseProgressEvent
|
* InvocationDenoiseProgressEvent
|
||||||
@ -8675,7 +8701,7 @@ export type components = {
|
|||||||
float_range: components["schemas"]["FloatCollectionOutput"];
|
float_range: components["schemas"]["FloatCollectionOutput"];
|
||||||
float_to_int: components["schemas"]["IntegerOutput"];
|
float_to_int: components["schemas"]["IntegerOutput"];
|
||||||
flux_model_loader: components["schemas"]["FluxModelLoaderOutput"];
|
flux_model_loader: components["schemas"]["FluxModelLoaderOutput"];
|
||||||
flux_text_encoder: components["schemas"]["ConditioningOutput"];
|
flux_text_encoder: components["schemas"]["FluxConditioningOutput"];
|
||||||
flux_text_to_image: components["schemas"]["ImageOutput"];
|
flux_text_to_image: components["schemas"]["ImageOutput"];
|
||||||
freeu: components["schemas"]["UNetOutput"];
|
freeu: components["schemas"]["UNetOutput"];
|
||||||
grounding_dino: components["schemas"]["BoundingBoxCollectionOutput"];
|
grounding_dino: components["schemas"]["BoundingBoxCollectionOutput"];
|
||||||
|
Loading…
Reference in New Issue
Block a user