feat(ui): store graph in image metadata

The previous super-minimal implementation had a major issue - the saved workflow didn't take into account batched field values. When generating with multiple iterations or dynamic prompts, the same workflow with the first prompt, seed, etc was stored in each image.

As a result, when the batch results in multiple queue items, only one of the images has the correct workflow - the others are mismatched.

To work around this, we can store the _graph_ in the image metadata (alongside the workflow, if generated via workflow editor). When loading a workflow from an image, we can choose to load the workflow or the graph, preferring the workflow.

Internally, we need to update images router image-saving services. The changes are minimal.

To avoid pydantic errors deserializing the graph, when we extract it from the image, we will leave it as stringified JSON and let the frontend's more sophisticated and flexible parsing handle it. The worklow is also changed to just return stringified JSON, so the API is consistent.
This commit is contained in:
psychedelicious 2024-05-17 18:10:04 +10:00
parent 66fc110b64
commit 922716d2ab
18 changed files with 510 additions and 158 deletions
invokeai
app
frontend/web
public/locales
src
app/store/middleware/listenerMiddleware/listeners
features
services/api

@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidato
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@ -185,14 +185,21 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
) -> WorkflowAndGraphResponse:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
except Exception:
raise HTTPException(status_code=404)

@ -5,6 +5,7 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -35,6 +36,7 @@ class ImageFileStorageBase(ABC):
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -46,6 +48,11 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets the workflow of an image."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

