feat(ui): update for InfillInvocation

This commit is contained in:
psychedelicious 2023-05-05 15:16:54 +10:00
parent da4eacdffe
commit bcc21531fb
21 changed files with 320 additions and 25 deletions

View File

@ -11,6 +11,7 @@ import ModelInputFieldComponent from './fields/ModelInputFieldComponent';
import NumberInputFieldComponent from './fields/NumberInputFieldComponent';
import StringInputFieldComponent from './fields/StringInputFieldComponent';
import ItemInputFieldComponent from './fields/ItemInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
@ -126,6 +127,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
if (type === 'color' && template.type === 'color') {
return (
<ColorInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@ -0,0 +1,31 @@
import {
ColorInputFieldTemplate,
ColorInputFieldValue,
} from 'features/nodes/types/types';
import { memo } from 'react';
import { FieldComponentProps } from './types';
import { RgbaColor, RgbaColorPicker } from 'react-colorful';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { useAppDispatch } from 'app/store/storeHooks';
const ColorInputFieldComponent = (
props: FieldComponentProps<ColorInputFieldValue, ColorInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (value: RgbaColor) => {
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
};
return (
<RgbaColorPicker
className="nodrag"
color={field.value}
onChange={handleValueChanged}
/>
);
};
export default memo(ColorInputFieldComponent);

View File

@ -11,13 +11,14 @@ import {
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { Graph, ImageField } from 'services/api';
import { ColorField, Graph, ImageField } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
import { log } from 'app/logging/useLogger';
import { size } from 'lodash-es';
import { isAnyGraphBuilt } from './actions';
import { RgbaColor } from 'react-colorful';
export type NodesState = {
nodes: Node<InvocationValue>[];
@ -69,6 +70,7 @@ const nodesSlice = createSlice({
| number
| boolean
| Pick<ImageField, 'image_name' | 'image_type'>
| RgbaColor
| undefined;
}>
) => {

View File

@ -15,6 +15,7 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
model: 'model',
array: 'array',
item: 'item',
ColorField: 'color',
};
const COLOR_TOKEN_VALUE = 500;
@ -89,4 +90,10 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Collection Item',
description: 'TODO: Collection Item type description.',
},
color: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),
title: 'Color',
description: 'A RGBA color.',
},
};

View File

