merge with main

This commit is contained in:
Lincoln Stein 2023-05-10 00:03:32 -04:00
commit fa6a580452
20 changed files with 591 additions and 90 deletions

View File

@ -2,8 +2,7 @@ name: mkdocs-material
on: on:
push: push:
branches: branches:
- 'main' - 'refs/heads/v2.3'
- 'development'
permissions: permissions:
contents: write contents: write
@ -12,6 +11,10 @@ jobs:
mkdocs-material: mkdocs-material:
if: github.event.pull_request.draft == false if: github.event.pull_request.draft == false
runs-on: ubuntu-latest runs-on: ubuntu-latest
env:
REPO_URL: '${{ github.server_url }}/${{ github.repository }}'
REPO_NAME: '${{ github.repository }}'
SITE_URL: 'https://${{ github.repository_owner }}.github.io/InvokeAI'
steps: steps:
- name: checkout sources - name: checkout sources
uses: actions/checkout@v3 uses: actions/checkout@v3
@ -22,11 +25,15 @@ jobs:
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: '3.10' python-version: '3.10'
cache: pip
cache-dependency-path: pyproject.toml
- name: install requirements - name: install requirements
env:
PIP_USE_PEP517: 1
run: | run: |
python -m \ python -m \
pip install -r docs/requirements-mkdocs.txt pip install ".[docs]"
- name: confirm buildability - name: confirm buildability
run: | run: |
@ -36,7 +43,7 @@ jobs:
--verbose --verbose
- name: deploy to gh-pages - name: deploy to gh-pages
if: ${{ github.ref == 'refs/heads/main' }} if: ${{ github.ref == 'refs/heads/v2.3' }}
run: | run: |
python -m \ python -m \
mkdocs gh-deploy \ mkdocs gh-deploy \

View File

@ -89,7 +89,7 @@ experimental versions later.
sudo apt update sudo apt update
sudo apt install -y software-properties-common sudo apt install -y software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt install python3.10 python3-pip python3.10-venv sudo apt install -y python3.10 python3-pip python3.10-venv
sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3 sudo update-alternatives --install /usr/local/bin/python python /usr/bin/python3.10 3
``` ```

View File

@ -0,0 +1,245 @@
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager
from compel import Compel
from compel.prompt_parser import (
Blend,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
)
from invokeai.backend.globals import Globals
# Reference to conditioning data by name: the model carries only the name
# string, not the conditioning data itself.
class ConditioningField(BaseModel):
    # Name identifying the conditioning data; the Python-side default is None.
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

    class Config:
        # Lists `conditioning_name` under "required" in the schema extras even
        # though the field declares a default value.
        schema_extra = {"required": ["conditioning_name"]}
class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    #fmt: off
    # Literal tag identifying this output type.
    type: Literal["compel_output"] = "compel_output"
    # Name-based reference to the conditioning produced by the invocation.
    # NOTE(review): annotated as non-optional but defaults to None — confirm intended.
    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    #fmt: on
class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> CompelOutput:
        """Turn `self.prompt` into a conditioning tensor and store it by name.

        The tensor and its extra conditioning info are saved through
        `context.services.latents` under a name derived from the graph
        execution state id and this node's id; the returned output carries
        only that name.
        """
        # TODO: load without model
        pipeline = choose_model(context.services.model_manager, self.model)["model"]
        tokenizer, text_encoder = pipeline.tokenizer, pipeline.text_encoder

        # TODO: global? input? (precision selection is not wired up yet)
        # TODO: redo TI when separate model loading implemented; until then
        # the textual inversion manager owned by the pipeline is reused.
        ti_manager = pipeline.textual_inversion_manager

        def load_huggingface_concepts(concepts: list[str]):
            ti_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        prompt_str = ti_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            load_huggingface_concepts,
            ti_manager.get_all_trigger_strings(),
        )

        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        ti_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_str)

        compel = Compel(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            textual_inversion_manager=ti_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True,  # TODO:
        )

        # TODO: support legacy blend?
        parsed_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

        if getattr(Globals, "log_tokenization", False):
            log_tokenization_for_prompt_object(parsed_prompt, tokenizer)

        cond_tensor, options = compel.build_conditioning_tensor_for_prompt_object(parsed_prompt)

        # TODO: long prompt support (pad conditioning tensors to the same
        # length when prompts are not truncated)

        extra_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, parsed_prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.set(conditioning_name, (cond_tensor, extra_info))

        return CompelOutput(
            conditioning=ConditioningField(conditioning_name=conditioning_name),
        )
