ui: redesign followups 8 (#5445)

* feat(ui): get rid of convoluted socket vs appSocket redux actions

There's no need to have `socket...` and `appSocket...` actions.

I did this initially due to a misunderstanding about the sequence of handling from middleware to reducers.

* feat(ui): bump deps

Mainly bumping to get latest `redux-remember`.

A change to socket.io required a change to the types in `useSocketIO`.

* chore(ui): format

* feat(ui): add error handling to redux persistence layer

- Add an error handler to `redux-remember` config using our logger
- Add custom errors representing storage set and get failures
- Update storage driver to raise these accordingly
- wrap method to clear idbkeyval storage and tidy its logic up

* feat(ui): add debuggingLoggerMiddleware

This simply logs every action and a diff of the state change.

Due to the noise this creates, it's not added by default at all. Add it to the middlewares if you want to use it.

* feat(ui): add $socket to window if in dev mode

* fix(ui): do not enable cancel hotkeys on inputs

* fix(ui): use JSON.stringify for ROARR logger serializer

A recent change to ROARR introduced limits to the size of data that will logged. This ends up making our logs far less useful. Change the serializer back to what it was previously.

* feat(ui): change diff util, update debuggerLoggerMiddleware

The previous diff library would present deleted things as `undefined`. Unfortunately, a JSON.stringify cycle will strip those values out. The ROARR logger does this and so the diffs end up being a lot less useful, not showing removed keys.

The new diff library uses a different format for the delta that serializes nicely.

* feat(ui): add migrations to redux persistence layer

- All persisted slices must now have a slice config, consisting of their initial state and a migrate callback. The migrate callback is very simple for now, with no type safety. It adds missing properties to the state. A future enhancement might be to model the each slice's state with e.g. zod and have proper validation and types.
- Persisted slices now have a `_version` property
- The migrate callback is called inside `redux-remember`'s `unserialize` handler. I couldn't figure out a good way to put this into the reducer and do logging (reducers should have no side effects). Also I ran into a weird race condition that I couldn't figure out. And finally, the typings are tricky. This works for now.
- `generationSlice` and `canvasSlice` both need migrations for the new aspect ratio setup, this has been added
- Stuff related to persistence has been moved in to `store.ts` for simplicity

* feat(ui): clean up StorageError class

* fix(ui): scale method default is now 'auto'

* feat(ui): when changing controlnet model, enable autoconfig

* fix(ui): make embedding popover immediately accessible

Prevents hotkeys from being captured when embeddings are still loading.
This commit is contained in:
psychedelicious 2024-01-09 01:11:45 +11:00 committed by GitHub
parent 5779542084
commit 0fc08bb384
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1722 additions and 1567 deletions

View File

@ -73,10 +73,11 @@
"chakra-react-select": "^4.7.6",
"compare-versions": "^6.1.0",
"dateformat": "^5.0.3",
"framer-motion": "^10.16.16",
"i18next": "^23.7.13",
"framer-motion": "^10.17.9",
"i18next": "^23.7.16",
"i18next-http-backend": "^2.4.2",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.6.0",
"konva": "^9.3.0",
"lodash-es": "^4.17.21",
"nanostores": "^0.9.5",
@ -90,7 +91,7 @@
"react-dropzone": "^14.2.3",
"react-error-boundary": "^4.0.12",
"react-hook-form": "^7.49.2",
"react-hotkeys-hook": "4.4.1",
"react-hotkeys-hook": "4.4.3",
"react-i18next": "^14.0.0",
"react-icons": "^4.12.0",
"react-konva": "^18.2.10",
@ -102,10 +103,10 @@
"react-virtuoso": "^4.6.2",
"reactflow": "^11.10.1",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.0.1",
"redux-remember": "^5.1.0",
"roarr": "^7.21.0",
"serialize-error": "^11.0.3",
"socket.io-client": "^4.7.2",
"socket.io-client": "^4.7.3",
"type-fest": "^4.9.0",
"use-debounce": "^10.0.0",
"use-image": "^1.1.1",
@ -121,27 +122,27 @@
"ts-toolbelt": "^9.6.0"
},
"devDependencies": {
"@arthurgeron/eslint-plugin-react-usememo": "^2.2.2",
"@arthurgeron/eslint-plugin-react-usememo": "^2.2.3",
"@chakra-ui/cli": "^2.4.1",
"@storybook/addon-docs": "^7.6.6",
"@storybook/addon-essentials": "^7.6.6",
"@storybook/addon-interactions": "^7.6.6",
"@storybook/addon-links": "^7.6.6",
"@storybook/addon-storysource": "^7.6.6",
"@storybook/blocks": "^7.6.6",
"@storybook/manager-api": "^7.6.6",
"@storybook/react": "^7.6.6",
"@storybook/react-vite": "^7.6.6",
"@storybook/test": "^7.6.6",
"@storybook/theming": "^7.6.6",
"@storybook/addon-docs": "^7.6.7",
"@storybook/addon-essentials": "^7.6.7",
"@storybook/addon-interactions": "^7.6.7",
"@storybook/addon-links": "^7.6.7",
"@storybook/addon-storysource": "^7.6.7",
"@storybook/blocks": "^7.6.7",
"@storybook/manager-api": "^7.6.7",
"@storybook/react": "^7.6.7",
"@storybook/react-vite": "^7.6.7",
"@storybook/test": "^7.6.7",
"@storybook/theming": "^7.6.7",
"@types/dateformat": "^5.0.2",
"@types/lodash-es": "^4.17.12",
"@types/node": "^20.10.6",
"@types/react": "^18.2.46",
"@types/node": "^20.10.7",
"@types/react": "^18.2.47",
"@types/react-dom": "^18.2.18",
"@types/uuid": "^9.0.7",
"@typescript-eslint/eslint-plugin": "^6.16.0",
"@typescript-eslint/parser": "^6.16.0",
"@typescript-eslint/eslint-plugin": "^6.18.0",
"@typescript-eslint/parser": "^6.18.0",
"@vitejs/plugin-react-swc": "^3.5.0",
"concurrently": "^8.2.2",
"eslint": "^8.56.0",
@ -159,10 +160,10 @@
"openapi-typescript": "^6.7.3",
"prettier": "^3.1.1",
"rollup-plugin-visualizer": "^5.12.0",
"storybook": "^7.6.6",
"storybook": "^7.6.7",
"ts-toolbelt": "^9.6.0",
"typescript": "^5.3.3",
"vite": "^5.0.10",
"vite": "^5.0.11",
"vite-plugin-css-injected-by-js": "^3.3.1",
"vite-plugin-dts": "^3.7.0",
"vite-plugin-eslint": "^1.8.1",

File diff suppressed because it is too large Load Diff

View File

@ -43,7 +43,7 @@ export const useSocketIO = () => {
}, [baseUrl]);
const socketOptions = useMemo(() => {
const options: Parameters<typeof io>[0] = {
const options: Partial<ManagerOptions & SocketOptions> = {
timeout: 60000,
path: '/ws/socket.io',
autoConnect: false, // achtung! removing this breaks the dynamic middleware
@ -71,7 +71,7 @@ export const useSocketIO = () => {
setEventListeners({ dispatch, socket });
socket.connect();
if ($isDebugging.get()) {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = $socketOptions;
console.log('Socket initialized', socket);
}
@ -79,7 +79,7 @@ export const useSocketIO = () => {
$isSocketInitialized.set(true);
return () => {
if ($isDebugging.get()) {
if ($isDebugging.get() || import.meta.env.MODE === 'development') {
window.$socketOptions = undefined;
console.log('Socket teardown', socket);
}

View File

@ -1,9 +1,14 @@
import { createLogWriter } from '@roarr/browser-log-writer';
import { atom } from 'nanostores';
import type { Logger } from 'roarr';
import type { Logger, MessageSerializer } from 'roarr';
import { ROARR, Roarr } from 'roarr';
import { z } from 'zod';
const serializeMessage: MessageSerializer = (message) => {
return JSON.stringify(message);
};
ROARR.serializeMessage = serializeMessage;
ROARR.write = createLogWriter();
export const BASE_CONTEXT = {};

View File

@ -0,0 +1,37 @@
import { StorageError } from 'app/store/enhancers/reduxRemember/errors';
import type { UseStore } from 'idb-keyval';
import {
clear,
createStore as createIDBKeyValStore,
get,
set,
} from 'idb-keyval';
import { action, atom } from 'nanostores';
import type { Driver } from 'redux-remember';
// Create a custom idb-keyval store (just needed to customize the name)
export const $idbKeyValStore = atom<UseStore>(
createIDBKeyValStore('invoke', 'invoke-store')
);
export const clearIdbKeyValStore = action($idbKeyValStore, 'clear', (store) => {
clear(store.get());
});
// Create redux-remember driver, wrapping idb-keyval
export const idbKeyValDriver: Driver = {
getItem: (key) => {
try {
return get(key, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({ key, originalError });
}
},
setItem: (key, value) => {
try {
return set(key, value, $idbKeyValStore.get());
} catch (originalError) {
throw new StorageError({ key, value, originalError });
}
},
};

View File

@ -0,0 +1,41 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { PersistError, RehydrateError } from 'redux-remember';
import { serializeError } from 'serialize-error';
export type StorageErrorArgs = {
key: string;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ // any is correct
value?: any;
originalError?: unknown;
};
export class StorageError extends Error {
key: string;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */ // any is correct
value?: any;
originalError?: Error;
constructor({ key, value, originalError }: StorageErrorArgs) {
super(`Error setting ${key}`);
this.name = 'StorageSetError';
this.key = key;
if (value !== undefined) {
this.value = value;
}
if (originalError instanceof Error) {
this.originalError = originalError;
}
}
}
export const errorHandler = (err: PersistError | RehydrateError) => {
const log = logger('system');
if (err instanceof PersistError) {
log.error({ error: serializeError(err) }, 'Problem persisting state');
} else if (err instanceof RehydrateError) {
log.error({ error: serializeError(err) }, 'Problem rehydrating state');
} else {
log.error({ error: parseify(err) }, 'Problem in persistence layer');
}
};

View File

@ -1,30 +0,0 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlAdaptersPersistDenylist } from 'features/controlAdapters/store/controlAdaptersPersistDenylist';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import { omit } from 'lodash-es';
import type { SerializeFunction } from 'redux-remember';
const serializationDenylist: {
[key: string]: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
ui: uiPersistDenylist,
controlNet: controlAdaptersPersistDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(data, serializationDenylist[key] ?? []);
return JSON.stringify(result);
};

View File

@ -1,34 +0,0 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlAdapterState } from 'features/controlAdapters/store/controlAdaptersSlice';
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { initialPostprocessingState } from 'features/parameters/store/postprocessingSlice';
import { initialSDXLState } from 'features/sdxl/store/sdxlSlice';
import { initialConfigState } from 'features/system/store/configSlice';
import { initialSystemState } from 'features/system/store/systemSlice';
import { initialUIState } from 'features/ui/store/uiSlice';
import { defaultsDeep } from 'lodash-es';
import type { UnserializeFunction } from 'redux-remember';
const initialStates: {
[key: string]: object; // TODO: type this properly
} = {
canvas: initialCanvasState,
gallery: initialGalleryState,
generation: initialGenerationState,
nodes: initialNodesState,
postprocessing: initialPostprocessingState,
system: initialSystemState,
config: initialConfigState,
ui: initialUIState,
controlAdapters: initialControlAdapterState,
dynamicPrompts: initialDynamicPromptsState,
sdxl: initialSDXLState,
};
export const unserialize: UnserializeFunction = (data, key) => {
const result = defaultsDeep(JSON.parse(data), initialStates[key]);
return result;
};

View File

@ -0,0 +1,16 @@
import type { Middleware, MiddlewareAPI } from '@reduxjs/toolkit';
import { diff } from 'jsondiffpatch';
/**
* Super simple logger middleware. Useful for debugging when the redux devtools are awkward.
*/
export const debugLoggerMiddleware: Middleware =
(api: MiddlewareAPI) => (next) => (action) => {
const originalState = api.getState();
console.log('REDUX: dispatching', action);
const result = next(action);
const nextState = api.getState();
console.log('REDUX: next state', nextState);
console.log('REDUX: diff', diff(originalState, nextState));
return result;
};

View File

@ -1,8 +1,10 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
import { cloneDeep } from 'lodash-es';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
if (isAnyGraphBuilt(action)) {
@ -30,5 +32,14 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = cloneDeep(action);
if (sanitized.payload.data.progress_image) {
sanitized.payload.data.progress_image.dataURL =
'<Progress image omitted>';
}
return sanitized;
}
return action;
};

View File

@ -1,16 +1,16 @@
/**
* This is a list of actions that should be excluded in the Redux DevTools.
*/
export const actionsDenylist = [
export const actionsDenylist: string[] = [
// very spammy canvas actions
'canvas/setStageCoordinates',
'canvas/setStageScale',
'canvas/setBoundingBoxCoordinates',
'canvas/setBoundingBoxDimensions',
'canvas/addPointToCurrentLine',
// 'canvas/setStageCoordinates',
// 'canvas/setStageScale',
// 'canvas/setBoundingBoxCoordinates',
// 'canvas/setBoundingBoxDimensions',
// 'canvas/addPointToCurrentLine',
// bazillions during generation
'socket/socketGeneratorProgress',
'socket/appSocketGeneratorProgress',
// 'socket/socketGeneratorProgress',
// 'socket/appSocketGeneratorProgress',
// this happens after every state change
'@@REMEMBER_PERSISTED',
// '@@REMEMBER_PERSISTED',
];

View File

@ -11,7 +11,7 @@ import {
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { setPositivePrompt } from 'features/parameters/store/generationSlice';
import { utilitiesApi } from 'services/api/endpoints/utilities';
import { appSocketConnected } from 'services/events/actions';
import { socketConnected } from 'services/events/actions';
import { startAppListening } from '..';
@ -20,7 +20,7 @@ const matcher = isAnyOf(
combinatorialToggled,
maxPromptsChanged,
maxPromptsReset,
appSocketConnected
socketConnected
);
export const addDynamicPromptsListener = () => {

View File

@ -3,16 +3,16 @@ import { isInitializedChanged } from 'features/system/store/systemSlice';
import { size } from 'lodash-es';
import { api } from 'services/api';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import { appSocketConnected, socketConnected } from 'services/events/actions';
import { socketConnected } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSocketConnectedEventListener = () => {
startAppListening({
actionCreator: socketConnected,
effect: (action, { dispatch, getState }) => {
const log = logger('socketio');
log.debug('Connected');
const { nodeTemplates, config, system } = getState();
@ -29,9 +29,6 @@ export const addSocketConnectedEventListener = () => {
} else {
dispatch(isInitializedChanged(true));
}
// pass along the socket event as an application action
dispatch(appSocketConnected(action.payload));
},
});
};

View File

@ -1,20 +1,15 @@
import { logger } from 'app/logging/logger';
import {
appSocketDisconnected,
socketDisconnected,
} from 'services/events/actions';
import { socketDisconnected } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSocketDisconnectedEventListener = () => {
startAppListening({
actionCreator: socketDisconnected,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: () => {
log.debug('Disconnected');
// pass along the socket event as an application action
dispatch(appSocketDisconnected(action.payload));
},
});
};

View File

@ -1,20 +1,15 @@
import { logger } from 'app/logging/logger';
import {
appSocketGeneratorProgress,
socketGeneratorProgress,
} from 'services/events/actions';
import { socketGeneratorProgress } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addGeneratorProgressEventListener = () => {
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.trace(action.payload, `Generator progress`);
dispatch(appSocketGeneratorProgress(action.payload));
},
});
};

View File

@ -1,19 +1,15 @@
import { logger } from 'app/logging/logger';
import {
appSocketGraphExecutionStateComplete,
socketGraphExecutionStateComplete,
} from 'services/events/actions';
import { socketGraphExecutionStateComplete } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addGraphExecutionStateCompleteEventListener = () => {
startAppListening({
actionCreator: socketGraphExecutionStateComplete,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.debug(action.payload, 'Session complete');
// pass along the socket event as an application action
dispatch(appSocketGraphExecutionStateComplete(action.payload));
},
});
};

View File

@ -15,21 +15,19 @@ import {
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import {
appSocketInvocationComplete,
socketInvocationComplete,
} from 'services/events/actions';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '../..';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
const nodeTypeDenylist = ['load_image', 'image'];
const log = logger('socketio');
export const addInvocationCompleteEventListener = () => {
startAppListening({
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const log = logger('socketio');
const { data } = action.payload;
log.debug(
{ data: parseify(data) },
@ -136,8 +134,6 @@ export const addInvocationCompleteEventListener = () => {
}
}
}
// pass along the socket event as an application action
dispatch(appSocketInvocationComplete(action.payload));
},
});
};

View File

@ -1,21 +1,18 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationError,
socketInvocationError,
} from 'services/events/actions';
import { socketInvocationError } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addInvocationErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.error(
action.payload,
`Invocation error (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationError(action.payload));
},
});
};

View File

@ -1,21 +1,18 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationRetrievalError,
socketInvocationRetrievalError,
} from 'services/events/actions';
import { socketInvocationRetrievalError } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addInvocationRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.error(
action.payload,
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketInvocationRetrievalError(action.payload));
},
});
};

View File

@ -1,23 +1,18 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationStarted,
socketInvocationStarted,
} from 'services/events/actions';
import { socketInvocationStarted } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addInvocationStartedEventListener = () => {
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.debug(
action.payload,
`Invocation started (${action.payload.data.node.type})`
);
dispatch(appSocketInvocationStarted(action.payload));
},
});
};

View File

@ -1,18 +1,17 @@
import { logger } from 'app/logging/logger';
import {
appSocketModelLoadCompleted,
appSocketModelLoadStarted,
socketModelLoadCompleted,
socketModelLoadStarted,
} from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addModelLoadEventListener = () => {
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
const { base_model, model_name, model_type, submodel } =
action.payload.data;
@ -23,16 +22,12 @@ export const addModelLoadEventListener = () => {
}
log.debug(action.payload, message);
// pass along the socket event as an application action
dispatch(appSocketModelLoadStarted(action.payload));
},
});
startAppListening({
actionCreator: socketModelLoadCompleted,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
const { base_model, model_name, model_type, submodel } =
action.payload.data;
@ -43,8 +38,6 @@ export const addModelLoadEventListener = () => {
}
log.debug(action.payload, message);
// pass along the socket event as an application action
dispatch(appSocketModelLoadCompleted(action.payload));
},
});
};

View File

@ -1,18 +1,15 @@
import { logger } from 'app/logging/logger';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import {
appSocketQueueItemStatusChanged,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import { socketQueueItemStatusChanged } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = () => {
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
const log = logger('socketio');
// we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data;
@ -73,9 +70,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
'InvocationCacheStatus',
])
);
// Pass the event along
dispatch(appSocketQueueItemStatusChanged(action.payload));
},
});
};

View File

@ -1,21 +1,18 @@
import { logger } from 'app/logging/logger';
import {
appSocketSessionRetrievalError,
socketSessionRetrievalError,
} from 'services/events/actions';
import { socketSessionRetrievalError } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSessionRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.error(
action.payload,
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketSessionRetrievalError(action.payload));
},
});
};

View File

@ -1,18 +1,15 @@
import { logger } from 'app/logging/logger';
import {
appSocketSubscribedSession,
socketSubscribedSession,
} from 'services/events/actions';
import { socketSubscribedSession } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSocketSubscribedEventListener = () => {
startAppListening({
actionCreator: socketSubscribedSession,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.debug(action.payload, 'Subscribed');
dispatch(appSocketSubscribedSession(action.payload));
},
});
};

View File

@ -1,18 +1,14 @@
import { logger } from 'app/logging/logger';
import {
appSocketUnsubscribedSession,
socketUnsubscribedSession,
} from 'services/events/actions';
import { socketUnsubscribedSession } from 'services/events/actions';
import { startAppListening } from '../..';
const log = logger('socketio');
export const addSocketUnsubscribedEventListener = () => {
startAppListening({
actionCreator: socketUnsubscribedSession,
effect: (action, { dispatch }) => {
const log = logger('socketio');
effect: (action) => {
log.debug(action.payload, 'Unsubscribed');
dispatch(appSocketUnsubscribedSession(action.payload));
},
});
};

View File

@ -4,40 +4,94 @@ import {
combineReducers,
configureStore,
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import { logger } from 'app/logging/logger';
import { idbKeyValDriver } from 'app/store/enhancers/reduxRemember/driver';
import { errorHandler } from 'app/store/enhancers/reduxRemember/errors';
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import canvasReducer, {
initialCanvasState,
migrateCanvasState,
} from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlAdaptersReducer from 'features/controlAdapters/store/controlAdaptersSlice';
import { controlAdaptersPersistDenylist } from 'features/controlAdapters/store/controlAdaptersPersistDenylist';
import controlAdaptersReducer, {
initialControlAdaptersState,
migrateControlAdaptersState,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import hrfReducer from 'features/hrf/store/hrfSlice';
import loraReducer from 'features/lora/store/loraSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import dynamicPromptsReducer, {
initialDynamicPromptsState,
migrateDynamicPromptsState,
} from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import galleryReducer, {
initialGalleryState,
migrateGalleryState,
} from 'features/gallery/store/gallerySlice';
import hrfReducer, {
initialHRFState,
migrateHRFState,
} from 'features/hrf/store/hrfSlice';
import loraReducer, {
initialLoraState,
migrateLoRAState,
} from 'features/lora/store/loraSlice';
import modelmanagerReducer, {
initialModelManagerState,
migrateModelManagerState,
} from 'features/modelManager/store/modelManagerSlice';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
import nodesReducer, {
initialNodesState,
migrateNodesState,
} from 'features/nodes/store/nodesSlice';
import nodeTemplatesReducer from 'features/nodes/store/nodeTemplatesSlice';
import workflowReducer from 'features/nodes/store/workflowSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import workflowReducer, {
initialWorkflowState,
migrateWorkflowState,
} from 'features/nodes/store/workflowSlice';
import { generationPersistDenylist } from 'features/parameters/store/generationPersistDenylist';
import generationReducer, {
initialGenerationState,
migrateGenerationState,
} from 'features/parameters/store/generationSlice';
import { postprocessingPersistDenylist } from 'features/parameters/store/postprocessingPersistDenylist';
import postprocessingReducer, {
initialPostprocessingState,
migratePostprocessingState,
} from 'features/parameters/store/postprocessingSlice';
import queueReducer from 'features/queue/store/queueSlice';
import sdxlReducer from 'features/sdxl/store/sdxlSlice';
import sdxlReducer, {
initialSDXLState,
migrateSDXLState,
} from 'features/sdxl/store/sdxlSlice';
import configReducer from 'features/system/store/configSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import { createStore as createIDBKeyValStore, get, set } from 'idb-keyval';
import { systemPersistDenylist } from 'features/system/store/systemPersistDenylist';
import systemReducer, {
initialSystemState,
migrateSystemState,
} from 'features/system/store/systemSlice';
import { uiPersistDenylist } from 'features/ui/store/uiPersistDenylist';
import uiReducer, {
initialUIState,
migrateUIState,
} from 'features/ui/store/uiSlice';
import { diff } from 'jsondiffpatch';
import { defaultsDeep, keys, omit, pick } from 'lodash-es';
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import type { Driver } from 'redux-remember';
import type { SerializeFunction, UnserializeFunction } from 'redux-remember';
import { rememberEnhancer, rememberReducer } from 'redux-remember';
import { serializeError } from 'serialize-error';
import { api } from 'services/api';
import { authToastMiddleware } from 'services/api/authToastMiddleware';
import type { JsonObject } from 'type-fest';
import { STORAGE_PREFIX } from './constants';
import { serialize } from './enhancers/reduxRemember/serialize';
import { unserialize } from './enhancers/reduxRemember/unserialize';
import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
const allReducers = {
canvas: canvasReducer,
gallery: galleryReducer,
@ -65,7 +119,7 @@ const rootReducer = combineReducers(allReducers);
const rememberedRootReducer = rememberReducer(rootReducer);
const rememberedKeys: (keyof typeof allReducers)[] = [
const rememberedKeys = [
'canvas',
'gallery',
'generation',
@ -80,15 +134,106 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'lora',
'modelmanager',
'hrf',
];
] satisfies (keyof typeof allReducers)[];
// Create a custom idb-keyval store (just needed to customize the name)
export const idbKeyValStore = createIDBKeyValStore('invoke', 'invoke-store');
type SliceConfig = {
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
initialState: any;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
migrate: (state: any) => any;
};
// Create redux-remember driver, wrapping idb-keyval
const idbKeyValDriver: Driver = {
getItem: (key) => get(key, idbKeyValStore),
setItem: (key, value) => set(key, value, idbKeyValStore),
const sliceConfigs: {
[key in (typeof rememberedKeys)[number]]: SliceConfig;
} = {
canvas: { initialState: initialCanvasState, migrate: migrateCanvasState },
gallery: { initialState: initialGalleryState, migrate: migrateGalleryState },
generation: {
initialState: initialGenerationState,
migrate: migrateGenerationState,
},
nodes: { initialState: initialNodesState, migrate: migrateNodesState },
postprocessing: {
initialState: initialPostprocessingState,
migrate: migratePostprocessingState,
},
system: { initialState: initialSystemState, migrate: migrateSystemState },
workflow: {
initialState: initialWorkflowState,
migrate: migrateWorkflowState,
},
ui: { initialState: initialUIState, migrate: migrateUIState },
controlAdapters: {
initialState: initialControlAdaptersState,
migrate: migrateControlAdaptersState,
},
dynamicPrompts: {
initialState: initialDynamicPromptsState,
migrate: migrateDynamicPromptsState,
},
sdxl: { initialState: initialSDXLState, migrate: migrateSDXLState },
lora: { initialState: initialLoraState, migrate: migrateLoRAState },
modelmanager: {
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
},
hrf: { initialState: initialHRFState, migrate: migrateHRFState },
};
const unserialize: UnserializeFunction = (data, key) => {
const log = logger('system');
const config = sliceConfigs[key as keyof typeof sliceConfigs];
if (!config) {
throw new Error(`No unserialize config for slice "${key}"`);
}
try {
const { initialState, migrate } = config;
const parsed = JSON.parse(data);
// strip out old keys
const stripped = pick(parsed, keys(initialState));
// run (additive) migrations
const migrated = migrate(stripped);
// merge in initial state as default values, covering any missing keys
const transformed = defaultsDeep(migrated, initialState);
log.debug(
{
persistedData: parsed,
rehydratedData: transformed,
diff: diff(parsed, transformed) as JsonObject, // this is always serializable
},
`Rehydrated slice "${key}"`
);
return transformed;
} catch (err) {
log.warn(
{ error: serializeError(err) },
`Error rehydrating slice "${key}", falling back to default initial state`
);
return config.initialState;
}
};
const serializationDenylist: {
[key in (typeof rememberedKeys)[number]]?: string[];
} = {
canvas: canvasPersistDenylist,
gallery: galleryPersistDenylist,
generation: generationPersistDenylist,
nodes: nodesPersistDenylist,
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
ui: uiPersistDenylist,
controlAdapters: controlAdaptersPersistDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist,
};
export const serialize: SerializeFunction = (data, key) => {
const result = omit(
data,
serializationDenylist[key as keyof typeof serializationDenylist] ?? []
);
return JSON.stringify(result);
};
export const createStore = (uniqueStoreKey?: string, persist = true) =>
@ -114,6 +259,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
prefix: uniqueStoreKey
? `${STORAGE_PREFIX}${uniqueStoreKey}-`
: STORAGE_PREFIX,
errorHandler,
})
);
}
@ -124,21 +270,9 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
stateSanitizer,
trace: true,
predicate: (state, action) => {
// TODO: hook up to the log level param in system slice
// manually type state, cannot type the arg
// const typedState = state as ReturnType<typeof rootReducer>;
// TODO: doing this breaks the rtk query devtools, commenting out for now
// if (action.type.startsWith('api/')) {
// // don't log api actions, with manual cache updates they are extremely noisy
// return false;
// }
if (actionsDenylist.includes(action.type)) {
// don't log other noisy actions
return false;
}
return true;
},
},

View File

@ -1,10 +1,9 @@
import { idbKeyValStore } from 'app/store/store';
import { clear } from 'idb-keyval';
import { clearIdbKeyValStore } from 'app/store/enhancers/reduxRemember/driver';
import { useCallback } from 'react';
export const useClearStorage = () => {
const clearStorage = useCallback(() => {
clear(idbKeyValStore);
clearIdbKeyValStore();
localStorage.clear();
}, []);

View File

@ -57,7 +57,6 @@ export const useGlobalHotkeys = () => {
{
enabled: () => !isDisabledCancelQueueItem && !isLoadingCancelQueueItem,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
[cancelQueueItem, isDisabledCancelQueueItem, isLoadingCancelQueueItem]
);
@ -74,7 +73,6 @@ export const useGlobalHotkeys = () => {
{
enabled: () => !isDisabledClearQueue && !isLoadingClearQueue,
preventDefault: true,
enableOnFormTags: ['input', 'textarea', 'select'],
},
[clearQueue, isDisabledClearQueue, isLoadingClearQueue]
);

View File

@ -11,6 +11,7 @@ import { STAGE_PADDING_PERCENTAGE } from 'features/canvas/util/constants';
import floorCoordinates from 'features/canvas/util/floorCoordinates';
import getScaledBoundingBoxDimensions from 'features/canvas/util/getScaledBoundingBoxDimensions';
import roundDimensionsToMultiple from 'features/canvas/util/roundDimensionsToMultiple';
import { initialAspectRatioState } from 'features/parameters/components/ImageSize/constants';
import type { AspectRatioState } from 'features/parameters/components/ImageSize/types';
import { modelChanged } from 'features/parameters/store/generationSlice';
import type { PayloadActionWithOptimalDimension } from 'features/parameters/store/types';
@ -23,7 +24,7 @@ import { clamp, cloneDeep } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO } from 'services/api/types';
import { appSocketQueueItemStatusChanged } from 'services/events/actions';
import { socketQueueItemStatusChanged } from 'services/events/actions';
import type {
BoundingBoxScaleMethod,
@ -53,10 +54,11 @@ export const initialLayerState: CanvasLayerState = {
};
export const initialCanvasState: CanvasState = {
_version: 1,
boundingBoxCoordinates: { x: 0, y: 0 },
boundingBoxDimensions: { width: 512, height: 512 },
boundingBoxPreviewFill: { r: 0, g: 0, b: 0, a: 0.5 },
boundingBoxScaleMethod: 'none',
boundingBoxScaleMethod: 'auto',
brushColor: { r: 90, g: 90, b: 255, a: 1 },
brushSize: 50,
colorPickerColor: { r: 90, g: 90, b: 255, a: 1 },
@ -695,7 +697,7 @@ export const canvasSlice = createSlice({
);
});
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
const batch_status = action.payload.data.batch_status;
if (!state.batchIds.includes(batch_status.batch_id)) {
return;
@ -784,3 +786,12 @@ export const {
export default canvasSlice.reducer;
export const selectCanvasSlice = (state: RootState) => state.canvas;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateCanvasState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
state.aspectRatio = initialAspectRatioState;
}
return state;
};

View File

@ -117,6 +117,7 @@ export const isCanvasAnyLine = (
): obj is CanvasMaskLine | CanvasBaseLine => obj.kind === 'line';
export interface CanvasState {
_version: 1;
boundingBoxCoordinates: Vector2d;
boundingBoxDimensions: Dimensions;
boundingBoxPreviewFill: RgbaColor;

View File

@ -9,7 +9,7 @@ import type {
ParameterT2IAdapterModel,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
import { socketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions';
@ -51,10 +51,12 @@ export const {
selectTotal: selectControlAdapterTotal,
} = caAdapterSelectors;
export const initialControlAdapterState: ControlAdaptersState =
export const initialControlAdaptersState: ControlAdaptersState =
caAdapter.getInitialState<{
_version: 1;
pendingControlImages: string[];
}>({
_version: 1,
pendingControlImages: [],
});
@ -96,7 +98,7 @@ export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
export const controlAdaptersSlice = createSlice({
name: 'controlAdapters',
initialState: initialControlAdapterState,
initialState: initialControlAdaptersState,
reducers: {
controlAdapterAdded: {
reducer: (
@ -267,31 +269,29 @@ export const controlAdaptersSlice = createSlice({
const update: Update<ControlNetConfig | T2IAdapterConfig, string> = {
id,
changes: { model },
changes: { model, shouldAutoConfig: true },
};
update.changes.processedControlImage = null;
if (cn.shouldAutoConfig) {
let processorType: ControlAdapterProcessorType | undefined = undefined;
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
caAdapter.updateOne(state, update);
@ -435,7 +435,7 @@ export const controlAdaptersSlice = createSlice({
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return cloneDeep(initialControlAdapterState);
return cloneDeep(initialControlAdaptersState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
@ -454,7 +454,7 @@ export const controlAdaptersSlice = createSlice({
}
});
builder.addCase(appSocketInvocationError, (state) => {
builder.addCase(socketInvocationError, (state) => {
state.pendingControlImages = [];
});
},
@ -493,3 +493,11 @@ export const isAnyControlAdapterAdded = isAnyOf(
export const selectControlAdaptersSlice = (state: RootState) =>
state.controlAdapters;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateControlAdaptersState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -7,7 +7,9 @@ export const zSeedBehaviour = z.enum(['PER_ITERATION', 'PER_PROMPT']);
export type SeedBehaviour = z.infer<typeof zSeedBehaviour>;
export const isSeedBehaviour = (v: unknown): v is SeedBehaviour =>
zSeedBehaviour.safeParse(v).success;
export interface DynamicPromptsState {
_version: 1;
maxPrompts: number;
combinatorial: boolean;
prompts: string[];
@ -18,6 +20,7 @@ export interface DynamicPromptsState {
}
export const initialDynamicPromptsState: DynamicPromptsState = {
_version: 1,
maxPrompts: 100,
combinatorial: true,
prompts: [],
@ -78,3 +81,11 @@ export default dynamicPromptsSlice.reducer;
export const selectDynamicPromptsSlice = (state: RootState) =>
state.dynamicPrompts;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateDynamicPromptsState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -2,7 +2,6 @@ import type { ChakraProps } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { InvControl } from 'common/components/InvControl/InvControl';
import { InvSelect } from 'common/components/InvSelect/InvSelect';
import { InvSelectFallback } from 'common/components/InvSelect/InvSelectFallback';
import { useGroupedModelInvSelect } from 'common/components/InvSelect/useGroupedModelInvSelect';
import type { EmbeddingSelectProps } from 'features/embedding/types';
import { t } from 'i18next';
@ -47,23 +46,16 @@ export const EmbeddingSelect = memo(
onChange: _onChange,
});
if (isLoading) {
return <InvSelectFallback label={t('common.loading')} />;
}
if (options.length === 0) {
return <InvSelectFallback label={t('embedding.noEmbeddingsLoaded')} />;
}
return (
<InvControl isDisabled={!options.length}>
<InvControl>
<InvSelect
placeholder={t('embedding.addEmbedding')}
placeholder={
isLoading ? t('common.loading') : t('embedding.addEmbedding')
}
defaultMenuIsOpen
autoFocus
value={null}
options={options}
isDisabled={!options.length}
noOptionsMessage={noOptionsMessage}
onChange={onChange}
onMenuClose={onClose}

View File

@ -106,3 +106,11 @@ const isAnyBoardDeleted = isAnyOf(
);
export const selectGallerySlice = (state: RootState) => state.gallery;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateGalleryState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -7,12 +7,14 @@ import type {
} from 'features/parameters/types/parameterSchemas';
export interface HRFState {
_version: 1;
hrfEnabled: boolean;
hrfStrength: ParameterStrength;
hrfMethod: ParameterHRFMethod;
}
export const initialHRFState: HRFState = {
_version: 1,
hrfStrength: 0.45,
hrfEnabled: false,
hrfMethod: 'ESRGAN',
@ -41,3 +43,11 @@ export const { setHrfEnabled, setHrfStrength, setHrfMethod } = hrfSlice.actions;
export default hrfSlice.reducer;
export const selectHrfSlice = (state: RootState) => state.hrf;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateHRFState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -14,16 +14,18 @@ export const defaultLoRAConfig = {
};
export type LoraState = {
_version: 1;
loras: Record<string, LoRA>;
};
export const intialLoraState: LoraState = {
export const initialLoraState: LoraState = {
_version: 1,
loras: {},
};
export const loraSlice = createSlice({
name: 'lora',
initialState: intialLoraState,
initialState: initialLoraState,
reducers: {
loraAdded: (state, action: PayloadAction<LoRAModelConfigEntity>) => {
const { model_name, id, base_model } = action.payload;
@ -77,3 +79,11 @@ export const {
export default loraSlice.reducer;
export const selectLoraSlice = (state: RootState) => state.lora;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateLoRAState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -3,11 +3,13 @@ import { createSlice } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
type ModelManagerState = {
_version: 1;
searchFolder: string | null;
advancedAddScanModel: string | null;
};
const initialModelManagerState: ModelManagerState = {
export const initialModelManagerState: ModelManagerState = {
_version: 1,
searchFolder: null,
advancedAddScanModel: null,
};
@ -31,3 +33,11 @@ export const { setSearchFolder, setAdvancedAddScanModel } =
export default modelManagerSlice.reducer;
export const selectModelManagerSlice = (state: RootState) => state.modelmanager;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateModelManagerState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -74,11 +74,11 @@ import {
} from 'reactflow';
import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import {
appSocketGeneratorProgress,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationStarted,
appSocketQueueItemStatusChanged,
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import type { z } from 'zod';
@ -96,6 +96,7 @@ const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
};
export const initialNodesState: NodesState = {
_version: 1,
nodes: [],
edges: [],
isReady: false,
@ -838,14 +839,14 @@ const nodesSlice = createSlice({
}, {});
});
builder.addCase(appSocketInvocationStarted, (state, action) => {
builder.addCase(socketInvocationStarted, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(appSocketInvocationComplete, (state, action) => {
builder.addCase(socketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
@ -856,7 +857,7 @@ const nodesSlice = createSlice({
nes.outputs.push(result);
}
});
builder.addCase(appSocketInvocationError, (state, action) => {
builder.addCase(socketInvocationError, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
@ -866,7 +867,7 @@ const nodesSlice = createSlice({
node.progressImage = null;
}
});
builder.addCase(appSocketGeneratorProgress, (state, action) => {
builder.addCase(socketGeneratorProgress, (state, action) => {
const { source_node_id, step, total_steps, progress_image } =
action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
@ -876,7 +877,7 @@ const nodesSlice = createSlice({
node.progressImage = progress_image ?? null;
}
});
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = zNodeStatus.enum.PENDING;
@ -990,3 +991,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
export default nodesSlice.reducer;
export const selectNodesSlice = (state: RootState) => state.nodes;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateNodesState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -14,6 +14,7 @@ import type {
} from 'reactflow';
export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
connectionStartParams: OnConnectStartParams | null;
@ -39,6 +40,7 @@ export type NodesState = {
};
export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & {
_version: 1;
isTouched: boolean;
};

View File

@ -12,6 +12,7 @@ import type { FieldIdentifier } from 'features/nodes/types/field';
import { cloneDeep, isEqual, uniqBy } from 'lodash-es';
export const initialWorkflowState: WorkflowState = {
_version: 1,
name: '',
author: '',
description: '',
@ -86,7 +87,7 @@ const workflowSlice = createSlice({
extraReducers: (builder) => {
builder.addCase(workflowLoaded, (state, action) => {
const { nodes: _nodes, edges: _edges, ...workflowExtra } = action.payload;
return { ...cloneDeep(workflowExtra), isTouched: true };
return { ...initialWorkflowState, ...cloneDeep(workflowExtra) };
});
builder.addCase(nodesDeleted, (state, action) => {
@ -123,3 +124,11 @@ export const {
export default workflowSlice.reducer;
export const selectWorkflowSlice = (state: RootState) => state.workflow;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -29,6 +29,7 @@ import type { ImageDTO } from 'services/api/types';
import type { GenerationState } from './types';
export const initialGenerationState: GenerationState = {
_version: 1,
cfgScale: 7.5,
cfgRescaleMultiplier: 0,
height: 512,
@ -276,3 +277,12 @@ export const { selectOptimalDimension } = generationSlice.selectors;
export default generationSlice.reducer;
export const selectGenerationSlice = (state: RootState) => state.generation;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateGenerationState = (state: any): GenerationState => {
if (!('_version' in state)) {
state._version = 1;
state.aspectRatio = initialAspectRatioState;
}
return state;
};

View File

@ -14,10 +14,12 @@ export const isParamESRGANModelName = (v: unknown): v is ParamESRGANModelName =>
zParamESRGANModelName.safeParse(v).success;
export interface PostprocessingState {
_version: 1;
esrganModelName: ParamESRGANModelName;
}
export const initialPostprocessingState: PostprocessingState = {
_version: 1,
esrganModelName: 'RealESRGAN_x4plus.pth',
};
@ -40,3 +42,11 @@ export default postprocessingSlice.reducer;
export const selectPostprocessingSlice = (state: RootState) =>
state.postprocessing;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migratePostprocessingState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -19,6 +19,7 @@ import type {
} from 'features/parameters/types/parameterSchemas';
export interface GenerationState {
_version: 1;
cfgScale: ParameterCFGScale;
cfgRescaleMultiplier: ParameterCFGRescaleMultiplier;
height: ParameterHeight;

View File

@ -9,6 +9,7 @@ import type {
} from 'features/parameters/types/parameterSchemas';
type SDXLState = {
_version: 1;
positiveStylePrompt: ParameterPositiveStylePromptSDXL;
negativeStylePrompt: ParameterNegativeStylePromptSDXL;
shouldConcatSDXLStylePrompt: boolean;
@ -22,6 +23,7 @@ type SDXLState = {
};
export const initialSDXLState: SDXLState = {
_version: 1,
positiveStylePrompt: '',
negativeStylePrompt: '',
shouldConcatSDXLStylePrompt: true,
@ -96,3 +98,11 @@ export const {
export default sdxlSlice.reducer;
export const selectSdxlSlice = (state: RootState) => state.sdxl;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateSDXLState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -8,23 +8,24 @@ import { t } from 'i18next';
import { startCase } from 'lodash-es';
import type { LogLevelName } from 'roarr';
import {
appSocketConnected,
appSocketDisconnected,
appSocketGeneratorProgress,
appSocketGraphExecutionStateComplete,
appSocketInvocationComplete,
appSocketInvocationError,
appSocketInvocationRetrievalError,
appSocketInvocationStarted,
appSocketModelLoadCompleted,
appSocketModelLoadStarted,
appSocketQueueItemStatusChanged,
appSocketSessionRetrievalError,
socketConnected,
socketDisconnected,
socketGeneratorProgress,
socketGraphExecutionStateComplete,
socketInvocationComplete,
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
socketModelLoadCompleted,
socketModelLoadStarted,
socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions';
import type { Language, SystemState } from './types';
export const initialSystemState: SystemState = {
_version: 1,
isInitialized: false,
isConnected: false,
shouldConfirmOnDelete: true,
@ -92,7 +93,7 @@ export const systemSlice = createSlice({
/**
* Socket Connected
*/
builder.addCase(appSocketConnected, (state) => {
builder.addCase(socketConnected, (state) => {
state.isConnected = true;
state.denoiseProgress = null;
state.status = 'CONNECTED';
@ -101,7 +102,7 @@ export const systemSlice = createSlice({
/**
* Socket Disconnected
*/
builder.addCase(appSocketDisconnected, (state) => {
builder.addCase(socketDisconnected, (state) => {
state.isConnected = false;
state.denoiseProgress = null;
state.status = 'DISCONNECTED';
@ -110,7 +111,7 @@ export const systemSlice = createSlice({
/**
* Invocation Started
*/
builder.addCase(appSocketInvocationStarted, (state) => {
builder.addCase(socketInvocationStarted, (state) => {
state.denoiseProgress = null;
state.status = 'PROCESSING';
});
@ -118,7 +119,7 @@ export const systemSlice = createSlice({
/**
* Generator Progress
*/
builder.addCase(appSocketGeneratorProgress, (state, action) => {
builder.addCase(socketGeneratorProgress, (state, action) => {
const {
step,
total_steps,
@ -144,7 +145,7 @@ export const systemSlice = createSlice({
/**
* Invocation Complete
*/
builder.addCase(appSocketInvocationComplete, (state) => {
builder.addCase(socketInvocationComplete, (state) => {
state.denoiseProgress = null;
state.status = 'CONNECTED';
});
@ -152,20 +153,20 @@ export const systemSlice = createSlice({
/**
* Graph Execution State Complete
*/
builder.addCase(appSocketGraphExecutionStateComplete, (state) => {
builder.addCase(socketGraphExecutionStateComplete, (state) => {
state.denoiseProgress = null;
state.status = 'CONNECTED';
});
builder.addCase(appSocketModelLoadStarted, (state) => {
builder.addCase(socketModelLoadStarted, (state) => {
state.status = 'LOADING_MODEL';
});
builder.addCase(appSocketModelLoadCompleted, (state) => {
builder.addCase(socketModelLoadCompleted, (state) => {
state.status = 'CONNECTED';
});
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (
['completed', 'canceled', 'failed'].includes(
action.payload.data.queue_item.status
@ -211,9 +212,17 @@ export const {
export default systemSlice.reducer;
const isAnyServerError = isAnyOf(
appSocketInvocationError,
appSocketSessionRetrievalError,
appSocketInvocationRetrievalError
socketInvocationError,
socketSessionRetrievalError,
socketInvocationRetrievalError
);
export const selectSystemSlice = (state: RootState) => state.system;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateSystemState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -43,6 +43,7 @@ export const isLanguage = (v: unknown): v is Language =>
zLanguage.safeParse(v).success;
export interface SystemState {
_version: 1;
isInitialized: boolean;
isConnected: boolean;
shouldConfirmOnDelete: boolean;

View File

@ -7,6 +7,7 @@ import type { InvokeTabName } from './tabMap';
import type { UIState } from './uiTypes';
export const initialUIState: UIState = {
_version: 1,
activeTab: 'txt2img',
shouldShowImageDetails: false,
shouldShowExistingModelsInSearch: false,
@ -63,3 +64,11 @@ export const {
export default uiSlice.reducer;
export const selectUiSlice = (state: RootState) => state.ui;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const migrateUIState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
return state;
};

View File

@ -1,6 +1,7 @@
import type { InvokeTabName } from './tabMap';
export interface UIState {
_version: 1;
activeTab: InvokeTabName;
shouldShowImageDetails: boolean;
shouldShowExistingModelsInSearch: boolean;

View File

@ -15,220 +15,54 @@ import type {
// Create actions for each socket
// Middleware and redux can then respond to them as needed
/**
* Socket.IO Connected
*
* Do not use. Only for use in middleware.
*/
export const socketConnected = createAction('socket/socketConnected');
/**
* App-level Socket.IO Connected
*/
export const appSocketConnected = createAction('socket/appSocketConnected');
/**
* Socket.IO Disconnect
*
* Do not use. Only for use in middleware.
*/
export const socketDisconnected = createAction('socket/socketDisconnected');
/**
* App-level Socket.IO Disconnected
*/
export const appSocketDisconnected = createAction(
'socket/appSocketDisconnected'
);
/**
* Socket.IO Subscribed Session
*
* Do not use. Only for use in middleware.
*/
export const socketSubscribedSession = createAction<{
sessionId: string;
}>('socket/socketSubscribedSession');
/**
* App-level Socket.IO Subscribed Session
*/
export const appSocketSubscribedSession = createAction<{
sessionId: string;
}>('socket/appSocketSubscribedSession');
/**
* Socket.IO Unsubscribed Session
*
* Do not use. Only for use in middleware.
*/
export const socketUnsubscribedSession = createAction<{ sessionId: string }>(
'socket/socketUnsubscribedSession'
);
/**
* App-level Socket.IO Unsubscribed Session
*/
export const appSocketUnsubscribedSession = createAction<{ sessionId: string }>(
'socket/appSocketUnsubscribedSession'
);
/**
* Socket.IO Invocation Started
*
* Do not use. Only for use in middleware.
*/
export const socketInvocationStarted = createAction<{
data: InvocationStartedEvent;
}>('socket/socketInvocationStarted');
/**
* App-level Socket.IO Invocation Started
*/
export const appSocketInvocationStarted = createAction<{
data: InvocationStartedEvent;
}>('socket/appSocketInvocationStarted');
/**
* Socket.IO Invocation Complete
*
* Do not use. Only for use in middleware.
*/
export const socketInvocationComplete = createAction<{
data: InvocationCompleteEvent;
}>('socket/socketInvocationComplete');
/**
* App-level Socket.IO Invocation Complete
*/
export const appSocketInvocationComplete = createAction<{
data: InvocationCompleteEvent;
}>('socket/appSocketInvocationComplete');
/**
* Socket.IO Invocation Error
*
* Do not use. Only for use in middleware.
*/
export const socketInvocationError = createAction<{
data: InvocationErrorEvent;
}>('socket/socketInvocationError');
/**
* App-level Socket.IO Invocation Error
*/
export const appSocketInvocationError = createAction<{
data: InvocationErrorEvent;
}>('socket/appSocketInvocationError');
/**
* Socket.IO Graph Execution State Complete
*
* Do not use. Only for use in middleware.
*/
export const socketGraphExecutionStateComplete = createAction<{
data: GraphExecutionStateCompleteEvent;
}>('socket/socketGraphExecutionStateComplete');
/**
* App-level Socket.IO Graph Execution State Complete
*/
export const appSocketGraphExecutionStateComplete = createAction<{
data: GraphExecutionStateCompleteEvent;
}>('socket/appSocketGraphExecutionStateComplete');
/**
* Socket.IO Generator Progress
*
* Do not use. Only for use in middleware.
*/
export const socketGeneratorProgress = createAction<{
data: GeneratorProgressEvent;
}>('socket/socketGeneratorProgress');
/**
* App-level Socket.IO Generator Progress
*/
export const appSocketGeneratorProgress = createAction<{
data: GeneratorProgressEvent;
}>('socket/appSocketGeneratorProgress');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadStarted = createAction<{
data: ModelLoadStartedEvent;
}>('socket/socketModelLoadStarted');
/**
* App-level Model Load Started
*/
export const appSocketModelLoadStarted = createAction<{
data: ModelLoadStartedEvent;
}>('socket/appSocketModelLoadStarted');
/**
* Socket.IO Model Load Started
*
* Do not use. Only for use in middleware.
*/
export const socketModelLoadCompleted = createAction<{
data: ModelLoadCompletedEvent;
}>('socket/socketModelLoadCompleted');
/**
* App-level Model Load Completed
*/
export const appSocketModelLoadCompleted = createAction<{
data: ModelLoadCompletedEvent;
}>('socket/appSocketModelLoadCompleted');
/**
* Socket.IO Session Retrieval Error
*
* Do not use. Only for use in middleware.
*/
export const socketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/socketSessionRetrievalError');
/**
* App-level Session Retrieval Error
*/
export const appSocketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/appSocketSessionRetrievalError');
/**
* Socket.IO Invocation Retrieval Error
*
* Do not use. Only for use in middleware.
*/
export const socketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/socketInvocationRetrievalError');
/**
* App-level Invocation Retrieval Error
*/
export const appSocketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/appSocketInvocationRetrievalError');
/**
* Socket.IO Quueue Item Status Changed
*
* Do not use. Only for use in middleware.
*/
export const socketQueueItemStatusChanged = createAction<{
data: QueueItemStatusChangedEvent;
}>('socket/socketQueueItemStatusChanged');
/**
* App-level Quueue Item Status Changed
*/
export const appSocketQueueItemStatusChanged = createAction<{
data: QueueItemStatusChangedEvent;
}>('socket/appSocketQueueItemStatusChanged');