mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): wip refactor socket events
This commit is contained in:
parent
4e2358cb09
commit
760b4b938c
7
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
7
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@ -14,6 +14,7 @@
|
||||
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageMetadata, ImageType } from 'services/api';
|
||||
|
||||
/**
|
||||
* TODO:
|
||||
@ -132,12 +133,10 @@ export declare type _Image = {
|
||||
*/
|
||||
export declare type Image = {
|
||||
name: string;
|
||||
type: ImageType;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
width: number;
|
||||
height: number;
|
||||
timestamp: number;
|
||||
metadata?: Metadata;
|
||||
metadata: ImageMetadata;
|
||||
};
|
||||
|
||||
// GalleryImages is an array of Image.
|
||||
|
@ -5,32 +5,39 @@ import {
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
} from 'services/events/types';
|
||||
/**
|
||||
* We can't use redux-toolkit's createSlice() to make these actions,
|
||||
* because they have no associated reducer. They only exist to dispatch
|
||||
* requests to the server via socketio. These actions will be handled
|
||||
* by the middleware.
|
||||
*/
|
||||
|
||||
export const emitSubscribe = createAction<string>('socketio/subscribe');
|
||||
export const emitUnsubscribe = createAction<string>('socketio/unsubscribe');
|
||||
|
||||
type Timestamp = {
|
||||
type SocketioPayload = {
|
||||
timestamp: Date;
|
||||
};
|
||||
|
||||
export const socketioConnected = createAction<SocketioPayload>(
|
||||
'socketio/socketioConnected'
|
||||
);
|
||||
|
||||
export const socketioDisconnected = createAction<SocketioPayload>(
|
||||
'socketio/socketioDisconnected'
|
||||
);
|
||||
|
||||
export const socketioSubscribed = createAction<
|
||||
SocketioPayload & { sessionId: string }
|
||||
>('socketio/socketioSubscribed');
|
||||
|
||||
export const socketioUnsubscribed = createAction<
|
||||
SocketioPayload & { sessionId: string }
|
||||
>('socketio/socketioUnsubscribed');
|
||||
|
||||
export const invocationStarted = createAction<
|
||||
{ data: InvocationStartedEvent } & Timestamp
|
||||
SocketioPayload & { data: InvocationStartedEvent }
|
||||
>('socketio/invocationStarted');
|
||||
|
||||
export const invocationComplete = createAction<
|
||||
{ data: InvocationCompleteEvent } & Timestamp
|
||||
SocketioPayload & { data: InvocationCompleteEvent }
|
||||
>('socketio/invocationComplete');
|
||||
|
||||
export const invocationError = createAction<
|
||||
{ data: InvocationErrorEvent } & Timestamp
|
||||
SocketioPayload & { data: InvocationErrorEvent }
|
||||
>('socketio/invocationError');
|
||||
|
||||
export const generatorProgress = createAction<
|
||||
{ data: GeneratorProgressEvent } & Timestamp
|
||||
SocketioPayload & { data: GeneratorProgressEvent }
|
||||
>('socketio/generatorProgress');
|
||||
|
@ -1,15 +0,0 @@
|
||||
import { Socket } from 'socket.io-client';
|
||||
|
||||
const makeSocketIOEmitters = (socketio: Socket) => {
|
||||
return {
|
||||
emitSubscribe: (sessionId: string) => {
|
||||
socketio.emit('subscribe', { session: sessionId });
|
||||
},
|
||||
|
||||
emitUnsubscribe: (sessionId: string) => {
|
||||
socketio.emit('unsubscribe', { session: sessionId });
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default makeSocketIOEmitters;
|
@ -1,192 +0,0 @@
|
||||
import { MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import dateFormat from 'dateformat';
|
||||
import i18n from 'i18n';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import {
|
||||
addLogEntry,
|
||||
errorOccurred,
|
||||
setCurrentStatus,
|
||||
setIsCancelable,
|
||||
setIsConnected,
|
||||
setIsProcessing,
|
||||
socketioConnected,
|
||||
socketioDisconnected,
|
||||
} from 'features/system/store/systemSlice';
|
||||
|
||||
import {
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
setIntermediateImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
|
||||
import type { AppDispatch, RootState } from 'app/store';
|
||||
import {
|
||||
GeneratorProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
} from 'services/events/types';
|
||||
import {
|
||||
setProgress,
|
||||
setProgressImage,
|
||||
setSessionId,
|
||||
setStatus,
|
||||
STATUS,
|
||||
} from 'services/apiSlice';
|
||||
import { emitUnsubscribe, invocationComplete } from './actions';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
|
||||
/**
|
||||
* Returns an object containing listener callbacks
|
||||
*/
|
||||
const makeSocketIOListeners = (
|
||||
store: MiddlewareAPI<AppDispatch, RootState>
|
||||
) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
return {
|
||||
/**
|
||||
* Callback to run when we receive a 'connect' event.
|
||||
*/
|
||||
onConnect: () => {
|
||||
try {
|
||||
dispatch(socketioConnected());
|
||||
|
||||
// fetch more images, but only if we don't already have images
|
||||
if (!getState().results.ids.length) {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (!getState().uploads.ids.length) {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'disconnect' event.
|
||||
*/
|
||||
onDisconnect: () => {
|
||||
try {
|
||||
dispatch(socketioDisconnected());
|
||||
dispatch(emitUnsubscribe(getState().api.sessionId));
|
||||
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Disconnected from server`,
|
||||
level: 'warning',
|
||||
})
|
||||
);
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
onInvocationStarted: (data: InvocationStartedEvent) => {
|
||||
console.log('invocation_started', data);
|
||||
dispatch(setStatus(STATUS.busy));
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'generationResult' event.
|
||||
*/
|
||||
onInvocationComplete: (data: InvocationCompleteEvent) => {
|
||||
console.log('invocation_complete', data);
|
||||
try {
|
||||
dispatch(invocationComplete({ data, timestamp: new Date() }));
|
||||
|
||||
const sessionId = data.graph_execution_state_id;
|
||||
if (data.result.type === 'image') {
|
||||
// const resultImage = deserializeImageField(data.result.image);
|
||||
|
||||
// dispatch(resultAdded(resultImage));
|
||||
// // need to update the type for this or figure out how to get these values
|
||||
// dispatch(
|
||||
// addImage({
|
||||
// category: 'result',
|
||||
// image: {
|
||||
// uuid: uuidv4(),
|
||||
// url: resultImage.url,
|
||||
// thumbnail: '',
|
||||
// width: 512,
|
||||
// height: 512,
|
||||
// category: 'result',
|
||||
// name: resultImage.name,
|
||||
// mtime: new Date().getTime(),
|
||||
// },
|
||||
// })
|
||||
// );
|
||||
// dispatch(setIsProcessing(false));
|
||||
// dispatch(setIsCancelable(false));
|
||||
|
||||
// dispatch(
|
||||
// addLogEntry({
|
||||
// timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
// message: `Generated: ${data.result.image.image_name}`,
|
||||
// })
|
||||
// );
|
||||
dispatch(emitUnsubscribe(sessionId));
|
||||
// dispatch(setSessionId(''));
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'progressUpdate' event.
|
||||
* TODO: Add additional progress phases
|
||||
*/
|
||||
onGeneratorProgress: (data: GeneratorProgressEvent) => {
|
||||
try {
|
||||
console.log('generator_progress', data);
|
||||
dispatch(setProgress(data.step / data.total_steps));
|
||||
if (data.progress_image) {
|
||||
dispatch(
|
||||
setIntermediateImage({
|
||||
// need to update the type for this or figure out how to get these values
|
||||
category: 'result',
|
||||
uuid: uuidv4(),
|
||||
mtime: new Date().getTime(),
|
||||
url: data.progress_image.dataURL,
|
||||
thumbnail: '',
|
||||
...data.progress_image,
|
||||
})
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'progressUpdate' event.
|
||||
*/
|
||||
onInvocationError: (data: InvocationErrorEvent) => {
|
||||
const { error } = data;
|
||||
|
||||
try {
|
||||
dispatch(
|
||||
addLogEntry({
|
||||
timestamp: dateFormat(new Date(), 'isoDateTime'),
|
||||
message: `Server error: ${error}`,
|
||||
level: 'error',
|
||||
})
|
||||
);
|
||||
dispatch(errorOccurred());
|
||||
dispatch(clearIntermediateImage());
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
},
|
||||
/**
|
||||
* Callback to run when we receive a 'galleryImages' event.
|
||||
*/
|
||||
};
|
||||
};
|
||||
|
||||
export default makeSocketIOListeners;
|
@ -1,16 +1,26 @@
|
||||
import { Middleware } from '@reduxjs/toolkit';
|
||||
import { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { io } from 'socket.io-client';
|
||||
|
||||
import makeSocketIOEmitters from './emitters';
|
||||
import makeSocketIOListeners from './listeners';
|
||||
|
||||
import {
|
||||
GeneratorProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
} from 'services/events/types';
|
||||
import { invocationComplete } from './actions';
|
||||
import {
|
||||
generatorProgress,
|
||||
invocationComplete,
|
||||
invocationError,
|
||||
invocationStarted,
|
||||
socketioConnected,
|
||||
socketioDisconnected,
|
||||
socketioSubscribed,
|
||||
} from './actions';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { AppDispatch, RootState } from 'app/store';
|
||||
|
||||
const socket_url = `ws://${window.location.host}`;
|
||||
|
||||
@ -22,64 +32,62 @@ const socketio = io(socket_url, {
|
||||
export const socketioMiddleware = () => {
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
const { emitSubscribe, emitUnsubscribe } = makeSocketIOEmitters(socketio);
|
||||
const middleware: Middleware =
|
||||
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
||||
const { dispatch, getState } = store;
|
||||
const timestamp = new Date();
|
||||
|
||||
const {
|
||||
onConnect,
|
||||
onDisconnect,
|
||||
onInvocationStarted,
|
||||
onGeneratorProgress,
|
||||
onInvocationError,
|
||||
onInvocationComplete,
|
||||
} = makeSocketIOListeners(store);
|
||||
if (!areListenersSet) {
|
||||
socketio.on('connect', () => {
|
||||
dispatch(socketioConnected({ timestamp }));
|
||||
|
||||
if (!areListenersSet) {
|
||||
socketio.on('connect', () => onConnect());
|
||||
socketio.on('disconnect', () => onDisconnect());
|
||||
}
|
||||
if (!getState().results.ids.length) {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
areListenersSet = true;
|
||||
if (!getState().uploads.ids.length) {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
});
|
||||
|
||||
// use the action's match() function for type narrowing and safety
|
||||
if (invocationComplete.match(action)) {
|
||||
emitUnsubscribe(action.payload.data.graph_execution_state_id);
|
||||
socketio.removeAllListeners();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle redux actions caught by middleware.
|
||||
*/
|
||||
switch (action.type) {
|
||||
case 'socketio/subscribe': {
|
||||
emitSubscribe(action.payload);
|
||||
|
||||
socketio.on('invocation_started', (data: InvocationStartedEvent) =>
|
||||
onInvocationStarted(data)
|
||||
);
|
||||
socketio.on('generator_progress', (data: GeneratorProgressEvent) =>
|
||||
onGeneratorProgress(data)
|
||||
);
|
||||
socketio.on('invocation_error', (data: InvocationErrorEvent) =>
|
||||
onInvocationError(data)
|
||||
);
|
||||
socketio.on('invocation_complete', (data: InvocationCompleteEvent) =>
|
||||
onInvocationComplete(data)
|
||||
);
|
||||
|
||||
break;
|
||||
socketio.on('disconnect', () => {
|
||||
dispatch(socketioDisconnected({ timestamp }));
|
||||
socketio.removeAllListeners();
|
||||
});
|
||||
}
|
||||
|
||||
// case 'socketio/unsubscribe': {
|
||||
// emitUnsubscribe(action.payload);
|
||||
areListenersSet = true;
|
||||
|
||||
// socketio.removeAllListeners();
|
||||
// break;
|
||||
// }
|
||||
}
|
||||
if (invocationComplete.match(action)) {
|
||||
socketio.emit('unsubscribe', {
|
||||
session: action.payload.data.graph_execution_state_id,
|
||||
});
|
||||
|
||||
next(action);
|
||||
};
|
||||
socketio.removeAllListeners();
|
||||
}
|
||||
|
||||
if (socketioSubscribed.match(action)) {
|
||||
socketio.emit('subscribe', { session: action.payload.sessionId });
|
||||
|
||||
socketio.on('invocation_started', (data: InvocationStartedEvent) => {
|
||||
dispatch(invocationStarted({ data, timestamp }));
|
||||
});
|
||||
|
||||
socketio.on('generator_progress', (data: GeneratorProgressEvent) => {
|
||||
dispatch(generatorProgress({ data, timestamp }));
|
||||
});
|
||||
|
||||
socketio.on('invocation_error', (data: InvocationErrorEvent) => {
|
||||
dispatch(invocationError({ data, timestamp }));
|
||||
});
|
||||
|
||||
socketio.on('invocation_complete', (data: InvocationCompleteEvent) => {
|
||||
dispatch(invocationComplete({ data, timestamp }));
|
||||
});
|
||||
}
|
||||
|
||||
next(action);
|
||||
};
|
||||
|
||||
return middleware;
|
||||
};
|
||||
|
@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
clearInitialImage,
|
||||
initialImageSelected,
|
||||
setInfillMethod,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
@ -146,7 +147,8 @@ const makeSocketIOListeners = (
|
||||
const activeTabName = tabMap[activeTab];
|
||||
switch (activeTabName) {
|
||||
case 'img2img': {
|
||||
dispatch(setInitialImage(newImage));
|
||||
dispatch(initialImageSelected(newImage.uuid));
|
||||
// dispatch(setInitialImage(newImage));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import {
|
||||
} from 'services/api';
|
||||
|
||||
import { _Image } from 'app/invokeai';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
|
||||
// fe todo fix model type (frontend uses null, backend uses undefined)
|
||||
// fe todo update front end to store to have whole image field (vs just name)
|
||||
@ -66,10 +67,16 @@ export function buildImg2ImgNode(
|
||||
seamless,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight: fit,
|
||||
initialImage,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
const initialImage = initialImageSelector(state);
|
||||
|
||||
if (!initialImage) {
|
||||
// TODO: handle this
|
||||
throw 'no initial image';
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
@ -83,8 +90,8 @@ export function buildImg2ImgNode(
|
||||
model,
|
||||
progress_images: shouldDisplayInProgressType === 'full-res',
|
||||
image: {
|
||||
image_name: (initialImage as _Image).name!,
|
||||
image_type: 'result',
|
||||
image_name: initialImage.name,
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
@ -107,7 +114,7 @@ export function buildFacetoolNode(
|
||||
image_name:
|
||||
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
|
||||
'',
|
||||
image_type: 'result',
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
};
|
||||
@ -130,7 +137,7 @@ export function buildUpscaleNode(
|
||||
image_name:
|
||||
(typeof initialImage === 'string' ? initialImage : initialImage?.url) ||
|
||||
'',
|
||||
image_type: 'result',
|
||||
image_type: 'results',
|
||||
},
|
||||
strength,
|
||||
level,
|
||||
|
@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
@ -129,8 +130,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
const handleClickUseAsInitialImage = () => {
|
||||
if (!currentImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
dispatch(setInitialImage(currentImage));
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(initialImageSelected(currentImage.uuid));
|
||||
// dispatch(setInitialImage(currentImage));
|
||||
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
};
|
||||
|
||||
const handleCopyImage = async () => {
|
||||
|
@ -1,29 +1,44 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[gallerySelector, uiSelector, selectedImageSelector],
|
||||
(gallery: GalleryState, ui, selectedImage) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
[uiSelector, selectedImageSelector, systemSelector],
|
||||
(ui, selectedImage, system) => {
|
||||
const { shouldShowImageDetails } = ui;
|
||||
const { progressImage } = system;
|
||||
|
||||
// TODO: Clean this up, this is really gross
|
||||
const imageToDisplay = progressImage
|
||||
? {
|
||||
url: progressImage.dataURL,
|
||||
width: progressImage.width,
|
||||
height: progressImage.height,
|
||||
isProgressImage: true,
|
||||
image: progressImage,
|
||||
}
|
||||
: selectedImage
|
||||
? {
|
||||
url: selectedImage.url,
|
||||
width: selectedImage.metadata.width,
|
||||
height: selectedImage.metadata.height,
|
||||
isProgressImage: false,
|
||||
image: selectedImage,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
imageToDisplay: intermediateImage ? intermediateImage : selectedImage,
|
||||
isIntermediate: Boolean(intermediateImage),
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -34,7 +49,7 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
export default function CurrentImagePreview() {
|
||||
const { shouldShowImageDetails, imageToDisplay, isIntermediate } =
|
||||
const { shouldShowImageDetails, imageToDisplay } =
|
||||
useAppSelector(imagesSelector);
|
||||
console.log(imageToDisplay);
|
||||
return (
|
||||
@ -52,34 +67,42 @@ export default function CurrentImagePreview() {
|
||||
src={imageToDisplay.url}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={!isIntermediate ? <CurrentImageFallback /> : undefined}
|
||||
fallback={
|
||||
!imageToDisplay.isProgressImage ? (
|
||||
<CurrentImageFallback />
|
||||
) : undefined
|
||||
}
|
||||
sx={{
|
||||
objectFit: 'contain',
|
||||
maxWidth: '100%',
|
||||
maxHeight: '100%',
|
||||
height: 'auto',
|
||||
position: 'absolute',
|
||||
imageRendering: isIntermediate ? 'pixelated' : 'initial',
|
||||
imageRendering: imageToDisplay.isProgressImage
|
||||
? 'pixelated'
|
||||
: 'initial',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
||||
{shouldShowImageDetails && imageToDisplay && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
{/* <ImageMetadataViewer image={imageToDisplay} /> */}
|
||||
</Box>
|
||||
)}
|
||||
{shouldShowImageDetails &&
|
||||
imageToDisplay &&
|
||||
'metadata' in imageToDisplay.image && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay.image} />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ import {
|
||||
setCurrentImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllImageToImageParameters,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { DragEvent, memo, useState } from 'react';
|
||||
@ -72,10 +72,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleUsePrompt = () => {
|
||||
if (image.metadata?.image?.prompt) {
|
||||
setBothPrompts(image.metadata?.image?.prompt);
|
||||
if (image.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(image.metadata?.sd_metadata?.prompt);
|
||||
}
|
||||
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
status: 'success',
|
||||
@ -85,7 +84,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseSeed = () => {
|
||||
image.metadata && dispatch(setSeed(image.metadata.image.seed));
|
||||
image.metadata.sd_metadata &&
|
||||
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
status: 'success',
|
||||
@ -95,16 +95,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleSendToImageToImage = () => {
|
||||
// dispatch(setInitialImage(image));
|
||||
if (activeTabName !== 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
}
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
dispatch(initialImageSelected(image.name));
|
||||
};
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
@ -125,7 +116,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseAllParameters = () => {
|
||||
metadata && dispatch(setAllParameters(metadata));
|
||||
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
|
||||
toast({
|
||||
title: t('toast.parametersSet'),
|
||||
status: 'success',
|
||||
@ -135,11 +126,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseInitialImage = async () => {
|
||||
if (metadata?.image?.init_image_path) {
|
||||
const response = await fetch(metadata.image.init_image_path);
|
||||
if (metadata.sd_metadata?.image?.init_image_path) {
|
||||
const response = await fetch(
|
||||
metadata.sd_metadata?.image?.init_image_path
|
||||
);
|
||||
if (response.ok) {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(setAllImageToImageParameters(metadata));
|
||||
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
|
||||
toast({
|
||||
title: t('toast.initialImageSet'),
|
||||
status: 'success',
|
||||
@ -160,7 +153,6 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
|
||||
const handleSelectImage = () => {
|
||||
dispatch(imageSelected(image.name));
|
||||
// dispatch(setCurrentImage(image));
|
||||
};
|
||||
|
||||
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
||||
@ -183,28 +175,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUsePrompt}
|
||||
isDisabled={image?.metadata?.image?.prompt === undefined}
|
||||
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
onClickCapture={handleUseSeed}
|
||||
isDisabled={image?.metadata?.image?.seed === undefined}
|
||||
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
|
||||
!['txt2img', 'img2img'].includes(
|
||||
image?.metadata?.sd_metadata?.type
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseInitialImage}
|
||||
isDisabled={image?.metadata?.image?.type !== 'img2img'}
|
||||
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem>
|
||||
|
@ -18,7 +18,7 @@ import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
@ -113,14 +113,14 @@ const MetadataItem = ({
|
||||
};
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: InvokeAI._Image;
|
||||
image: InvokeAI.Image;
|
||||
};
|
||||
|
||||
// TODO: I don't know if this is needed.
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.uuid === next.image.uuid;
|
||||
) => prev.image.name === next.image.name;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
@ -137,8 +137,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
|
||||
const metadata = image?.metadata?.image || {};
|
||||
const dreamPrompt = image?.dreamPrompt;
|
||||
const metadata = image?.metadata.sd_metadata || {};
|
||||
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
@ -160,6 +160,7 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
type,
|
||||
variations,
|
||||
width,
|
||||
model_weights,
|
||||
} = metadata;
|
||||
|
||||
const { t } = useTranslation();
|
||||
@ -193,8 +194,8 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Generation type" value={type} />}
|
||||
{image.metadata?.model_weights && (
|
||||
<MetadataItem label="Model" value={image.metadata.model_weights} />
|
||||
{model_weights && (
|
||||
<MetadataItem label="Model" value={model_weights} />
|
||||
)}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
@ -288,14 +289,14 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
onClick={() => dispatch(setHeight(height))}
|
||||
/>
|
||||
)}
|
||||
{init_image_path && (
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)}
|
||||
)} */}
|
||||
{mask_image_path && (
|
||||
<MetadataItem
|
||||
label="Mask image"
|
||||
|
@ -7,8 +7,16 @@ import {
|
||||
uiSelector,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { selectResultsAll, selectResultsEntities } from './resultsSlice';
|
||||
import { selectUploadsAll, selectUploadsEntities } from './uploadsSlice';
|
||||
import {
|
||||
selectResultsAll,
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from './resultsSlice';
|
||||
import {
|
||||
selectUploadsAll,
|
||||
selectUploadsById,
|
||||
selectUploadsEntities,
|
||||
} from './uploadsSlice';
|
||||
|
||||
export const gallerySelector = (state: RootState) => state.gallery;
|
||||
|
||||
|
@ -3,13 +3,14 @@ import { Image } from 'app/invokeai';
|
||||
import { invocationComplete } from 'app/nodesSocketio/actions';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import { socketioConnected } from 'features/system/store/systemSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
// import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { setCurrentCategory } from './gallerySlice';
|
||||
|
||||
// use `createEntityAdapter` to create a slice for results images
|
||||
@ -21,7 +22,7 @@ export const resultsAdapter = createEntityAdapter<Image>({
|
||||
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
|
||||
selectId: (image) => image.name,
|
||||
// Order all images by their time (in descending order)
|
||||
sortComparer: (a, b) => b.timestamp - a.timestamp,
|
||||
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
|
||||
});
|
||||
|
||||
// This type is intersected with the Entity type to create the shape of the state
|
||||
@ -61,7 +62,9 @@ const resultsSlice = createSlice({
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const resultImages = items.map((image) => deserializeImageField(image));
|
||||
const resultImages = items.map((image) =>
|
||||
deserializeImageResponse(image)
|
||||
);
|
||||
|
||||
// use the adapter reducer to append all the results to state
|
||||
resultsAdapter.addMany(state, resultImages);
|
||||
|
@ -7,10 +7,11 @@ import {
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { deserializeImageField } from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||
selectId: (image) => image.name,
|
||||
sortComparer: (a, b) => b.timestamp - a.timestamp,
|
||||
sortComparer: (a, b) => b.metadata.timestamp - a.metadata.timestamp,
|
||||
});
|
||||
|
||||
type AdditionalUploadsState = {
|
||||
@ -38,7 +39,7 @@ const uploadsSlice = createSlice({
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const images = items.map((image) => deserializeImageField(image));
|
||||
const images = items.map((image) => deserializeImageResponse(image));
|
||||
|
||||
uploadsAdapter.addMany(state, images);
|
||||
|
||||
|
@ -1,5 +1,11 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { gallerySelector } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from 'features/gallery/store/resultsSlice';
|
||||
import { selectUploadsById } from 'features/gallery/store/uploadsSlice';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
export const generationSelector = (state: RootState) => state.generation;
|
||||
@ -15,3 +21,15 @@ export const mayGenerateMultipleImagesSelector = createSelector(
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const initialImageSelector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const { initialImage: initialImageName } = generation;
|
||||
|
||||
return (
|
||||
selectResultsById(state, initialImageName as string) ??
|
||||
selectUploadsById(state, initialImageName as string)
|
||||
);
|
||||
}
|
||||
);
|
||||
|
@ -317,12 +317,12 @@ export const generationSlice = createSlice({
|
||||
setShouldRandomizeSeed: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldRandomizeSeed = action.payload;
|
||||
},
|
||||
setInitialImage: (
|
||||
state,
|
||||
action: PayloadAction<InvokeAI._Image | string>
|
||||
) => {
|
||||
state.initialImage = action.payload;
|
||||
},
|
||||
// setInitialImage: (
|
||||
// state,
|
||||
// action: PayloadAction<InvokeAI._Image | string>
|
||||
// ) => {
|
||||
// state.initialImage = action.payload;
|
||||
// },
|
||||
clearInitialImage: (state) => {
|
||||
state.initialImage = undefined;
|
||||
},
|
||||
@ -353,6 +353,9 @@ export const generationSlice = createSlice({
|
||||
setVerticalSymmetrySteps: (state, action: PayloadAction<number>) => {
|
||||
state.verticalSymmetrySteps = action.payload;
|
||||
},
|
||||
initialImageSelected: (state, action: PayloadAction<string>) => {
|
||||
state.initialImage = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@ -368,7 +371,7 @@ export const {
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
setInfillMethod,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setIterations,
|
||||
setMaskPath,
|
||||
setParameter,
|
||||
@ -394,6 +397,7 @@ export const {
|
||||
setShouldUseSymmetry,
|
||||
setHorizontalSymmetrySteps,
|
||||
setVerticalSymmetrySteps,
|
||||
initialImageSelected,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export default generationSlice.reducer;
|
||||
|
@ -1,13 +1,22 @@
|
||||
import { ExpandedIndex, UseToastOptions } from '@chakra-ui/react';
|
||||
import { ExpandedIndex, StatHelpText, UseToastOptions } from '@chakra-ui/react';
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { invocationComplete } from 'app/nodesSocketio/actions';
|
||||
import {
|
||||
generatorProgress,
|
||||
invocationComplete,
|
||||
invocationError,
|
||||
invocationStarted,
|
||||
socketioConnected,
|
||||
socketioDisconnected,
|
||||
} from 'app/nodesSocketio/actions';
|
||||
import { resultAdded } from 'features/gallery/store/resultsSlice';
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
import i18n from 'i18n';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
|
||||
export type LogLevel = 'info' | 'warning' | 'error';
|
||||
|
||||
@ -61,6 +70,7 @@ export interface SystemState
|
||||
cancelType: CancelType;
|
||||
cancelAfter: number | null;
|
||||
};
|
||||
progressImage: ProgressImage | null;
|
||||
}
|
||||
|
||||
const initialSystemState: SystemState = {
|
||||
@ -103,6 +113,7 @@ const initialSystemState: SystemState = {
|
||||
cancelType: 'immediate',
|
||||
cancelAfter: null,
|
||||
},
|
||||
progressImage: null,
|
||||
};
|
||||
|
||||
export const systemSlice = createSlice({
|
||||
@ -276,21 +287,63 @@ export const systemSlice = createSlice({
|
||||
setCancelAfter: (state, action: PayloadAction<number | null>) => {
|
||||
state.cancelOptions.cancelAfter = action.payload;
|
||||
},
|
||||
socketioConnected: (state) => {
|
||||
state.isConnected = true;
|
||||
state.currentStatus = i18n.t('common.statusConnected');
|
||||
},
|
||||
socketioDisconnected: (state) => {
|
||||
state.isConnected = false;
|
||||
state.currentStatus = i18n.t('common.statusDisconnected');
|
||||
},
|
||||
// socketioConnected: (state) => {
|
||||
// state.isConnected = true;
|
||||
// state.currentStatus = i18n.t('common.statusConnected');
|
||||
// },
|
||||
// socketioDisconnected: (state) => {
|
||||
// state.isConnected = false;
|
||||
// state.currentStatus = i18n.t('common.statusDisconnected');
|
||||
// },
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(socketioConnected, (state, action) => {
|
||||
const { timestamp } = action.payload;
|
||||
|
||||
state.isConnected = true;
|
||||
state.currentStatus = i18n.t('common.statusConnected');
|
||||
state.log.push({
|
||||
timestamp: dateFormat(timestamp, 'isoDateTime'),
|
||||
message: `Connected to server`,
|
||||
level: 'info',
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(socketioDisconnected, (state, action) => {
|
||||
const { timestamp } = action.payload;
|
||||
|
||||
state.isConnected = false;
|
||||
state.currentStatus = i18n.t('common.statusDisconnected');
|
||||
state.log.push({
|
||||
timestamp: dateFormat(timestamp, 'isoDateTime'),
|
||||
message: `Disconnected from server`,
|
||||
level: 'warning',
|
||||
});
|
||||
});
|
||||
|
||||
builder.addCase(invocationStarted, (state, action) => {
|
||||
state.isProcessing = true;
|
||||
state.currentStatusHasSteps = false;
|
||||
});
|
||||
|
||||
builder.addCase(generatorProgress, (state, action) => {
|
||||
const { step, total_steps, progress_image } = action.payload.data;
|
||||
|
||||
state.currentStatusHasSteps = true;
|
||||
state.currentStep = step + 1; // TODO: step starts at -1, think this is a bug
|
||||
state.totalSteps = total_steps;
|
||||
state.progressImage = progress_image ?? null;
|
||||
});
|
||||
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data, timestamp } = action.payload;
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = false;
|
||||
|
||||
state.isProcessing = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.progressImage = null;
|
||||
|
||||
// TODO: handle logging for other invocation types
|
||||
if (isImageOutput(data.result)) {
|
||||
state.log.push({
|
||||
timestamp: dateFormat(timestamp, 'isoDateTime'),
|
||||
@ -299,6 +352,29 @@ export const systemSlice = createSlice({
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
builder.addCase(invocationError, (state, action) => {
|
||||
const { data, timestamp } = action.payload;
|
||||
|
||||
state.log.push({
|
||||
timestamp: dateFormat(timestamp, 'isoDateTime'),
|
||||
message: `Server error: ${data.error}`,
|
||||
level: 'error',
|
||||
});
|
||||
|
||||
state.wasErrorSeen = true;
|
||||
state.progressImage = null;
|
||||
state.isProcessing = false;
|
||||
});
|
||||
|
||||
builder.addCase(initialImageSelected, (state) => {
|
||||
state.toastQueue.push({
|
||||
title: i18n.t('toast.sentToImageToImage'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
@ -334,8 +410,8 @@ export const {
|
||||
setOpenModel,
|
||||
setCancelType,
|
||||
setCancelAfter,
|
||||
socketioConnected,
|
||||
socketioDisconnected,
|
||||
// socketioConnected,
|
||||
// socketioDisconnected,
|
||||
} = systemSlice.actions;
|
||||
|
||||
export default systemSlice.reducer;
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { Box, BoxProps, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
activeTabNameSelector,
|
||||
uiSelector,
|
||||
@ -47,7 +47,7 @@ const InvokeWorkarea = (props: InvokeWorkareaProps) => {
|
||||
const image = getImageByUuid(uuid);
|
||||
if (!image) return;
|
||||
if (activeTabName === 'img2img') {
|
||||
dispatch(setInitialImage(image));
|
||||
dispatch(initialImageSelected(image.uuid));
|
||||
} else if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
}
|
||||
|
@ -1,14 +1,12 @@
|
||||
import { Flex, Image, Text, useToast } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import ImageUploaderIconButton from 'common/components/ImageUploaderIconButton';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function InitImagePreview() {
|
||||
const initialImage = useAppSelector(
|
||||
(state: RootState) => state.generation.initialImage
|
||||
);
|
||||
const initialImage = useAppSelector(initialImageSelector);
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
|
13
invokeai/frontend/web/src/features/ui/store/extraReducers.ts
Normal file
13
invokeai/frontend/web/src/features/ui/store/extraReducers.ts
Normal file
@ -0,0 +1,13 @@
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
|
||||
export const setActiveTabReducer = (
|
||||
state: UIState,
|
||||
newActiveTab: number | InvokeTabName
|
||||
) => {
|
||||
if (typeof newActiveTab === 'number') {
|
||||
state.activeTab = newActiveTab;
|
||||
} else {
|
||||
state.activeTab = tabMap.indexOf(newActiveTab);
|
||||
}
|
||||
};
|
@ -1,5 +1,7 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageSelected } from 'features/parameters/store/generationSlice';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { AddNewModelType, UIState } from './uiTypes';
|
||||
|
||||
@ -25,11 +27,7 @@ export const uiSlice = createSlice({
|
||||
initialState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<number | InvokeTabName>) => {
|
||||
if (typeof action.payload === 'number') {
|
||||
state.activeTab = action.payload;
|
||||
} else {
|
||||
state.activeTab = tabMap.indexOf(action.payload);
|
||||
}
|
||||
setActiveTabReducer(state, action.payload);
|
||||
},
|
||||
setCurrentTheme: (state, action: PayloadAction<string>) => {
|
||||
state.currentTheme = action.payload;
|
||||
@ -93,6 +91,13 @@ export const uiSlice = createSlice({
|
||||
}
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageSelected, (state) => {
|
||||
if (tabMap[state.activeTab] !== 'img2img') {
|
||||
setActiveTabReducer(state, 'img2img');
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { isFulfilled, Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { emitSubscribe } from 'app/nodesSocketio/actions';
|
||||
import { socketioSubscribed } from 'app/nodesSocketio/actions';
|
||||
import { AppDispatch, RootState } from 'app/store';
|
||||
import { setSessionId } from './apiSlice';
|
||||
import { uploadImage } from './thunks/image';
|
||||
@ -10,7 +10,7 @@ import * as InvokeAI from 'app/invokeai';
|
||||
import { addImage } from 'features/gallery/store/gallerySlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { initialImageSelected as initialImageSet } from 'features/parameters/store/generationSlice';
|
||||
|
||||
/**
|
||||
* `redux-toolkit` provides nice matching utilities, which can be used as type guards
|
||||
@ -24,12 +24,14 @@ export const invokeMiddleware: Middleware =
|
||||
(store: MiddlewareAPI<AppDispatch, RootState>) => (next) => (action) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
const timestamp = new Date();
|
||||
|
||||
if (isFulfilledCreateSession(action)) {
|
||||
const sessionId = action.payload.id;
|
||||
console.log('createSession.fulfilled');
|
||||
|
||||
dispatch(setSessionId(sessionId));
|
||||
dispatch(emitSubscribe(sessionId));
|
||||
dispatch(socketioSubscribed({ sessionId, timestamp }));
|
||||
dispatch(invokeSession({ sessionId }));
|
||||
} else if (isFulfilledUploadImage(action)) {
|
||||
const uploadLocation = action.payload;
|
||||
@ -54,7 +56,8 @@ export const invokeMiddleware: Middleware =
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(newImage));
|
||||
} else if (activeTabName === 'img2img') {
|
||||
dispatch(setInitialImage(newImage));
|
||||
// dispatch(setInitialImage(newImage));
|
||||
dispatch(initialImageSet(newImage.uuid));
|
||||
}
|
||||
} else {
|
||||
next(action);
|
||||
|
@ -27,6 +27,10 @@ export const extractTimestampFromImageName = (imageName: string) => {
|
||||
return Number(timestamp);
|
||||
};
|
||||
|
||||
/**
|
||||
* Process ImageField objects. These come from `invocation_complete` events and do not contain all the data we need.
|
||||
* This is a WIP on the server side.
|
||||
*/
|
||||
export const deserializeImageField = (image: ImageField): Image => {
|
||||
const name = image.image_name;
|
||||
const type = image.image_type;
|
||||
@ -37,10 +41,13 @@ export const deserializeImageField = (image: ImageField): Image => {
|
||||
|
||||
return {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
timestamp,
|
||||
height: 512,
|
||||
width: 512,
|
||||
metadata: {
|
||||
timestamp,
|
||||
height: 512, // TODO: need the server to give this to us
|
||||
width: 512,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
@ -0,0 +1,22 @@
|
||||
import { Image } from 'app/invokeai';
|
||||
import { ImageResponse } from 'services/api';
|
||||
|
||||
/**
|
||||
* Process ImageReponse objects, which we get from the `list_images` endpoint.
|
||||
*/
|
||||
export const deserializeImageResponse = (
|
||||
imageResponse: ImageResponse
|
||||
): Image => {
|
||||
const { image_name, image_type, image_url, metadata, thumbnail_url } =
|
||||
imageResponse;
|
||||
|
||||
// TODO: parse metadata - just leaving it as-is for now
|
||||
|
||||
return {
|
||||
name: image_name,
|
||||
type: image_type,
|
||||
url: image_url,
|
||||
thumbnail: thumbnail_url,
|
||||
metadata,
|
||||
};
|
||||
};
|
@ -9,7 +9,9 @@ const { defineMultiStyleConfig, definePartsStyle } =
|
||||
|
||||
const invokeAIFilledTrack = defineStyle((_props) => ({
|
||||
bg: 'accent.600',
|
||||
transition: 'width 0.2s ease-in-out',
|
||||
// TODO: the animation is nice but looks weird bc it is substantially longer than each step
|
||||
// so we get to 100% long before it finishes
|
||||
// transition: 'width 0.2s ease-in-out',
|
||||
_indeterminate: {
|
||||
bgGradient:
|
||||
'linear(to-r, transparent 0%, accent.600 50%, transparent 100%);',
|
||||
|
Loading…
Reference in New Issue
Block a user