feat(ui): wip gallery migration

This commit is contained in:
psychedelicious 2023-04-04 22:58:46 +10:00
parent cfe86ec541
commit cc3401a159
15 changed files with 231 additions and 98 deletions

View File

@ -1,5 +1,10 @@
import { createAction } from '@reduxjs/toolkit';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
@ -9,3 +14,23 @@ import { createAction } from '@reduxjs/toolkit';
export const emitSubscribe = createAction<string>('socketio/subscribe');
export const emitUnsubscribe = createAction<string>('socketio/unsubscribe');
type Timestamp = {
timestamp: Date;
};
export const invocationStarted = createAction<
{ data: InvocationStartedEvent } & Timestamp
>('socketio/invocationStarted');
export const invocationComplete = createAction<
{ data: InvocationCompleteEvent } & Timestamp
>('socketio/invocationComplete');
export const invocationError = createAction<
{ data: InvocationErrorEvent } & Timestamp
>('socketio/invocationError');
export const generatorProgress = createAction<
{ data: GeneratorProgressEvent } & Timestamp
>('socketio/generatorProgress');

View File

@ -10,6 +10,8 @@ import {
setIsCancelable,
setIsConnected,
setIsProcessing,
socketioConnected,
socketioDisconnected,
} from 'features/system/store/systemSlice';
import {
@ -32,13 +34,13 @@ import {
setStatus,
STATUS,
} from 'services/apiSlice';
import { emitUnsubscribe } from './actions';
import { emitUnsubscribe, invocationComplete } from './actions';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import {
getNextResultsPage,
getNextUploadsPage,
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { processImageField } from 'services/util/processImageField';
import { deserializeImageField } from 'services/util/deserializeImageField';
/**
* Returns an object containing listener callbacks
@ -54,15 +56,15 @@ const makeSocketIOListeners = (
*/
onConnect: () => {
try {
dispatch(setIsConnected(true));
dispatch(setCurrentStatus(i18n.t('common.statusConnected')));
dispatch(socketioConnected());
// fetch more images, but only if we don't already have images
if (!getState().results.ids.length) {
dispatch(getNextResultsPage());
dispatch(receivedResultImagesPage());
}
if (!getState().uploads.ids.length) {
dispatch(getNextUploadsPage());
dispatch(receivedUploadImagesPage());
}
} catch (e) {
console.error(e);
@ -73,8 +75,8 @@ const makeSocketIOListeners = (
*/
onDisconnect: () => {
try {
dispatch(setIsConnected(false));
dispatch(setCurrentStatus(i18n.t('common.statusDisconnected')));
dispatch(socketioDisconnected());
dispatch(emitUnsubscribe(getState().api.sessionId));
dispatch(
addLogEntry({
@ -97,38 +99,40 @@ const makeSocketIOListeners = (
onInvocationComplete: (data: InvocationCompleteEvent) => {
console.log('invocation_complete', data);
try {
dispatch(invocationComplete({ data, timestamp: new Date() }));
const sessionId = data.graph_execution_state_id;
if (data.result.type === 'image') {
const resultImage = processImageField(data.result.image);
// const resultImage = deserializeImageField(data.result.image);
dispatch(resultAdded(resultImage));
// dispatch(resultAdded(resultImage));
// // need to update the type for this or figure out how to get these values
dispatch(
addImage({
category: 'result',
image: {
uuid: uuidv4(),
url: resultImage.url,
thumbnail: '',
width: 512,
height: 512,
category: 'result',
name: resultImage.name,
mtime: new Date().getTime(),
},
})
);
// dispatch(
// addImage({
// category: 'result',
// image: {
// uuid: uuidv4(),
// url: resultImage.url,
// thumbnail: '',
// width: 512,
// height: 512,
// category: 'result',
// name: resultImage.name,
// mtime: new Date().getTime(),
// },
// })
// );
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(false));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Generated: ${data.result.image.image_name}`,
})
);
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(false));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Generated: ${data.result.image.image_name}`,
// })
// );
dispatch(emitUnsubscribe(sessionId));
dispatch(setSessionId(null));
// dispatch(setSessionId(''));
}
} catch (e) {
console.error(e);

View File

@ -10,6 +10,7 @@ import {
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
import { invocationComplete } from './actions';
const socket_url = `ws://${window.location.host}`;
@ -40,6 +41,12 @@ export const socketioMiddleware = () => {
areListenersSet = true;
// use the action's match() function for type narrowing and safety
if (invocationComplete.match(action)) {
emitUnsubscribe(action.payload.data.graph_execution_state_id);
socketio.removeAllListeners();
}
/**
* Handle redux actions caught by middleware.
*/
@ -63,12 +70,12 @@ export const socketioMiddleware = () => {
break;
}
case 'socketio/unsubscribe': {
emitUnsubscribe(action.payload);
// case 'socketio/unsubscribe': {
// emitUnsubscribe(action.payload);
socketio.removeAllListeners();
break;
}
// socketio.removeAllListeners();
// break;
// }
}
next(action);

View File

@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector],
(gallery) => {
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: currentImage || intermediateImage,
hasAnImageToDisplay: selectedImage || intermediateImage,
};
},
{

View File

@ -6,19 +6,22 @@ import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector(
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
[gallerySelector, uiSelector, selectedImageSelector],
(gallery: GalleryState, ui, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
const { shouldShowImageDetails } = ui;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
imageToDisplay: intermediateImage ? intermediateImage : selectedImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
};
@ -33,7 +36,7 @@ export const imagesSelector = createSelector(
export default function CurrentImagePreview() {
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
useAppSelector(imagesSelector);
console.log(imageToDisplay);
return (
<Flex
sx={{
@ -74,7 +77,7 @@ export default function CurrentImagePreview() {
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
{/* <ImageMetadataViewer image={imageToDisplay} /> */}
</Box>
)}
</Flex>

View File

@ -9,7 +9,10 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
setAllImageToImageParameters,
setAllParameters,
@ -33,14 +36,14 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
interface HoverableImageProps {
image: InvokeAI._Image;
image: InvokeAI.Image;
isSelected: boolean;
}
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@ -55,7 +58,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, uuid, metadata } = image;
const { url, thumbnail, name, metadata } = image;
const [isHovered, setIsHovered] = useState<boolean>(false);
@ -92,7 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(setInitialImage(image));
// dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
@ -105,7 +108,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToCanvas = () => {
dispatch(setInitialCanvasImage(image));
// dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -155,16 +158,19 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
// dispatch(setCurrentImage(image));
};
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageUuid', uuid);
e.dataTransfer.effectAllowed = 'move';
// e.dataTransfer.setData('invokeai/imageUuid', uuid);
// e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
return (
@ -209,9 +215,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal>
</DeleteImageModal> */}
</MenuItem>
</MenuList>
)}
@ -219,7 +225,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={uuid}
key={name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@ -290,7 +296,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@ -298,7 +304,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Box>
)}
</Box>

View File

@ -31,8 +31,8 @@ import {
selectResultsTotal,
} from '../store/resultsSlice';
import {
getNextResultsPage,
getNextUploadsPage,
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
@ -90,11 +90,11 @@ const ImageGalleryContent = () => {
// };
const handleClickLoadMore = () => {
if (currentCategory === 'result') {
dispatch(getNextResultsPage());
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(getNextUploadsPage());
dispatch(receivedUploadImagesPage());
}
};
@ -249,20 +249,17 @@ const ImageGalleryContent = () => {
gap={2}
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{/* {images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
{images.map((image) => {
const { name } = image;
const isSelected = currentImageUuid === name;
return (
<HoverableImage
key={uuid}
key={name}
image={image}
isSelected={isSelected}
/>
);
})} */}
{images.map((image) => (
<Image key={image.name} src={image.thumbnail} />
))}
})}
</Grid>
<IAIButton
onClick={handleClickLoadMore}

View File

@ -7,6 +7,8 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { selectResultsAll, selectResultsEntities } from './resultsSlice';
import { selectUploadsAll, selectUploadsEntities } from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@ -75,3 +77,18 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@ -1,9 +1,11 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
export type GalleryCategory = 'user' | 'result';
@ -23,6 +25,7 @@ export type Gallery = {
};
export interface GalleryState {
selectedImageName: string;
currentImage?: InvokeAI._Image;
currentImageUuid: string;
intermediateImage?: InvokeAI._Image & {
@ -42,6 +45,7 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@ -69,6 +73,9 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
@ -255,9 +262,19 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@ -1,9 +1,16 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { createEntityAdapter, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { RootState } from 'app/store';
import { getNextResultsPage, IMAGES_PER_PAGE } from 'services/thunks/gallery';
import { processImageField } from 'services/util/processImageField';
import { socketioConnected } from 'features/system/store/systemSlice';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { setCurrentCategory } from './gallerySlice';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
@ -47,13 +54,14 @@ const resultsSlice = createSlice({
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
builder.addCase(getNextResultsPage.pending, (state) => {
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
builder.addCase(getNextResultsPage.fulfilled, (state, action) => {
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) => processImageField(image));
const resultImages = items.map((image) => deserializeImageField(image));
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
@ -63,6 +71,15 @@ const resultsSlice = createSlice({
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
const resultImage = deserializeImageField(data.result.image);
resultsAdapter.addOne(state, resultImage);
}
});
},
});

View File

@ -2,8 +2,11 @@ import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import { getNextUploadsPage, IMAGES_PER_PAGE } from 'services/thunks/gallery';
import { processImageField } from 'services/util/processImageField';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { deserializeImageField } from 'services/util/deserializeImageField';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
@ -29,13 +32,13 @@ const uploadsSlice = createSlice({
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
builder.addCase(getNextUploadsPage.pending, (state) => {
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
builder.addCase(getNextUploadsPage.fulfilled, (state, action) => {
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => processImageField(image));
const images = items.map((image) => deserializeImageField(image));
uploadsAdapter.addMany(state, images);

View File

@ -2,7 +2,12 @@ import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import dateFormat from 'dateformat';
import i18n from 'i18n';
import { isImageOutput } from 'services/types/guards';
export type LogLevel = 'info' | 'warning' | 'error';
@ -271,6 +276,29 @@ export const systemSlice = createSlice({
setCancelAfter: (state, action: PayloadAction<number | null>) => {
state.cancelOptions.cancelAfter = action.payload;
},
socketioConnected: (state) => {
state.isConnected = true;
state.currentStatus = i18n.t('common.statusConnected');
},
socketioDisconnected: (state) => {
state.isConnected = false;
state.currentStatus = i18n.t('common.statusDisconnected');
},
},
extraReducers(builder) {
builder.addCase(invocationComplete, (state, action) => {
const { data, timestamp } = action.payload;
state.isProcessing = false;
state.isCancelable = false;
if (isImageOutput(data.result)) {
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
message: `Generated: ${data.result.image.image_name}`,
level: 'info',
});
}
});
},
});
@ -306,6 +334,8 @@ export const {
setOpenModel,
setCancelType,
setCancelAfter,
socketioConnected,
socketioDisconnected,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -3,6 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import { ProgressImage } from './events/types';
import { createSession, invokeSession } from 'services/thunks/session';
import { getImage, uploadImage } from './thunks/image';
import { invocationComplete } from 'app/nodesSocketio/actions';
/**
* Just temp until we work out better statuses
@ -17,14 +18,14 @@ export enum STATUS {
* Type for the temp (?) API slice.
*/
export interface APIState {
sessionId: string | null;
sessionId: string;
progressImage: ProgressImage | null;
progress: number | null;
status: STATUS;
}
const initialSystemState: APIState = {
sessionId: null,
sessionId: '',
status: STATUS.idle,
progress: null,
progressImage: null,
@ -106,6 +107,9 @@ export const apiSlice = createSlice({
// !HTTP 200
// state.networkStatus = 'idle'
});
builder.addCase(invocationComplete, (state) => {
state.sessionId = '';
});
},
});

View File

@ -3,8 +3,8 @@ import { ImagesService } from 'services/api';
export const IMAGES_PER_PAGE = 20;
export const getNextResultsPage = createAppAsyncThunk(
'results/getInitialResultsPage',
export const receivedResultImagesPage = createAppAsyncThunk(
'results/receivedResultImagesPage',
async (_arg, { getState }) => {
const response = await ImagesService.listImages({
imageType: 'results',
@ -16,8 +16,8 @@ export const getNextResultsPage = createAppAsyncThunk(
}
);
export const getNextUploadsPage = createAppAsyncThunk(
'uploads/getNextUploadsPage',
export const receivedUploadImagesPage = createAppAsyncThunk(
'uploads/receivedUploadImagesPage',
async (_arg, { getState }) => {
const response = await ImagesService.listImages({
imageType: 'uploads',

View File

@ -27,7 +27,7 @@ export const extractTimestampFromImageName = (imageName: string) => {
return Number(timestamp);
};
export const processImageField = (image: ImageField): Image => {
export const deserializeImageField = (image: ImageField): Image => {
const name = image.image_name;
const type = image.image_type;