Merge branch 'main' into lstein/global-configuration

Lincoln Stein 2023-05-06 21:20:25 -04:00 committed by GitHub
commit afd2e32092
14 changed files with 390 additions and 24 deletions


@@ -89,7 +89,7 @@ experimental versions later.
sudo apt update
sudo apt install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
-sudo apt install python3.10 python3-pip python3.10-venv
+sudo apt install -y python3.10 python3-pip python3.10-venv
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
```


@@ -0,0 +1,245 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager

from compel import Compel
from compel.prompt_parser import (
    Blend,
    CrossAttentionControlSubstitute,
    FlattenedPrompt,
    Fragment,
)

from invokeai.backend.globals import Globals


class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

    class Config:
        schema_extra = {"required": ["conditioning_name"]}


class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    #fmt: off
    type: Literal["compel_output"] = "compel_output"
    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    #fmt: on


class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> CompelOutput:
        # TODO: load without model
        model = choose_model(context.services.model_manager, self.model)
        pipeline = model["model"]
        tokenizer = pipeline.tokenizer
        text_encoder = pipeline.text_encoder

        # TODO: global? input?
        #use_full_precision = precision == "float32" or precision == "autocast"
        #use_full_precision = False

        # TODO: redo TI when separate model loding implemented
        #textual_inversion_manager = TextualInversionManager(
        #    tokenizer=tokenizer,
        #    text_encoder=text_encoder,
        #    full_precision=use_full_precision,
        #)

        def load_huggingface_concepts(concepts: list[str]):
            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )

        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
            prompt_str
        )

        compel = Compel(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            textual_inversion_manager=pipeline.textual_inversion_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True,  # TODO:
        )

        # TODO: support legacy blend?

        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

        if getattr(Globals, "log_tokenization", False):
            log_tokenization_for_prompt_object(prompt, tokenizer)

        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

        # TODO: long prompt support
        #if not self.truncate_long_prompts:
        #    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.set(conditioning_name, (c, ec))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )


def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max(
            [
                get_max_token_count(tokenizer, c, truncate_if_too_long)
                for c in blend.prompts
            ]
        )
    else:
        return len(
            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        )


def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> [str]:
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text_fragments = [
        x.text
        if type(x) is Fragment
        else (
            " ".join([f.text for f in x.original])
            if type(x) is CrossAttentionControlSubstitute
            else str(x)
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens


def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c,
                tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
            )
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
            original_fragments = []
            edited_fragments = []
            for f in flattened_prompt.children:
                if type(f) is CrossAttentionControlSubstitute:
                    original_fragments += f.original
                    edited_fragments += f.edited
                else:
                    original_fragments.append(f)
                    edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(
                original_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap originals)",
            )
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(
                edited_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap replacements)",
            )
        else:
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(
                text, tokenizer, display_label=display_label_prefix
            )


def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    """
    tokens = tokenizer.tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)

    for i in range(0, totalTokens):
        token = tokens[i].replace("</w>", " ")
        # alternate color
        s = (usedTokens % 6) + 1
        if truncate_if_too_long and i >= tokenizer.model_max_length:
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
        else:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1

    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f"{tokenized}\x1b[0m")

    if discarded != "":
        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
        print(f"{discarded}\x1b[0m")


@@ -250,8 +250,8 @@ class InpaintInvocation(ImageToImageInvocation):
        outputs = Inpaint(model).generate(
            prompt=self.prompt,
-            init_img=image,
-            init_mask=mask,
+            init_image=image,
+            mask_image=mask,
            step_callback=partial(self.dispatch_progress, context, source_node_id),
            **self.dict(
                exclude={"prompt", "image", "mask"}


@@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
-from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output
+from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
@@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):
# Text to image
class TextToLatentsInvocation(BaseInvocation):
-    """Generates latents from a prompt."""
+    """Generates latents from conditionings."""
    type: Literal["t2l"] = "t2l"

    # Inputs
-    # TODO: consider making prompt optional to enable providing prompt through a link
    # fmt: off
-    prompt: Optional[str] = Field(description="The prompt to generate an image from")
+    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
+    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@@ -203,8 +203,10 @@ class TextToLatentsInvocation(BaseInvocation):
        return model

-    def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
-        uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model)
+    def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
+        c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
+        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
        conditioning_data = ConditioningData(
            uc,
            c,
@@ -231,7 +233,7 @@ class TextToLatentsInvocation(BaseInvocation):
            self.dispatch_progress(context, source_node_id, state)

        model = self.get_model(context.services.model_manager)
-        conditioning_data = self.get_conditioning_data(model)
+        conditioning_data = self.get_conditioning_data(context, model)

        # TODO: Verify the noise is the right size


@@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
+from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC
@@ -16,24 +17,32 @@ def create_text_to_image() -> LibraryGraph:
        nodes={
            'width': ParamIntInvocation(id='width', a=512),
            'height': ParamIntInvocation(id='height', a=512),
+            'seed': ParamIntInvocation(id='seed', a=-1),
            '3': NoiseInvocation(id='3'),
-            '4': TextToLatentsInvocation(id='4'),
-            '5': LatentsToImageInvocation(id='5')
+            '4': CompelInvocation(id='4'),
+            '5': CompelInvocation(id='5'),
+            '6': TextToLatentsInvocation(id='6'),
+            '7': LatentsToImageInvocation(id='7'),
        },
        edges=[
            Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
            Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
-            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')),
-            Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')),
+            Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
+            Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
+            Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
+            Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
+            Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
        ]
    ),
    exposed_inputs=[
-        ExposedNodeInput(node_path='4', field='prompt', alias='prompt'),
+        ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
+        ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
        ExposedNodeInput(node_path='width', field='a', alias='width'),
-        ExposedNodeInput(node_path='height', field='a', alias='height')
+        ExposedNodeInput(node_path='height', field='a', alias='height'),
+        ExposedNodeInput(node_path='seed', field='a', alias='seed'),
    ],
    exposed_outputs=[
-        ExposedNodeOutput(node_path='5', field='image', alias='image')
+        ExposedNodeOutput(node_path='7', field='image', alias='image')
    ])
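
Again not part of the commit: a small sketch of what the rebuilt default graph now exposes, assuming the LibraryGraph fields shown above (exposed_inputs, exposed_outputs) are readable attributes. The prompts feed the two CompelInvocation nodes '4' and '5', the new 'seed' parameter drives NoiseInvocation '3', and the output image now comes from LatentsToImageInvocation '7'.

# Sketch only: list the exposed aliases of the updated text-to-image library graph.
lg = create_text_to_image()
for inp in lg.exposed_inputs:
    print(f"{inp.alias} -> {inp.node_path}.{inp.field}")   # e.g. positive_prompt -> 4.prompt
for out in lg.exposed_outputs:
    print(f"{out.node_path}.{out.field} -> {out.alias}")   # e.g. 7.image -> image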


@@ -78,7 +78,6 @@ class InvokeAIWebServer:
        mimetypes.add_type("application/javascript", ".js")
        mimetypes.add_type("text/css", ".css")

        # Socket IO
-        logger = True if args.web_verbose else False
        engineio_logger = True if args.web_verbose else False
        max_http_buffer_size = 10000000


@@ -1,5 +1,5 @@
import { forEach, size } from 'lodash-es';
-import { ImageField, LatentsField } from 'services/api';
+import { ImageField, LatentsField, ConditioningField } from 'services/api';

const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
@@ -74,8 +74,38 @@ const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
  };
};

+const parseConditioningField = (
+  conditioningField: unknown
+): ConditioningField | undefined => {
+  // Must be an object
+  if (!isObject(conditioningField)) {
+    return;
+  }
+
+  // A ConditioningField must have a `conditioning_name`
+  if (!('conditioning_name' in conditioningField)) {
+    return;
+  }
+
+  // A ConditioningField's `conditioning_name` must be a string
+  if (typeof conditioningField.conditioning_name !== 'string') {
+    return;
+  }
+
+  // Build a valid ConditioningField
+  return {
+    conditioning_name: conditioningField.conditioning_name,
+  };
+};
+
type NodeMetadata = {
-  [key: string]: string | number | boolean | ImageField | LatentsField;
+  [key: string]:
+    | string
+    | number
+    | boolean
+    | ImageField
+    | LatentsField
+    | ConditioningField;
};

type InvokeAIMetadata = {
@@ -101,7 +131,7 @@ export const parseNodeMetadata = (
      return;
    }

-    // the only valid object types are ImageField and LatentsField
+    // the only valid object types are ImageField, LatentsField and ConditioningField
    if (isObject(nodeItem)) {
      if ('image_name' in nodeItem || 'image_type' in nodeItem) {
        const imageField = parseImageField(nodeItem);
@@ -118,6 +148,14 @@
        }
        return;
      }

+      if ('conditioning_name' in nodeItem) {
+        const conditioningField = parseConditioningField(nodeItem);
+        if (conditioningField) {
+          parsed[nodeKey] = conditioningField;
+        }
+        return;
+      }
    }

    // otherwise we accept any string, number or boolean


@@ -6,6 +6,7 @@ import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
+import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
@@ -84,6 +85,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
    );
  }

+  if (type === 'conditioning' && template.type === 'conditioning') {
+    return (
+      <ConditioningInputFieldComponent
+        nodeId={nodeId}
+        field={field}
+        template={template}
+      />
+    );
+  }
+
  if (type === 'model' && template.type === 'model') {
    return (
      <ModelInputFieldComponent


@@ -0,0 +1,19 @@
import {
  ConditioningInputFieldTemplate,
  ConditioningInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';

const ConditioningInputFieldComponent = (
  props: FieldComponentProps<
    ConditioningInputFieldValue,
    ConditioningInputFieldTemplate
  >
) => {
  const { nodeId, field } = props;

  return null;
};

export default memo(ConditioningInputFieldComponent);


@@ -11,6 +11,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
  enum: 'enum',
  ImageField: 'image',
  LatentsField: 'latents',
+  ConditioningField: 'conditioning',
  model: 'model',
  array: 'array',
};
@@ -63,6 +64,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
    title: 'Latents',
    description: 'Latents may be passed between nodes.',
  },
+  conditioning: {
+    color: 'cyan',
+    colorCssVar: getColorTokenCssVariable('cyan'),
+    title: 'Conditioning',
+    description: 'Conditioning may be passed between nodes.',
+  },
  model: {
    color: 'teal',
    colorCssVar: getColorTokenCssVariable('teal'),


@@ -56,6 +56,7 @@ export type FieldType =
  | 'enum'
  | 'image'
  | 'latents'
+  | 'conditioning'
  | 'model'
  | 'array';
@@ -74,6 +75,7 @@ export type InputFieldValue =
  | BooleanInputFieldValue
  | ImageInputFieldValue
  | LatentsInputFieldValue
+  | ConditioningInputFieldValue
  | EnumInputFieldValue
  | ModelInputFieldValue
  | ArrayInputFieldValue;
@@ -91,6 +93,7 @@ export type InputFieldTemplate =
  | BooleanInputFieldTemplate
  | ImageInputFieldTemplate
  | LatentsInputFieldTemplate
+  | ConditioningInputFieldTemplate
  | EnumInputFieldTemplate
  | ModelInputFieldTemplate
  | ArrayInputFieldTemplate;
@@ -162,6 +165,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
  value?: undefined;
};

+export type ConditioningInputFieldValue = FieldValueBase & {
+  type: 'conditioning';
+  value?: undefined;
+};
+
export type ImageInputFieldValue = FieldValueBase & {
  type: 'image';
  value?: Pick<ImageField, 'image_name' | 'image_type'>;
@@ -229,6 +237,11 @@ export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
  type: 'latents';
};

+export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
+  default: undefined;
+  type: 'conditioning';
+};
+
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
  default: string | number;
  type: 'enum';


@@ -9,6 +9,7 @@ import {
  ImageInputFieldTemplate,
  IntegerInputFieldTemplate,
  LatentsInputFieldTemplate,
+  ConditioningInputFieldTemplate,
  StringInputFieldTemplate,
  ModelInputFieldTemplate,
  InputFieldTemplateBase,
@@ -196,6 +197,21 @@ const buildLatentsInputFieldTemplate = ({
  return template;
};

+const buildConditioningInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): ConditioningInputFieldTemplate => {
+  const template: ConditioningInputFieldTemplate = {
+    ...baseField,
+    type: 'conditioning',
+    inputRequirement: 'always',
+    inputKind: 'connection',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
const buildEnumInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -266,6 +282,9 @@ export const buildInputFieldTemplate = (
  if (['latents'].includes(fieldType)) {
    return buildLatentsInputFieldTemplate({ schemaObject, baseField });
  }
+  if (['conditioning'].includes(fieldType)) {
+    return buildConditioningInputFieldTemplate({ schemaObject, baseField });
+  }
  if (['model'].includes(fieldType)) {
    return buildModelInputFieldTemplate({ schemaObject, baseField });
  }


@@ -48,6 +48,10 @@ export const buildInputFieldValue = (
    fieldValue.value = undefined;
  }

+  if (template.type === 'conditioning') {
+    fieldValue.value = undefined;
+  }
+
  if (template.type === 'model') {
    fieldValue.value = undefined;
  }


@@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():
    n4 = ShowImageInvocation(id = "4")
    g.add_node(n4)
-    g.add_edge(create_edge("1.5","image","4","image"))
+    g.add_edge(create_edge("1.7","image","4","image"))

    # Validate
    dg = g.nx_graph_flat()
-    assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4'])
+    assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
    expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
    expected_edges.extend([
        ('2','1.width'),
        ('3','1.height'),
-        ('1.5','4')
+        ('1.7','4')
    ])
    print(expected_edges)
    print(list(dg.edges))