feat(ui): handle new image origin/category setup

- Update all thunks & network related things
- Update gallery

What I have not done yet is rename the gallery tabs and the relevant slices, but I believe the functionality is all there.

Also I fixed several bugs along the way but couldn't really commit them separately bc I was refactoring. Can't remember what they were, but related to the gallery image switching.
This commit is contained in:
psychedelicious 2023-05-27 21:46:03 +10:00 committed by Kent Keirsey
parent d78e3572e3
commit 29fcc92da9
29 changed files with 181 additions and 345 deletions

View File

@ -67,6 +67,10 @@ import {
addReceivedUploadImagesPageFulfilledListener, addReceivedUploadImagesPageFulfilledListener,
addReceivedUploadImagesPageRejectedListener, addReceivedUploadImagesPageRejectedListener,
} from './listeners/receivedUploadImages'; } from './listeners/receivedUploadImages';
import {
addImageUpdatedFulfilledListener,
addImageUpdatedRejectedListener,
} from './listeners/imageUpdated';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -90,6 +94,11 @@ export type AppListenerEffect = ListenerEffect<
addImageUploadedFulfilledListener(); addImageUploadedFulfilledListener();
addImageUploadedRejectedListener(); addImageUploadedRejectedListener();
// Image updated
addImageUpdatedFulfilledListener();
addImageUpdatedRejectedListener();
// Image selected
addInitialImageSelectedListener(); addInitialImageSelectedListener();
// Image deleted // Image deleted

View File

@ -57,7 +57,6 @@ export const addCanvasMergedListener = () => {
}, },
imageCategory: 'general', imageCategory: 'general',
isIntermediate: true, isIntermediate: true,
showInGallery: false,
}) })
); );

View File

@ -37,8 +37,7 @@ export const addCanvasSavedToGalleryListener = () => {
file: new File([blob], filename, { type: 'image/png' }), file: new File([blob], filename, { type: 'image/png' }),
}, },
imageCategory: 'general', imageCategory: 'general',
isIntermediate: false, isIntermediate: true,
showInGallery: true,
}) })
); );

View File