def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    """Return the token count of `prompt`.

    For a `Blend` this is the largest count among its `.prompts`, computed
    recursively; for anything else it is the number of tokens returned by
    `get_tokens_for_prompt_object`. `truncate_if_too_long` is forwarded
    unchanged.
    """
    if type(prompt) is not Blend:
        tokens = get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        return len(tokens)

    part_counts = [
        get_max_token_count(tokenizer, part, truncate_if_too_long)
        for part in prompt.prompts
    ]
    return max(part_counts)
def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> list[str]:
    """Tokenize the text of a flattened prompt.

    The text of every child of `parsed_prompt` is joined with single spaces
    and passed to `tokenizer.tokenize`.

    Args:
        tokenizer: object providing `tokenize(text)` and `model_max_length`.
        parsed_prompt: prompt whose `.children` are converted to text.
        truncate_if_too_long: when true, keep at most
            `tokenizer.model_max_length - 2` tokens.

    Returns:
        The (possibly truncated) sequence of tokens produced by the tokenizer.

    Raises:
        ValueError: if `parsed_prompt` is a `Blend`.
    """
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text = " ".join([_text_for_prompt_child(child) for child in parsed_prompt.children])
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        # two slots of the model's maximum length are left out of the budget
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[:max_tokens_length]
    return tokens


def _text_for_prompt_child(child) -> str:
    """Return the plain text represented by one child of a flattened prompt."""
    if type(child) is Fragment:
        return child.text
    if type(child) is CrossAttentionControlSubstitute:
        # for a substitution, only the original fragments contribute text
        return " ".join([fragment.text for fragment in child.original])
    return str(child)
def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    """Log the tokenization of a parsed prompt object.

    A `Blend` is logged part by part (recursively), each labelled with its
    1-based position and weight. A `FlattenedPrompt` is logged either as one
    text, or, when it wants cross-attention control, as two texts: the
    ".swap originals" and the ".swap replacements". Any other type is ignored.
    """
    prefix = display_label_prefix or ""

    if type(p) is Blend:
        for index, part in enumerate(p.prompts):
            part_label = f"{prefix}(blend part {index + 1}, weight={p.weights[index]})"
            log_tokenization_for_prompt_object(
                part, tokenizer, display_label_prefix=part_label
            )
        return

    if type(p) is not FlattenedPrompt:
        return

    if not p.wants_cross_attention_control:
        plain_text = " ".join([child.text for child in p.children])
        log_tokenization_for_text(plain_text, tokenizer, display_label=prefix)
        return

    # Build the "before" and "after" fragment sequences: substitutions
    # contribute different fragments to each, everything else goes to both.
    before_fragments = []
    after_fragments = []
    for child in p.children:
        if type(child) is CrossAttentionControlSubstitute:
            before_fragments.extend(child.original)
            after_fragments.extend(child.edited)
        else:
            before_fragments.append(child)
            after_fragments.append(child)

    before_text = " ".join([fragment.text for fragment in before_fragments])
    log_tokenization_for_text(
        before_text,
        tokenizer,
        display_label=f"{prefix}(.swap originals)",
    )
    after_text = " ".join([fragment.text for fragment in after_fragments])
    log_tokenization_for_text(
        after_text,
        tokenizer,
        display_label=f"{prefix}(.swap replacements)",
    )
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """Print how `text` is tokenized, one ANSI colour per token.

    Tokens usually carry an end-of-word marker '</w>'; for readability it is
    replaced with a space before printing.

    Args:
        text: the text to tokenize.
        tokenizer: object providing `tokenize(text)` and `model_max_length`.
        display_label: optional label printed in the header line.
        truncate_if_too_long: when true, tokens at index
            `tokenizer.model_max_length` and beyond are reported in a separate
            "Tokens Discarded" section instead of the main one.
    """
    tokens = tokenizer.tokenize(text)
    total_tokens = len(tokens)

    kept_parts = []
    discarded_parts = []
    for index, raw_token in enumerate(tokens):
        token = raw_token.replace("</w>", " ")
        # alternate colour; the cycle advances only with kept tokens
        colour = (len(kept_parts) % 6) + 1
        coloured = f"\x1b[0;3{colour};40m{token}"
        if truncate_if_too_long and index >= tokenizer.model_max_length:
            discarded_parts.append(coloured)
        else:
            kept_parts.append(coloured)

    used_tokens = len(kept_parts)
    if used_tokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({used_tokens}):')
        print(f"{''.join(kept_parts)}\x1b[0m")

    if discarded_parts:
        print(f"\n>> [TOKENLOG] Tokens Discarded ({total_tokens - used_tokens}):")
        print(f"{''.join(discarded_parts)}\x1b[0m")