@ -1,6 +1,9 @@
import { Image } from 'app/types/invokeai';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { ImageField } from 'services/api';
import { AnyInvocationType } from 'services/events/types';
import { O } from 'ts-toolbelt';
export type InvocationValue = {
id: string;
@ -59,7 +62,8 @@ export type FieldType =
| 'conditioning'
| 'model'
| 'array'
| 'item';
| 'item'
| 'color';
/**
* An input field is persisted across reloads as part of the user's local state.
@ -80,7 +84,8 @@ export type InputFieldValue =
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue;
| ItemInputFieldValue
| ColorInputFieldValue;
/**
* An input field template is generated on each page load from the OpenAPI schema.
@ -99,7 +104,8 @@ export type InputFieldTemplate =
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate;
| ItemInputFieldTemplate
| ColorInputFieldTemplate;
/**
* An output field is persisted across as part of the user's local state.
@ -193,6 +199,11 @@ export type ItemInputFieldValue = FieldValueBase & {
value?: undefined;
};
export type ColorInputFieldValue = FieldValueBase & {
type: 'color';
value?: RgbaColor;
};
export type InputFieldTemplateBase = {
name: string;
title: string;
@ -241,7 +252,7 @@ export type ImageInputFieldTemplate = InputFieldTemplateBase & {
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
default: string;
type: 'latents';
};
@ -272,6 +283,11 @@ export type ItemInputFieldTemplate = InputFieldTemplateBase & {
type: 'item';
};
export type ColorInputFieldTemplate = InputFieldTemplateBase & {
default: RgbaColor;
type: 'color';
};
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/

View File

@ -1,8 +1,30 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import { reduce } from 'lodash-es';
import { cloneDeep, reduce } from 'lodash-es';
import { RootState } from 'app/store/store';
import { AnyInvocation } from 'services/events/types';
import { InputFieldValue } from '../types/types';
/**
* We need to do special handling for some fields
*/
export const parseFieldValue = (field: InputFieldValue) => {
if (field.type === 'color') {
if (field.value) {
const clonedValue = cloneDeep(field.value);
const { r, g, b, a } = field.value;
// scale alpha value to PIL's desired range 0-255
const scaledAlpha = Math.max(0, Math.min(a * 255, 255));
const transformedColor = { r, g, b, a: scaledAlpha };
Object.assign(clonedValue, transformedColor);
return clonedValue;
}
}
return field.value;
};
/**
* Builds a graph from the node editor state.
@ -20,7 +42,8 @@ export const buildNodesGraph = (state: RootState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
const parsedValue = parseFieldValue(input);
inputsAccumulator[name] = parsedValue;
return inputsAccumulator;
},

View File

@ -12,12 +12,13 @@ import {
ConditioningInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
ArrayInputFieldTemplate,
ItemInputFieldTemplate,
ColorInputFieldTemplate,
InputFieldTemplateBase,
OutputFieldTemplate,
TypeHints,
FieldType,
ArrayInputFieldTemplate,
ItemInputFieldTemplate,
} from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -262,6 +263,21 @@ const buildItemInputFieldTemplate = ({
return template;
};
const buildColorInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ColorInputFieldTemplate => {
const template: ColorInputFieldTemplate = {
...baseField,
type: 'color',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 },
};
return template;
};
export const getFieldType = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
@ -341,6 +357,9 @@ export const buildInputFieldTemplate = (
if (['item'].includes(fieldType)) {
return buildItemInputFieldTemplate({ schemaObject, baseField });
}
if (['color'].includes(fieldType)) {
return buildColorInputFieldTemplate({ schemaObject, baseField });
}
return;
};

View File

@ -31,7 +31,9 @@ export type { ImageOutput } from './models/ImageOutput';
export type { ImageResponse } from './models/ImageResponse';
export type { ImageResponseMetadata } from './models/ImageResponseMetadata';
export type { ImageToImageInvocation } from './models/ImageToImageInvocation';
export type { ImageToLatentsInvocation } from './models/ImageToLatentsInvocation';
export type { ImageType } from './models/ImageType';
export type { InfillImageInvocation } from './models/InfillImageInvocation';
export type { InpaintInvocation } from './models/InpaintInvocation';
export type { IntCollectionOutput } from './models/IntCollectionOutput';
export type { IntOutput } from './models/IntOutput';
@ -47,6 +49,7 @@ export type { LerpInvocation } from './models/LerpInvocation';
export type { LoadImageInvocation } from './models/LoadImageInvocation';
export type { MaskFromAlphaInvocation } from './models/MaskFromAlphaInvocation';
export type { MaskOutput } from './models/MaskOutput';
export type { MetadataColorField } from './models/MetadataColorField';
export type { MetadataImageField } from './models/MetadataImageField';
export type { MetadataLatentsField } from './models/MetadataLatentsField';
export type { ModelsList } from './models/ModelsList';
@ -96,7 +99,9 @@ export { $ImageOutput } from './schemas/$ImageOutput';
export { $ImageResponse } from './schemas/$ImageResponse';
export { $ImageResponseMetadata } from './schemas/$ImageResponseMetadata';
export { $ImageToImageInvocation } from './schemas/$ImageToImageInvocation';
export { $ImageToLatentsInvocation } from './schemas/$ImageToLatentsInvocation';
export { $ImageType } from './schemas/$ImageType';
export { $InfillImageInvocation } from './schemas/$InfillImageInvocation';
export { $InpaintInvocation } from './schemas/$InpaintInvocation';
export { $IntCollectionOutput } from './schemas/$IntCollectionOutput';
export { $IntOutput } from './schemas/$IntOutput';
@ -112,6 +117,7 @@ export { $LerpInvocation } from './schemas/$LerpInvocation';
export { $LoadImageInvocation } from './schemas/$LoadImageInvocation';
export { $MaskFromAlphaInvocation } from './schemas/$MaskFromAlphaInvocation';
export { $MaskOutput } from './schemas/$MaskOutput';
export { $MetadataColorField } from './schemas/$MetadataColorField';
export { $MetadataImageField } from './schemas/$MetadataImageField';
export { $MetadataLatentsField } from './schemas/$MetadataLatentsField';
export { $ModelsList } from './schemas/$ModelsList';

View File

@ -7,17 +7,17 @@ export type ColorField = {
* The red component
*/
'r': number;
/**
* The blue component
*/
'b': number;
/**
* The green component
*/
'g': number;
/**
* The blue component
*/
'b': number;
/**
* The alpha component
*/
'a'?: number;
'a': number;
};

