feat(ui): store graph in image metadata

The previous super-minimal implementation had a major issue - the saved workflow didn't take into account batched field values. When generating with multiple iterations or dynamic prompts, the same workflow with the first prompt, seed, etc was stored in each image.

As a result, when the batch results in multiple queue items, only one of the images has the correct workflow - the others are mismatched.

To work around this, we can store the _graph_ in the image metadata (alongside the workflow, if generated via workflow editor). When loading a workflow from an image, we can choose to load the workflow or the graph, preferring the workflow.

Internally, we need to update images router image-saving services. The changes are minimal.

To avoid pydantic errors deserializing the graph, when we extract it from the image, we will leave it as stringified JSON and let the frontend's more sophisticated and flexible parsing handle it. The worklow is also changed to just return stringified JSON, so the API is consistent.
This commit is contained in:
psychedelicious 2024-05-17 18:10:04 +10:00
parent 66fc110b64
commit 922716d2ab
18 changed files with 510 additions and 158 deletions

View File

@ -12,7 +12,7 @@ from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidato
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies
@ -185,14 +185,21 @@ async def get_image_metadata(
raise HTTPException(status_code=404)
class WorkflowAndGraphResponse(BaseModel):
workflow: Optional[str] = Field(description="The workflow used to generate the image, as stringified JSON")
graph: Optional[str] = Field(description="The graph used to generate the image, as stringified JSON")
@images_router.get(
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
image_name: str = Path(description="The name of image whose workflow to get"),
) -> Optional[WorkflowWithoutID]:
) -> WorkflowAndGraphResponse:
try:
return ApiDependencies.invoker.services.images.get_workflow(image_name)
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
return WorkflowAndGraphResponse(workflow=workflow, graph=graph)
except Exception:
raise HTTPException(status_code=404)

View File

@ -5,6 +5,7 @@ from typing import Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -35,6 +36,7 @@ class ImageFileStorageBase(ABC):
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@ -46,6 +48,11 @@ class ImageFileStorageBase(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets the workflow of an image."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets the graph of an image."""
pass

View File

@ -9,6 +9,7 @@ from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -58,6 +59,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
thumbnail_size: int = 256,
) -> None:
try:
@ -75,6 +77,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)
if graph is not None:
graph_json = graph.model_dump_json()
info_dict["invokeai_graph"] = graph_json
pnginfo.add_text("invokeai_graph", graph_json)
# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
@ -129,11 +135,18 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()
def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
def get_workflow(self, image_name: str) -> str | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
if isinstance(workflow, str):
return workflow
return None
def get_graph(self, image_name: str) -> str | None:
image = self.get(image_name)
graph = image.info.get("invokeai_graph", None)
if isinstance(graph, str):
return graph
return None
def __validate_storage_folders(self) -> None:

View File

