wip: add Transformer Field to Node UI

This commit is contained in:
blessedcoolant 2024-06-14 22:25:26 +05:30
parent 0c970bc880
commit ddbd2ebd9d
7 changed files with 160 additions and 146 deletions

View File

@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
class ModelIdentifierField(BaseModel):
@ -54,6 +48,11 @@ class UNetField(BaseModel):
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class CLIPField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")

View File

@ -1,17 +1,10 @@
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, VAEField
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
"""Stable Diffuion 3 base model loader output"""

View File

@ -36,6 +36,7 @@ export const MODEL_TYPES = [
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'TransformerField',
'VAEField',
'CLIPField',
'T2IAdapterModelField',
@ -68,6 +69,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@ -298,6 +298,11 @@ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: {
isCollection: false,
isCollectionOrScalar: false,
},
TransformerField: {
name: 'TransformerField',
isCollection: false,
isCollectionOrScalar: false,
},
VaeField: {
name: 'VaeField',
isCollection: false,

View File

@ -99,6 +99,7 @@ const zFieldTypeV1 = z.enum([
'T2IAdapterModelField',
'T2IAdapterPolymorphic',
'UNetField',
'TransformerField',
'VaeField',
'VaeModelField',
]);
@ -367,6 +368,17 @@ const zUNetInputFieldValue = zInputFieldValueBase.extend({
value: zUNetField.optional(),
});
const zTransformerField = z.object({
unet: zModelInfo,
scheduler: zModelInfo,
loras: z.array(zLoraInfo),
});
const zTransformerInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('TransformerField'),
value: zTransformerField.optional(),
});
const zClipField = z.object({
tokenizer: zModelInfo,
text_encoder: zModelInfo,
@ -588,6 +600,7 @@ const zInputFieldValue = z.discriminatedUnion('type', [
zT2IAdapterCollectionInputFieldValue,
zT2IAdapterPolymorphicInputFieldValue,
zUNetInputFieldValue,
zTransformerInputFieldValue,
zVaeInputFieldValue,
zVaeModelInputFieldValue,
zMetadataItemInputFieldValue,

View File

@ -7276,145 +7276,145 @@ export type components = {
project_id: string | null;
};
InvocationOutputMap: {
float_to_int: components["schemas"]["IntegerOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
img_blur: components["schemas"]["ImageOutput"];
infill_tile: components["schemas"]["ImageOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
div: components["schemas"]["IntegerOutput"];
color_correct: components["schemas"]["ImageOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
img_mul: components["schemas"]["ImageOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
img_lerp: components["schemas"]["ImageOutput"];
l2i: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
face_off: components["schemas"]["FaceOffOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
float_math: components["schemas"]["FloatOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
sub: components["schemas"]["IntegerOutput"];
string: components["schemas"]["StringOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
latents: components["schemas"]["LatentsOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
i2l: components["schemas"]["LatentsOutput"];
face_identifier: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
lblend: components["schemas"]["LatentsOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
lresize: components["schemas"]["LatentsOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
tile_image_processor: components["schemas"]["ImageOutput"];
mask_combine: components["schemas"]["ImageOutput"];
string_replace: components["schemas"]["StringOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
add: components["schemas"]["IntegerOutput"];
metadata: components["schemas"]["MetadataOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
img_resize: components["schemas"]["ImageOutput"];
mul: components["schemas"]["IntegerOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
string_join: components["schemas"]["StringOutput"];
esrgan: components["schemas"]["ImageOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
round_float: components["schemas"]["FloatOutput"];
noise: components["schemas"]["NoiseOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
img_chan: components["schemas"]["ImageOutput"];
sub: components["schemas"]["IntegerOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
compel: components["schemas"]["ConditioningOutput"];
sd3_model_loader: components["schemas"]["SD3ModelLoaderOutput"];
rand_float: components["schemas"]["FloatOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
float: components["schemas"]["FloatOutput"];
merge_metadata: components["schemas"]["MetadataOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
string_join: components["schemas"]["StringOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
lblend: components["schemas"]["LatentsOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
infill_lama: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
lscale: components["schemas"]["LatentsOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
float_math: components["schemas"]["FloatOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
boolean: components["schemas"]["BooleanOutput"];
latents: components["schemas"]["LatentsOutput"];
blank_image: components["schemas"]["ImageOutput"];
vae_loader: components["schemas"]["VAEOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
tomask: components["schemas"]["ImageOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
controlnet: components["schemas"]["ControlOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
metadata: components["schemas"]["MetadataOutput"];
freeu: components["schemas"]["UNetOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
face_off: components["schemas"]["FaceOffOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
show_image: components["schemas"]["ImageOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
round_float: components["schemas"]["FloatOutput"];
string: components["schemas"]["StringOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
img_crop: components["schemas"]["ImageOutput"];
mask_edge: components["schemas"]["ImageOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
save_image: components["schemas"]["ImageOutput"];
add: components["schemas"]["IntegerOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
color: components["schemas"]["ColorOutput"];
string_replace: components["schemas"]["StringOutput"];
img_lerp: components["schemas"]["ImageOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
noise: components["schemas"]["NoiseOutput"];
img_watermark: components["schemas"]["ImageOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
i2l: components["schemas"]["LatentsOutput"];
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
mul: components["schemas"]["IntegerOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
img_conv: components["schemas"]["ImageOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
face_identifier: components["schemas"]["ImageOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
esrgan: components["schemas"]["ImageOutput"];
color_correct: components["schemas"]["ImageOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
string_split: components["schemas"]["String2Output"];
string_join_three: components["schemas"]["StringOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
rand_int: components["schemas"]["IntegerOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
merge_metadata: components["schemas"]["MetadataOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
img_scale: components["schemas"]["ImageOutput"];
img_blur: components["schemas"]["ImageOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
integer_math: components["schemas"]["IntegerOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
img_mul: components["schemas"]["ImageOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
image: components["schemas"]["ImageOutput"];
img_resize: components["schemas"]["ImageOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
lresize: components["schemas"]["LatentsOutput"];
l2i: components["schemas"]["ImageOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
div: components["schemas"]["IntegerOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
integer: components["schemas"]["IntegerOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
img_paste: components["schemas"]["ImageOutput"];
save_image: components["schemas"]["ImageOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
rand_float: components["schemas"]["FloatOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
blank_image: components["schemas"]["ImageOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
tomask: components["schemas"]["ImageOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
rand_int: components["schemas"]["IntegerOutput"];
sd3_model_loader: components["schemas"]["SD3ModelLoaderOutput"];
infill_lama: components["schemas"]["ImageOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
vae_loader: components["schemas"]["VAEOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
integer: components["schemas"]["IntegerOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
tile_image_processor: components["schemas"]["ImageOutput"];
freeu: components["schemas"]["UNetOutput"];
boolean: components["schemas"]["BooleanOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
img_chan: components["schemas"]["ImageOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
img_scale: components["schemas"]["ImageOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
mask_edge: components["schemas"]["ImageOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
controlnet: components["schemas"]["ControlOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
string_split: components["schemas"]["String2Output"];
img_watermark: components["schemas"]["ImageOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
img_conv: components["schemas"]["ImageOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
image_mask_to_tensor: components["schemas"]["MaskOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
string_join_three: components["schemas"]["StringOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
integer_math: components["schemas"]["IntegerOutput"];
color: components["schemas"]["ColorOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
img_crop: components["schemas"]["ImageOutput"];
lscale: components["schemas"]["LatentsOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
image: components["schemas"]["ImageOutput"];
compel: components["schemas"]["ConditioningOutput"];
};
/**
* InvocationStartedEvent

View File

@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
ModelIdentifierField,
ModelLoaderOutput,
SDXLLoRALoaderOutput,
TransformerField,
UNetField,
UNetOutput,
VAEField,
@ -117,6 +118,7 @@ __all__ = [
# invokeai.app.invocations.model
"ModelIdentifierField",
"UNetField",
"TransformerField",
"CLIPField",
"VAEField",
"UNetOutput",