feat: save workflow to images db

- Add `workflow` column to `images` table
- Revise image saving and uploading logic to save workflow and metadata to db
- Update UI queries to fetch metadata and workflow from db instead of file
This commit is contained in:
psychedelicious 2023-09-22 23:09:05 +10:00
parent b152fbf72f
commit 78dda533e2
10 changed files with 146 additions and 127 deletions

View File

@ -45,13 +45,17 @@ async def upload_image(
if not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
contents = await file.read()
metadata: Optional[str] = None
workflow: Optional[str] = None
contents = await file.read()
try:
pil_image = Image.open(io.BytesIO(contents))
if crop_visible:
bbox = pil_image.getbbox()
pil_image = pil_image.crop(bbox)
metadata = pil_image.info.get("invokeai_metadata", None)
workflow = pil_image.info.get("invokeai_workflow", None)
except Exception:
# Error opening the image
raise HTTPException(status_code=415, detail="Failed to read image")
@ -63,6 +67,8 @@ async def upload_image(
image_category=image_category,
session_id=session_id,
board_id=board_id,
metadata=metadata,
workflow=workflow,
is_intermediate=is_intermediate,
)

View File

@ -85,11 +85,8 @@ class CoreMetadata(BaseModelExcludeNull):
class ImageMetadata(BaseModelExcludeNull):
"""An image's generation metadata"""
metadata: Optional[dict] = Field(
default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI",
)
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
metadata: Optional[dict] = Field(default=None, description="The metadata associated with the image")
workflow: Optional[dict] = Field(default=None, description="The workflow associated with the image")
@invocation_output("metadata_accumulator_output")

View File

@ -59,7 +59,7 @@ class ImageFileStorageBase(ABC):
self,
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
metadata: Optional[Union[str, dict]] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
@ -109,7 +109,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self,
image: PILImageType,
image_name: str,
metadata: Optional[dict] = None,
metadata: Optional[Union[str, dict]] = None,
workflow: Optional[str] = None,
thumbnail_size: int = 256,
) -> None:
@ -119,20 +119,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None or workflow is not None:
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
pnginfo.add_text("invokeai_metadata", json.dumps(metadata) if type(metadata) is dict else metadata)
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow)
else:
# For uploaded images, we want to retain metadata. PIL strips it on save; manually add it back
# TODO: retain non-invokeai metadata on save...
original_metadata = image.info.get("invokeai_metadata", None)
if original_metadata is not None:
pnginfo.add_text("invokeai_metadata", original_metadata)
original_workflow = image.info.get("invokeai_workflow", None)
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo)

View File