View File

@ -13,13 +13,13 @@ from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec
from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np import numpy as np
from ..services.image_storage import ImageType from ..services.image_storage import ImageType
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput, build_image_output from .image import ImageField, ImageOutput, build_image_output
from .compel import ConditioningField
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers import diffusers
@ -138,14 +138,14 @@ class NoiseInvocation(BaseInvocation):
# Text to image # Text to image
class TextToLatentsInvocation(BaseInvocation): class TextToLatentsInvocation(BaseInvocation):
"""Generates latents from a prompt.""" """Generates latents from conditionings."""
type: Literal["t2l"] = "t2l" type: Literal["t2l"] = "t2l"
# Inputs # Inputs
# TODO: consider making prompt optional to enable providing prompt through a link
# fmt: off # fmt: off
prompt: Optional[str] = Field(description="The prompt to generate an image from") positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use") noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
@ -204,8 +204,10 @@ class TextToLatentsInvocation(BaseInvocation):
return model_ctx return model_ctx
def get_conditioning_data(self, model: StableDiffusionGeneratorPipeline) -> ConditioningData: def get_conditioning_data(self, context: InvocationContext, model: StableDiffusionGeneratorPipeline) -> ConditioningData:
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(self.prompt, model=model) c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
conditioning_data = ConditioningData( conditioning_data = ConditioningData(
uc, uc,
c, c,
@ -231,8 +233,8 @@ class TextToLatentsInvocation(BaseInvocation):
def step_callback(state: PipelineIntermediateState): def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state) self.dispatch_progress(context, source_node_id, state)
with self.get_model(context.services.model_manager) as model: model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(model) conditioning_data = self.get_conditioning_data(context, model)
# TODO: Verify the noise is the right size # TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings( result_latents, result_attention_map_saver = model.latents_from_embeddings(

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation from ..invocations.latent import LatentsToImageInvocation, NoiseInvocation, TextToLatentsInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation from ..invocations.params import ParamIntInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph
from .item_storage import ItemStorageABC from .item_storage import ItemStorageABC
@ -16,24 +17,32 @@ def create_text_to_image() -> LibraryGraph:
nodes={ nodes={
'width': ParamIntInvocation(id='width', a=512), 'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512), 'height': ParamIntInvocation(id='height', a=512),
'seed': ParamIntInvocation(id='seed', a=-1),
'3': NoiseInvocation(id='3'), '3': NoiseInvocation(id='3'),
'4': TextToLatentsInvocation(id='4'), '4': CompelInvocation(id='4'),
'5': LatentsToImageInvocation(id='5') '5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
}, },
edges=[ edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')), Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')), Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='4', field='noise')), Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
Edge(source=EdgeConnection(node_id='4', field='latents'), destination=EdgeConnection(node_id='5', field='latents')), Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
] ]
), ),
exposed_inputs=[ exposed_inputs=[
ExposedNodeInput(node_path='4', field='prompt', alias='prompt'), ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'), ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height') ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
], ],
exposed_outputs=[ exposed_outputs=[
ExposedNodeOutput(node_path='5', field='image', alias='image') ExposedNodeOutput(node_path='7', field='image', alias='image')
]) ])

