feat: workflow saving and loading

This commit is contained in:
psychedelicious 2023-08-24 21:42:32 +10:00
parent 7f6fdf5d39
commit 7d1942e9f0
51 changed files with 1175 additions and 320 deletions

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
import json
from typing import (
TYPE_CHECKING,
AbstractSet,
@ -20,7 +21,7 @@ from typing import (
get_type_hints,
)
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined
from pydantic.typing import NoArgAnyCallable
@ -141,9 +142,11 @@ class UIType(str, Enum):
# endregion
# region Misc
FilePath = "FilePath"
Enum = "enum"
Scheduler = "Scheduler"
WorkflowField = "WorkflowField"
IsIntermediate = "IsIntermediate"
MetadataField = "MetadataField"
# endregion
@ -507,8 +510,24 @@ class BaseInvocation(ABC, BaseModel):
id: str = Field(description="The id of this node. Must be unique among all nodes.")
is_intermediate: bool = InputField(
default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct
default=False, description="Whether or not this node is an intermediate node.", ui_type=UIType.IsIntermediate
)
workflow: Optional[str] = InputField(
default=None,
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
if v is None:
return None
try:
json.loads(v)
except json.decoder.JSONDecodeError:
raise ValueError("Workflow must be valid JSON")
return v
UIConfig: ClassVar[Type[UIConfigBase]]

View File

@ -151,11 +151,6 @@ class ImageProcessorInvocation(BaseInvocation):
# image type should be PIL.PngImagePlugin.PngImageFile ?
processed_image = self.run_processor(raw_image)
# FIXME: what happened to image metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self
# )
# currently can't see processed image in node UI without a showImage node,
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
image_dto = context.services.images.create(
@ -165,6 +160,7 @@ class ImageProcessorInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
"""Builds an ImageOutput and its ImageField"""

View File

@ -45,6 +45,7 @@ class CvInpaintInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -65,6 +65,7 @@ class BlankImageInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -102,6 +103,7 @@ class ImageCropInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -154,6 +156,7 @@ class ImagePasteInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -189,6 +192,7 @@ class MaskFromAlphaInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -223,6 +227,7 @@ class ImageMultiplyInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -259,6 +264,7 @@ class ImageChannelInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -295,6 +301,7 @@ class ImageConvertInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -333,6 +340,7 @@ class ImageBlurInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -393,6 +401,7 @@ class ImageResizeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -438,6 +447,7 @@ class ImageScaleInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -475,6 +485,7 @@ class ImageLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -512,6 +523,7 @@ class ImageInverseLerpInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -555,6 +567,7 @@ class ImageNSFWBlurInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -596,6 +609,7 @@ class ImageWatermarkInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
@ -644,6 +658,7 @@ class MaskEdgeInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -677,6 +692,7 @@ class MaskCombineInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -785,6 +801,7 @@ class ColorCorrectInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -827,6 +844,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -877,6 +895,7 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(
@ -925,6 +944,7 @@ class ImageSaturationAdjustmentInvocation(BaseInvocation):
node_id=self.id,
is_intermediate=self.is_intermediate,
session_id=context.graph_execution_state_id,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -145,6 +145,7 @@ class InfillColorInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -184,6 +185,7 @@ class InfillTileInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(
@ -218,6 +220,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -545,6 +545,7 @@ class LatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -376,6 +376,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation):
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -7,7 +7,7 @@ from pydantic import validator
from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, UIType, tags, title
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, tags, title
@title("Dynamic Prompt")
@ -41,7 +41,7 @@ class PromptsFromFileInvocation(BaseInvocation):
type: Literal["prompt_from_file"] = "prompt_from_file"
# Inputs
file_path: str = InputField(description="Path to prompt text file", ui_type=UIType.FilePath)
file_path: str = InputField(description="Path to prompt text file")
pre_prompt: Optional[str] = InputField(
default=None, description="String to prepend to each prompt", ui_component=UIComponent.Textarea
)

View File

@ -110,6 +110,7 @@ class ESRGANInvocation(BaseInvocation):
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
workflow=self.workflow,
)
return ImageOutput(

View File

@ -60,7 +60,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -110,7 +110,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
graph: Optional[dict] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -121,8 +121,8 @@ class DiskImageFileStorage(ImageFileStorageBase):
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
pnginfo.add_text("invokeai_graph", json.dumps(graph))
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)

View File

@ -54,6 +54,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -177,6 +178,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -186,16 +188,16 @@ class ImageService(ImageServiceABC):
image_name = self._services.names.create_image_name()
graph = None
if session_id is not None:
session_raw = self._services.graph_execution_manager.get_raw(session_id)
if session_raw is not None:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
# TODO: Do we want to store the graph in the image at all? I don't think so...
# graph = None
# if session_id is not None:
# session_raw = self._services.graph_execution_manager.get_raw(session_id)
# if session_raw is not None:
# try:
# graph = get_metadata_graph_from_raw_session(session_raw)
# except Exception as e:
# self._services.logger.warn(f"Failed to parse session graph: {e}")
# graph = None
(width, height) = image.size
@ -217,7 +219,7 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, workflow=workflow)
image_dto = self.get_dto(image_name)
return image_dto

View File

@ -7,5 +7,4 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts

View File

@ -7,8 +7,7 @@ index.html
.yarn/
.yalc/
*.scss
src/services/api/
src/services/fixtures/*
src/services/api/schema.d.ts
docs/
static/
src/theme/css/overlayscrollbars.css

View File

@ -74,6 +74,7 @@
"@nanostores/react": "^0.7.1",
"@reduxjs/toolkit": "^1.9.5",
"@roarr/browser-log-writer": "^1.1.5",
"@stevebel/png": "^1.5.1",
"dateformat": "^5.0.3",
"formik": "^2.4.3",
"framer-motion": "^10.16.1",

View File

@ -716,7 +716,7 @@
},
"nodes": {
"reloadNodeTemplates": "Reload Node Templates",
"saveWorkflow": "Save Workflow",
"downloadWorkflow": "Download Workflow JSON",
"loadWorkflow": "Load Workflow",
"resetWorkflow": "Reset Workflow",
"resetWorkflowDesc": "Are you sure you want to reset this workflow?",

View File

@ -1,10 +1,12 @@
import { Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppToaster } from 'app/components/Toaster';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectIsBusy } from 'features/system/store/systemSelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { AnimatePresence, motion } from 'framer-motion';
import {
KeyboardEvent,
ReactNode,
@ -18,8 +20,6 @@ import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import ImageUploadOverlay from './ImageUploadOverlay';
import { AnimatePresence, motion } from 'framer-motion';
import { stateSelector } from 'app/store/store';
const selector = createSelector(
[stateSelector, activeTabNameSelector],

View File

@ -9,20 +9,24 @@ import {
MenuButton,
MenuList,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import ParamUpscalePopover from 'features/parameters/components/Parameters/Upscale/ParamUpscaleSettings';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
setActiveTab,
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
@ -37,12 +41,12 @@ import {
FaSeedling,
FaShareAlt,
} from 'react-icons/fa';
import { MdDeviceHub } from 'react-icons/md';
import {
useGetImageDTOQuery,
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
} from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
@ -101,22 +105,36 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const { recallBothPrompts, recallSeed, recallAllParameters } =
useRecallParameters();
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
lastSelectedImage,
500
);
const { currentData: imageDTO } = useGetImageDTOQuery(
lastSelectedImage?.image_name ?? skipToken
);
const { currentData: metadataData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg?.image_name ?? skipToken
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
lastSelectedImage?.image_name ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const metadata = metadataData?.metadata;
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleClickUseAllParameters = useCallback(() => {
recallAllParameters(metadata);
@ -153,6 +171,8 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys('p', handleUsePrompt, [imageDTO]);
useHotkeys('w', handleLoadWorkflow, [workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@ -259,22 +279,31 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<ButtonGroup isAttached={true} isDisabled={shouldDisableToolbarButtons}>
<IAIIconButton
isLoading={isLoading}
icon={<MdDeviceHub />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!workflow}
onClick={handleLoadWorkflow}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!metadata?.positive_prompt}
onClick={handleUsePrompt}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!metadata?.seed}
onClick={handleUseSeed}
/>
<IAIIconButton
isLoading={isLoading}
icon={<FaAsterisk />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}

View File

@ -1,5 +1,4 @@
import { MenuItem } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
@ -8,9 +7,12 @@ import {
isModalOpenChanged,
} from 'features/changeBoardModal/store/slice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { workflowLoaded } from 'features/nodes/store/nodesSlice';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { initialImageSelected } from 'features/parameters/store/actions';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCopyImageToClipboard } from 'features/ui/hooks/useCopyImageToClipboard';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
@ -26,14 +28,13 @@ import {
FaShare,
FaTrash,
} from 'react-icons/fa';
import { MdStar, MdStarBorder } from 'react-icons/md';
import { MdDeviceHub, MdStar, MdStarBorder } from 'react-icons/md';
import {
useGetImageMetadataQuery,
useGetImageMetadataFromFileQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
type SingleSelectionMenuItemsProps = {
@ -50,15 +51,15 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
imageDTO.image_name,
500
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
{
selectFromResult: (res) => ({
isLoading: res.isFetching,
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const [starImages] = useStarImagesMutation();
@ -67,8 +68,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { isClipboardAPIAvailable, copyImageToClipboard } =
useCopyImageToClipboard();
const metadata = currentData?.metadata;
const handleDelete = useCallback(() => {
if (!imageDTO) {
return;
@ -99,6 +98,22 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
recallSeed(metadata?.seed);
}, [metadata?.seed, recallSeed]);
const handleLoadWorkflow = useCallback(() => {
if (!workflow) {
return;
}
dispatch(workflowLoaded(workflow));
dispatch(setActiveTab('nodes'));
dispatch(
addToast(
makeToast({
title: 'Workflow Loaded',
status: 'success',
})
)
);
}, [dispatch, workflow]);
const handleSendToImageToImage = useCallback(() => {
dispatch(sentImageToImg2Img());
dispatch(initialImageSelected(imageDTO));
@ -118,7 +133,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
}, [dispatch, imageDTO, t, toaster]);
const handleUseAllParameters = useCallback(() => {
console.log(metadata);
recallAllParameters(metadata);
}, [metadata, recallAllParameters]);
@ -169,27 +183,34 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem
icon={<FaQuoteRight />}
icon={isLoading ? <SpinnerIcon /> : <MdDeviceHub />}
onClickCapture={handleLoadWorkflow}
isDisabled={isLoading || !workflow}
>
{t('nodes.loadWorkflow')}
</MenuItem>
<MenuItem
icon={isLoading ? <SpinnerIcon /> : <FaQuoteRight />}
onClickCapture={handleRecallPrompt}
isDisabled={
metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined
isLoading ||
(metadata?.positive_prompt === undefined &&
metadata?.negative_prompt === undefined)
}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
icon={<FaSeedling />}
icon={isLoading ? <SpinnerIcon /> : <FaSeedling />}
onClickCapture={handleRecallSeed}
isDisabled={metadata?.seed === undefined}
isDisabled={isLoading || metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
icon={<FaAsterisk />}
icon={isLoading ? <SpinnerIcon /> : <FaAsterisk />}
onClickCapture={handleUseAllParameters}
isDisabled={!metadata}
isDisabled={isLoading || !metadata}
>
{t('parameters.useAll')}
</MenuItem>
@ -233,3 +254,9 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
};
export default memo(SingleSelectionMenuItems);
const SpinnerIcon = () => (
<Flex w="14px" alignItems="center" justifyContent="center">
<Spinner size="xs" />
</Flex>
);

View File

@ -2,7 +2,7 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { FaCopy, FaSave } from 'react-icons/fa';
import { FaCopy, FaDownload } from 'react-icons/fa';
type Props = {
label: string;
@ -23,7 +23,7 @@ const DataViewer = (props: Props) => {
navigator.clipboard.writeText(dataString);
}, [dataString]);
const handleSave = useCallback(() => {
const handleDownload = useCallback(() => {
const blob = new Blob([dataString]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
@ -73,13 +73,13 @@ const DataViewer = (props: Props) => {
</Box>
<Flex sx={{ position: 'absolute', top: 0, insetInlineEnd: 0, p: 2 }}>
{withDownload && (
<Tooltip label={`Save ${label} JSON`}>
<Tooltip label={`Download ${label} JSON`}>
<IconButton
aria-label={`Save ${label} JSON`}
icon={<FaSave />}
aria-label={`Download ${label} JSON`}
icon={<FaDownload />}
variant="ghost"
opacity={0.7}
onClick={handleSave}
onClick={handleDownload}
/>
</Tooltip>
)}

View File

@ -1,10 +1,10 @@
import { CoreMetadata } from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useCallback } from 'react';
import { UnsafeImageMetadata } from 'services/api/types';
import ImageMetadataItem from './ImageMetadataItem';
type Props = {
metadata?: UnsafeImageMetadata['metadata'];
metadata?: CoreMetadata;
};
const ImageMetadataActions = (props: Props) => {
@ -91,14 +91,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallNegativePrompt}
/>
)}
{metadata.seed !== undefined && (
{metadata.seed !== undefined && metadata.seed !== null && (
<ImageMetadataItem
label="Seed"
value={metadata.seed}
onClick={handleRecallSeed}
/>
)}
{metadata.model !== undefined && (
{metadata.model !== undefined && metadata.model !== null && (
<ImageMetadataItem
label="Model"
value={metadata.model.model_name}
@ -147,7 +147,7 @@ const ImageMetadataActions = (props: Props) => {
onClick={handleRecallSteps}
/>
)}
{metadata.cfg_scale !== undefined && (
{metadata.cfg_scale !== undefined && metadata.cfg_scale !== null && (
<ImageMetadataItem
label="CFG scale"
value={metadata.cfg_scale}

View File

@ -9,14 +9,12 @@ import {
Tabs,
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import ImageMetadataActions from './ImageMetadataActions';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
type ImageMetadataViewerProps = {
image: ImageDTO;
@ -29,19 +27,16 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// dispatch(setShouldShowImageDetails(false));
// });
const [debouncedMetadataQueryArg, debounceState] = useDebounce(
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
image.image_name,
500
{
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,
workflow: res?.currentData?.workflow,
}),
}
);
const { currentData } = useGetImageMetadataQuery(
debounceState.isPending()
? skipToken
: debouncedMetadataQueryArg ?? skipToken
);
const metadata = currentData?.metadata;
const graph = currentData?.graph;
return (
<Flex
layerStyle="first"
@ -71,17 +66,17 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }}
>
<TabList>
<Tab>Core Metadata</Tab>
<Tab>Metadata</Tab>
<Tab>Image Details</Tab>
<Tab>Graph</Tab>
<Tab>Workflow</Tab>
</TabList>
<TabPanels>
<TabPanel>
{metadata ? (
<DataViewer data={metadata} label="Core Metadata" />
<DataViewer data={metadata} label="Metadata" />
) : (
<IAINoContentFallback label="No core metadata found" />
<IAINoContentFallback label="No metadata found" />
)}
</TabPanel>
<TabPanel>
@ -92,10 +87,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
)}
</TabPanel>
<TabPanel>
{graph ? (
<DataViewer data={graph} label="Graph" />
{workflow ? (
<DataViewer data={workflow} label="Workflow" />
) : (
<IAINoContentFallback label="No graph found" />
<IAINoContentFallback label="No workflow found" />
)}
</TabPanel>
</TabPanels>

View File

@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useEmbedWorkflow } from 'features/nodes/hooks/useEmbedWorkflow';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { nodeEmbedWorkflowChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const embedWorkflow = useEmbedWorkflow(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeEmbedWorkflowChanged({
nodeId,
embedWorkflow: e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={embedWorkflow}
/>
</FormControl>
);
};
export default memo(EmbedWorkflowCheckbox);

View File

@ -1,16 +1,8 @@
import {
Checkbox,
Flex,
FormControl,
FormLabel,
Spacer,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice';
import { Flex } from '@chakra-ui/react';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { ChangeEvent, memo, useCallback } from 'react';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
type Props = {
nodeId: string;
@ -27,48 +19,13 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
px: 2,
py: 0,
h: 6,
justifyContent: 'space-between',
}}
>
<Spacer />
<SaveImageCheckbox nodeId={nodeId} />
<EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveToGalleryCheckbox nodeId={nodeId} />
</Flex>
);
};
export default memo(InvocationNodeFooter);
const SaveImageCheckbox = memo(({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const is_intermediate = useIsIntermediate(nodeId);
const handleChangeIsIntermediate = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldBooleanValueChanged({
nodeId,
fieldName: 'is_intermediate',
value: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save Output</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChangeIsIntermediate}
isChecked={!is_intermediate}
/>
</FormControl>
);
});
SaveImageCheckbox.displayName = 'SaveImageCheckbox';

View File

@ -0,0 +1,41 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
import { useIsIntermediate } from 'features/nodes/hooks/useIsIntermediate';
import { nodeIsIntermediateChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const SaveToGalleryCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const hasImageOutput = useHasImageOutput(nodeId);
const isIntermediate = useIsIntermediate(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeIsIntermediateChanged({
nodeId,
isIntermediate: !e.target.checked,
})
);
},
[dispatch, nodeId]
);
if (!hasImageOutput) {
return null;
}
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Save to Gallery</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={!isIntermediate}
/>
</FormControl>
);
};
export default memo(SaveToGalleryCheckbox);

View File

@ -2,12 +2,12 @@ import IAIIconButton from 'common/components/IAIIconButton';
import { useWorkflow } from 'features/nodes/hooks/useWorkflow';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaSave } from 'react-icons/fa';
import { FaDownload } from 'react-icons/fa';
const SaveWorkflowButton = () => {
const DownloadWorkflowButton = () => {
const { t } = useTranslation();
const workflow = useWorkflow();
const handleSave = useCallback(() => {
const handleDownload = useCallback(() => {
const blob = new Blob([JSON.stringify(workflow, null, 2)]);
const a = document.createElement('a');
a.href = URL.createObjectURL(blob);
@ -18,12 +18,12 @@ const SaveWorkflowButton = () => {
}, [workflow]);
return (
<IAIIconButton
icon={<FaSave />}
tooltip={t('nodes.saveWorkflow')}
aria-label={t('nodes.saveWorkflow')}
onClick={handleSave}
icon={<FaDownload />}
tooltip={t('nodes.downloadWorkflow')}
aria-label={t('nodes.downloadWorkflow')}
onClick={handleDownload}
/>
);
};
export default memo(SaveWorkflowButton);
export default memo(DownloadWorkflowButton);

View File

@ -2,7 +2,7 @@ import { Flex } from '@chakra-ui/layout';
import { memo } from 'react';
import LoadWorkflowButton from './LoadWorkflowButton';
import ResetWorkflowButton from './ResetWorkflowButton';
import SaveWorkflowButton from './SaveWorkflowButton';
import DownloadWorkflowButton from './DownloadWorkflowButton';
const TopCenterPanel = () => {
return (
@ -15,7 +15,7 @@ const TopCenterPanel = () => {
transform: 'translate(-50%)',
}}
>
<SaveWorkflowButton />
<DownloadWorkflowButton />
<LoadWorkflowButton />
<ResetWorkflowButton />
</Flex>

View File

@ -22,6 +22,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
}
return map(nodeTemplate.inputs)
.filter((field) => ['any', 'direct'].includes(field.input))
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@ -143,6 +143,8 @@ export const useBuildNodeData = () => {
isOpen: true,
label: '',
notes: '',
embedWorkflow: false,
isIntermediate: true,
},
};

View File

@ -22,6 +22,7 @@ export const useConnectionInputFieldNames = (nodeId: string) => {
}
return map(nodeTemplate.inputs)
.filter((field) => field.input === 'connection')
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@ -0,0 +1,27 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useEmbedWorkflow = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
return node.data.embedWorkflow;
},
defaultSelectorOptions
),
[nodeId]
);
const embedWorkflow = useAppSelector(selector);
return embedWorkflow;
};

View File

@ -15,7 +15,7 @@ export const useIsIntermediate = (nodeId: string) => {
if (!isInvocationNode(node)) {
return false;
}
return Boolean(node.data.inputs.is_intermediate?.value);
return node.data.isIntermediate;
},
defaultSelectorOptions
),

View File

@ -21,6 +21,7 @@ export const useOutputFieldNames = (nodeId: string) => {
return [];
}
return map(nodeTemplate.outputs)
.filter((field) => !field.ui_hidden)
.sort((a, b) => (a.ui_order ?? 0) - (b.ui_order ?? 0))
.map((field) => field.name)
.filter((fieldName) => fieldName !== 'is_intermediate');

View File

@ -245,6 +245,34 @@ const nodesSlice = createSlice({
}
field.label = label;
},
nodeEmbedWorkflowChanged: (
state,
action: PayloadAction<{ nodeId: string; embedWorkflow: boolean }>
) => {
const { nodeId, embedWorkflow } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.embedWorkflow = embedWorkflow;
},
nodeIsIntermediateChanged: (
state,
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
) => {
const { nodeId, isIntermediate } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.isIntermediate = isIntermediate;
},
nodeIsOpenChanged: (
state,
action: PayloadAction<{ nodeId: string; isOpen: boolean }>
@ -850,6 +878,8 @@ export const {
addNodePopoverClosed,
addNodePopoverToggled,
selectionModeChanged,
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -169,11 +169,6 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: 'Color Collection',
description: 'A collection of colors.',
},
FilePath: {
color: 'base.500',
title: 'File Path',
description: 'A path to a file.',
},
ONNXModelField: {
color: 'base.500',
title: 'ONNX Model',

View File

@ -2,6 +2,7 @@ import {
SchedulerParam,
zBaseModel,
zMainOrOnnxModel,
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
import { OpenAPIV3 } from 'openapi-types';
@ -97,7 +98,6 @@ export const zFieldType = z.enum([
// endregion
// region Misc
'FilePath',
'enum',
'Scheduler',
// endregion
@ -105,8 +105,17 @@ export const zFieldType = z.enum([
export type FieldType = z.infer<typeof zFieldType>;
export const zReservedFieldType = z.enum([
'WorkflowField',
'IsIntermediate',
'MetadataField',
]);
export type ReservedFieldType = z.infer<typeof zReservedFieldType>;
export const isFieldType = (value: unknown): value is FieldType =>
zFieldType.safeParse(value).success;
zFieldType.safeParse(value).success ||
zReservedFieldType.safeParse(value).success;
/**
* An input field template is generated on each page load from the OpenAPI schema.
@ -619,6 +628,11 @@ export type SchedulerInputFieldTemplate = InputFieldTemplateBase & {
type: 'Scheduler';
};
export type WorkflowInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'WorkflowField';
};
export const isInputFieldValue = (
field?: InputFieldValue | OutputFieldValue
): field is InputFieldValue => Boolean(field && field.fieldKind === 'input');
@ -715,6 +729,47 @@ export const isInvocationFieldSchema = (
export type InvocationEdgeExtra = { type: 'default' | 'collapsed' };
export const zCoreMetadata = z
.object({
app_version: z.string().nullish(),
generation_mode: z.string().nullish(),
positive_prompt: z.string().nullish(),
negative_prompt: z.string().nullish(),
width: z.number().int().nullish(),
height: z.number().int().nullish(),
seed: z.number().int().nullish(),
rand_device: z.string().nullish(),
cfg_scale: z.number().nullish(),
steps: z.number().int().nullish(),
scheduler: z.string().nullish(),
clip_skip: z.number().int().nullish(),
model: zMainOrOnnxModel.nullish(),
controlnets: z.array(zControlField).nullish(),
loras: z
.array(
z.object({
lora: zLoRAModelField,
weight: z.number(),
})
)
.nullish(),
vae: zVaeModelField.nullish(),
strength: z.number().nullish(),
init_image: z.string().nullish(),
positive_style_prompt: z.string().nullish(),
negative_style_prompt: z.string().nullish(),
refiner_model: zSDXLRefinerModel.nullish(),
refiner_cfg_scale: z.number().nullish(),
refiner_steps: z.number().int().nullish(),
refiner_scheduler: z.string().nullish(),
refiner_positive_aesthetic_store: z.number().nullish(),
refiner_negative_aesthetic_store: z.number().nullish(),
refiner_start: z.number().nullish(),
})
.catchall(z.record(z.any()));
export type CoreMetadata = z.infer<typeof zCoreMetadata>;
export const zInvocationNodeData = z.object({
id: z.string().trim().min(1),
// no easy way to build this dynamically, and we don't want to anyways, because this will be used
@ -725,6 +780,8 @@ export const zInvocationNodeData = z.object({
label: z.string(),
isOpen: z.boolean(),
notes: z.string(),
embedWorkflow: z.boolean(),
isIntermediate: z.boolean(),
});
// Massage this to get better type safety while developing
@ -817,10 +874,18 @@ export const zWorkflow = z.object({
nodes: z.array(zWorkflowNode),
edges: z.array(zWorkflowEdge),
exposedFields: z.array(zFieldIdentifier),
meta: z.object({
version: zSemVer,
}),
});
export type Workflow = z.infer<typeof zWorkflow>;
export type ImageMetadataAndWorkflow = {
metadata?: CoreMetadata;
workflow?: Workflow;
};
export type CurrentImageNodeData = {
id: string;
type: 'current_image';

View File

@ -1,7 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { logger } from 'app/logging/logger';
import { NodesState } from '../store/types';
import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types';
import { fromZodError } from 'zod-validation-error';
import { parseify } from 'common/util/serialize';
export const buildWorkflow = (nodesState: NodesState): Workflow => {
const { workflow: workflowMeta, nodes, edges } = nodesState;
@ -14,6 +15,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
nodes.forEach((node) => {
const result = zWorkflowNode.safeParse(node);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse node',
});
logger('nodes').warn({ node: parseify(node) }, message);
return;
}
workflow.nodes.push(result.data);
@ -22,6 +27,10 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
edges.forEach((edge) => {
const result = zWorkflowEdge.safeParse(edge);
if (!result.success) {
const { message } = fromZodError(result.error, {
prefix: 'Unable to parse edge',
});
logger('nodes').warn({ edge: parseify(edge) }, message);
return;
}
workflow.edges.push(result.data);
@ -29,7 +38,3 @@ export const buildWorkflow = (nodesState: NodesState): Workflow => {
return workflow;
};
export const workflowSelector = createSelector(stateSelector, ({ nodes }) =>
buildWorkflow(nodes)
);

View File

@ -27,7 +27,6 @@ import {
UNetInputFieldTemplate,
VaeInputFieldTemplate,
VaeModelInputFieldTemplate,
isFieldType,
} from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
@ -408,9 +407,7 @@ const buildSchedulerInputFieldTemplate = ({
return template;
};
export const getFieldType = (
schemaObject: InvocationFieldSchema
): FieldType => {
export const getFieldType = (schemaObject: InvocationFieldSchema): string => {
let fieldType = '';
const { ui_type } = schemaObject;
@ -446,10 +443,6 @@ export const getFieldType = (
}
}
if (!isFieldType(fieldType)) {
throw `Field type "${fieldType}" is unknown!`;
}
return fieldType;
};
@ -461,12 +454,9 @@ export const getFieldType = (
export const buildInputFieldTemplate = (
nodeSchema: InvocationSchemaObject,
fieldSchema: InvocationFieldSchema,
name: string
name: string,
fieldType: FieldType
) => {
// console.log('input', schemaObject);
const fieldType = getFieldType(fieldSchema);
// console.log('input fieldType', fieldType);
const { input, ui_hidden, ui_component, ui_type, ui_order } = fieldSchema;
const extra = {

View File

@ -0,0 +1,37 @@
import * as png from '@stevebel/png';
import { logger } from 'app/logging/logger';
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
zWorkflow,
} from 'features/nodes/types/types';
import { get } from 'lodash-es';
export const getMetadataAndWorkflowFromImageBlob = async (
image: Blob
): Promise<ImageMetadataAndWorkflow> => {
const data: ImageMetadataAndWorkflow = {};
try {
const buffer = await image.arrayBuffer();
const text = png.decode(buffer).text;
const rawMetadata = get(text, 'invokeai_metadata');
const rawWorkflow = get(text, 'invokeai_workflow');
if (rawMetadata) {
try {
data.metadata = zCoreMetadata.parse(JSON.parse(rawMetadata));
} catch {
// no-op
}
}
if (rawWorkflow) {
try {
data.workflow = zWorkflow.parse(JSON.parse(rawWorkflow));
} catch {
// no-op
}
}
} catch {
logger('nodes').warn('Unable to parse image');
}
return data;
};

View File

@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es';
import { Graph } from 'services/api/types';
import { AnyInvocation } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import { buildWorkflow } from '../buildWorkflow';
/**
* We need to do special handling for some fields
@ -34,12 +35,13 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
const filteredNodes = nodes.filter(isInvocationNode);
const workflowJSON = JSON.stringify(buildWorkflow(nodesState));
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>(
(nodesAccumulator, node) => {
const { id, data } = node;
const { type, inputs } = data;
const { type, inputs, isIntermediate, embedWorkflow } = data;
// Transform each node's inputs to simple key-value pairs
const transformedInputs = reduce(
@ -58,8 +60,14 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
type,
id,
...transformedInputs,
is_intermediate: isIntermediate,
};
if (embedWorkflow) {
// add the workflow to the node
Object.assign(graphNode, { workflow: workflowJSON });
}
// Add it to the nodes object
Object.assign(nodesAccumulator, {
[id]: graphNode,

View File

@ -4,10 +4,12 @@ import { reduce } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { AnyInvocationType } from 'services/events/types';
import {
FieldType,
InputFieldTemplate,
InvocationSchemaObject,
InvocationTemplate,
OutputFieldTemplate,
isFieldType,
isInvocationFieldSchema,
isInvocationOutputSchemaObject,
isInvocationSchemaObject,
@ -16,23 +18,35 @@ import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders';
const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'metadata'];
const RESERVED_OUTPUT_FIELD_NAMES = ['type'];
const RESERVED_FIELD_TYPES = [
'WorkflowField',
'MetadataField',
'IsIntermediate',
];
const invocationDenylist: AnyInvocationType[] = [
'graph',
'metadata_accumulator',
];
const isAllowedInputField = (nodeType: string, fieldName: string) => {
const isReservedInputField = (nodeType: string, fieldName: string) => {
if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) {
return false;
return true;
}
if (nodeType === 'collect' && fieldName === 'collection') {
return false;
return true;
}
if (nodeType === 'iterate' && fieldName === 'index') {
return false;
}
return true;
}
return false;
};
const isReservedFieldType = (fieldType: FieldType) => {
if (RESERVED_FIELD_TYPES.includes(fieldType)) {
return true;
}
return false;
};
const isAllowedOutputField = (nodeType: string, fieldName: string) => {
@ -62,10 +76,14 @@ export const parseSchema = (
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (!isAllowedInputField(type, propertyName)) {
(
inputsAccumulator: Record<string, InputFieldTemplate>,
property,
propertyName
) => {
if (isReservedInputField(type, propertyName)) {
logger('nodes').trace(
{ type, propertyName, property: parseify(property) },
{ node: type, fieldName: propertyName, field: parseify(property) },
'Skipped reserved input field'
);
return inputsAccumulator;
@ -73,21 +91,64 @@ export const parseSchema = (
if (!isInvocationFieldSchema(property)) {
logger('nodes').warn(
{ type, propertyName, property: parseify(property) },
{ node: type, propertyName, property: parseify(property) },
'Unhandled input property'
);
return inputsAccumulator;
}
const field = buildInputFieldTemplate(schema, property, propertyName);
const fieldType = getFieldType(property);
if (field) {
inputsAccumulator[propertyName] = field;
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Skipping unknown input field type'
);
return inputsAccumulator;
}
if (isReservedFieldType(fieldType)) {
logger('nodes').trace(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Skipping reserved field type'
);
return inputsAccumulator;
}
const field = buildInputFieldTemplate(
schema,
property,
propertyName,
fieldType
);
if (!field) {
logger('nodes').warn(
{
node: type,
fieldName: propertyName,
fieldType,
field: parseify(property),
},
'Skipping input field with no template'
);
return inputsAccumulator;
}
inputsAccumulator[propertyName] = field;
return inputsAccumulator;
},
{} as Record<string, InputFieldTemplate>
{}
);
const outputSchemaName = schema.output.$ref.split('/').pop();
@ -136,6 +197,13 @@ export const parseSchema = (
}
const fieldType = getFieldType(property);
if (!isFieldType(fieldType)) {
logger('nodes').warn(
{ fieldName: propertyName, fieldType, field: parseify(property) },
'Skipping unknown output field type'
);
} else {
outputsAccumulator[propertyName] = {
fieldKind: 'output',
name: propertyName,
@ -146,6 +214,7 @@ export const parseSchema = (
ui_type: property.ui_type,
ui_order: property.ui_order,
};
}
return outputsAccumulator;
},

View File

@ -1,5 +1,6 @@
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch } from 'app/store/storeHooks';
import { CoreMetadata } from 'features/nodes/types/types';
import {
refinerModelChanged,
setNegativeStylePromptSDXL,
@ -13,7 +14,7 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { ImageDTO, UnsafeImageMetadata } from 'services/api/types';
import { ImageDTO } from 'services/api/types';
import { initialImageSelected, modelSelected } from '../store/actions';
import {
setCfgScale,
@ -317,7 +318,7 @@ export const useRecallParameters = () => {
);
const recallAllParameters = useCallback(
(metadata: UnsafeImageMetadata['metadata'] | undefined) => {
(metadata: CoreMetadata | undefined) => {
if (!metadata) {
allParameterNotSetToast();
return;

View File

@ -29,11 +29,13 @@ export const $projectId = atom<string | undefined>();
* @example
* const { get, post, del } = $client.get();
*/
export const $client = computed([$authToken, $baseUrl, $projectId], (authToken, baseUrl, projectId) =>
export const $client = computed(
[$authToken, $baseUrl, $projectId],
(authToken, baseUrl, projectId) =>
createClient<paths>({
headers: {
...(authToken ? { Authorization: `Bearer ${authToken}` } : {}),
...(projectId ? { "project-id": projectId } : {})
...(projectId ? { 'project-id': projectId } : {}),
},
// do not include `api/v1` in the base url for this client
baseUrl: `${baseUrl ?? ''}`,

View File

@ -19,7 +19,7 @@ export const boardsApi = api.injectEndpoints({
*/
listBoards: build.query<OffsetPaginatedResults_BoardDTO_, ListBoardsArg>({
query: (arg) => ({ url: 'boards/', params: arg }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];
@ -42,7 +42,7 @@ export const boardsApi = api.injectEndpoints({
url: 'boards/',
params: { all: true },
}),
providesTags: (result, error, arg) => {
providesTags: (result) => {
// any list of boards
const tags: ApiFullTagDescription[] = [{ type: 'Board', id: LIST_TAG }];

View File

@ -6,7 +6,8 @@ import {
IMAGE_CATEGORIES,
IMAGE_LIMIT,
} from 'features/gallery/store/types';
import { keyBy } from 'lodash';
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
import { keyBy } from 'lodash-es';
import { ApiFullTagDescription, LIST_TAG, api } from '..';
import { components, paths } from '../schema';
import {
@ -26,6 +27,7 @@ import {
imagesSelectors,
} from '../util';
import { boardsApi } from './boards';
import { ImageMetadataAndWorkflow } from 'features/nodes/types/types';
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -113,6 +115,19 @@ export const imagesApi = api.injectEndpoints({
],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadataFromFile: build.query<ImageMetadataAndWorkflow, string>({
query: (image_name) => ({
url: `images/i/${image_name}/full`,
responseHandler: async (res) => {
return await res.blob();
},
}),
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadataFromFile', id: image_name },
],
transformResponse: (response: Blob) =>
getMetadataAndWorkflowFromImageBlob(response),
}),
clearIntermediates: build.mutation<number, void>({
query: () => ({ url: `images/clear-intermediates`, method: 'POST' }),
invalidatesTags: ['IntermediatesCount'],
@ -357,7 +372,7 @@ export const imagesApi = api.injectEndpoints({
],
async onQueryStarted(
{ imageDTO, session_id },
{ dispatch, queryFulfilled, getState }
{ dispatch, queryFulfilled }
) {
/**
* Cache changes for `changeImageSessionId`:
@ -432,7 +447,9 @@ export const imagesApi = api.injectEndpoints({
data.updated_image_names.includes(i.image_name)
);
if (!updatedImages[0]) return;
if (!updatedImages[0]) {
return;
}
// assume all images are on the same board/category
const categories = getCategories(updatedImages[0]);
@ -544,7 +561,9 @@ export const imagesApi = api.injectEndpoints({
data.updated_image_names.includes(i.image_name)
);
if (!updatedImages[0]) return;
if (!updatedImages[0]) {
return;
}
// assume all images are on the same board/category
const categories = getCategories(updatedImages[0]);
const boardId = updatedImages[0].board_id;
@ -645,17 +664,7 @@ export const imagesApi = api.injectEndpoints({
},
};
},
async onQueryStarted(
{
file,
image_category,
is_intermediate,
postUploadAction,
session_id,
board_id,
},
{ dispatch, queryFulfilled }
) {
async onQueryStarted(_, { dispatch, queryFulfilled }) {
try {
/**
* NOTE: PESSIMISTIC UPDATE
@ -712,7 +721,7 @@ export const imagesApi = api.injectEndpoints({
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: `boards/${board_id}`, method: 'DELETE' }),
invalidatesTags: (result, error, board_id) => [
invalidatesTags: () => [
{ type: 'Board', id: LIST_TAG },
// invalidate the 'No Board' cache
{
@ -732,7 +741,7 @@ export const imagesApi = api.injectEndpoints({
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
/**
* Cache changes for deleteBoard:
* - Update every image in the 'getImageDTO' cache that has the board_id
@ -802,7 +811,7 @@ export const imagesApi = api.injectEndpoints({
method: 'DELETE',
params: { include_images: true },
}),
invalidatesTags: (result, error, board_id) => [
invalidatesTags: () => [
{ type: 'Board', id: LIST_TAG },
{
type: 'ImageList',
@ -821,7 +830,7 @@ export const imagesApi = api.injectEndpoints({
{ type: 'BoardImagesTotal', id: 'none' },
{ type: 'BoardAssetsTotal', id: 'none' },
],
async onQueryStarted(board_id, { dispatch, queryFulfilled, getState }) {
async onQueryStarted(board_id, { dispatch, queryFulfilled }) {
/**
* Cache changes for deleteBoardAndImages:
* - ~~Remove every image in the 'getImageDTO' cache that has the board_id~~
@ -1253,9 +1262,8 @@ export const imagesApi = api.injectEndpoints({
];
result?.removed_image_names.forEach((image_name) => {
const board_id = imageDTOs.find(
(i) => i.image_name === image_name
)?.board_id;
const board_id = imageDTOs.find((i) => i.image_name === image_name)
?.board_id;
if (!board_id || touchedBoardIds.includes(board_id)) {
return;
@ -1385,4 +1393,5 @@ export const {
useDeleteBoardMutation,
useStarImagesMutation,
useUnstarImagesMutation,
useGetImageMetadataFromFileQuery,
} = imagesApi;

View File

@ -178,7 +178,7 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'OnnxModel', id: LIST_TAG },
];
@ -194,11 +194,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: OnnxModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: OnnxModelConfig[] }) => {
const entities = createModelEntities<OnnxModelConfigEntity>(
response.models
);
@ -221,7 +217,7 @@ export const modelsApi = api.injectEndpoints({
const query = queryString.stringify(params, { arrayFormat: 'none' });
return `models/?${query}`;
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
];
@ -237,11 +233,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: MainModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: MainModelConfig[] }) => {
const entities = createModelEntities<MainModelConfigEntity>(
response.models
);
@ -361,7 +353,7 @@ export const modelsApi = api.injectEndpoints({
}),
getLoRAModels: build.query<EntityState<LoRAModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'lora' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'LoRAModel', id: LIST_TAG },
];
@ -377,11 +369,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: LoRAModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: LoRAModelConfig[] }) => {
const entities = createModelEntities<LoRAModelConfigEntity>(
response.models
);
@ -421,7 +409,7 @@ export const modelsApi = api.injectEndpoints({
void
>({
query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ControlNetModel', id: LIST_TAG },
];
@ -437,11 +425,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: ControlNetModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: ControlNetModelConfig[] }) => {
const entities = createModelEntities<ControlNetModelConfigEntity>(
response.models
);
@ -453,7 +437,7 @@ export const modelsApi = api.injectEndpoints({
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'VaeModel', id: LIST_TAG },
];
@ -469,11 +453,7 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: VaeModelConfig[] },
meta,
arg
) => {
transformResponse: (response: { models: VaeModelConfig[] }) => {
const entities = createModelEntities<VaeModelConfigEntity>(
response.models
);
@ -488,7 +468,7 @@ export const modelsApi = api.injectEndpoints({
void
>({
query: () => ({ url: 'models/', params: { model_type: 'embedding' } }),
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'TextualInversionModel', id: LIST_TAG },
];
@ -504,11 +484,9 @@ export const modelsApi = api.injectEndpoints({
return tags;
},
transformResponse: (
response: { models: TextualInversionModelConfig[] },
meta,
arg
) => {
transformResponse: (response: {
models: TextualInversionModelConfig[];
}) => {
const entities = createModelEntities<TextualInversionModelConfigEntity>(
response.models
);
@ -525,7 +503,7 @@ export const modelsApi = api.injectEndpoints({
url: `/models/search?${folderQueryStr}`,
};
},
providesTags: (result, error, arg) => {
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'ScannedModels', id: LIST_TAG },
];

View File

@ -16,6 +16,7 @@ export const tagTypes = [
'ImageNameList',
'ImageList',
'ImageMetadata',
'ImageMetadataFromFile',
'Model',
];
export type ApiFullTagDescription = FullTagDescription<
@ -39,7 +40,7 @@ const dynamicBaseQuery: BaseQueryFn<
headers.set('Authorization', `Bearer ${authToken}`);
}
if (projectId) {
headers.set("project-id", projectId)
headers.set('project-id', projectId);
}
return headers;

File diff suppressed because it is too large Load Diff

View File

@ -1,14 +1,16 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
function getCircularReplacer() {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const ancestors: Record<string, any>[] = [];
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return function (key: string, value: any) {
if (typeof value !== 'object' || value === null) {
return value;
}
// `this` is the object that value is contained in,
// i.e., its direct parent.
// @ts-ignore
// `this` is the object that value is contained in, i.e., its direct parent.
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore don't think it's possible to not have TS complain about this...
while (ancestors.length > 0 && ancestors.at(-1) !== this) {
ancestors.pop();
}

View File

@ -73,7 +73,7 @@ export const sessionInvoked = createAsyncThunk<
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { PUT } = $client.get();
const { data, error, response } = await PUT(
const { error, response } = await PUT(
'/api/v1/sessions/{session_id}/invoke',
{
params: { query: { all: true }, path: { session_id } },
@ -85,6 +85,7 @@ export const sessionInvoked = createAsyncThunk<
return rejectWithValue({
arg,
status: response.status,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (error as any).body.detail,
});
}
@ -124,14 +125,11 @@ export const sessionCanceled = createAsyncThunk<
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { DELETE } = $client.get();
const { data, error, response } = await DELETE(
'/api/v1/sessions/{session_id}/invoke',
{
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
params: {
path: { session_id },
},
}
);
});
if (error) {
return rejectWithValue({ arg, error });
@ -164,7 +162,7 @@ export const listedSessions = createAsyncThunk<
>('api/listSessions', async (arg, { rejectWithValue }) => {
const { params } = arg;
const { GET } = $client.get();
const { data, error, response } = await GET('/api/v1/sessions/', {
const { data, error } = await GET('/api/v1/sessions/', {
params,
});

View File

@ -26,15 +26,21 @@ export const getIsImageInDateRange = (
for (let index = 0; index < totalCachedImageDtos.length; index++) {
const image = totalCachedImageDtos[index];
if (image?.starred) cachedStarredImages.push(image);
if (!image?.starred) cachedUnstarredImages.push(image);
if (image?.starred) {
cachedStarredImages.push(image);
}
if (!image?.starred) {
cachedUnstarredImages.push(image);
}
}
if (imageDTO.starred) {
const lastStarredImage =
cachedStarredImages[cachedStarredImages.length - 1];
// if starring or already starred, want to look in list of starred images
if (!lastStarredImage) return true; // no starred images showing, so always show this one
if (!lastStarredImage) {
return true;
} // no starred images showing, so always show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastStarredImage.created_at);
return createdDate >= oldestDate;
@ -42,7 +48,9 @@ export const getIsImageInDateRange = (
const lastUnstarredImage =
cachedUnstarredImages[cachedUnstarredImages.length - 1];
// if unstarring or already unstarred, want to look in list of unstarred images
if (!lastUnstarredImage) return false; // no unstarred images showing, so don't show this one
if (!lastUnstarredImage) {
return false;
} // no unstarred images showing, so don't show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastUnstarredImage.created_at);
return createdDate >= oldestDate;

View File

@ -1727,6 +1727,13 @@
resolved "https://registry.yarnpkg.com/@socket.io/component-emitter/-/component-emitter-3.1.0.tgz#96116f2a912e0c02817345b3c10751069920d553"
integrity sha512-+9jVqKhRSpsc591z5vX+X5Yyw+he/HCB4iQ/RYxw35CEPaY1gnsNE43nf9n9AaYjAQrTiI/mOwKUKdUs9vf7Xg==
"@stevebel/png@^1.5.1":
version "1.5.1"
resolved "https://registry.yarnpkg.com/@stevebel/png/-/png-1.5.1.tgz#c1179a2787c7440fc20082d6eff85362450f24f6"
integrity sha512-cUVgrRCgOQLqLpXvV4HffvkITWF1BBgslXkINKfMD2b+GkAbV+PeO6IeMF6k7c6FLvGox6mMLwwqcXKoDha9rw==
dependencies:
pako "^2.1.0"
"@swc/core-darwin-arm64@1.3.70":
version "1.3.70"
resolved "https://registry.yarnpkg.com/@swc/core-darwin-arm64/-/core-darwin-arm64-1.3.70.tgz#056ac6899e22cb7f7be21388d4d938ca5123a72b"
@ -5372,6 +5379,11 @@ p-locate@^5.0.0:
dependencies:
p-limit "^3.0.2"
pako@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/pako/-/pako-2.1.0.tgz#266cc37f98c7d883545d11335c00fbd4062c9a86"
integrity sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==
parent-module@^1.0.0:
version "1.0.1"
resolved "https://registry.yarnpkg.com/parent-module/-/parent-module-1.0.1.tgz#691d2709e78c79fae3a156622452d00762caaaa2"