@ -3,11 +3,12 @@ import sqlite3
import threading
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Generic, Optional, TypeVar, cast
from typing import Generic, Optional, TypeVar, Union, cast
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.invocations.metadata import ImageMetadata
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.services.models.image_record import ImageRecord, ImageRecordChanges, deserialize_image_record
@ -81,7 +82,7 @@ class ImageRecordStorageBase(ABC):
pass
@abstractmethod
def get_metadata(self, image_name: str) -> Optional[dict]:
def get_metadata(self, image_name: str) -> ImageMetadata:
"""Gets an image's metadata'."""
pass
@ -134,7 +135,8 @@ class ImageRecordStorageBase(ABC):
height: int,
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[dict],
metadata: Optional[Union[str, dict]],
workflow: Optional[str],
is_intermediate: bool = False,
starred: bool = False,
) -> datetime:
@ -204,6 +206,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)
if "workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images ADD COLUMN workflow TEXT;
"""
)
# Create the `images` table indices.
self._cursor.execute(
"""--sql
@ -269,22 +278,31 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[dict]:
def get_metadata(self, image_name: str) -> ImageMetadata:
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT images.metadata FROM images
SELECT metadata, workflow FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], self._cursor.fetchone())
if not result or not result[0]:
return None
return json.loads(result[0])
if not result:
return ImageMetadata()
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
workflow_raw = cast(Optional[str], as_dict.get("workflow", None))
return ImageMetadata(
metadata=json.loads(metadata_raw) if metadata_raw is not None else None,
workflow=json.loads(workflow_raw) if workflow_raw is not None else None,
)
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordNotFoundException from e
@ -519,12 +537,15 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
width: int,
height: int,
node_id: Optional[str],
metadata: Optional[dict],
metadata: Optional[Union[str, dict]],
workflow: Optional[str],
is_intermediate: bool = False,
starred: bool = False,
) -> datetime:
try:
metadata_json = None if metadata is None else json.dumps(metadata)
metadata_json: Optional[str] = None
if metadata is not None:
metadata_json = metadata if type(metadata) is str else json.dumps(metadata)
self._lock.acquire()
self._cursor.execute(
"""--sql
@ -537,10 +558,11 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata,
workflow,
is_intermediate,
starred
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -551,6 +573,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata_json,
workflow,
is_intermediate,
starred,
),

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from logging import Logger
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union
from PIL.Image import Image as PILImageType
@ -29,7 +29,6 @@ from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.models.image_record import ImageDTO, ImageRecord, ImageRecordChanges, image_record_to_dto
from invokeai.app.services.resource_name import NameServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.metadata import get_metadata_graph_from_raw_session
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
@ -71,7 +70,7 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
metadata: Optional[Union[str, dict]] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
@ -196,7 +195,7 @@ class ImageService(ImageServiceABC):
session_id: Optional[str] = None,
board_id: Optional[str] = None,
is_intermediate: bool = False,
metadata: Optional[dict] = None,
metadata: Optional[Union[str, dict]] = None,
workflow: Optional[str] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
@ -234,6 +233,7 @@ class ImageService(ImageServiceABC):
# Nullable fields
node_id=node_id,
metadata=metadata,
workflow=workflow,
session_id=session_id,
)
if board_id is not None:
@ -311,23 +311,7 @@ class ImageService(ImageServiceABC):
def get_metadata(self, image_name: str) -> Optional[ImageMetadata]:
try:
image_record = self._services.image_records.get(image_name)
metadata = self._services.image_records.get_metadata(image_name)
if not image_record.session_id:
return ImageMetadata(metadata=metadata)
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
graph = None
if session_raw:
try:
graph = get_metadata_graph_from_raw_session(session_raw)
except Exception as e:
self._services.logger.warn(f"Failed to parse session graph: {e}")
graph = None
return ImageMetadata(graph=graph, metadata=metadata)
return self._services.image_records.get_metadata(image_name)
except ImageRecordNotFoundException:
self._services.logger.error("Image record not found")
raise

View File

@ -28,7 +28,7 @@ import {
setShouldShowImageDetails,
setShouldShowProgressInViewer,
} from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import {
@ -41,9 +41,10 @@ import {
import { FaCircleNodes, FaEllipsis } from 'react-icons/fa6';
import {
useGetImageDTOQuery,
useGetImageMetadataFromFileQuery,
useGetImageMetadataQuery,
} from 'services/api/endpoints/images';
import { menuListMotionProps } from 'theme/components/menu';
import { useDebounce } from 'use-debounce';
import { sentImageToImg2Img } from '../../store/actions';
import SingleSelectionMenuItems from '../ImageContextMenu/SingleSelectionMenuItems';
@ -92,7 +93,6 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
shouldShowImageDetails,
lastSelectedImage,
shouldShowProgressInViewer,
shouldFetchMetadataFromApi,
} = useAppSelector(currentImageButtonsSelector);
const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
@ -107,16 +107,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
lastSelectedImage?.image_name ?? skipToken
);
const getMetadataArg = useMemo(() => {
if (lastSelectedImage) {
return { image: lastSelectedImage, shouldFetchMetadataFromApi };
} else {
return skipToken;
}
}, [lastSelectedImage, shouldFetchMetadataFromApi]);
const [debouncedImageName] = useDebounce(lastSelectedImage?.image_name, 300);
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
getMetadataArg,
const { metadata, workflow, isLoading } = useGetImageMetadataQuery(
debouncedImageName ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,

View File

@ -1,8 +1,9 @@
import { Flex, MenuItem, Spinner } from '@chakra-ui/react';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { useAppToaster } from 'app/components/Toaster';
import { $customStarUI } from 'app/store/nanostores/customStarUI';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
imagesToChangeSelected,
@ -32,12 +33,12 @@ import {
import { FaCircleNodes } from 'react-icons/fa6';
import { MdStar, MdStarBorder } from 'react-icons/md';
import {
useGetImageMetadataFromFileQuery,
useGetImageMetadataQuery,
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { configSelector } from '../../../system/store/configSelectors';
import { useDebounce } from 'use-debounce';
import { sentImageToCanvas, sentImageToImg2Img } from '../../store/actions';
type SingleSelectionMenuItemsProps = {
@ -53,11 +54,12 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const toaster = useAppToaster();
const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
const customStarUi = useStore($customStarUI);
const { metadata, workflow, isLoading } = useGetImageMetadataFromFileQuery(
{ image: imageDTO, shouldFetchMetadataFromApi },
const [debouncedImageName] = useDebounce(imageDTO.image_name, 300);
const { metadata, workflow, isLoading } = useGetImageMetadataQuery(
debouncedImageName ?? skipToken,
{
selectFromResult: (res) => ({
isLoading: res.isFetching,

View File

@ -9,15 +9,15 @@ import {
Tabs,
Text,
} from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { memo } from 'react';
import { useGetImageMetadataFromFileQuery } from 'services/api/endpoints/images';
import { useTranslation } from 'react-i18next';
import { useGetImageMetadataQuery } from 'services/api/endpoints/images';
import { ImageDTO } from 'services/api/types';
import { useDebounce } from 'use-debounce';
import DataViewer from './DataViewer';
import ImageMetadataActions from './ImageMetadataActions';
import { useAppSelector } from '../../../../app/store/storeHooks';
import { configSelector } from '../../../system/store/configSelectors';
import { useTranslation } from 'react-i18next';
type ImageMetadataViewerProps = {
image: ImageDTO;
@ -31,10 +31,10 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => {
// });
const { t } = useTranslation();
const { shouldFetchMetadataFromApi } = useAppSelector(configSelector);
const [debouncedImageName] = useDebounce(image.image_name, 300);
const { metadata, workflow } = useGetImageMetadataFromFileQuery(
{ image, shouldFetchMetadataFromApi },
const { metadata, workflow } = useGetImageMetadataQuery(
debouncedImageName ?? skipToken,
{
selectFromResult: (res) => ({
metadata: res?.currentData?.metadata,

View File

@ -10,6 +10,7 @@ import {
import {
ImageMetadataAndWorkflow,
zCoreMetadata,
zWorkflow,
} from 'features/nodes/types/types';
import { getMetadataAndWorkflowFromImageBlob } from 'features/nodes/util/getMetadataAndWorkflowFromImageBlob';
import { keyBy } from 'lodash-es';
@ -23,7 +24,6 @@ import {
ListImagesArgs,
OffsetPaginatedResults_ImageDTO_,
PostUploadAction,
UnsafeImageMetadata,
} from '../types';
import {
getCategories,
@ -33,6 +33,7 @@ import {
imagesSelectors,
} from '../util';
import { boardsApi } from './boards';
import { logger } from 'app/logging/logger';
export const imagesApi = api.injectEndpoints({
endpoints: (build) => ({
@ -113,11 +114,33 @@ export const imagesApi = api.injectEndpoints({
],
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadata: build.query<UnsafeImageMetadata, string>({
getImageMetadata: build.query<ImageMetadataAndWorkflow, string>({
query: (image_name) => ({ url: `images/i/${image_name}/metadata` }),
providesTags: (result, error, image_name) => [
{ type: 'ImageMetadata', id: image_name },
],
transformResponse: (
response: paths['/api/v1/images/i/{image_name}/metadata']['get']['responses']['200']['content']['application/json']
) => {
const imageMetadataAndWorkflow: ImageMetadataAndWorkflow = {};
if (response?.metadata) {
const metadataResult = zCoreMetadata.safeParse(response.metadata);
if (metadataResult.success) {
imageMetadataAndWorkflow.metadata = metadataResult.data;
} else {
logger('images').warn('Problem parsing metadata');
}
}
if (response?.workflow) {
const workflowResult = zWorkflow.safeParse(response.workflow);
if (workflowResult.success) {
imageMetadataAndWorkflow.workflow = workflowResult.data;
} else {
logger('images').warn('Problem parsing workflow');
}
}
return imageMetadataAndWorkflow;
},
keepUnusedDataFor: 86400, // 24 hours
}),
getImageMetadataFromFile: build.query<

View File

@ -1287,6 +1287,11 @@ export type components = {
* @default true
*/
use_cache?: boolean;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
* Skipped Layers
* @description Number of layers to skip in text encoder
@ -1299,11 +1304,6 @@ export type components = {
* @enum {string}
*/
type: "clip_skip";
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
/**
* ClipSkipInvocationOutput
@ -3916,14 +3916,14 @@ export type components = {
ImageMetadata: {
/**
* Metadata
* @description The image's core metadata, if it was created in the Linear or Canvas UI
* @description The metadata associated with the image
*/
metadata?: Record<string, never>;
/**
* Graph
* @description The graph that created the image
* Workflow
* @description The workflow associated with the image
*/
graph?: Record<string, never>;
workflow?: Record<string, never>;
};
/**
* Multiply Images
@ -7550,6 +7550,11 @@ export type components = {
* @default false
*/
use_cache?: boolean;
/**
* Image
* @description The image to load
*/
image?: components["schemas"]["ImageField"];
/**
* Metadata
* @description Optional core metadata to be written to image
@ -7561,11 +7566,6 @@ export type components = {
* @enum {string}
*/
type: "save_image";
/**
* Image
* @description The image to load
*/
image?: components["schemas"]["ImageField"];
};
/**
* Scale Latents
@ -7862,16 +7862,6 @@ export type components = {
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.
*/
session_id: string;
/**
* Field Values
* @description The field values that were used for this queue item
*/
field_values?: components["schemas"]["NodeFieldValue"][];
/**
* Queue Id
* @description The id of the queue with which this item is associated
*/
queue_id: string;
/**
* Error
* @description The error message if this queue item errored
@ -7897,6 +7887,16 @@ export type components = {
* @description When this queue item was completed
*/
completed_at?: string;
/**
* Queue Id
* @description The id of the queue with which this item is associated
*/
queue_id: string;
/**
* Field Values
* @description The field values that were used for this queue item
*/
field_values?: components["schemas"]["NodeFieldValue"][];
/**
* Session
* @description The fully-populated session to be executed
@ -7936,16 +7936,6 @@ export type components = {
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.
*/
session_id: string;
/**
* Field Values
* @description The field values that were used for this queue item
*/
field_values?: components["schemas"]["NodeFieldValue"][];
/**
* Queue Id
* @description The id of the queue with which this item is associated
*/
queue_id: string;
/**
* Error
* @description The error message if this queue item errored
@ -7971,6 +7961,16 @@ export type components = {
* @description When this queue item was completed
*/
completed_at?: string;
/**
* Queue Id
* @description The id of the queue with which this item is associated
*/
queue_id: string;
/**
* Field Values
* @description The field values that were used for this queue item
*/
field_values?: components["schemas"]["NodeFieldValue"][];
};
/** SessionQueueStatus */
SessionQueueStatus: {
@ -9095,6 +9095,12 @@ export type components = {
/** Ui Order */
ui_order?: number;
};
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@ -9107,18 +9113,18 @@ export type components = {
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* CLIPVisionModelFormat
* @description An enumeration.
* @enum {string}
*/
CLIPVisionModelFormat: "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* CLIPVisionModelFormat
* @description An enumeration.
* @enum {string}
*/
CLIPVisionModelFormat: "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
@ -9131,12 +9137,6 @@ export type components = {
* @enum {string}
*/
IPAdapterModelFormat: "invokeai";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;