View File

@ -1,5 +1,5 @@
import { forEach, size } from 'lodash-es'; import { forEach, size } from 'lodash-es';
import { ImageField, LatentsField } from 'services/api'; import { ImageField, LatentsField, ConditioningField } from 'services/api';
const OBJECT_TYPESTRING = '[object Object]'; const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]'; const STRING_TYPESTRING = '[object String]';
@ -74,8 +74,38 @@ const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
}; };
}; };
/**
 * Validates an unknown value as a ConditioningField.
 *
 * Returns a new object holding only `conditioning_name` when the value is an
 * object whose `conditioning_name` is a string; otherwise returns undefined.
 */
const parseConditioningField = (
  conditioningField: unknown
): ConditioningField | undefined => {
  // Reject anything that is not an object carrying a string `conditioning_name`
  if (
    !isObject(conditioningField) ||
    !('conditioning_name' in conditioningField) ||
    typeof conditioningField.conditioning_name !== 'string'
  ) {
    return undefined;
  }

  // Copy only the validated property into the result
  return { conditioning_name: conditioningField.conditioning_name };
};
type NodeMetadata = { type NodeMetadata = {
[key: string]: string | number | boolean | ImageField | LatentsField; [key: string]:
| string
| number
| boolean
| ImageField
| LatentsField
| ConditioningField;
}; };
type InvokeAIMetadata = { type InvokeAIMetadata = {
@ -101,7 +131,7 @@ export const parseNodeMetadata = (
return; return;
} }
// the only valid object types are ImageField and LatentsField // the only valid object types are ImageField, LatentsField and ConditioningField
if (isObject(nodeItem)) { if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) { if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem); const imageField = parseImageField(nodeItem);
@ -118,6 +148,14 @@ export const parseNodeMetadata = (
} }
return; return;
} }
if ('conditioning_name' in nodeItem) {
const conditioningField = parseConditioningField(nodeItem);
if (conditioningField) {
parsed[nodeKey] = conditioningField;
}
return;
}
} }
// otherwise we accept any string, number or boolean // otherwise we accept any string, number or boolean

View File

@ -6,9 +6,11 @@ import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
type InputFieldComponentProps = { type InputFieldComponentProps = {
nodeId: string; nodeId: string;
@ -84,6 +86,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'conditioning' && template.type === 'conditioning') {
return (
<ConditioningInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') { if (type === 'model' && template.type === 'model') {
return ( return (
<ModelInputFieldComponent <ModelInputFieldComponent
@ -104,6 +116,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
); );
} }
if (type === 'item' && template.type === 'item') {
return (
<ItemInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>; return <Box p={2}>Unknown field type: {type}</Box>;
}; };

View File

@ -0,0 +1,19 @@
import {
ConditioningInputFieldTemplate,
ConditioningInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
/**
 * Input field component for `conditioning` fields.
 *
 * Renders nothing: this field type has no directly editable value.
 */
const ConditioningInputFieldComponent = (
  _props: FieldComponentProps<
    ConditioningInputFieldValue,
    ConditioningInputFieldTemplate
  >
) => {
  // Props are accepted for interface parity with the other field components
  // but are not read.
  return null;
};
export default memo(ConditioningInputFieldComponent);

View File

@ -0,0 +1,17 @@
import {
ItemInputFieldTemplate,
ItemInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FaAddressCard, FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';
/**
 * Input field component for `item` fields.
 *
 * Always renders a static icon; the props are not read.
 */
const ItemInputFieldComponent = (
  _props: FieldComponentProps<ItemInputFieldValue, ItemInputFieldTemplate>
) => {
  return <FaAddressCard />;
};
export default memo(ItemInputFieldComponent);

View File

@ -11,8 +11,10 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
enum: 'enum', enum: 'enum',
ImageField: 'image', ImageField: 'image',
LatentsField: 'latents', LatentsField: 'latents',
ConditioningField: 'conditioning',
model: 'model', model: 'model',
array: 'array', array: 'array',
item: 'item',
}; };
const COLOR_TOKEN_VALUE = 500; const COLOR_TOKEN_VALUE = 500;
@ -63,6 +65,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Latents', title: 'Latents',
description: 'Latents may be passed between nodes.', description: 'Latents may be passed between nodes.',
}, },
conditioning: {
color: 'cyan',
colorCssVar: getColorTokenCssVariable('cyan'),
title: 'Conditioning',
description: 'Conditioning may be passed between nodes.',
},
model: { model: {
color: 'teal', color: 'teal',
colorCssVar: getColorTokenCssVariable('teal'), colorCssVar: getColorTokenCssVariable('teal'),
@ -75,4 +83,10 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Array', title: 'Array',
description: 'TODO: Array type description.', description: 'TODO: Array type description.',
}, },
item: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),
title: 'Collection Item',
description: 'TODO: Collection Item type description.',
},
}; };