View File

@ -12,6 +12,8 @@ import type { DivideInvocation } from './DivideInvocation';
import type { Edge } from './Edge';
import type { GraphInvocation } from './GraphInvocation';
import type { ImageToImageInvocation } from './ImageToImageInvocation';
import type { ImageToLatentsInvocation } from './ImageToLatentsInvocation';
import type { InfillImageInvocation } from './InfillImageInvocation';
import type { InpaintInvocation } from './InpaintInvocation';
import type { InverseLerpInvocation } from './InverseLerpInvocation';
import type { IterateInvocation } from './IterateInvocation';
@ -43,7 +45,7 @@ export type Graph = {
/**
* The nodes in this graph
*/
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
nodes?: Record<string, (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation)>;
/**
* The connections between nodes and their fields in this graph
*/

View File

@ -0,0 +1,25 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ImageField } from './ImageField';
/**
* Encodes an image into latents.
*/
export type ImageToLatentsInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'i2l';
/**
* The image to encode
*/
image?: ImageField;
/**
* The model to use
*/
model?: string;
};

View File

@ -0,0 +1,38 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
import type { ColorField } from './ColorField';
import type { ImageField } from './ImageField';
/**
* Infills transparent areas of an image
*/
export type InfillImageInvocation = {
/**
* The id of this node. Must be unique among all nodes.
*/
id: string;
type?: 'infill';
/**
* The image to infill
*/
image?: ImageField;
/**
* The method used to infill empty regions (px)
*/
infill_method?: 'patchmatch' | 'tile' | 'solid';
/**
* The solid infill method color
*/
inpaint_fill?: ColorField;
/**
* The tile infill method size (px)
*/
tile_size?: number;
/**
* The seed to use (-1 for a random seed)
*/
seed?: number;
};

View File

@ -2,11 +2,12 @@
/* tslint:disable */
/* eslint-disable */
import type { MetadataColorField } from './MetadataColorField';
import type { MetadataImageField } from './MetadataImageField';
import type { MetadataLatentsField } from './MetadataLatentsField';
export type InvokeAIMetadata = {
session_id?: string;
node?: Record<string, (string | number | boolean | MetadataImageField | MetadataLatentsField)>;
node?: Record<string, (string | number | boolean | MetadataImageField | MetadataLatentsField | MetadataColorField)>;
};

View File

@ -0,0 +1,11 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type MetadataColorField = {
'r': number;
'g': number;
'b': number;
'a': number;
};

View File

@ -9,21 +9,22 @@ export const $ColorField = {
isRequired: true,
maximum: 255,
},
'b': {
type: 'number',
description: `The blue component`,
isRequired: true,
maximum: 255,
},
'g': {
type: 'number',
description: `The green component`,
isRequired: true,
maximum: 255,
},
'b': {
type: 'number',
description: `The blue component`,
isRequired: true,
maximum: 255,
},
'a': {
type: 'number',
description: `The alpha component`,
isRequired: true,
maximum: 255,
},
},

View File