@ -4,8 +4,15 @@ import { imageDeleted } from 'services/thunks/image';
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import { imageSelected } from 'features/gallery/store/gallerySlice';
import { uploadsAdapter } from 'features/gallery/store/uploadsSlice'; import {
import { resultsAdapter } from 'features/gallery/store/resultsSlice'; uploadRemoved,
uploadsAdapter,
} from 'features/gallery/store/uploadsSlice';
import {
resultRemoved,
resultsAdapter,
} from 'features/gallery/store/resultsSlice';
import { isUploadsImageDTO } from 'services/types/guards';
const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' }); const moduleLog = log.child({ namespace: 'addRequestedImageDeletionListener' });
@ -22,13 +29,17 @@ export const addRequestedImageDeletionListener = () => {
return; return;
} }
const { image_name, image_type } = image; const { image_name, image_origin } = image;
const selectedImageName = getState().gallery.selectedImage?.image_name; const state = getState();
const selectedImage = state.gallery.selectedImage;
const isUserImage = isUploadsImageDTO(selectedImage);
if (selectedImage && selectedImage.image_name === image_name) {
const allIds = isUserImage ? state.uploads.ids : state.results.ids;
if (selectedImageName === image_name) { const allEntities = isUserImage
const allIds = getState()[image_type].ids; ? state.uploads.entities
const allEntities = getState()[image_type].entities; : state.results.entities;
const deletedImageIndex = allIds.findIndex( const deletedImageIndex = allIds.findIndex(
(result) => result.toString() === image_name (result) => result.toString() === image_name
@ -53,7 +64,15 @@ export const addRequestedImageDeletionListener = () => {
} }
} }
dispatch(imageDeleted({ imageName: image_name, imageType: image_type })); if (isUserImage) {
dispatch(uploadRemoved(image_name));
} else {
dispatch(resultRemoved(image_name));
}
dispatch(
imageDeleted({ imageName: image_name, imageOrigin: image_origin })
);
}, },
}); });
}; };
@ -65,12 +84,12 @@ export const addImageDeletedPendingListener = () => {
startAppListening({ startAppListening({
actionCreator: imageDeleted.pending, actionCreator: imageDeleted.pending,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const { imageName, imageType } = action.meta.arg; const { imageName, imageOrigin } = action.meta.arg;
// Preemptively remove the image from the gallery // Preemptively remove the image from the gallery
if (imageType === 'uploads') { if (imageOrigin === 'external') {
uploadsAdapter.removeOne(getState().uploads, imageName); uploadsAdapter.removeOne(getState().uploads, imageName);
} }
if (imageType === 'results') { if (imageOrigin === 'internal') {
resultsAdapter.removeOne(getState().results, imageName); resultsAdapter.removeOne(getState().results, imageName);
} }
}, },

View File

@ -1,14 +1,9 @@
import { log } from 'app/logging/useLogger'; import { log } from 'app/logging/useLogger';
import { startAppListening } from '..'; import { startAppListening } from '..';
import { imageMetadataReceived } from 'services/thunks/image'; import { imageMetadataReceived } from 'services/thunks/image';
import { import { resultUpserted } from 'features/gallery/store/resultsSlice';
ResultsImageDTO, import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
resultUpserted, import { imageSelected } from 'features/gallery/store/gallerySlice';
} from 'features/gallery/store/resultsSlice';
import {
UploadsImageDTO,
uploadUpserted,
} from 'features/gallery/store/uploadsSlice';
const moduleLog = log.child({ namespace: 'image' }); const moduleLog = log.child({ namespace: 'image' });
@ -16,15 +11,15 @@ export const addImageMetadataReceivedFulfilledListener = () => {
startAppListening({ startAppListening({
actionCreator: imageMetadataReceived.fulfilled, actionCreator: imageMetadataReceived.fulfilled,
effect: (action, { getState, dispatch }) => { effect: (action, { getState, dispatch }) => {
const image = action.payload; const imageDTO = action.payload;
moduleLog.debug({ data: { image } }, 'Image metadata received'); moduleLog.debug({ data: { imageDTO } }, 'Image metadata received');
if (image.image_type === 'results') { if (imageDTO.image_origin === 'internal') {
dispatch(resultUpserted(action.payload as ResultsImageDTO)); dispatch(resultUpserted(imageDTO));
} }
if (image.image_type === 'uploads') { if (imageDTO.image_origin === 'external') {
dispatch(uploadUpserted(action.payload as UploadsImageDTO)); dispatch(uploadUpserted(imageDTO));
} }
}, },
}); });

View File

@ -0,0 +1,26 @@
import { startAppListening } from '..';
import { imageUpdated } from 'services/thunks/image';
import { log } from 'app/logging/useLogger';
const moduleLog = log.child({ namespace: 'image' });
export const addImageUpdatedFulfilledListener = () => {
startAppListening({
actionCreator: imageUpdated.fulfilled,
effect: (action, { dispatch, getState }) => {
moduleLog.debug(
{ oldImage: action.meta.arg, updatedImage: action.payload },
'Image updated'
);
},
});
};
export const addImageUpdatedRejectedListener = () => {
startAppListening({
actionCreator: imageUpdated.rejected,
effect: (action, { dispatch }) => {
moduleLog.debug({ oldImage: action.meta.arg }, 'Image update failed');
},
});
};

View File

@ -1,6 +1,9 @@
import { startAppListening } from '..'; import { startAppListening } from '..';
import { uploadUpserted } from 'features/gallery/store/uploadsSlice'; import { uploadUpserted } from 'features/gallery/store/uploadsSlice';
import { imageSelected } from 'features/gallery/store/gallerySlice'; import {
imageSelected,
setCurrentCategory,
} from 'features/gallery/store/gallerySlice';
import { imageUploaded } from 'services/thunks/image'; import { imageUploaded } from 'services/thunks/image';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { resultUpserted } from 'features/gallery/store/resultsSlice'; import { resultUpserted } from 'features/gallery/store/resultsSlice';
@ -10,31 +13,30 @@ const moduleLog = log.child({ namespace: 'image' });
export const addImageUploadedFulfilledListener = () => { export const addImageUploadedFulfilledListener = () => {
startAppListening({ startAppListening({
predicate: (action): action is ReturnType<typeof imageUploaded.fulfilled> => actionCreator: imageUploaded.fulfilled,
imageUploaded.fulfilled.match(action) &&
action.payload.is_intermediate === false,
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const image = action.payload; const image = action.payload;
moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded'); moduleLog.debug({ arg: '<Blob>', image }, 'Image uploaded');
if (action.payload.is_intermediate) {
// No further actions needed for intermediate images
return;
}
const state = getState(); const state = getState();
// Handle uploads // Handle uploads
if (!image.show_in_gallery && image.image_type === 'uploads') { if (image.image_category === 'user' && !image.is_intermediate) {
dispatch(uploadUpserted(image)); dispatch(uploadUpserted(image));
dispatch(addToast({ title: 'Image Uploaded', status: 'success' })); dispatch(addToast({ title: 'Image Uploaded', status: 'success' }));
if (state.gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(image));
}
} }
// Handle results // Handle results
// TODO: Can this ever happen? I don't think so... // TODO: Can this ever happen? I don't think so...
if (image.show_in_gallery) { if (image.image_category !== 'user' && !image.is_intermediate) {
dispatch(resultUpserted(image)); dispatch(resultUpserted(image));
dispatch(setCurrentCategory('results'));
} }
}, },
}); });
@ -44,6 +46,9 @@ export const addImageUploadedRejectedListener = () => {
startAppListening({ startAppListening({
actionCreator: imageUploaded.rejected, actionCreator: imageUploaded.rejected,
effect: (action, { dispatch }) => { effect: (action, { dispatch }) => {
const { formData, ...rest } = action.meta.arg;
const sanitizedData = { arg: { ...rest, formData: { file: '<Blob>' } } };
moduleLog.error({ data: sanitizedData }, 'Image upload failed');
dispatch( dispatch(
addToast({ addToast({
title: 'Image Upload Failed', title: 'Image Upload Failed',

View File

@ -13,9 +13,9 @@ export const addImageUrlsReceivedFulfilledListener = () => {
const image = action.payload; const image = action.payload;
moduleLog.debug({ data: { image } }, 'Image URLs received'); moduleLog.debug({ data: { image } }, 'Image URLs received');
const { image_type, image_name, image_url, thumbnail_url } = image; const { image_origin, image_name, image_url, thumbnail_url } = image;
if (image_type === 'results') { if (image_origin === 'results') {
resultsAdapter.updateOne(getState().results, { resultsAdapter.updateOne(getState().results, {
id: image_name, id: image_name,
changes: { changes: {
@ -25,7 +25,7 @@ export const addImageUrlsReceivedFulfilledListener = () => {
}); });
} }
if (image_type === 'uploads') { if (image_origin === 'uploads') {
uploadsAdapter.updateOne(getState().uploads, { uploadsAdapter.updateOne(getState().uploads, {
id: image_name, id: image_name,
changes: { changes: {

View File

@ -30,14 +30,14 @@ export const addInitialImageSelectedListener = () => {
return; return;
} }
const { image_name, image_type } = action.payload; const { image_name, image_origin } = action.payload;
let image: ImageDTO | undefined; let image: ImageDTO | undefined;
const state = getState(); const state = getState();
if (image_type === 'results') { if (image_origin === 'results') {
image = selectResultsById(state, image_name); image = selectResultsById(state, image_name);
} else if (image_type === 'uploads') { } else if (image_origin === 'uploads') {
image = selectUploadsById(state, image_name); image = selectUploadsById(state, image_name);
} }

View File

@ -34,13 +34,13 @@ export const addInvocationCompleteListener = () => {
// This complete event has an associated image output // This complete event has an associated image output
if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { if (isImageOutput(result) && !nodeDenylist.includes(node.type)) {
const { image_name, image_type } = result.image; const { image_name, image_origin } = result.image;
// Get its metadata // Get its metadata
dispatch( dispatch(
imageMetadataReceived({ imageMetadataReceived({
imageName: image_name, imageName: image_name,
imageType: image_type, imageOrigin: image_origin,
}) })
); );
@ -48,10 +48,6 @@ export const addInvocationCompleteListener = () => {
imageMetadataReceived.fulfilled.match imageMetadataReceived.fulfilled.match
); );
if (getState().gallery.shouldAutoSwitchToNewImages) {
dispatch(imageSelected(imageDTO));
}
// Handle canvas image // Handle canvas image
if ( if (
graph_execution_state_id === graph_execution_state_id ===

View File

@ -103,7 +103,6 @@ export const addUserInvokedCanvasListener = () => {
}, },
imageCategory: 'general', imageCategory: 'general',
isIntermediate: true, isIntermediate: true,
showInGallery: false,
}) })
); );
@ -117,7 +116,7 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.image = { baseNode.image = {
image_name: baseImageDTO.image_name, image_name: baseImageDTO.image_name,
image_type: baseImageDTO.image_type, image_origin: baseImageDTO.image_origin,
}; };
} }
@ -131,7 +130,6 @@ export const addUserInvokedCanvasListener = () => {
}, },
imageCategory: 'mask', imageCategory: 'mask',
isIntermediate: true, isIntermediate: true,
showInGallery: false,
}) })
); );
@ -145,7 +143,7 @@ export const addUserInvokedCanvasListener = () => {
// Update the base node with the image name and type // Update the base node with the image name and type
baseNode.mask = { baseNode.mask = {
image_name: maskImageDTO.image_name, image_name: maskImageDTO.image_name,
image_type: maskImageDTO.image_type, image_origin: maskImageDTO.image_origin,
}; };
} }
@ -162,7 +160,7 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.image.image_name, imageName: baseNode.image.image_name,
imageType: baseNode.image.image_type, imageOrigin: baseNode.image.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );
@ -173,7 +171,7 @@ export const addUserInvokedCanvasListener = () => {
dispatch( dispatch(
imageUpdated({ imageUpdated({
imageName: baseNode.mask.image_name, imageName: baseNode.mask.image_name,
imageType: baseNode.mask.image_type, imageOrigin: baseNode.mask.image_origin,
requestBody: { session_id: sessionId }, requestBody: { session_id: sessionId },
}) })
); );

View File

@ -15,7 +15,7 @@
import { SelectedImage } from 'features/parameters/store/actions'; import { SelectedImage } from 'features/parameters/store/actions';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types'; import { IRect } from 'konva/lib/types';
import { ImageResponseMetadata, ImageType } from 'services/api'; import { ImageResponseMetadata, ResourceOrigin } from 'services/api';
import { O } from 'ts-toolbelt'; import { O } from 'ts-toolbelt';
/** /**
@ -124,7 +124,7 @@ export type PostProcessedImageMetadata = ESRGANMetadata | FacetoolMetadata;
*/ */
// export ty`pe Image = { // export ty`pe Image = {
// name: string; // name: string;
// type: ImageType; // type: image_origin;
// url: string; // url: string;
// thumbnail: string; // thumbnail: string;
// metadata: ImageResponseMetadata; // metadata: ImageResponseMetadata;

View File

@ -69,9 +69,8 @@ const ImageUploader = (props: ImageUploaderProps) => {
dispatch( dispatch(
imageUploaded({ imageUploaded({
formData: { file }, formData: { file },
imageCategory: 'general', imageCategory: 'user',
isIntermediate: false, isIntermediate: false,
showInGallery: false,
}) })
); );
}, },

View File

@ -1,239 +0,0 @@
import { forEach, size } from 'lodash-es';
import {
ImageField,
LatentsField,
ConditioningField,
ControlField,
} from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
const parseConditioningField = (
conditioningField: unknown
): ConditioningField | undefined => {
// Must be an object
if (!isObject(conditioningField)) {
return;
}
// A ConditioningField must have a `conditioning_name`
if (!('conditioning_name' in conditioningField)) {
return;
}
// A ConditioningField's `conditioning_name` must be a string
if (typeof conditioningField.conditioning_name !== 'string') {
return;
}
// Build a valid ConditioningField
return {
conditioning_name: conditioningField.conditioning_name,
};
};
const parseControlField = (controlField: unknown): ControlField | undefined => {
// Must be an object
if (!isObject(controlField)) {
return;
}
// A ControlField must have a `control`
if (!('control' in controlField)) {
return;
}
// console.log(typeof controlField.control);
// Build a valid ControlField
return {
control: controlField.control,
};
};
type NodeMetadata = {
[key: string]:
| string
| number
| boolean
| ImageField
| LatentsField
| ConditioningField
| ControlField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// the only valid object types are ImageField, LatentsField, ConditioningField, ControlField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
if ('conditioning_name' in nodeItem) {
const conditioningField = parseConditioningField(nodeItem);
if (conditioningField) {
parsed[nodeKey] = conditioningField;
}
return;
}
if ('control' in nodeItem) {
const controlField = parseControlField(nodeItem);
if (controlField) {
parsed[nodeKey] = controlField;
}
return;
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@ -62,7 +62,7 @@ const CurrentImagePreview = () => {
return; return;
} }
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type); e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]

View File

@ -147,7 +147,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleDragStart = useCallback( const handleDragStart = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageName', image.image_name); e.dataTransfer.setData('invokeai/imageName', image.image_name);
e.dataTransfer.setData('invokeai/imageType', image.image_type); e.dataTransfer.setData('invokeai/imageOrigin', image.image_origin);
e.dataTransfer.effectAllowed = 'move'; e.dataTransfer.effectAllowed = 'move';
}, },
[image] [image]

View File

@ -1,6 +1,6 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { ImageType } from 'services/api'; import { ResourceOrigin } from 'services/api';
import { selectResultsEntities } from '../store/resultsSlice'; import { selectResultsEntities } from '../store/resultsSlice';
import { selectUploadsEntities } from '../store/uploadsSlice'; import { selectUploadsEntities } from '../store/uploadsSlice';
@ -11,17 +11,17 @@ const useGetImageByNameSelector = createSelector(
} }
); );
const useGetImageByNameAndType = () => { const useGetImageByNameAndOrigin = () => {
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector); const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
return (name: string, type: ImageType) => { return (name: string, origin: ResourceOrigin) => {
if (type === 'results') { if (origin === 'internal') {
const resultImagesResult = allResults[name]; const resultImagesResult = allResults[name];
if (resultImagesResult) { if (resultImagesResult) {
return resultImagesResult; return resultImagesResult;
} }
} }
if (type === 'uploads') { if (origin === 'external') {
const userImagesResult = allUploads[name]; const userImagesResult = allUploads[name];
if (userImagesResult) { if (userImagesResult) {
return userImagesResult; return userImagesResult;
@ -30,4 +30,4 @@ const useGetImageByNameAndType = () => {
}; };
}; };
export default useGetImageByNameAndType; export default useGetImageByNameAndOrigin;

View File

@ -1,9 +1,9 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { ImageNameAndType } from 'features/parameters/store/actions'; import { ImageNameAndOrigin } from 'features/parameters/store/actions';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
export const requestedImageDeletion = createAction< export const requestedImageDeletion = createAction<
ImageDTO | ImageNameAndType | undefined ImageDTO | ImageNameAndOrigin | undefined
>('gallery/requestedImageDeletion'); >('gallery/requestedImageDeletion');
export const sentImageToCanvas = createAction('gallery/sentImageToCanvas'); export const sentImageToCanvas = createAction('gallery/sentImageToCanvas');

View File

@ -5,6 +5,8 @@ import {
receivedUploadImages, receivedUploadImages,
} from '../../../services/thunks/gallery'; } from '../../../services/thunks/gallery';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { resultUpserted } from './resultsSlice';
import { uploadUpserted } from './uploadsSlice';
type GalleryImageObjectFitType = 'contain' | 'cover'; type GalleryImageObjectFitType = 'contain' | 'cover';
@ -76,6 +78,7 @@ export const gallerySlice = createSlice({
} }
} }
}); });
builder.addCase(receivedUploadImages.fulfilled, (state, action) => { builder.addCase(receivedUploadImages.fulfilled, (state, action) => {
// rehydrate selectedImage URL when results list comes in // rehydrate selectedImage URL when results list comes in
// solves case when outdated URL is in local storage // solves case when outdated URL is in local storage
@ -92,6 +95,20 @@ export const gallerySlice = createSlice({
} }
} }
}); });
builder.addCase(resultUpserted, (state, action) => {
if (state.shouldAutoSwitchToNewImages) {
state.selectedImage = action.payload;
state.currentCategory = 'results';
}
});
builder.addCase(uploadUpserted, (state, action) => {
if (state.shouldAutoSwitchToNewImages) {
state.selectedImage = action.payload;
state.currentCategory = 'uploads';
}
});
}, },
}); });