View File

@ -56,8 +56,10 @@ export type FieldType =
| 'enum' | 'enum'
| 'image' | 'image'
| 'latents' | 'latents'
| 'conditioning'
| 'model' | 'model'
| 'array'; | 'array'
| 'item';
/** /**
* An input field is persisted across reloads as part of the user's local state. * An input field is persisted across reloads as part of the user's local state.
@ -74,9 +76,11 @@ export type InputFieldValue =
| BooleanInputFieldValue | BooleanInputFieldValue
| ImageInputFieldValue | ImageInputFieldValue
| LatentsInputFieldValue | LatentsInputFieldValue
| ConditioningInputFieldValue
| EnumInputFieldValue | EnumInputFieldValue
| ModelInputFieldValue | ModelInputFieldValue
| ArrayInputFieldValue; | ArrayInputFieldValue
| ItemInputFieldValue;
/** /**
* An input field template is generated on each page load from the OpenAPI schema. * An input field template is generated on each page load from the OpenAPI schema.
@ -91,9 +95,11 @@ export type InputFieldTemplate =
| BooleanInputFieldTemplate | BooleanInputFieldTemplate
| ImageInputFieldTemplate | ImageInputFieldTemplate
| LatentsInputFieldTemplate | LatentsInputFieldTemplate
| ConditioningInputFieldTemplate
| EnumInputFieldTemplate | EnumInputFieldTemplate
| ModelInputFieldTemplate | ModelInputFieldTemplate
| ArrayInputFieldTemplate; | ArrayInputFieldTemplate
| ItemInputFieldTemplate;
/** /**
* An output field is persisted across as part of the user's local state. * An output field is persisted across as part of the user's local state.
@ -162,6 +168,11 @@ export type LatentsInputFieldValue = FieldValueBase & {
value?: undefined; value?: undefined;
}; };
export type ConditioningInputFieldValue = FieldValueBase & {
type: 'conditioning';
value?: undefined;
};
export type ImageInputFieldValue = FieldValueBase & { export type ImageInputFieldValue = FieldValueBase & {
type: 'image'; type: 'image';
value?: Pick<ImageField, 'image_name' | 'image_type'>; value?: Pick<ImageField, 'image_name' | 'image_type'>;
@ -177,6 +188,11 @@ export type ArrayInputFieldValue = FieldValueBase & {
value?: (string | number)[]; value?: (string | number)[];
}; };
export type ItemInputFieldValue = FieldValueBase & {
type: 'item';
value?: undefined;
};
export type InputFieldTemplateBase = { export type InputFieldTemplateBase = {
name: string; name: string;
title: string; title: string;
@ -229,6 +245,11 @@ export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
type: 'latents'; type: 'latents';
}; };
export type ConditioningInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'conditioning';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & { export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number; default: string | number;
type: 'enum'; type: 'enum';
@ -242,10 +263,15 @@ export type ModelInputFieldTemplate = InputFieldTemplateBase & {
}; };
export type ArrayInputFieldTemplate = InputFieldTemplateBase & { export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: (string | number)[]; default: [];
type: 'array'; type: 'array';
}; };
export type ItemInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'item';
};
/** /**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES * JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/ */