@ -9,6 +9,7 @@ from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -58,6 +59,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -75,6 +77,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
if graph is not None:
graph_json = graph.model_dump_json()
info_dict["invokeai_graph"] = graph_json
pnginfo.add_text("invokeai_graph", graph_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
@ -129,11 +135,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
def get_workflow(self, image_name: str) -> str | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
if isinstance(workflow, str):
return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
return None
def __validate_storage_folders(self) -> None:

@ -11,6 +11,7 @@ from invokeai.app.services.image_records.image_records_common import (
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -53,6 +54,7 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -87,7 +89,12 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass

@ -4,6 +4,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -44,6 +45,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -64,7 +66,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
has_workflow=workflow is not None or graph is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@ -75,7 +77,7 @@ class ImageService(ImageServiceABC):
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
image_dto = self.get_dto(image_name)
@ -157,7 +159,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image metadata")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
@ -167,6 +169,16 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image workflow")
raise
def get_graph(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_graph(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image graph")
raise
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))

@ -199,6 +199,7 @@ class ImagesInterface(InvocationContextInterface):
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow,
graph=self._data.queue_item.session.graph,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
)

@ -880,6 +880,7 @@
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"noGraph": "No Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",

@ -4,31 +4,49 @@ import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
} else {
throw new Error('No workflow or graph provided');
}
};
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
delete validatedWorkflow.id;
delete workflow.id;
}
dispatch(workflowLoaded(validatedWorkflow));
dispatch(workflowLoaded(workflow));
if (!warnings.length) {
dispatch(
addToast(

@ -0,0 +1,34 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
type Props = {
image: ImageDTO;
};
const ImageMetadataGraphTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData } = useDebouncedImageWorkflow(image);
const graph = useMemo(() => {
if (currentData?.graph) {
try {
return JSON.parse(currentData.graph);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!graph) {
return <IAINoContentFallback label={t('nodes.noGraph')} />;
}
return <DataViewer data={graph} label={t('nodes.graph')} />;
};
export default memo(ImageMetadataGraphTabContent);

@ -1,6 +1,7 @@
import { ExternalLink, Flex, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import ImageMetadataGraphTabContent from 'features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent';
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
import { handlers } from 'features/metadata/util/handlers';
import { memo } from 'react';
@ -52,6 +53,7 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<Tab>{t('metadata.metadata')}</Tab>
<Tab>{t('metadata.imageDetails')}</Tab>
<Tab>{t('metadata.workflow')}</Tab>
<Tab>{t('nodes.graph')}</Tab>
</TabList>
<TabPanels>
@ -81,6 +83,9 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<TabPanel>
<ImageMetadataWorkflowTabContent image={image} />
</TabPanel>
<TabPanel>
<ImageMetadataGraphTabContent image={image} />
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

@ -1,5 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
@ -12,7 +12,17 @@ type Props = {
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { workflow } = useDebouncedImageWorkflow(image);
const { currentData } = useDebouncedImageWorkflow(image);
const workflow = useMemo(() => {
if (currentData?.workflow) {
try {
return JSON.parse(currentData.workflow);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;

@ -1,6 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { Graph } from 'services/api/types';
import type { Graph, GraphAndWorkflowResponse } from 'services/api/types';
const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
const imageToImageGraphBuilt = createAction<Graph>('nodes/imageToImageGraphBuilt');
@ -15,7 +15,7 @@ export const isAnyGraphBuilt = isAnyOf(
);
export const workflowLoadRequested = createAction<{
workflow: unknown;
data: GraphAndWorkflowResponse;
asCopy: boolean;
}>('nodes/workflowLoadRequested');

@ -58,8 +58,7 @@ export const LoadWorkflowFromGraphModal = () => {
setWorkflowRaw(JSON.stringify(workflow, null, 2));
}, [graphRaw, shouldAutoLayout]);
const loadWorkflow = useCallback(() => {
const workflow = JSON.parse(workflowRaw);
dispatch(workflowLoadRequested({ workflow, asCopy: true }));
dispatch(workflowLoadRequested({ data: { workflow: workflowRaw, graph: null }, asCopy: true }));
onClose();
}, [dispatch, onClose, workflowRaw]);
return (

@ -27,10 +27,17 @@ export const useGetAndLoadEmbeddedWorkflow: UseGetAndLoadEmbeddedWorkflow = ({ o
const getAndLoadEmbeddedWorkflow = useCallback(
async (imageName: string) => {
try {
const workflow = await _getAndLoadEmbeddedWorkflow(imageName);
dispatch(workflowLoadRequested({ workflow: workflow.data, asCopy: true }));
// No toast - the listener for this action does that after the workflow is loaded
onSuccess && onSuccess();
const { data } = await _getAndLoadEmbeddedWorkflow(imageName);
if (data) {
dispatch(workflowLoadRequested({ data, asCopy: true }));
// No toast - the listener for this action does that after the workflow is loaded
onSuccess && onSuccess();
} else {
toaster({
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
}
} catch {
toaster({
title: t('toast.problemRetrievingWorkflow'),

@ -10,6 +10,7 @@ import { keyBy } from 'lodash-es';
import type { components, paths } from 'services/api/schema';
import type {
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageCategory,
ImageDTO,
ListImagesArgs,
@ -122,10 +123,7 @@ export const imagesApi = api.injectEndpoints({
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageWorkflow: build.query<
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
string
>({
getImageWorkflow: build.query<GraphAndWorkflowResponse, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours

@ -9,7 +9,7 @@ export const useDebouncedImageWorkflow = (imageDTO?: ImageDTO | null) => {
const [debouncedImageName] = useDebounce(imageDTO?.has_workflow ? imageDTO.image_name : null, workflowFetchDebounce);
const { data: workflow, isLoading } = useGetImageWorkflowQuery(debouncedImageName ?? skipToken);
const result = useGetImageWorkflowQuery(debouncedImageName ?? skipToken);
return { workflow, isLoading };
return result;
};

File diff suppressed because one or more lines are too long

@ -16,6 +16,9 @@ export type UpdateBoardArg = paths['/api/v1/boards/{board_id}']['patch']['parame
changes: paths['/api/v1/boards/{board_id}']['patch']['requestBody']['content']['application/json'];
};
export type GraphAndWorkflowResponse =
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'];
export type BatchConfig =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];