mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into maryhipp/informational-popover
This commit is contained in:
@ -1,8 +1,9 @@
|
||||
import { CoreMetadata } from 'features/nodes/types/types';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||
import { memo, useCallback } from 'react';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { isValidLoRAModel } from '../../../parameters/types/parameterSchemas';
|
||||
import ImageMetadataItem from './ImageMetadataItem';
|
||||
|
||||
type Props = {
|
||||
metadata?: CoreMetadata;
|
||||
@ -24,6 +25,7 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
} = useRecallParameters();
|
||||
|
||||
const handleRecallPositivePrompt = useCallback(() => {
|
||||
@ -66,6 +68,13 @@ const ImageMetadataActions = (props: Props) => {
|
||||
recallStrength(metadata?.strength);
|
||||
}, [metadata?.strength, recallStrength]);
|
||||
|
||||
const handleRecallLoRA = useCallback(
|
||||
(lora: LoRAMetadataItem) => {
|
||||
recallLoRA(lora);
|
||||
},
|
||||
[recallLoRA]
|
||||
);
|
||||
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
@ -130,20 +139,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallHeight}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label={t('metadata.threshold')}
|
||||
value={metadata.threshold}
|
||||
onClick={() => dispatch(setThreshold(Number(metadata.threshold)))}
|
||||
/>
|
||||
)}
|
||||
{metadata.perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label={t('metadata.perlin')}
|
||||
value={metadata.perlin}
|
||||
onClick={() => dispatch(setPerlin(Number(metadata.perlin)))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.scheduler && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.scheduler')}
|
||||
@ -165,40 +160,6 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallCfgScale}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.variations && metadata.variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="{t('metadata.variations')}
|
||||
value={seedWeightsToString(metadata.variations)}
|
||||
onClick={() =>
|
||||
dispatch(
|
||||
setSeedWeights(seedWeightsToString(metadata.variations))
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{metadata.seamless && (
|
||||
<MetadataItem
|
||||
label={t('metadata.seamless')}
|
||||
value={metadata.seamless}
|
||||
onClick={() => dispatch(setSeamless(metadata.seamless))}
|
||||
/>
|
||||
)}
|
||||
{metadata.hires_fix && (
|
||||
<MetadataItem
|
||||
label={t('metadata.hiresFix')}
|
||||
value={metadata.hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(metadata.hires_fix))}
|
||||
/>
|
||||
)} */}
|
||||
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label={t('metadata.initImage')}
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.strength && (
|
||||
<ImageMetadataItem
|
||||
label={t('metadata.strength')}
|
||||
@ -206,13 +167,19 @@ const ImageMetadataActions = (props: Props) => {
|
||||
onClick={handleRecallStrength}
|
||||
/>
|
||||
)}
|
||||
{/* {metadata.fit && (
|
||||
<MetadataItem
|
||||
label={t('metadata.fit')}
|
||||
value={metadata.fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(metadata.fit))}
|
||||
/>
|
||||
)} */}
|
||||
{metadata.loras &&
|
||||
metadata.loras.map((lora, index) => {
|
||||
if (isValidLoRAModel(lora.lora)) {
|
||||
return (
|
||||
<ImageMetadataItem
|
||||
key={index}
|
||||
label="LoRA"
|
||||
value={`${lora.lora.model_name} - ${lora.weight}`}
|
||||
onClick={() => handleRecallLoRA(lora)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
@ -27,6 +27,13 @@ export const loraSlice = createSlice({
|
||||
const { model_name, id, base_model } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, ...defaultLoRAConfig };
|
||||
},
|
||||
loraRecalled: (
|
||||
state,
|
||||
action: PayloadAction<LoRAModelConfigEntity & { weight: number }>
|
||||
) => {
|
||||
const { model_name, id, base_model, weight } = action.payload;
|
||||
state.loras[id] = { id, model_name, base_model, weight };
|
||||
},
|
||||
loraRemoved: (state, action: PayloadAction<string>) => {
|
||||
const id = action.payload;
|
||||
delete state.loras[id];
|
||||
@ -62,6 +69,7 @@ export const {
|
||||
loraWeightChanged,
|
||||
loraWeightReset,
|
||||
lorasCleared,
|
||||
loraRecalled,
|
||||
} = loraSlice.actions;
|
||||
|
||||
export default loraSlice.reducer;
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { FieldType, FieldUIConfig } from './types';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
export const COLOR_TOKEN_VALUE = 500;
|
||||
@ -102,73 +103,73 @@ export const isPolymorphicItemType = (
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
boolean: {
|
||||
color: 'green.500',
|
||||
description: 'Booleans are true or false.',
|
||||
title: 'Boolean',
|
||||
description: t('nodes.booleanDescription'),
|
||||
title: t('nodes.boolean'),
|
||||
},
|
||||
BooleanCollection: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Collection',
|
||||
description: t('nodes.booleanCollectionDescription'),
|
||||
title: t('nodes.booleanCollection'),
|
||||
},
|
||||
BooleanPolymorphic: {
|
||||
color: 'green.500',
|
||||
description: 'A collection of booleans.',
|
||||
title: 'Boolean Polymorphic',
|
||||
description: t('nodes.booleanPolymorphicDescription'),
|
||||
title: t('nodes.booleanPolymorphic'),
|
||||
},
|
||||
ClipField: {
|
||||
color: 'green.500',
|
||||
description: 'Tokenizer and text_encoder submodels.',
|
||||
title: 'Clip',
|
||||
description: t('nodes.clipFieldDescription'),
|
||||
title: t('nodes.clipField'),
|
||||
},
|
||||
Collection: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Collection',
|
||||
description: t('nodes.collectionDescription'),
|
||||
title: t('nodes.collection'),
|
||||
},
|
||||
CollectionItem: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Collection Item',
|
||||
description: t('nodes.collectionItemDescription'),
|
||||
title: t('nodes.collectionItem'),
|
||||
},
|
||||
ColorCollection: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Collection',
|
||||
description: t('nodes.colorCollectionDescription'),
|
||||
title: t('nodes.colorCollection'),
|
||||
},
|
||||
ColorField: {
|
||||
color: 'pink.300',
|
||||
description: 'A RGBA color.',
|
||||
title: 'Color',
|
||||
description: t('nodes.colorFieldDescription'),
|
||||
title: t('nodes.colorField'),
|
||||
},
|
||||
ColorPolymorphic: {
|
||||
color: 'pink.300',
|
||||
description: 'A collection of colors.',
|
||||
title: 'Color Polymorphic',
|
||||
description: t('nodes.colorPolymorphicDescription'),
|
||||
title: t('nodes.colorPolymorphic'),
|
||||
},
|
||||
ConditioningCollection: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Collection',
|
||||
description: t('nodes.conditioningCollectionDescription'),
|
||||
title: t('nodes.conditioningCollection'),
|
||||
},
|
||||
ConditioningField: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning',
|
||||
description: t('nodes.conditioningFieldDescription'),
|
||||
title: t('nodes.conditioningField'),
|
||||
},
|
||||
ConditioningPolymorphic: {
|
||||
color: 'cyan.500',
|
||||
description: 'Conditioning may be passed between nodes.',
|
||||
title: 'Conditioning Polymorphic',
|
||||
description: t('nodes.conditioningPolymorphicDescription'),
|
||||
title: t('nodes.conditioningPolymorphic'),
|
||||
},
|
||||
ControlCollection: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control Collection',
|
||||
description: t('nodes.controlCollectionDescription'),
|
||||
title: t('nodes.controlCollection'),
|
||||
},
|
||||
ControlField: {
|
||||
color: 'teal.500',
|
||||
description: 'Control info passed between nodes.',
|
||||
title: 'Control',
|
||||
description: t('nodes.controlFieldDescription'),
|
||||
title: t('nodes.controlField'),
|
||||
},
|
||||
ControlNetModelField: {
|
||||
color: 'teal.500',
|
||||
@ -182,132 +183,132 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
},
|
||||
DenoiseMaskField: {
|
||||
color: 'blue.300',
|
||||
description: 'Denoise Mask may be passed between nodes',
|
||||
title: 'Denoise Mask',
|
||||
description: t('nodes.denoiseMaskFieldDescription'),
|
||||
title: t('nodes.denoiseMaskField'),
|
||||
},
|
||||
enum: {
|
||||
color: 'blue.500',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
title: 'Enum',
|
||||
description: t('nodes.enumDescription'),
|
||||
title: t('nodes.enum'),
|
||||
},
|
||||
float: {
|
||||
color: 'orange.500',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
title: 'Float',
|
||||
description: t('nodes.floatDescription'),
|
||||
title: t('nodes.float'),
|
||||
},
|
||||
FloatCollection: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Collection',
|
||||
description: t('nodes.floatCollectionDescription'),
|
||||
title: t('nodes.floatCollection'),
|
||||
},
|
||||
FloatPolymorphic: {
|
||||
color: 'orange.500',
|
||||
description: 'A collection of floats.',
|
||||
title: 'Float Polymorphic',
|
||||
description: t('nodes.floatPolymorphicDescription'),
|
||||
title: t('nodes.floatPolymorphic'),
|
||||
},
|
||||
ImageCollection: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Collection',
|
||||
description: t('nodes.imageCollectionDescription'),
|
||||
title: t('nodes.imageCollection'),
|
||||
},
|
||||
ImageField: {
|
||||
color: 'purple.500',
|
||||
description: 'Images may be passed between nodes.',
|
||||
title: 'Image',
|
||||
description: t('nodes.imageFieldDescription'),
|
||||
title: t('nodes.imageField'),
|
||||
},
|
||||
ImagePolymorphic: {
|
||||
color: 'purple.500',
|
||||
description: 'A collection of images.',
|
||||
title: 'Image Polymorphic',
|
||||
description: t('nodes.imagePolymorphicDescription'),
|
||||
title: t('nodes.imagePolymorphic'),
|
||||
},
|
||||
integer: {
|
||||
color: 'red.500',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
title: 'Integer',
|
||||
description: t('nodes.integerDescription'),
|
||||
title: t('nodes.integer'),
|
||||
},
|
||||
IntegerCollection: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Collection',
|
||||
description: t('nodes.integerCollectionDescription'),
|
||||
title: t('nodes.integerCollection'),
|
||||
},
|
||||
IntegerPolymorphic: {
|
||||
color: 'red.500',
|
||||
description: 'A collection of integers.',
|
||||
title: 'Integer Polymorphic',
|
||||
description: t('nodes.integerPolymorphicDescription'),
|
||||
title: t('nodes.integerPolymorphic'),
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Collection',
|
||||
description: t('nodes.latentsCollectionDescription'),
|
||||
title: t('nodes.latentsCollection'),
|
||||
},
|
||||
LatentsField: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents',
|
||||
description: t('nodes.latentsFieldDescription'),
|
||||
title: t('nodes.latentsField'),
|
||||
},
|
||||
LatentsPolymorphic: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
title: 'Latents Polymorphic',
|
||||
description: t('nodes.latentsPolymorphicDescription'),
|
||||
title: t('nodes.latentsPolymorphic'),
|
||||
},
|
||||
LoRAModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'LoRA',
|
||||
description: t('nodes.loRAModelFieldDescription'),
|
||||
title: t('nodes.loRAModelField'),
|
||||
},
|
||||
MainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Model',
|
||||
description: t('nodes.mainModelFieldDescription'),
|
||||
title: t('nodes.mainModelField'),
|
||||
},
|
||||
ONNXModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'ONNX model field.',
|
||||
title: 'ONNX Model',
|
||||
description: t('nodes.oNNXModelFieldDescription'),
|
||||
title: t('nodes.oNNXModelField'),
|
||||
},
|
||||
Scheduler: {
|
||||
color: 'base.500',
|
||||
description: 'TODO',
|
||||
title: 'Scheduler',
|
||||
description: t('nodes.schedulerDescription'),
|
||||
title: t('nodes.scheduler'),
|
||||
},
|
||||
SDXLMainModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'SDXL model field.',
|
||||
title: 'SDXL Model',
|
||||
description: t('nodes.sDXLMainModelFieldDescription'),
|
||||
title: t('nodes.sDXLMainModelField'),
|
||||
},
|
||||
SDXLRefinerModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'Refiner Model',
|
||||
description: t('nodes.sDXLRefinerModelFieldDescription'),
|
||||
title: t('nodes.sDXLRefinerModelField'),
|
||||
},
|
||||
string: {
|
||||
color: 'yellow.500',
|
||||
description: 'Strings are text.',
|
||||
title: 'String',
|
||||
description: t('nodes.stringDescription'),
|
||||
title: t('nodes.string'),
|
||||
},
|
||||
StringCollection: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Collection',
|
||||
description: t('nodes.stringCollectionDescription'),
|
||||
title: t('nodes.stringCollection'),
|
||||
},
|
||||
StringPolymorphic: {
|
||||
color: 'yellow.500',
|
||||
description: 'A collection of strings.',
|
||||
title: 'String Polymorphic',
|
||||
description: t('nodes.stringPolymorphicDescription'),
|
||||
title: t('nodes.stringPolymorphic'),
|
||||
},
|
||||
UNetField: {
|
||||
color: 'red.500',
|
||||
description: 'UNet submodel.',
|
||||
title: 'UNet',
|
||||
description: t('nodes.uNetFieldDescription'),
|
||||
title: t('nodes.uNetField'),
|
||||
},
|
||||
VaeField: {
|
||||
color: 'blue.500',
|
||||
description: 'Vae submodel.',
|
||||
title: 'Vae',
|
||||
description: t('nodes.vaeFieldDescription'),
|
||||
title: t('nodes.vaeField'),
|
||||
},
|
||||
VaeModelField: {
|
||||
color: 'teal.500',
|
||||
description: 'TODO',
|
||||
title: 'VAE',
|
||||
description: t('nodes.vaeModelFieldDescription'),
|
||||
title: t('nodes.vaeModelField'),
|
||||
},
|
||||
};
|
||||
|
@ -20,6 +20,7 @@ import {
|
||||
import { O } from 'ts-toolbelt';
|
||||
import { JsonObject } from 'type-fest';
|
||||
import { z } from 'zod';
|
||||
import i18n from 'i18next';
|
||||
|
||||
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
|
||||
|
||||
@ -1056,6 +1057,13 @@ export const isInvocationFieldSchema = (
|
||||
|
||||
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
|
||||
|
||||
const zLoRAMetadataItem = z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
});
|
||||
|
||||
export type LoRAMetadataItem = z.infer<typeof zLoRAMetadataItem>;
|
||||
|
||||
export const zCoreMetadata = z
|
||||
.object({
|
||||
app_version: z.string().nullish(),
|
||||
@ -1075,14 +1083,7 @@ export const zCoreMetadata = z
|
||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||
.nullish(),
|
||||
controlnets: z.array(zControlField.deepPartial()).nullish(),
|
||||
loras: z
|
||||
.array(
|
||||
z.object({
|
||||
lora: zLoRAModelField.deepPartial(),
|
||||
weight: z.number(),
|
||||
})
|
||||
)
|
||||
.nullish(),
|
||||
loras: z.array(zLoRAMetadataItem).nullish(),
|
||||
vae: zVaeModelField.nullish(),
|
||||
strength: z.number().nullish(),
|
||||
init_image: z.string().nullish(),
|
||||
@ -1258,23 +1259,35 @@ export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const issues: string[] = [];
|
||||
if (!sourceNode) {
|
||||
issues.push(`Output node ${edge.source} does not exist`);
|
||||
issues.push(
|
||||
`${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||
`${i18n.t('nodes.outputField')}"${edge.source}.${
|
||||
edge.sourceHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
);
|
||||
}
|
||||
if (!targetNode) {
|
||||
issues.push(`Input node ${edge.target} does not exist`);
|
||||
issues.push(
|
||||
`${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.targetHandle in targetNode.data.inputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||
`${i18n.t('nodes.inputField')} "${edge.target}.${
|
||||
edge.targetHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
);
|
||||
}
|
||||
if (issues.length) {
|
||||
@ -1282,7 +1295,9 @@ export const zValidatedWorkflow = zWorkflow.transform((workflow) => {
|
||||
const src = edge.type === 'default' ? edge.sourceHandle : edge.source;
|
||||
const tgt = edge.type === 'default' ? edge.targetHandle : edge.target;
|
||||
warnings.push({
|
||||
message: `Edge "${src} -> ${tgt}" skipped`,
|
||||
message: `${i18n.t('nodes.edge')} "${src} -> ${tgt}" ${i18n.t(
|
||||
'nodes.skipped'
|
||||
)}`,
|
||||
issues,
|
||||
data: edge,
|
||||
});
|
||||
|
@ -3,6 +3,7 @@ import { NodesState } from '../store/types';
|
||||
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
|
||||
import { fromZodError } from 'zod-validation-error';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import i18n from 'i18next';
|
||||
|
||||
export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
const { workflow: workflowMeta, nodes, edges } = nodesState;
|
||||
@ -20,7 +21,7 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
const result = zWorkflowNode.safeParse(node);
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: 'Unable to parse node',
|
||||
prefix: i18n.t('nodes.unableToParseNode'),
|
||||
});
|
||||
logger('nodes').warn({ node: parseify(node) }, message);
|
||||
return;
|
||||
@ -32,7 +33,7 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
|
||||
const result = zWorkflowEdge.safeParse(edge);
|
||||
if (!result.success) {
|
||||
const { message } = fromZodError(result.error, {
|
||||
prefix: 'Unable to parse edge',
|
||||
prefix: i18n.t('nodes.unableToParseEdge'),
|
||||
});
|
||||
logger('nodes').warn({ edge: parseify(edge) }, message);
|
||||
return;
|
||||
|
@ -79,8 +79,8 @@ export const buildCanvasInpaintGraph = (
|
||||
} = state.generation;
|
||||
|
||||
if (!model) {
|
||||
log.error('No model found in state');
|
||||
throw new Error('No model found in state');
|
||||
log.error('No Image found in state');
|
||||
throw new Error('No Image found in state');
|
||||
}
|
||||
|
||||
// The bounding box determines width and height, not the width and height params
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
isWorkflowInvocationNode,
|
||||
} from '../types/types';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import i18n from 'i18next';
|
||||
|
||||
export const validateWorkflow = (
|
||||
workflow: Workflow,
|
||||
@ -25,8 +26,14 @@ export const validateWorkflow = (
|
||||
const nodeTemplate = nodeTemplates[node.data.type];
|
||||
if (!nodeTemplate) {
|
||||
errors.push({
|
||||
message: `Node "${node.data.type}" skipped`,
|
||||
issues: [`Node type "${node.data.type}" does not exist`],
|
||||
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
|
||||
'nodes.skipped'
|
||||
)}`,
|
||||
issues: [
|
||||
`${i18n.t('nodes.nodeType')}"${node.data.type}" ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`,
|
||||
],
|
||||
data: node,
|
||||
});
|
||||
return;
|
||||
@ -38,9 +45,13 @@ export const validateWorkflow = (
|
||||
compareVersions(nodeTemplate.version, node.data.version) !== 0
|
||||
) {
|
||||
errors.push({
|
||||
message: `Node "${node.data.type}" has mismatched version`,
|
||||
message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t(
|
||||
'nodes.mismatchedVersion'
|
||||
)}`,
|
||||
issues: [
|
||||
`Node "${node.data.type}" v${node.data.version} may be incompatible with installed v${nodeTemplate.version}`,
|
||||
`${i18n.t('nodes.node')} "${node.data.type}" v${
|
||||
node.data.version
|
||||
} ${i18n.t('nodes.maybeIncompatible')} v${nodeTemplate.version}`,
|
||||
],
|
||||
data: { node, nodeTemplate: parseify(nodeTemplate) },
|
||||
});
|
||||
@ -52,33 +63,49 @@ export const validateWorkflow = (
|
||||
const targetNode = keyedNodes[edge.target];
|
||||
const issues: string[] = [];
|
||||
if (!sourceNode) {
|
||||
issues.push(`Output node ${edge.source} does not exist`);
|
||||
issues.push(
|
||||
`${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.sourceHandle in sourceNode.data.outputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Output field "${edge.source}.${edge.sourceHandle}" does not exist`
|
||||
`${i18n.t('nodes.outputNodes')} "${edge.source}.${
|
||||
edge.sourceHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
);
|
||||
}
|
||||
if (!targetNode) {
|
||||
issues.push(`Input node ${edge.target} does not exist`);
|
||||
issues.push(
|
||||
`${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t(
|
||||
'nodes.doesNotExist'
|
||||
)}`
|
||||
);
|
||||
} else if (
|
||||
edge.type === 'default' &&
|
||||
!(edge.targetHandle in targetNode.data.inputs)
|
||||
) {
|
||||
issues.push(
|
||||
`Input field "${edge.target}.${edge.targetHandle}" does not exist`
|
||||
`${i18n.t('nodes.inputFeilds')} "${edge.target}.${
|
||||
edge.targetHandle
|
||||
}" ${i18n.t('nodes.doesNotExist')}`
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
issues.push(
|
||||
`Source node "${edge.source}" missing template "${sourceNode?.data.type}"`
|
||||
`${i18n.t('nodes.sourceNode')} "${edge.source}" ${i18n.t(
|
||||
'nodes.missingTemplate'
|
||||
)} "${sourceNode?.data.type}"`
|
||||
);
|
||||
}
|
||||
if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) {
|
||||
issues.push(
|
||||
`Source node "${edge.target}" missing template "${targetNode?.data.type}"`
|
||||
`${i18n.t('nodes.sourceNode')}"${edge.target}" ${i18n.t(
|
||||
'nodes.missingTemplate'
|
||||
)} "${targetNode?.data.type}"`
|
||||
);
|
||||
}
|
||||
if (issues.length) {
|
||||
|
@ -1,6 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { CoreMetadata } from 'features/nodes/types/types';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { CoreMetadata, LoRAMetadataItem } from 'features/nodes/types/types';
|
||||
import {
|
||||
refinerModelChanged,
|
||||
setNegativeStylePromptSDXL,
|
||||
@ -15,6 +17,11 @@ import {
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import {
|
||||
loraModelsAdapter,
|
||||
useGetLoRAModelsQuery,
|
||||
} from '../../../services/api/endpoints/models';
|
||||
import { loraRecalled } from '../../lora/store/loraSlice';
|
||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||
import {
|
||||
setCfgScale,
|
||||
@ -30,6 +37,7 @@ import {
|
||||
import {
|
||||
isValidCfgScale,
|
||||
isValidHeight,
|
||||
isValidLoRAModel,
|
||||
isValidMainModel,
|
||||
isValidNegativePrompt,
|
||||
isValidPositivePrompt,
|
||||
@ -46,10 +54,16 @@ import {
|
||||
isValidWidth,
|
||||
} from '../types/parameterSchemas';
|
||||
|
||||
const selector = createSelector(stateSelector, ({ generation }) => {
|
||||
const { model } = generation;
|
||||
return { model };
|
||||
});
|
||||
|
||||
export const useRecallParameters = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const toaster = useAppToaster();
|
||||
const { t } = useTranslation();
|
||||
const { model } = useAppSelector(selector);
|
||||
|
||||
const parameterSetToast = useCallback(() => {
|
||||
toaster({
|
||||
@ -60,14 +74,18 @@ export const useRecallParameters = () => {
|
||||
});
|
||||
}, [t, toaster]);
|
||||
|
||||
const parameterNotSetToast = useCallback(() => {
|
||||
toaster({
|
||||
title: t('toast.parameterNotSet'),
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [t, toaster]);
|
||||
const parameterNotSetToast = useCallback(
|
||||
(description?: string) => {
|
||||
toaster({
|
||||
title: t('toast.parameterNotSet'),
|
||||
description,
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
const allParameterSetToast = useCallback(() => {
|
||||
toaster({
|
||||
@ -78,14 +96,18 @@ export const useRecallParameters = () => {
|
||||
});
|
||||
}, [t, toaster]);
|
||||
|
||||
const allParameterNotSetToast = useCallback(() => {
|
||||
toaster({
|
||||
title: t('toast.parametersNotSet'),
|
||||
status: 'warning',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
}, [t, toaster]);
|
||||
const allParameterNotSetToast = useCallback(
|
||||
(description?: string) => {
|
||||
toaster({
|
||||
title: t('toast.parametersNotSet'),
|
||||
status: 'warning',
|
||||
description,
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
},
|
||||
[t, toaster]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall both prompts with toast
|
||||
@ -307,6 +329,67 @@ export const useRecallParameters = () => {
|
||||
[dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/**
|
||||
* Recall LoRA with toast
|
||||
*/
|
||||
|
||||
const { loras } = useGetLoRAModelsQuery(undefined, {
|
||||
selectFromResult: (result) => ({
|
||||
loras: result.data
|
||||
? loraModelsAdapter.getSelectors().selectAll(result.data)
|
||||
: [],
|
||||
}),
|
||||
});
|
||||
|
||||
const prepareLoRAMetadataItem = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
if (!isValidLoRAModel(loraMetadataItem.lora)) {
|
||||
return { lora: null, error: 'Invalid LoRA model' };
|
||||
}
|
||||
|
||||
const { base_model, model_name } = loraMetadataItem.lora;
|
||||
|
||||
const matchingLoRA = loras.find(
|
||||
(l) => l.base_model === base_model && l.model_name === model_name
|
||||
);
|
||||
|
||||
if (!matchingLoRA) {
|
||||
return { lora: null, error: 'LoRA model is not installed' };
|
||||
}
|
||||
|
||||
const isCompatibleBaseModel =
|
||||
matchingLoRA?.base_model === model?.base_model;
|
||||
|
||||
if (!isCompatibleBaseModel) {
|
||||
return {
|
||||
lora: null,
|
||||
error: 'LoRA incompatible with currently-selected model',
|
||||
};
|
||||
}
|
||||
|
||||
return { lora: matchingLoRA, error: null };
|
||||
},
|
||||
[loras, model?.base_model]
|
||||
);
|
||||
|
||||
const recallLoRA = useCallback(
|
||||
(loraMetadataItem: LoRAMetadataItem) => {
|
||||
const result = prepareLoRAMetadataItem(loraMetadataItem);
|
||||
|
||||
if (!result.lora) {
|
||||
parameterNotSetToast(result.error);
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
loraRecalled({ ...result.lora, weight: loraMetadataItem.weight })
|
||||
);
|
||||
|
||||
parameterSetToast();
|
||||
},
|
||||
[prepareLoRAMetadataItem, dispatch, parameterSetToast, parameterNotSetToast]
|
||||
);
|
||||
|
||||
/*
|
||||
* Sets image as initial image with toast
|
||||
*/
|
||||
@ -344,6 +427,7 @@ export const useRecallParameters = () => {
|
||||
refiner_positive_aesthetic_score,
|
||||
refiner_negative_aesthetic_score,
|
||||
refiner_start,
|
||||
loras,
|
||||
} = metadata;
|
||||
|
||||
if (isValidCfgScale(cfg_scale)) {
|
||||
@ -425,9 +509,21 @@ export const useRecallParameters = () => {
|
||||
dispatch(setRefinerStart(refiner_start));
|
||||
}
|
||||
|
||||
loras?.forEach((lora) => {
|
||||
const result = prepareLoRAMetadataItem(lora);
|
||||
if (result.lora) {
|
||||
dispatch(loraRecalled({ ...result.lora, weight: lora.weight }));
|
||||
}
|
||||
});
|
||||
|
||||
allParameterSetToast();
|
||||
},
|
||||
[allParameterNotSetToast, allParameterSetToast, dispatch]
|
||||
[
|
||||
allParameterNotSetToast,
|
||||
allParameterSetToast,
|
||||
dispatch,
|
||||
prepareLoRAMetadataItem,
|
||||
]
|
||||
);
|
||||
|
||||
return {
|
||||
@ -444,6 +540,7 @@ export const useRecallParameters = () => {
|
||||
recallWidth,
|
||||
recallHeight,
|
||||
recallStrength,
|
||||
recallLoRA,
|
||||
recallAllParameters,
|
||||
sendToImageToImage,
|
||||
};
|
||||
|
Reference in New Issue
Block a user