View File

@ -9,12 +9,15 @@ import {
ImageInputFieldTemplate, ImageInputFieldTemplate,
IntegerInputFieldTemplate, IntegerInputFieldTemplate,
LatentsInputFieldTemplate, LatentsInputFieldTemplate,
ConditioningInputFieldTemplate,
StringInputFieldTemplate, StringInputFieldTemplate,
ModelInputFieldTemplate, ModelInputFieldTemplate,
InputFieldTemplateBase, InputFieldTemplateBase,
OutputFieldTemplate, OutputFieldTemplate,
TypeHints, TypeHints,
FieldType, FieldType,
ArrayInputFieldTemplate,
ItemInputFieldTemplate,
} from '../types/types'; } from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description'; export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -196,6 +199,21 @@ const buildLatentsInputFieldTemplate = ({
return template; return template;
}; };
/**
 * Builds the input field template for a `conditioning` field: always
 * required, supplied via connection, with the schema's default (or
 * undefined when the schema has none).
 */
const buildConditioningInputFieldTemplate = ({
  schemaObject,
  baseField,
}: BuildInputFieldArg): ConditioningInputFieldTemplate => ({
  ...baseField,
  type: 'conditioning',
  inputRequirement: 'always',
  inputKind: 'connection',
  default: schemaObject.default ?? undefined,
});
const buildEnumInputFieldTemplate = ({ const buildEnumInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -214,6 +232,36 @@ const buildEnumInputFieldTemplate = ({
return template; return template;
}; };
/**
 * Builds the input field template for an `array` field: always required,
 * direct input, defaulting to an empty array. The schema object is not
 * consulted.
 */
const buildArrayInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): ArrayInputFieldTemplate => {
  const template: ArrayInputFieldTemplate = {
    ...baseField,
    type: 'array',
    inputRequirement: 'always',
    inputKind: 'direct',
    default: [],
  };

  return template;
};
/**
 * Builds the input field template for an `item` field: always required,
 * direct input, with no default value. The schema object is not consulted.
 */
const buildItemInputFieldTemplate = ({
  baseField,
}: BuildInputFieldArg): ItemInputFieldTemplate => {
  const template: ItemInputFieldTemplate = {
    ...baseField,
    type: 'item',
    inputRequirement: 'always',
    inputKind: 'direct',
    default: undefined,
  };

  return template;
};
export const getFieldType = ( export const getFieldType = (
schemaObject: OpenAPIV3.SchemaObject, schemaObject: OpenAPIV3.SchemaObject,
name: string, name: string,
@ -266,6 +314,9 @@ export const buildInputFieldTemplate = (
if (['latents'].includes(fieldType)) { if (['latents'].includes(fieldType)) {
return buildLatentsInputFieldTemplate({ schemaObject, baseField }); return buildLatentsInputFieldTemplate({ schemaObject, baseField });
} }
if (['conditioning'].includes(fieldType)) {
return buildConditioningInputFieldTemplate({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) { if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField }); return buildModelInputFieldTemplate({ schemaObject, baseField });
} }
@ -284,6 +335,12 @@ export const buildInputFieldTemplate = (
if (['boolean'].includes(fieldType)) { if (['boolean'].includes(fieldType)) {
return buildBooleanInputFieldTemplate({ schemaObject, baseField }); return buildBooleanInputFieldTemplate({ schemaObject, baseField });
} }
if (['array'].includes(fieldType)) {
return buildArrayInputFieldTemplate({ schemaObject, baseField });
}
if (['item'].includes(fieldType)) {
return buildItemInputFieldTemplate({ schemaObject, baseField });
}
return; return;
}; };

View File

@ -48,6 +48,10 @@ export const buildInputFieldValue = (
fieldValue.value = undefined; fieldValue.value = undefined;
} }
if (template.type === 'conditioning') {
fieldValue.value = undefined;
}
if (template.type === 'model') { if (template.type === 'model') {
fieldValue.value = undefined; fieldValue.value = undefined;
} }

View File

@ -7,7 +7,7 @@ export const buildIterateNode = (): IterateInvocation => {
return { return {
id: nodeId, id: nodeId,
type: 'iterate', type: 'iterate',
collection: [], // collection: [],
index: 0, // index: 0,
}; };
}; };