@ -39,6 +39,8 @@ export const $Graph = {
type: 'ResizeLatentsInvocation',
}, {
type: 'ScaleLatentsInvocation',
}, {
type: 'ImageToLatentsInvocation',
}, {
type: 'AddInvocation',
}, {
@ -61,6 +63,8 @@ export const $Graph = {
type: 'RestoreFaceInvocation',
}, {
type: 'TextToImageInvocation',
}, {
type: 'InfillImageInvocation',
}, {
type: 'GraphInvocation',
}, {

View File

@ -0,0 +1,27 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $ImageToLatentsInvocation = {
description: `Encodes an image into latents.`,
properties: {
id: {
type: 'string',
description: `The id of this node. Must be unique among all nodes.`,
isRequired: true,
},
type: {
type: 'Enum',
},
image: {
type: 'all-of',
description: `The image to encode`,
contains: [{
type: 'ImageField',
}],
},
model: {
type: 'string',
description: `The model to use`,
},
},
} as const;

View File

@ -0,0 +1,44 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $InfillImageInvocation = {
description: `Infills transparent areas of an image`,
properties: {
id: {
type: 'string',
description: `The id of this node. Must be unique among all nodes.`,
isRequired: true,
},
type: {
type: 'Enum',
},
image: {
type: 'all-of',
description: `The image to infill`,
contains: [{
type: 'ImageField',
}],
},
infill_method: {
type: 'Enum',
},
inpaint_fill: {
type: 'all-of',
description: `The solid infill method color`,
contains: [{
type: 'ColorField',
}],
},
tile_size: {
type: 'number',
description: `The tile infill method size (px)`,
minimum: 1,
},
seed: {
type: 'number',
description: `The seed to use (-1 for a random seed)`,
maximum: 4294967295,
minimum: -1,
},
},
} as const;

View File

@ -22,6 +22,8 @@ export const $InvokeAIMetadata = {
type: 'MetadataImageField',
}, {
type: 'MetadataLatentsField',
}, {
type: 'MetadataColorField',
}],
},
},

View File

@ -0,0 +1,23 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export const $MetadataColorField = {
properties: {
'r': {
type: 'number',
isRequired: true,
},
'g': {
type: 'number',
isRequired: true,
},
'b': {
type: 'number',
isRequired: true,
},
'a': {
type: 'number',
isRequired: true,
},
},
} as const;

View File

@ -13,6 +13,8 @@ import type { Graph } from '../models/Graph';
import type { GraphExecutionState } from '../models/GraphExecutionState';
import type { GraphInvocation } from '../models/GraphInvocation';
import type { ImageToImageInvocation } from '../models/ImageToImageInvocation';
import type { ImageToLatentsInvocation } from '../models/ImageToLatentsInvocation';
import type { InfillImageInvocation } from '../models/InfillImageInvocation';
import type { InpaintInvocation } from '../models/InpaintInvocation';
import type { InverseLerpInvocation } from '../models/InverseLerpInvocation';
import type { IterateInvocation } from '../models/IterateInvocation';
@ -145,7 +147,7 @@ export class SessionsService {
* The id of the session
*/
sessionId: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<string> {
return __request(OpenAPI, {
method: 'POST',
@ -182,7 +184,7 @@ export class SessionsService {
* The path to the node in the graph
*/
nodePath: string,
requestBody: (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
requestBody: (LoadImageInvocation | ShowImageInvocation | DataURLToImageInvocation | CropImageInvocation | PasteImageInvocation | MaskFromAlphaInvocation | BlurInvocation | LerpInvocation | InverseLerpInvocation | NoiseInvocation | TextToLatentsInvocation | LatentsToImageInvocation | ResizeLatentsInvocation | ScaleLatentsInvocation | ImageToLatentsInvocation | AddInvocation | SubtractInvocation | MultiplyInvocation | DivideInvocation | ParamIntInvocation | CvInpaintInvocation | RangeInvocation | RandomRangeInvocation | UpscaleInvocation | RestoreFaceInvocation | TextToImageInvocation | InfillImageInvocation | GraphInvocation | IterateInvocation | CollectInvocation | LatentsToLatentsInvocation | ImageToImageInvocation | InpaintInvocation),
}): CancelablePromise<GraphExecutionState> {
return __request(OpenAPI, {
method: 'PUT',