feat(ui): wip refactor socket events

This commit is contained in:
psychedelicious 2023-04-05 18:01:32 +10:00
parent 4e2358cb09
commit 760b4b938c
25 changed files with 395 additions and 398 deletions

View File

@ -14,6 +14,7 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
/**
* TODO:
@ -132,12 +133,10 @@ export declare type _Image = {
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
width: number;
height: number;
timestamp: number;
metadata?: Metadata;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.

View File

@ -5,32 +5,39 @@ import {
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
/**
* We can't use redux-toolkit's createSlice() to make these actions,
* because they have no associated reducer. They only exist to dispatch
* requests to the server via socketio. These actions will be handled
* by the middleware.
*/
export const emitSubscribe = createAction<string>('socketio/subscribe');
export const emitUnsubscribe = createAction<string>('socketio/unsubscribe');
type Timestamp = {
type SocketioPayload = {
timestamp: Date;
};
export const socketioConnected = createAction<SocketioPayload>(
'socketio/socketioConnected'
);
export const socketioDisconnected = createAction<SocketioPayload>(
'socketio/socketioDisconnected'
);
export const socketioSubscribed = createAction<
SocketioPayload & { sessionId: string }
>('socketio/socketioSubscribed');
export const socketioUnsubscribed = createAction<
SocketioPayload & { sessionId: string }
>('socketio/socketioUnsubscribed');
export const invocationStarted = createAction<
{ data: InvocationStartedEvent } & Timestamp
SocketioPayload & { data: InvocationStartedEvent }
>('socketio/invocationStarted');
export const invocationComplete = createAction<
{ data: InvocationCompleteEvent } & Timestamp
SocketioPayload & { data: InvocationCompleteEvent }
>('socketio/invocationComplete');
export const invocationError = createAction<
{ data: InvocationErrorEvent } & Timestamp
SocketioPayload & { data: InvocationErrorEvent }
>('socketio/invocationError');
export const generatorProgress = createAction<
{ data: GeneratorProgressEvent } & Timestamp
SocketioPayload & { data: GeneratorProgressEvent }
>('socketio/generatorProgress');

View File

@ -1,15 +0,0 @@
import { Socket } from 'socket.io-client';
const makeSocketIOEmitters = (socketio: Socket) => {
return {
emitSubscribe: (sessionId: string) => {
socketio.emit('subscribe', { session: sessionId });
},
emitUnsubscribe: (sessionId: string) => {
socketio.emit('unsubscribe', { session: sessionId });
},
};
};
export default makeSocketIOEmitters;

View File

@ -1,192 +0,0 @@
import { MiddlewareAPI } from '@reduxjs/toolkit';
import dateFormat from 'dateformat';
import i18n from 'i18n';
import { v4 as uuidv4 } from 'uuid';
import {
addLogEntry,
errorOccurred,
setCurrentStatus,
setIsCancelable,
setIsConnected,
setIsProcessing,
socketioConnected,
socketioDisconnected,
} from 'features/system/store/systemSlice';
import {
addImage,
clearIntermediateImage,
setIntermediateImage,
} from 'features/gallery/store/gallerySlice';
import type { AppDispatch, RootState } from 'app/store';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
import {
setProgress,
setProgressImage,
setSessionId,
setStatus,
STATUS,
} from 'services/apiSlice';
import { emitUnsubscribe, invocationComplete } from './actions';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { deserializeImageField } from 'services/util/deserializeImageField';
/**
* Returns an object containing listener callbacks
*/
const makeSocketIOListeners = (
store: MiddlewareAPI<AppDispatch, RootState>
) => {
const { dispatch, getState } = store;
return {
/**
* Callback to run when we receive a 'connect' event.
*/
onConnect: () => {
try {
dispatch(socketioConnected());
// fetch more images, but only if we don't already have images
if (!getState().results.ids.length) {
dispatch(receivedResultImagesPage());
}
if (!getState().uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'disconnect' event.
*/
onDisconnect: () => {
try {
dispatch(socketioDisconnected());
dispatch(emitUnsubscribe(getState().api.sessionId));
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
})
);
} catch (e) {
console.error(e);
}
},
onInvocationStarted: (data: InvocationStartedEvent) => {
console.log('invocation_started', data);
dispatch(setStatus(STATUS.busy));
},
/**
* Callback to run when we receive a 'generationResult' event.
*/
onInvocationComplete: (data: InvocationCompleteEvent) => {
console.log('invocation_complete', data);
try {
dispatch(invocationComplete({ data, timestamp: new Date() }));
const sessionId = data.graph_execution_state_id;
if (data.result.type === 'image') {
// const resultImage = deserializeImageField(data.result.image);
// dispatch(resultAdded(resultImage));
// // need to update the type for this or figure out how to get these values
// dispatch(
// addImage({
// category: 'result',
// image: {
// uuid: uuidv4(),
// url: resultImage.url,
// thumbnail: '',
// width: 512,
// height: 512,
// category: 'result',
// name: resultImage.name,
// mtime: new Date().getTime(),
// },
// })
// );
// dispatch(setIsProcessing(false));
// dispatch(setIsCancelable(false));
// dispatch(
// addLogEntry({
// timestamp: dateFormat(new Date(), 'isoDateTime'),
// message: `Generated: ${data.result.image.image_name}`,
// })
// );
dispatch(emitUnsubscribe(sessionId));
// dispatch(setSessionId(''));
}
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
* TODO: Add additional progress phases
*/
onGeneratorProgress: (data: GeneratorProgressEvent) => {
try {
console.log('generator_progress', data);
dispatch(setProgress(data.step / data.total_steps));
if (data.progress_image) {
dispatch(
setIntermediateImage({
// need to update the type for this or figure out how to get these values
category: 'result',
uuid: uuidv4(),
mtime: new Date().getTime(),
url: data.progress_image.dataURL,
thumbnail: '',
...data.progress_image,
})
);
}
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'progressUpdate' event.
*/
onInvocationError: (data: InvocationErrorEvent) => {
const { error } = data;
try {
dispatch(
addLogEntry({
timestamp: dateFormat(new Date(), 'isoDateTime'),
message: `Server error: ${error}`,
level: 'error',
})
);
dispatch(errorOccurred());
dispatch(clearIntermediateImage());
} catch (e) {
console.error(e);
}
},
/**
* Callback to run when we receive a 'galleryImages' event.
*/
};
};
export default makeSocketIOListeners;

View File

@ -1,16 +1,26 @@
import { Middleware } from '@reduxjs/toolkit';
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { io } from 'socket.io-client';
import makeSocketIOEmitters from './emitters';
import makeSocketIOListeners from './listeners';
import {
GeneratorProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
InvocationStartedEvent,
} from 'services/events/types';
import { invocationComplete } from './actions';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
socketioConnected,
socketioDisconnected,
socketioSubscribed,
} from './actions';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { AppDispatch, RootState } from 'app/store';
const socket_url = `ws://${window.location.host}`;
@ -22,64 +32,62 @@ const socketio = io(socket_url, {
export const socketioMiddleware = () => {
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {
const { emitSubscribe, emitUnsubscribe } = makeSocketIOEmitters(socketio);
const middleware: Middleware =
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store;
const timestamp = new Date();
const {
onConnect,
onDisconnect,
onInvocationStarted,
onGeneratorProgress,
onInvocationError,
onInvocationComplete,
} = makeSocketIOListeners(store);
if (!areListenersSet) {
socketio.on('connect', () => {
dispatch(socketioConnected({ timestamp }));
if (!areListenersSet) {
socketio.on('connect', () => onConnect());
socketio.on('disconnect', () => onDisconnect());
}
if (!getState().results.ids.length) {
dispatch(receivedResultImagesPage());
}
areListenersSet = true;
if (!getState().uploads.ids.length) {
dispatch(receivedUploadImagesPage());
}
});
// use the action's match() function for type narrowing and safety
if (invocationComplete.match(action)) {
emitUnsubscribe(action.payload.data.graph_execution_state_id);
socketio.removeAllListeners();
}
/**
* Handle redux actions caught by middleware.
*/
switch (action.type) {
case 'socketio/subscribe': {
emitSubscribe(action.payload);
socketio.on('invocation_started', (data: InvocationStartedEvent) =>
onInvocationStarted(data)
);
socketio.on('generator_progress', (data: GeneratorProgressEvent) =>
onGeneratorProgress(data)
);
socketio.on('invocation_error', (data: InvocationErrorEvent) =>
onInvocationError(data)
);
socketio.on('invocation_complete', (data: InvocationCompleteEvent) =>
onInvocationComplete(data)
);
break;
socketio.on('disconnect', () => {
dispatch(socketioDisconnected({ timestamp }));
socketio.removeAllListeners();
});
}
// case 'socketio/unsubscribe': {
// emitUnsubscribe(action.payload);
areListenersSet = true;
// socketio.removeAllListeners();
// break;
// }
}
if (invocationComplete.match(action)) {
socketio.emit('unsubscribe', {
session: action.payload.data.graph_execution_state_id,
});
next(action);
};
socketio.removeAllListeners();
}
if (socketioSubscribed.match(action)) {
socketio.emit('subscribe', { session: action.payload.sessionId });
socketio.on('invocation_started', (data: InvocationStartedEvent) => {
dispatch(invocationStarted({ data, timestamp }));
});
socketio.on('generator_progress', (data: GeneratorProgressEvent) => {
dispatch(generatorProgress({ data, timestamp }));
});
socketio.on('invocation_error', (data: InvocationErrorEvent) => {
dispatch(invocationError({ data, timestamp }));
});
socketio.on('invocation_complete', (data: InvocationCompleteEvent) => {
dispatch(invocationComplete({ data, timestamp }));
});
}
next(action);
};
return middleware;
};

View File

@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
setInitialImage,
// setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@ -146,7 +147,8 @@ const makeSocketIOListeners = (
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(setInitialImage(newImage));
dispatch(initialImageSelected(newImage.uuid));
// dispatch(setInitialImage(newImage));
break;
}
}

View File

@ -7,6 +7,7 @@ import {
} from 'services/api';
import { _Image } from 'app/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
// fe todo fix model type (frontend uses null, backend uses undefined)
// fe todo update front end to store to have whole image field (vs just name)
@ -66,10 +67,16 @@ export function buildImg2ImgNode(
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
initialImage,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
}
return {
type: 'img2img',
prompt,
@ -83,8 +90,8 @@ export function buildImg2ImgNode(
model,
progress_images: shouldDisplayInProgressType === 'full-res',
image: {
image_name: (initialImage as _Image).name!,
image_type: 'result',
image_name: initialImage.name,
image_type: 'results',
},
strength,
fit,
@ -107,7 +114,7 @@ export function buildFacetoolNode(
image_name:
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
'',
image_type: 'result',
image_type: 'results',
},
strength,
};
@ -130,7 +137,7 @@ export function buildUpscaleNode(
image_name:
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
'',
image_type: 'result',
image_type: 'results',
},
strength,
level,

View File

@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
setInitialImage,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@ -129,8 +130,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
const handleClickUseAsInitialImage = () => {
if (!currentImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
dispatch(initialImageSelected(currentImage.uuid));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {

View File

@ -1,29 +1,44 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import { selectedImageSelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
export const imagesSelector = createSelector(
[gallerySelector, uiSelector, selectedImageSelector],
(gallery: GalleryState, ui, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
return {
imageToDisplay: intermediateImage ? intermediateImage : selectedImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
imageToDisplay,
};
},
{
@ -34,7 +49,7 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
const { shouldShowImageDetails, imageToDisplay } =
useAppSelector(imagesSelector);
console.log(imageToDisplay);
return (
@ -52,34 +67,42 @@ export default function CurrentImagePreview() {
src={imageToDisplay.url}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
fallback={
!imageToDisplay.isProgressImage ? (
<CurrentImageFallback />
) : undefined
}
sx={{
objectFit: 'contain',
maxWidth: '100%',
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails && imageToDisplay && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
{/* <ImageMetadataViewer image={imageToDisplay} /> */}
</Box>
)}
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
</Flex>
);
}

View File

@ -14,9 +14,9 @@ import {
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@ -72,10 +72,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@ -85,7 +84,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata && dispatch(setSeed(image.metadata.image.seed));
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@ -95,16 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
// dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
dispatch(initialImageSelected(image.name));
};
const handleSendToCanvas = () => {
@ -125,7 +116,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata && dispatch(setAllParameters(metadata));
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@ -135,11 +126,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@ -160,7 +153,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
// dispatch(setCurrentImage(image));
};
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
@ -183,28 +175,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.image?.prompt === undefined}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.image?.seed === undefined}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>

View File

@ -18,7 +18,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
setInitialImage,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@ -113,14 +113,14 @@ const MetadataItem = ({
};
type ImageMetadataViewerProps = {
image: InvokeAI._Image;
image: InvokeAI.Image;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
@ -137,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const {
cfg_scale,
@ -160,6 +160,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
@ -193,8 +194,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
@ -288,14 +289,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
onClick={() => dispatch(setHeight(height))}
/>
)}
{init_image_path && (
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)}
)} */}
{mask_image_path && (
<MetadataItem
label="Mask image"

View File

@ -7,8 +7,16 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { selectResultsAll, selectResultsEntities } from './resultsSlice';
import { selectUploadsAll, selectUploadsEntities } from './uploadsSlice';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;

View File

@ -3,13 +3,14 @@ import { Image } from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import { RootState } from 'app/store';
import { socketioConnected } from 'features/system/store/systemSlice';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
// import { deserializeImageField } from 'services/util/deserializeImageField';
import { setCurrentCategory } from './gallerySlice';
// use `createEntityAdapter` to create a slice for results images
@ -21,7 +22,7 @@ export const resultsAdapter = createEntityAdapter<Image>({
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.timestamp - a.timestamp,
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
});
// This type is intersected with the Entity type to create the shape of the state
@ -61,7 +62,9 @@ const resultsSlice = createSlice({
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) => deserializeImageField(image));
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);

View File

@ -7,10 +7,11 @@ import {
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { deserializeImageField } from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.timestamp - a.timestamp,
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
});
type AdditionalUploadsState = {
@ -38,7 +39,7 @@ const uploadsSlice = createSlice({
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageField(image));
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);

View File

@ -1,5 +1,11 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
import {
selectResultsById,
selectResultsEntities,
} from 'features/gallery/store/resultsSlice';
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
import { isEqual } from 'lodash';
export const generationSelector = (state: RootState) => state.generation;
@ -15,3 +21,15 @@ export const mayGenerateMultipleImagesSelector = createSelector(
},
}
);
export const initialImageSelector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const { initialImage: initialImageName } = generation;
return (
selectResultsById(state, initialImageName as string) ??
selectUploadsById(state, initialImageName as string)
);
}
);

View File

@ -317,12 +317,12 @@ export const generationSlice = createSlice({
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
state.shouldRandomizeSeed = action.payload;
},
setInitialImage: (
state,
action: PayloadAction<InvokeAI._Image | string>
) => {
state.initialImage = action.payload;
},
// setInitialImage: (
// state,
// action: PayloadAction<InvokeAI._Image | string>
// ) => {
// state.initialImage = action.payload;
// },
clearInitialImage: (state) => {
state.initialImage = undefined;
},
@ -353,6 +353,9 @@ export const generationSlice = createSlice({
setVerticalSymmetrySteps: (state, action: PayloadAction<number>) => {
state.verticalSymmetrySteps = action.payload;
},
initialImageSelected: (state, action: PayloadAction<string>) => {
state.initialImage = action.payload;
},
},
});
@ -368,7 +371,7 @@ export const {
setHeight,
setImg2imgStrength,
setInfillMethod,
setInitialImage,
// setInitialImage,
setIterations,
setMaskPath,
setParameter,
@ -394,6 +397,7 @@ export const {
setShouldUseSymmetry,
setHorizontalSymmetrySteps,
setVerticalSymmetrySteps,
initialImageSelected,
} = generationSlice.actions;
export default generationSlice.reducer;

View File

@ -1,13 +1,22 @@
import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
import { ExpandedIndex, StatHelpText, UseToastOptions } from '@chakra-ui/react';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'app/nodesSocketio/actions';
import {
generatorProgress,
invocationComplete,
invocationError,
invocationStarted,
socketioConnected,
socketioDisconnected,
} from 'app/nodesSocketio/actions';
import { resultAdded } from 'features/gallery/store/resultsSlice';
import dateFormat from 'dateformat';
import i18n from 'i18n';
import { isImageOutput } from 'services/types/guards';
import { ProgressImage } from 'services/events/types';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
export type LogLevel = 'info' | 'warning' | 'error';
@ -61,6 +70,7 @@ export interface SystemState
cancelType: CancelType;
cancelAfter: number | null;
};
progressImage: ProgressImage | null;
}
const initialSystemState: SystemState = {
@ -103,6 +113,7 @@ const initialSystemState: SystemState = {
cancelType: 'immediate',
cancelAfter: null,
},
progressImage: null,
};
export const systemSlice = createSlice({
@ -276,21 +287,63 @@ export const systemSlice = createSlice({
setCancelAfter: (state, action: PayloadAction<number | null>) => {
state.cancelOptions.cancelAfter = action.payload;
},
socketioConnected: (state) => {
state.isConnected = true;
state.currentStatus = i18n.t('common.statusConnected');
},
socketioDisconnected: (state) => {
state.isConnected = false;
state.currentStatus = i18n.t('common.statusDisconnected');
},
// socketioConnected: (state) => {
// state.isConnected = true;
// state.currentStatus = i18n.t('common.statusConnected');
// },
// socketioDisconnected: (state) => {
// state.isConnected = false;
// state.currentStatus = i18n.t('common.statusDisconnected');
// },
},
extraReducers(builder) {
builder.addCase(socketioConnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = true;
state.currentStatus = i18n.t('common.statusConnected');
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
message: `Connected to server`,
level: 'info',
});
});
builder.addCase(socketioDisconnected, (state, action) => {
const { timestamp } = action.payload;
state.isConnected = false;
state.currentStatus = i18n.t('common.statusDisconnected');
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
message: `Disconnected from server`,
level: 'warning',
});
});
builder.addCase(invocationStarted, (state, action) => {
state.isProcessing = true;
state.currentStatusHasSteps = false;
});
builder.addCase(generatorProgress, (state, action) => {
const { step, total_steps, progress_image } = action.payload.data;
state.currentStatusHasSteps = true;
state.currentStep = step + 1; // TODO: step starts at -1, think this is a bug
state.totalSteps = total_steps;
state.progressImage = progress_image ?? null;
});
builder.addCase(invocationComplete, (state, action) => {
const { data, timestamp } = action.payload;
state.isProcessing = false;
state.isCancelable = false;
state.isProcessing = false;
state.currentStep = 0;
state.totalSteps = 0;
state.progressImage = null;
// TODO: handle logging for other invocation types
if (isImageOutput(data.result)) {
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
@ -299,6 +352,29 @@ export const systemSlice = createSlice({
});
}
});
builder.addCase(invocationError, (state, action) => {
const { data, timestamp } = action.payload;
state.log.push({
timestamp: dateFormat(timestamp, 'isoDateTime'),
message: `Server error: ${data.error}`,
level: 'error',
});
state.wasErrorSeen = true;
state.progressImage = null;
state.isProcessing = false;
});
builder.addCase(initialImageSelected, (state) => {
state.toastQueue.push({
title: i18n.t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
},
});
@ -334,8 +410,8 @@ export const {
setOpenModel,
setCancelType,
setCancelAfter,
socketioConnected,
socketioDisconnected,
// socketioConnected,
// socketioDisconnected,
} = systemSlice.actions;
export default systemSlice.reducer;

View File

@ -1,7 +1,7 @@
import { Box, BoxProps, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
import {
activeTabNameSelector,
uiSelector,
@ -47,7 +47,7 @@ const InvokeWorkarea = (props: InvokeWorkareaProps) => {
const image = getImageByUuid(uuid);
if (!image) return;
if (activeTabName === 'img2img') {
dispatch(setInitialImage(image));
dispatch(initialImageSelected(image.uuid));
} else if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(image));
}

View File

@ -1,14 +1,12 @@
import { Flex, Image, Text, useToast } from '@chakra-ui/react';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import ImageUploaderIconButton from 'common/components/ImageUploaderIconButton';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
export default function InitImagePreview() {
const initialImage = useAppSelector(
(state: RootState) => state.generation.initialImage
);
const initialImage = useAppSelector(initialImageSelector);
const { t } = useTranslation();

View File

@ -0,0 +1,13 @@
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';
export const setActiveTabReducer = (
state: UIState,
newActiveTab: number | InvokeTabName
) => {
if (typeof newActiveTab === 'number') {
state.activeTab = newActiveTab;
} else {
state.activeTab = tabMap.indexOf(newActiveTab);
}
};

View File

@ -1,5 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { initialImageSelected } from 'features/parameters/store/generationSlice';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName, tabMap } from './tabMap';
import { AddNewModelType, UIState } from './uiTypes';
@ -25,11 +27,7 @@ export const uiSlice = createSlice({
initialState,
reducers: {
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
if (typeof action.payload === 'number') {
state.activeTab = action.payload;
} else {
state.activeTab = tabMap.indexOf(action.payload);
}
setActiveTabReducer(state, action.payload);
},
setCurrentTheme: (state, action: PayloadAction<string>) => {
state.currentTheme = action.payload;
@ -93,6 +91,13 @@ export const uiSlice = createSlice({
}
},
},
extraReducers(builder) {
builder.addCase(initialImageSelected, (state) => {
if (tabMap[state.activeTab] !== 'img2img') {
setActiveTabReducer(state, 'img2img');
}
});
},
});
export const {

View File

@ -1,7 +1,7 @@
import { isFulfilled, Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { v4 as uuidv4 } from 'uuid';
import { emitSubscribe } from 'app/nodesSocketio/actions';
import { socketioSubscribed } from 'app/nodesSocketio/actions';
import { AppDispatch, RootState } from 'app/store';
import { setSessionId } from './apiSlice';
import { uploadImage } from './thunks/image';
@ -10,7 +10,7 @@ import * as InvokeAI from 'app/invokeai';
import { addImage } from 'features/gallery/store/gallerySlice';
import { tabMap } from 'features/ui/store/tabMap';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { initialImageSelected as initialImageSet } from 'features/parameters/store/generationSlice';
/**
* `redux-toolkit` provides nice matching utilities, which can be used as type guards
@ -24,12 +24,14 @@ export const invokeMiddleware: Middleware =
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
const { dispatch, getState } = store;
const timestamp = new Date();
if (isFulfilledCreateSession(action)) {
const sessionId = action.payload.id;
console.log('createSession.fulfilled');
dispatch(setSessionId(sessionId));
dispatch(emitSubscribe(sessionId));
dispatch(socketioSubscribed({ sessionId, timestamp }));
dispatch(invokeSession({ sessionId }));
} else if (isFulfilledUploadImage(action)) {
const uploadLocation = action.payload;
@ -54,7 +56,8 @@ export const invokeMiddleware: Middleware =
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
// dispatch(setInitialImage(newImage));
dispatch(initialImageSet(newImage.uuid));
}
} else {
next(action);

View File

@ -27,6 +27,10 @@ export const extractTimestampFromImageName = (imageName: string) => {
return Number(timestamp);
};
/**
* Process ImageField objects. These come from `invocation_complete` events and do not contain all the data we need.
* This is a WIP on the server side.
*/
export const deserializeImageField = (image: ImageField): Image => {
const name = image.image_name;
const type = image.image_type;
@ -37,10 +41,13 @@ export const deserializeImageField = (image: ImageField): Image => {
return {
name,
type,
url,
thumbnail,
timestamp,
height: 512,
width: 512,
metadata: {
timestamp,
height: 512, // TODO: need the server to give this to us
width: 512,
},
};
};

View File

@ -0,0 +1,22 @@
import { Image } from 'app/invokeai';
import { ImageResponse } from 'services/api';
/**
* Process ImageReponse objects, which we get from the `list_images` endpoint.
*/
export const deserializeImageResponse = (
imageResponse: ImageResponse
): Image => {
const { image_name, image_type, image_url, metadata, thumbnail_url } =
imageResponse;
// TODO: parse metadata - just leaving it as-is for now
return {
name: image_name,
type: image_type,
url: image_url,
thumbnail: thumbnail_url,
metadata,
};
};

View File

@ -9,7 +9,9 @@ const { defineMultiStyleConfig, definePartsStyle } =
const invokeAIFilledTrack = defineStyle((_props) => ({
bg: 'accent.600',
transition: 'width 0.2s ease-in-out',
// TODO: the animation is nice but looks weird bc it is substantially longer than each step
// so we get to 100% long before it finishes
// transition: 'width 0.2s ease-in-out',
_indeterminate: {
bgGradient:
'linear(to-r, transparent 0%, accent.600 50%, transparent 100%);',