@ -11,6 +11,7 @@ from invokeai.app.services.image_records.image_records_common import (
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -53,6 +54,7 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@ -87,7 +89,12 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass
@abstractmethod
def get_graph(self, image_name: str) -> Optional[str]:
"""Gets an image's workflow."""
pass

View File

@ -4,6 +4,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
@ -44,6 +45,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowWithoutID] = None,
graph: Optional[Graph] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@ -64,7 +66,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
has_workflow=workflow is not None or graph is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@ -75,7 +77,7 @@ class ImageService(ImageServiceABC):
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
image_dto = self.get_dto(image_name)
@ -157,7 +159,7 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image metadata")
raise e
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
def get_workflow(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
@ -167,6 +169,16 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image workflow")
raise
def get_graph(self, image_name: str) -> Optional[str]:
try:
return self.__invoker.services.image_files.get_graph(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image graph")
raise
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
return str(self.__invoker.services.image_files.get_path(image_name, thumbnail))

View File

@ -199,6 +199,7 @@ class ImagesInterface(InvocationContextInterface):
metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow,
graph=self._data.queue_item.session.graph,
session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id,
)

View File

@ -880,6 +880,7 @@
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"noGraph": "No Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",

View File

@ -4,31 +4,49 @@ import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import type { GraphAndWorkflowResponse, NonNullableGraph } from 'services/api/types';
import { z } from 'zod';
import { fromZodError } from 'zod-validation-error';
const getWorkflow = (data: GraphAndWorkflowResponse, templates: Templates) => {
if (data.workflow) {
// Prefer to load the workflow if it's available - it has more information
const parsed = JSON.parse(data.workflow);
return validateWorkflow(parsed, templates);
} else if (data.graph) {
// Else we fall back on the graph, using the graphToWorkflow function to convert and do layout
const parsed = JSON.parse(data.graph);
const workflow = graphToWorkflow(parsed as NonNullableGraph, true);
return validateWorkflow(workflow, templates);
} else {
throw new Error('No workflow or graph provided');
}
};
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const { data, asCopy } = action.payload;
const nodeTemplates = $templates.get();
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);
const { workflow, warnings } = getWorkflow(data, nodeTemplates);
if (asCopy) {
// If we're loading a copy, we need to remove the ID so that the backend will create a new workflow
delete validatedWorkflow.id;
delete workflow.id;
}
dispatch(workflowLoaded(validatedWorkflow));
dispatch(workflowLoaded(workflow));
if (!warnings.length) {
dispatch(
addToast(

View File

@ -0,0 +1,34 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
import DataViewer from './DataViewer';
type Props = {
image: ImageDTO;
};
const ImageMetadataGraphTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { currentData } = useDebouncedImageWorkflow(image);
const graph = useMemo(() => {
if (currentData?.graph) {
try {
return JSON.parse(currentData.graph);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!graph) {
return <IAINoContentFallback label={t('nodes.noGraph')} />;
}
return <DataViewer data={graph} label={t('nodes.graph')} />;
};
export default memo(ImageMetadataGraphTabContent);

View File

@ -1,6 +1,7 @@
import { ExternalLink, Flex, Tab, TabList, TabPanel, TabPanels, Tabs, Text } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import ImageMetadataGraphTabContent from 'features/gallery/components/ImageMetadataViewer/ImageMetadataGraphTabContent';
import { useMetadataItem } from 'features/metadata/hooks/useMetadataItem';
import { handlers } from 'features/metadata/util/handlers';
import { memo } from 'react';
@ -52,6 +53,7 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<Tab>{t('metadata.metadata')}</Tab>
<Tab>{t('metadata.imageDetails')}</Tab>
<Tab>{t('metadata.workflow')}</Tab>
<Tab>{t('nodes.graph')}</Tab>
</TabList>
<TabPanels>
@ -81,6 +83,9 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
<TabPanel>
<ImageMetadataWorkflowTabContent image={image} />
</TabPanel>
<TabPanel>
<ImageMetadataGraphTabContent image={image} />
</TabPanel>
</TabPanels>
</Tabs>
</Flex>

View File

@ -1,5 +1,5 @@
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useDebouncedImageWorkflow } from 'services/api/hooks/useDebouncedImageWorkflow';
import type { ImageDTO } from 'services/api/types';
@ -12,7 +12,17 @@ type Props = {
const ImageMetadataWorkflowTabContent = ({ image }: Props) => {
const { t } = useTranslation();
const { workflow } = useDebouncedImageWorkflow(image);
const { currentData } = useDebouncedImageWorkflow(image);
const workflow = useMemo(() => {
if (currentData?.workflow) {
try {
return JSON.parse(currentData.workflow);
} catch {
return null;
}
}
return null;
}, [currentData]);
if (!workflow) {
return <IAINoContentFallback label={t('nodes.noWorkflow')} />;

View File

@ -1,6 +1,6 @@
import { createAction, isAnyOf } from '@reduxjs/toolkit';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import type { Graph } from 'services/api/types';
import type { Graph, GraphAndWorkflowResponse } from 'services/api/types';
const textToImageGraphBuilt = createAction<Graph>('nodes/textToImageGraphBuilt');
const imageToImageGraphBuilt = createAction<Graph>('nodes/imageToImageGraphBuilt');
@ -15,7 +15,7 @@ export const isAnyGraphBuilt = isAnyOf(
);
export const workflowLoadRequested = createAction<{
workflow: unknown;
data: GraphAndWorkflowResponse;
asCopy: boolean;
}>('nodes/workflowLoadRequested');

View File

@ -58,8 +58,7 @@ export const LoadWorkflowFromGraphModal = () => {
setWorkflowRaw(JSON.stringify(workflow, null, 2));
}, [graphRaw, shouldAutoLayout]);
const loadWorkflow = useCallback(() => {
const workflow = JSON.parse(workflowRaw);
dispatch(workflowLoadRequested({ workflow, asCopy: true }));
dispatch(workflowLoadRequested({ data: { workflow: workflowRaw, graph: null }, asCopy: true }));
onClose();
}, [dispatch, onClose, workflowRaw]);
return (

View File

@ -27,10 +27,17 @@ export const useGetAndLoadEmbeddedWorkflow: UseGetAndLoadEmbeddedWorkflow = ({ o
const getAndLoadEmbeddedWorkflow = useCallback(
async (imageName: string) => {
try {
const workflow = await _getAndLoadEmbeddedWorkflow(imageName);
dispatch(workflowLoadRequested({ workflow: workflow.data, asCopy: true }));
// No toast - the listener for this action does that after the workflow is loaded
onSuccess && onSuccess();
const { data } = await _getAndLoadEmbeddedWorkflow(imageName);
if (data) {
dispatch(workflowLoadRequested({ data, asCopy: true }));
// No toast - the listener for this action does that after the workflow is loaded
onSuccess && onSuccess();
} else {
toaster({
title: t('toast.problemRetrievingWorkflow'),
status: 'error',
});
}
} catch {
toaster({
title: t('toast.problemRetrievingWorkflow'),

View File

@ -10,6 +10,7 @@ import { keyBy } from 'lodash-es';
import type { components, paths } from 'services/api/schema';
import type {
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageCategory,
ImageDTO,
ListImagesArgs,
@ -122,10 +123,7 @@ export const imagesApi = api.injectEndpoints({
providesTags: (result, error, image_name) => [{ type: 'ImageMetadata', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageWorkflow: build.query<
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'],
string
>({
getImageWorkflow: build.query<GraphAndWorkflowResponse, string>({
query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}/workflow`) }),
providesTags: (result, error, image_name) => [{ type: 'ImageWorkflow', id: image_name }],
keepUnusedDataFor: 86400, // 24 hours

View File

@ -9,7 +9,7 @@ export const useDebouncedImageWorkflow = (imageDTO?: ImageDTO | null) => {
const [debouncedImageName] = useDebounce(imageDTO?.has_workflow ? imageDTO.image_name : null, workflowFetchDebounce);
const { data: workflow, isLoading } = useGetImageWorkflowQuery(debouncedImageName ?? skipToken);
const result = useGetImageWorkflowQuery(debouncedImageName ?? skipToken);
return { workflow, isLoading };
return result;
};

File diff suppressed because one or more lines are too long

View File

@ -16,6 +16,9 @@ export type UpdateBoardArg = paths['/api/v1/boards/{board_id}']['patch']['parame
changes: paths['/api/v1/boards/{board_id}']['patch']['requestBody']['content']['application/json'];
};
export type GraphAndWorkflowResponse =
paths['/api/v1/images/i/{image_name}/workflow']['get']['responses']['200']['content']['application/json'];
export type BatchConfig =
paths['/api/v1/queue/{queue_id}/enqueue_batch']['post']['requestBody']['content']['application/json'];