mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Switch flux to using its own conditioning field
This commit is contained in:
parent
9ae190fc3e
commit
29595ce059
@ -236,6 +236,12 @@ class ColorField(BaseModel):
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class FluxConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
|
@ -6,7 +6,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokeniz
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.invocations.primitives import FluxConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
@ -38,17 +38,15 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
||||
)
|
||||
prompt: str = InputField(description="Text prompt to encode.")
|
||||
|
||||
# TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
|
||||
# compatible with other ConditioningOutputs.
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
|
||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||
conditioning_data = ConditioningFieldData(
|
||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return FluxConditioningOutput.build(conditioning_name)
|
||||
|
||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Load CLIP.
|
||||
|
@ -4,8 +4,8 @@ from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
Input,
|
||||
InputField,
|
||||
WithBoard,
|
||||
@ -41,7 +41,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
positive_text_conditioning: ConditioningField = InputField(
|
||||
positive_text_conditioning: FluxConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
|
||||
|
@ -12,6 +12,7 @@ from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
FluxConditioningField,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
@ -414,6 +415,17 @@ class MaskOutput(BaseInvocationOutput):
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("flux_conditioning_output")
|
||||
class FluxConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: FluxConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "FluxConditioningOutput":
|
||||
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
@ -52,6 +52,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
CLIPField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
FluxConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
ControlNetModelField: 'teal.500',
|
||||
EnumField: 'blue.500',
|
||||
|
@ -5720,6 +5720,32 @@ export type components = {
|
||||
*/
|
||||
type: "float_to_int";
|
||||
};
|
||||
/**
|
||||
* FluxConditioningField
|
||||
* @description A conditioning tensor primitive value
|
||||
*/
|
||||
FluxConditioningField: {
|
||||
/**
|
||||
* Conditioning Name
|
||||
* @description The name of conditioning tensor
|
||||
*/
|
||||
conditioning_name: string;
|
||||
};
|
||||
/**
|
||||
* FluxConditioningOutput
|
||||
* @description Base class for nodes that output a single conditioning tensor
|
||||
*/
|
||||
FluxConditioningOutput: {
|
||||
/** @description Conditioning tensor */
|
||||
conditioning: components["schemas"]["FluxConditioningField"];
|
||||
/**
|
||||
* type
|
||||
* @default flux_conditioning_output
|
||||
* @constant
|
||||
* @enum {string}
|
||||
*/
|
||||
type: "flux_conditioning_output";
|
||||
};
|
||||
/**
|
||||
* Flux Main Model
|
||||
* @description Loads a flux base model, outputting its submodels.
|
||||
@ -5781,7 +5807,7 @@ export type components = {
|
||||
vae: components["schemas"]["VAEField"];
|
||||
/**
|
||||
* Max Seq Length
|
||||
* @description VAE
|
||||
* @description The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)
|
||||
* @enum {integer}
|
||||
*/
|
||||
max_seq_len: 256 | 512;
|
||||
@ -5835,11 +5861,11 @@ export type components = {
|
||||
*/
|
||||
t5_max_seq_len?: 256 | 512;
|
||||
/**
|
||||
* Positive Prompt
|
||||
* @description Positive prompt for text-to-image generation.
|
||||
* Prompt
|
||||
* @description Text prompt to encode.
|
||||
* @default null
|
||||
*/
|
||||
positive_prompt?: string;
|
||||
prompt?: string;
|
||||
/**
|
||||
* type
|
||||
* @default flux_text_encoder
|
||||
@ -5895,7 +5921,7 @@ export type components = {
|
||||
* @description Positive conditioning tensor
|
||||
* @default null
|
||||
*/
|
||||
positive_text_conditioning?: components["schemas"]["ConditioningField"];
|
||||
positive_text_conditioning?: components["schemas"]["FluxConditioningField"];
|
||||
/**
|
||||
* Width
|
||||
* @description Width of the generated image.
|
||||
@ -6105,7 +6131,7 @@ export type components = {
|
||||
* @description The results of node executions
|
||||
*/
|
||||
results?: {
|
||||
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||
[key: string]: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||
};
|
||||
/**
|
||||
* Errors
|
||||
@ -8500,7 +8526,7 @@ export type components = {
|
||||
* Result
|
||||
* @description The result of the invocation
|
||||
*/
|
||||
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||
result: components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["BoundingBoxCollectionOutput"] | components["schemas"]["BoundingBoxOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FluxConditioningOutput"] | components["schemas"]["FluxModelLoaderOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["String2Output"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["VAEOutput"];
|
||||
};
|
||||
/**
|
||||
* InvocationDenoiseProgressEvent
|
||||
@ -8675,7 +8701,7 @@ export type components = {
|
||||
float_range: components["schemas"]["FloatCollectionOutput"];
|
||||
float_to_int: components["schemas"]["IntegerOutput"];
|
||||
flux_model_loader: components["schemas"]["FluxModelLoaderOutput"];
|
||||
flux_text_encoder: components["schemas"]["ConditioningOutput"];
|
||||
flux_text_encoder: components["schemas"]["FluxConditioningOutput"];
|
||||
flux_text_to_image: components["schemas"]["ImageOutput"];
|
||||
freeu: components["schemas"]["UNetOutput"];
|
||||
grounding_dino: components["schemas"]["BoundingBoxCollectionOutput"];
|
||||
|
Loading…
Reference in New Issue
Block a user