View File

@ -13,7 +13,7 @@ import {
buildOutputFieldTemplates, buildOutputFieldTemplates,
} from './fieldTemplateBuilders'; } from './fieldTemplateBuilders';
const invocationDenylist = ['Graph', 'Collect', 'LoadImage']; const invocationDenylist = ['Graph', 'LoadImage'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => { export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now // filter out non-invocation schemas, plus some tricky invocations for now
@ -32,16 +32,43 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
if (isInvocationSchemaObject(schema)) { if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default; const type = schema.properties.type.default;
const title = const title = schema.ui?.title ?? schema.title.replace('Invocation', '');
schema.ui?.title ??
schema.title
.replace('Invocation', '')
.split(/(?=[A-Z])/) // split PascalCase into array
.join(' ');
const typeHints = schema.ui?.type_hints; const typeHints = schema.ui?.type_hints;
const inputs = reduce( const inputs: Record<string, InputFieldTemplate> = {};
if (type === 'collect') {
const itemProperty = schema.properties[
'item'
] as InvocationSchemaObject;
// Handle the special Collect node
inputs.item = {
type: 'item',
name: 'item',
description: itemProperty.description ?? '',
title: 'Collection Item',
inputKind: 'connection',
inputRequirement: 'always',
default: undefined,
};
} else if (type === 'iterate') {
const itemProperty = schema.properties[
'collection'
] as InvocationSchemaObject;
inputs.collection = {
type: 'array',
name: 'collection',
title: itemProperty.title ?? '',
default: [],
description: itemProperty.description ?? '',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
// All other nodes
reduce(
schema.properties, schema.properties,
(inputsAccumulator, property, propertyName) => { (inputsAccumulator, property, propertyName) => {
if ( if (
@ -49,32 +76,18 @@ export const parseSchema = (openAPI: OpenAPIV3.Document) => {
!['type', 'id'].includes(propertyName) && !['type', 'id'].includes(propertyName) &&
isSchemaObject(property) isSchemaObject(property)
) { ) {
let field: InputFieldTemplate | undefined; const field: InputFieldTemplate | undefined =
if (propertyName === 'collection') { buildInputFieldTemplate(property, propertyName, typeHints);
field = {
default: property.default ?? [],
name: 'collection',
title: property.title ?? '',
description: property.description ?? '',
type: 'array',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
field = buildInputFieldTemplate(
property,
propertyName,
typeHints
);
}
if (field) { if (field) {
inputsAccumulator[propertyName] = field; inputsAccumulator[propertyName] = field;
} }
} }
return inputsAccumulator; return inputsAccumulator;
}, },
{} as Record<string, InputFieldTemplate> inputs
); );
}
const rawOutput = (schema as InvocationSchemaObject).output; const rawOutput = (schema as InvocationSchemaObject).output;

View File

@ -107,7 +107,7 @@ const initialSystemState: SystemState = {
subscribedNodeIds: [], subscribedNodeIds: [],
wereModelsReceived: false, wereModelsReceived: false,
wasSchemaParsed: false, wasSchemaParsed: false,
consoleLogLevel: 'error', consoleLogLevel: 'debug',
shouldLogToConsole: true, shouldLogToConsole: true,
statusTranslationKey: 'common.statusDisconnected', statusTranslationKey: 'common.statusDisconnected',
canceledSession: '', canceledSession: '',
@ -384,6 +384,13 @@ export const systemSlice = createSlice({
state.statusTranslationKey = 'common.statusPreparing'; state.statusTranslationKey = 'common.statusPreparing';
}); });
builder.addCase(sessionInvoked.rejected, (state, action) => {
const error = action.payload as string | undefined;
state.toastQueue.push(
makeToast({ title: error || t('toast.serverError'), status: 'error' })
);
});
/** /**
* Session Canceled * Session Canceled
*/ */

View File

@ -46,6 +46,8 @@ export const socketMiddleware = () => {
// TODO: handle providing jwt to socket.io // TODO: handle providing jwt to socket.io
socketOptions.auth = { token: OpenAPI.TOKEN }; socketOptions.auth = { token: OpenAPI.TOKEN };
} }
socketOptions.transports = ['websocket', 'polling'];
} }
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io( const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(

View File

@ -22,6 +22,8 @@ import {
} from 'services/thunks/gallery'; } from 'services/thunks/gallery';
import { receivedModels } from 'services/thunks/model'; import { receivedModels } from 'services/thunks/model';
import { receivedOpenAPISchema } from 'services/thunks/schema'; import { receivedOpenAPISchema } from 'services/thunks/schema';
import { makeToast } from '../../../features/system/hooks/useToastWatcher';
import { addToast } from '../../../features/system/store/systemSlice';
type SetEventListenersArg = { type SetEventListenersArg = {
socket: Socket<ServerToClientEvents, ClientToServerEvents>; socket: Socket<ServerToClientEvents, ClientToServerEvents>;
@ -78,6 +80,16 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
} }
}); });
socket.on('connect_error', (error) => {
if (error && error.message) {
dispatch(
addToast(
makeToast({ title: error.message, status: 'error', duration: 10000 })
)
);
}
});
/** /**
* Disconnect * Disconnect
*/ */

