tidy(ui): add type for templates

This commit is contained in:
psychedelicious 2024-05-16 17:06:07 +10:00
parent 1d884fb794
commit 708c68413d
7 changed files with 18 additions and 18 deletions

View File

@ -11,8 +11,8 @@ import type { Layer } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice';
@ -29,7 +29,7 @@ const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
regional_guidance_layer: 'controlLayers.regionalGuidance', regional_guidance_layer: 'controlLayers.regionalGuidance',
}; };
const createSelector = (templates: Record<string, InvocationTemplate>) => const createSelector = (templates: Templates) =>
createMemoizedSelector( createMemoizedSelector(
[ [
selectControlAdaptersSlice, selectControlAdaptersSlice,

View File

@ -1,8 +1,8 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor'; import { getFieldColor } from './getEdgeColor';
@ -15,7 +15,7 @@ const defaultReturnValue = {
}; };
export const makeEdgeSelector = ( export const makeEdgeSelector = (
templates: Record<string, InvocationTemplate>, templates: Templates,
source: string, source: string,
sourceHandleId: string | null | undefined, sourceHandleId: string | null | undefined,
target: string, target: string,

View File

@ -43,12 +43,7 @@ import {
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zVAEModelFieldValue, zVAEModelFieldValue,
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import type { import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/nodes/types/invocation';
AnyNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es'; import { forEach } from 'lodash-es';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
@ -74,7 +69,7 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import type { z } from 'zod'; import type { z } from 'zod';
import type { NodesState } from './types'; import type { NodesState, Templates } from './types';
import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; import { findConnectionToValidHandle } from './util/findConnectionToValidHandle';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
@ -766,7 +761,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
); );
export const $cursorPos = atom<XYPosition | null>(null); export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Record<string, InvocationTemplate>>({}); export const $templates = atom<Templates>({});
export const $copiedNodes = atom<AnyNode[]>([]); export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]); export const $copiedEdges = atom<InvocationNodeEdge[]>([]);

View File

@ -2,11 +2,14 @@ import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/no
import type { import type {
AnyNode, AnyNode,
InvocationNodeEdge, InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState, NodeExecutionState,
} from 'features/nodes/types/invocation'; } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow'; import type { OnConnectStartParams, Viewport, XYPosition } from 'reactflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodesState = { export type NodesState = {
_version: 1; _version: 1;
nodes: AnyNode[]; nodes: AnyNode[];

View File

@ -1,5 +1,6 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Connection, Edge, HandleType, Node } from 'reactflow'; import type { Connection, Edge, HandleType, Node } from 'reactflow';
@ -43,7 +44,7 @@ export const findConnectionToValidHandle = (
node: AnyNode, node: AnyNode,
nodes: AnyNode[], nodes: AnyNode[],
edges: InvocationNodeEdge[], edges: InvocationNodeEdge[],
templates: Record<string, InvocationTemplate>, templates: Templates,
handleCurrentNodeId: string, handleCurrentNodeId: string,
handleCurrentName: string, handleCurrentName: string,
handleCurrentType: HandleType, handleCurrentType: HandleType,

View File

@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import type { Templates } from 'features/nodes/store/types';
import { FieldParseError } from 'features/nodes/types/error'; import { FieldParseError } from 'features/nodes/types/error';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { InvocationTemplate } from 'features/nodes/types/invocation'; import type { InvocationTemplate } from 'features/nodes/types/invocation';
@ -58,14 +59,14 @@ export const parseSchema = (
openAPI: OpenAPIV3_1.Document, openAPI: OpenAPIV3_1.Document,
nodesAllowlistExtra: string[] | undefined = undefined, nodesAllowlistExtra: string[] | undefined = undefined,
nodesDenylistExtra: string[] | undefined = undefined nodesDenylistExtra: string[] | undefined = undefined
): Record<string, InvocationTemplate> => { ): Templates => {
const filteredSchemas = Object.values(openAPI.components?.schemas ?? {}) const filteredSchemas = Object.values(openAPI.components?.schemas ?? {})
.filter(isInvocationSchemaObject) .filter(isInvocationSchemaObject)
.filter(isNotInDenylist) .filter(isNotInDenylist)
.filter((schema) => (nodesAllowlistExtra ? nodesAllowlistExtra.includes(schema.properties.type.default) : true)) .filter((schema) => (nodesAllowlistExtra ? nodesAllowlistExtra.includes(schema.properties.type.default) : true))
.filter((schema) => (nodesDenylistExtra ? !nodesDenylistExtra.includes(schema.properties.type.default) : true)); .filter((schema) => (nodesDenylistExtra ? !nodesDenylistExtra.includes(schema.properties.type.default) : true));
const invocations = filteredSchemas.reduce<Record<string, InvocationTemplate>>((invocationsAccumulator, schema) => { const invocations = filteredSchemas.reduce<Templates>((invocationsAccumulator, schema) => {
const type = schema.properties.type.default; const type = schema.properties.type.default;
const title = schema.title.replace('Invocation', ''); const title = schema.title.replace('Invocation', '');
const tags = schema.tags ?? []; const tags = schema.tags ?? [];

View File

@ -1,6 +1,6 @@
import type { JSONObject } from 'common/types'; import type { JSONObject } from 'common/types';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import type { InvocationTemplate } from 'features/nodes/types/invocation'; import type { Templates } from 'features/nodes/store/types';
import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
@ -33,7 +33,7 @@ type ValidateWorkflowResult = {
*/ */
export const validateWorkflow = ( export const validateWorkflow = (
workflow: unknown, workflow: unknown,
invocationTemplates: Record<string, InvocationTemplate> invocationTemplates: Templates
): ValidateWorkflowResult => { ): ValidateWorkflowResult => {
// Parse the raw workflow data & migrate it to the latest version // Parse the raw workflow data & migrate it to the latest version
const _workflow = parseAndMigrateWorkflow(workflow); const _workflow = parseAndMigrateWorkflow(workflow);