View File

@ -11,8 +11,8 @@ import {
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator'; import { dateComparator } from 'common/util/dateComparator';
export type ResultsImageDTO = Omit<ImageDTO, 'image_type'> & { export type ResultsImageDTO = Omit<ImageDTO, 'image_origin'> & {
image_type: 'results'; image_origin: 'results';
}; };
export const resultsAdapter = createEntityAdapter<ImageDTO>({ export const resultsAdapter = createEntityAdapter<ImageDTO>({
@ -47,6 +47,9 @@ const resultsSlice = createSlice({
resultsAdapter.upsertOne(state, action.payload); resultsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1; state.upsertedImageCount += 1;
}, },
resultRemoved: (state, action: PayloadAction<string>) => {
resultsAdapter.removeOne(state, action.payload);
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
/** /**
@ -83,6 +86,6 @@ export const {
selectTotal: selectResultsTotal, selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results); } = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultUpserted } = resultsSlice.actions; export const { resultUpserted, resultRemoved } = resultsSlice.actions;
export default resultsSlice.reducer; export default resultsSlice.reducer;

View File

@ -9,8 +9,12 @@ import { receivedUploadImages, IMAGES_PER_PAGE } from 'services/thunks/gallery';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { dateComparator } from 'common/util/dateComparator'; import { dateComparator } from 'common/util/dateComparator';
export type UploadsImageDTO = Omit<ImageDTO, 'image_type'> & { export type UploadsImageDTO = Omit<
image_type: 'uploads'; ImageDTO,
'image_origin' | 'image_category'
> & {
image_origin: 'external';
image_category: 'user';
}; };
export const uploadsAdapter = createEntityAdapter<ImageDTO>({ export const uploadsAdapter = createEntityAdapter<ImageDTO>({
@ -45,6 +49,9 @@ const uploadsSlice = createSlice({
uploadsAdapter.upsertOne(state, action.payload); uploadsAdapter.upsertOne(state, action.payload);
state.upsertedImageCount += 1; state.upsertedImageCount += 1;
}, },
uploadRemoved: (state, action: PayloadAction<string>) => {
uploadsAdapter.removeOne(state, action.payload);
},
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
/** /**
@ -81,6 +88,6 @@ export const {
selectTotal: selectUploadsTotal, selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads); } = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadUpserted } = uploadsSlice.actions; export const { uploadUpserted, uploadRemoved } = uploadsSlice.actions;
export default uploadsSlice.reducer; export default uploadsSlice.reducer;

View File

@ -2,7 +2,7 @@ import { Box, Image } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder'; import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl'; import { useGetUrl } from 'common/util/getUrl';
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName'; import useGetImageByNameAndOrigin from 'features/gallery/hooks/useGetImageByName';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import { import {
@ -11,7 +11,7 @@ import {
} from 'features/nodes/types/types'; } from 'features/nodes/types/types';
import { DragEvent, memo, useCallback, useState } from 'react'; import { DragEvent, memo, useCallback, useState } from 'react';
import { ImageType } from 'services/api'; import { ResourceOrigin } from 'services/api';
import { FieldComponentProps } from './types'; import { FieldComponentProps } from './types';
const ImageInputFieldComponent = ( const ImageInputFieldComponent = (
@ -19,7 +19,7 @@ const ImageInputFieldComponent = (
) => { ) => {
const { nodeId, field } = props; const { nodeId, field } = props;
const getImageByNameAndType = useGetImageByNameAndType(); const getImageByNameAndType = useGetImageByNameAndOrigin();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const [url, setUrl] = useState<string | undefined>(field.value?.image_url); const [url, setUrl] = useState<string | undefined>(field.value?.image_url);
const { getUrl } = useGetUrl(); const { getUrl } = useGetUrl();
@ -27,7 +27,9 @@ const ImageInputFieldComponent = (
const handleDrop = useCallback( const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName'); const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; const type = e.dataTransfer.getData(
'invokeai/imageOrigin'
) as ResourceOrigin;
if (!name || !type) { if (!name || !type) {
return; return;

View File

@ -64,7 +64,7 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
model, model,
image: { image: {
image_name: initialImage?.image_name, image_name: initialImage?.image_name,
image_type: initialImage?.image_type, image_origin: initialImage?.image_origin,
}, },
}; };

View File

@ -58,7 +58,7 @@ export const buildImg2ImgNode = (
imageToImageNode.image = { imageToImageNode.image = {
image_name: initialImage.name, image_name: initialImage.name,
image_type: initialImage.type, image_origin: initialImage.type,
}; };
} }

View File

@ -51,7 +51,7 @@ export const buildInpaintNode = (
inpaintNode.image = { inpaintNode.image = {
image_name: initialImage.name, image_name: initialImage.name,
image_type: initialImage.type, image_origin: initialImage.type,
}; };
} }

View File

@ -5,7 +5,7 @@ import { useGetUrl } from 'common/util/getUrl';
import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { DragEvent, useCallback } from 'react'; import { DragEvent, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { ImageType } from 'services/api'; import { ResourceOrigin } from 'services/api';
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import { initialImageSelected } from 'features/parameters/store/actions'; import { initialImageSelected } from 'features/parameters/store/actions';
@ -55,9 +55,11 @@ const InitialImagePreview = () => {
const handleDrop = useCallback( const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => { (e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName'); const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType; const type = e.dataTransfer.getData(
'invokeai/imageOrigin'
) as ResourceOrigin;
dispatch(initialImageSelected({ image_name: name, image_type: type })); dispatch(initialImageSelected({ image_name: name, image_origin: type }));
}, },
[dispatch] [dispatch]
); );

View File

@ -1,10 +1,10 @@
import { createAction } from '@reduxjs/toolkit'; import { createAction } from '@reduxjs/toolkit';
import { isObject } from 'lodash-es'; import { isObject } from 'lodash-es';
import { ImageDTO, ImageType } from 'services/api'; import { ImageDTO, ResourceOrigin } from 'services/api';
export type ImageNameAndType = { export type ImageNameAndOrigin = {
image_name: string; image_name: string;
image_type: ImageType; image_origin: ResourceOrigin;
}; };
export const isImageDTO = (image: any): image is ImageDTO => { export const isImageDTO = (image: any): image is ImageDTO => {
@ -13,8 +13,8 @@ export const isImageDTO = (image: any): image is ImageDTO => {
isObject(image) && isObject(image) &&
'image_name' in image && 'image_name' in image &&
image?.image_name !== undefined && image?.image_name !== undefined &&
'image_type' in image && 'image_origin' in image &&
image?.image_type !== undefined && image?.image_origin !== undefined &&
'image_url' in image && 'image_url' in image &&
image?.image_url !== undefined && image?.image_url !== undefined &&
'thumbnail_url' in image && 'thumbnail_url' in image &&
@ -27,5 +27,5 @@ export const isImageDTO = (image: any): image is ImageDTO => {
}; };
export const initialImageSelected = createAction< export const initialImageSelected = createAction<
ImageDTO | ImageNameAndType | undefined ImageDTO | ImageNameAndOrigin | undefined
>('generation/initialImageSelected'); >('generation/initialImageSelected');

View File

@ -23,8 +23,8 @@ export const receivedGalleryImages = createAppAsyncThunk<
const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE);
const response = await ImagesService.listImagesWithMetadata({ const response = await ImagesService.listImagesWithMetadata({
excludeCategories: ['user'],
isIntermediate: false, isIntermediate: false,
showInGallery: true,
page: nextPage + pageOffset, page: nextPage + pageOffset,
perPage: IMAGES_PER_PAGE, perPage: IMAGES_PER_PAGE,
}); });
@ -53,9 +53,8 @@ export const receivedUploadImages = createAppAsyncThunk<
const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE); const pageOffset = Math.floor(upsertedImageCount / IMAGES_PER_PAGE);
const response = await ImagesService.listImagesWithMetadata({ const response = await ImagesService.listImagesWithMetadata({
imageType: 'uploads', includeCategories: ['user'],
isIntermediate: false, isIntermediate: false,
showInGallery: false,
page: nextPage + pageOffset, page: nextPage + pageOffset,
perPage: IMAGES_PER_PAGE, perPage: IMAGES_PER_PAGE,
}); });

View File

@ -1,4 +1,3 @@
import { ResultsImageDTO } from 'features/gallery/store/resultsSlice';
import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice'; import { UploadsImageDTO } from 'features/gallery/store/uploadsSlice';
import { get, isObject, isString } from 'lodash-es'; import { get, isObject, isString } from 'lodash-es';
import { import {
@ -9,17 +8,18 @@ import {
PromptOutput, PromptOutput,
IterateInvocationOutput, IterateInvocationOutput,
CollectInvocationOutput, CollectInvocationOutput,
ImageType,
ImageField, ImageField,
LatentsOutput, LatentsOutput,
ImageDTO, ImageDTO,
ResourceOrigin,
} from 'services/api'; } from 'services/api';
export const isUploadsImageDTO = (image: ImageDTO): image is UploadsImageDTO => export const isUploadsImageDTO = (
image.image_type === 'uploads'; image: ImageDTO | undefined
): image is UploadsImageDTO =>
export const isResultsImageDTO = (image: ImageDTO): image is ResultsImageDTO => image !== undefined &&
image.image_type === 'results'; image.image_origin === 'external' &&
image.image_category === 'user';
export const isImageOutput = ( export const isImageOutput = (
output: GraphExecutionState['results'][string] output: GraphExecutionState['results'][string]
@ -49,10 +49,10 @@ export const isCollectOutput = (
output: GraphExecutionState['results'][string] output: GraphExecutionState['results'][string]
): output is CollectInvocationOutput => output.type === 'collect_output'; ): output is CollectInvocationOutput => output.type === 'collect_output';
export const isImageType = (t: unknown): t is ImageType => export const isResourceOrigin = (t: unknown): t is ResourceOrigin =>
isString(t) && ['results', 'uploads', 'intermediates'].includes(t); isString(t) && ['internal', 'external'].includes(t);
export const isImageField = (imageField: unknown): imageField is ImageField => export const isImageField = (imageField: unknown): imageField is ImageField =>
isObject(imageField) && isObject(imageField) &&
isString(get(imageField, 'image_name')) && isString(get(imageField, 'image_name')) &&
isImageType(get(imageField, 'image_type')); isResourceOrigin(get(imageField, 'image_origin'));