View File

@ -101,17 +101,24 @@ export const nodeAdded = createAppAsyncThunk(
*/ */
export const sessionInvoked = createAppAsyncThunk( export const sessionInvoked = createAppAsyncThunk(
'api/sessionInvoked', 'api/sessionInvoked',
async (arg: { sessionId: string }, _thunkApi) => { async (arg: { sessionId: string }, { rejectWithValue }) => {
const { sessionId } = arg; const { sessionId } = arg;
try {
const response = await SessionsService.invokeSession({ const response = await SessionsService.invokeSession({
sessionId, sessionId,
all: true, all: true,
}); });
sessionLog.info({ arg, response }, `Session invoked (${sessionId})`); sessionLog.info({ arg, response }, `Session invoked (${sessionId})`);
return response; return response;
} catch (error) {
const err = error as any;
if (err.status === 403) {
return rejectWithValue(err.body.detail);
}
throw error;
}
} }
); );

View File

@ -463,16 +463,16 @@ def test_graph_subgraph_t2i():
n4 = ShowImageInvocation(id = "4") n4 = ShowImageInvocation(id = "4")
g.add_node(n4) g.add_node(n4)
g.add_edge(create_edge("1.5","image","4","image")) g.add_edge(create_edge("1.7","image","4","image"))
# Validate # Validate
dg = g.nx_graph_flat() dg = g.nx_graph_flat()
assert set(dg.nodes) == set(['1.width', '1.height', '1.3', '1.4', '1.5', '2', '3', '4']) assert set(dg.nodes) == set(['1.width', '1.height', '1.seed', '1.3', '1.4', '1.5', '1.6', '1.7', '2', '3', '4'])
expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges] expected_edges = [(f'1.{e.source.node_id}',f'1.{e.destination.node_id}') for e in lg.graph.edges]
expected_edges.extend([ expected_edges.extend([
('2','1.width'), ('2','1.width'),
('3','1.height'), ('3','1.height'),
('1.5','4') ('1.7','4')
]) ])
print(expected_edges) print(expected_edges)
print(list(dg.edges)) print(list(dg.edges))