mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into responsive-ui
This commit is contained in:
@ -6,3 +6,5 @@ stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
@ -3,4 +3,8 @@ dist/
|
||||
node_modules/
|
||||
patches/
|
||||
stats.html
|
||||
index.html
|
||||
.yarn/
|
||||
*.scss
|
||||
src/services/api/
|
||||
src/services/fixtures/*
|
||||
|
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
87
invokeai/frontend/web/docs/API_CLIENT.md
Normal file
@ -0,0 +1,87 @@
|
||||
# Generated axios API client
|
||||
|
||||
- [Generated axios API client](#generated-axios-api-client)
|
||||
- [Generation](#generation)
|
||||
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
|
||||
- [Generate the API client from JSON](#generate-the-api-client-from-json)
|
||||
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
|
||||
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
|
||||
- [Generate the API client](#generate-the-api-client)
|
||||
- [The generated client](#the-generated-client)
|
||||
- [API client customisation](#api-client-customisation)
|
||||
|
||||
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
|
||||
|
||||
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
|
||||
|
||||
## Generation
|
||||
|
||||
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
|
||||
|
||||
### Generate the API client from the nodes web server
|
||||
|
||||
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
|
||||
|
||||
1. Start the nodes web server.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python scripts/invoke-new.py --web
|
||||
```
|
||||
|
||||
2. Generate the API client.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:web
|
||||
```
|
||||
|
||||
### Generate the API client from JSON
|
||||
|
||||
The JSON can be acquired from the nodes web server, or with a python script.
|
||||
|
||||
#### Getting the JSON from the nodes web server
|
||||
|
||||
Start the nodes web server as described above, then download the file.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
curl http://localhost:9090/openapi.json -o openapi.json
|
||||
```
|
||||
|
||||
#### Getting the JSON with a python script
|
||||
|
||||
Run this python script from the repo root, so it can access the nodes server modules.
|
||||
|
||||
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
|
||||
|
||||
```bash
|
||||
# from the repo root
|
||||
python invokeai/app/util/generate_openapi_json.py
|
||||
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
|
||||
```
|
||||
|
||||
#### Generate the API client
|
||||
|
||||
Now we can generate the API client from the JSON.
|
||||
|
||||
```bash
|
||||
# from invokeai/frontend/web/
|
||||
yarn api:file
|
||||
```
|
||||
|
||||
## The generated client
|
||||
|
||||
The client will be written to `invokeai/frontend/web/services/api/`:
|
||||
|
||||
- `axios` client
|
||||
- TS types
|
||||
- An easily parseable schema, which we can use to generate UI
|
||||
|
||||
## API client customisation
|
||||
|
||||
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
|
||||
|
||||
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
|
||||
|
||||
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.
|
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
21
invokeai/frontend/web/docs/EVENTS.md
Normal file
@ -0,0 +1,21 @@
|
||||
# Events
|
||||
|
||||
Events via `socket.io`
|
||||
|
||||
## `actions.ts`
|
||||
|
||||
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
|
||||
|
||||
Any reducer (or middleware) can respond to the actions.
|
||||
|
||||
## `middleware.ts`
|
||||
|
||||
Redux middleware for events.
|
||||
|
||||
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
|
||||
|
||||
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
|
||||
|
||||
## `types.ts`
|
||||
|
||||
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.
|
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
17
invokeai/frontend/web/docs/NODE_EDITOR.md
Normal file
@ -0,0 +1,17 @@
|
||||
# Node Editor Design
|
||||
|
||||
WIP
|
||||
|
||||
nodes
|
||||
|
||||
everything in `src/features/nodes/`
|
||||
|
||||
have a look at `state.nodes.invocation`
|
||||
|
||||
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
|
||||
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
|
||||
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
|
||||
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
|
||||
- `reactflow` sends changes to nodes/edges to redux
|
||||
- to invoke, `buildNodesGraph()` state, then send this
|
||||
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`
|
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
29
invokeai/frontend/web/docs/PACKAGE_SCRIPTS.md
Normal file
@ -0,0 +1,29 @@
|
||||
# Package Scripts
|
||||
|
||||
WIP walkthrough of `package.json` scripts.
|
||||
|
||||
## `theme` & `theme:watch`
|
||||
|
||||
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
|
||||
|
||||
The CLI essentially monkeypatches Chakra's files in `node_modules`.
|
||||
|
||||
## `postinstall`
|
||||
|
||||
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
|
||||
|
||||
### Patch `@chakra-ui/cli`
|
||||
|
||||
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
|
||||
|
||||
### Patch `redux-persist`
|
||||
|
||||
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
|
||||
|
||||
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
|
||||
|
||||
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
|
||||
|
||||
### Patch `redux-deep-persist`
|
||||
|
||||
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.
|
@ -1,10 +1,16 @@
|
||||
# InvokeAI Web UI
|
||||
|
||||
- [InvokeAI Web UI](#invokeai-web-ui)
|
||||
- [Stack](#stack)
|
||||
- [Contributing](#contributing)
|
||||
- [Dev Environment](#dev-environment)
|
||||
- [Production builds](#production-builds)
|
||||
|
||||
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
|
||||
|
||||
Code in `invokeai/frontend/web/` if you want to have a look.
|
||||
|
||||
## Details
|
||||
## Stack
|
||||
|
||||
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
|
||||
|
||||
@ -32,7 +38,7 @@ Start everything in dev mode:
|
||||
|
||||
1. Start the dev server: `yarn dev`
|
||||
2. Start the InvokeAI UI per usual: `invokeai --web`
|
||||
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
|
||||
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
|
||||
|
||||
### Production builds
|
||||
|
21
invokeai/frontend/web/index.d.ts
vendored
21
invokeai/frontend/web/index.d.ts
vendored
@ -1,6 +1,7 @@
|
||||
import React, { PropsWithChildren } from 'react';
|
||||
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
|
||||
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
|
||||
export {};
|
||||
|
||||
@ -64,9 +65,25 @@ declare module '@invoke-ai/invoke-ai-ui' {
|
||||
declare class SettingsModal extends React.Component<SettingsModalProps> {
|
||||
public constructor(props: SettingsModalProps);
|
||||
}
|
||||
|
||||
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
|
||||
public constructor(props: StatusIndicatorProps);
|
||||
}
|
||||
|
||||
declare class ModelSelect extends React.Component<ModelSelectProps> {
|
||||
public constructor(props: ModelSelectProps);
|
||||
}
|
||||
}
|
||||
|
||||
declare function Invoke(props: PropsWithChildren): JSX.Element;
|
||||
interface InvokeProps extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
shouldTransformUrls?: boolean;
|
||||
}
|
||||
|
||||
declare function Invoke(props: InvokeProps): JSX.Element;
|
||||
|
||||
export {
|
||||
ThemeChanger,
|
||||
@ -74,5 +91,7 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
export = Invoke;
|
||||
|
@ -5,7 +5,10 @@
|
||||
"scripts": {
|
||||
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
|
||||
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
|
||||
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
|
||||
"build": "yarn run lint && vite build",
|
||||
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
|
||||
"preview": "vite preview",
|
||||
"lint:madge": "madge --circular src/main.tsx",
|
||||
"lint:eslint": "eslint --max-warnings=0 .",
|
||||
@ -41,9 +44,11 @@
|
||||
"@chakra-ui/react": "^2.5.1",
|
||||
"@chakra-ui/styled-system": "^2.6.1",
|
||||
"@chakra-ui/theme-tools": "^2.0.16",
|
||||
"@dagrejs/graphlib": "^2.1.12",
|
||||
"@emotion/react": "^11.10.6",
|
||||
"@emotion/styled": "^11.10.6",
|
||||
"@reduxjs/toolkit": "^1.9.2",
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@reduxjs/toolkit": "^1.9.3",
|
||||
"chakra-ui-contextmenu": "^1.0.5",
|
||||
"dateformat": "^5.0.3",
|
||||
"formik": "^2.2.9",
|
||||
@ -67,15 +72,17 @@
|
||||
"react-redux": "^8.0.5",
|
||||
"react-transition-group": "^4.4.5",
|
||||
"react-zoom-pan-pinch": "^2.6.1",
|
||||
"reactflow": "^11.7.0",
|
||||
"redux-deep-persist": "^1.0.7",
|
||||
"redux-dynamic-middlewares": "^2.2.0",
|
||||
"redux-persist": "^6.0.0",
|
||||
"socket.io-client": "^4.6.0",
|
||||
"use-image": "^1.1.0",
|
||||
"uuid": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@fontsource/inter": "^4.5.15",
|
||||
"@types/dateformat": "^5.0.0",
|
||||
"@types/lodash": "^4.14.194",
|
||||
"@types/react": "^18.0.28",
|
||||
"@types/react-dom": "^18.0.11",
|
||||
"@types/react-transition-group": "^4.4.5",
|
||||
@ -83,6 +90,7 @@
|
||||
"@typescript-eslint/eslint-plugin": "^5.52.0",
|
||||
"@typescript-eslint/parser": "^5.52.0",
|
||||
"@vitejs/plugin-react-swc": "^3.2.0",
|
||||
"axios": "^1.3.4",
|
||||
"babel-plugin-transform-imports": "^2.0.0",
|
||||
"concurrently": "^7.6.0",
|
||||
"eslint": "^8.34.0",
|
||||
@ -90,13 +98,17 @@
|
||||
"eslint-plugin-prettier": "^4.2.1",
|
||||
"eslint-plugin-react": "^7.32.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"form-data": "^4.0.0",
|
||||
"husky": "^8.0.3",
|
||||
"lint-staged": "^13.1.2",
|
||||
"madge": "^6.0.0",
|
||||
"openapi-types": "^12.1.0",
|
||||
"openapi-typescript-codegen": "^0.23.0",
|
||||
"postinstall-postinstall": "^2.1.0",
|
||||
"prettier": "^2.8.4",
|
||||
"rollup-plugin-visualizer": "^5.9.0",
|
||||
"terser": "^5.16.4",
|
||||
"typescript": "4.9.5",
|
||||
"vite": "^4.1.2",
|
||||
"vite-plugin-eslint": "^1.8.1",
|
||||
"vite-tsconfig-paths": "^4.0.5",
|
||||
|
@ -53,6 +53,7 @@
|
||||
"txt2img": "Text To Image",
|
||||
"img2img": "Image To Image",
|
||||
"unifiedCanvas": "Unified Canvas",
|
||||
"linear": "Linear",
|
||||
"nodes": "Nodes",
|
||||
"postprocessing": "Post Processing",
|
||||
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
|
||||
@ -525,6 +526,10 @@
|
||||
"resetComplete": "Web UI has been reset. Refresh the page to reload."
|
||||
},
|
||||
"toast": {
|
||||
"serverError": "Server Error",
|
||||
"disconnected": "Disconnected from Server",
|
||||
"connected": "Connected to Server",
|
||||
"canceled": "Processing Canceled",
|
||||
"tempFoldersEmptied": "Temp Folder Emptied",
|
||||
"uploadFailed": "Upload failed",
|
||||
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",
|
||||
|
@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
|
||||
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
|
||||
import Lightbox from 'features/lightbox/components/Lightbox';
|
||||
import { useAppSelector } from './storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from './storeHooks';
|
||||
import { PropsWithChildren, useEffect } from 'react';
|
||||
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
|
||||
|
||||
keepGUIAlive();
|
||||
|
||||
const App = (props: PropsWithChildren) => {
|
||||
interface Props extends PropsWithChildren {
|
||||
options: {
|
||||
disabledPanels: string[];
|
||||
disabledTabs: InvokeTabName[];
|
||||
shouldTransformUrls?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
const App = (props: Props) => {
|
||||
useToastWatcher();
|
||||
|
||||
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
|
||||
const { setColorMode } = useColorMode();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledPanels(props.options.disabledPanels));
|
||||
}, [dispatch, props.options.disabledPanels]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(setDisabledTabs(props.options.disabledTabs));
|
||||
}, [dispatch, props.options.disabledTabs]);
|
||||
|
||||
useEffect(() => {
|
||||
dispatch(
|
||||
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
|
||||
);
|
||||
}, [dispatch, props.options.shouldTransformUrls]);
|
||||
|
||||
useEffect(() => {
|
||||
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');
|
||||
|
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
22
invokeai/frontend/web/src/app/invokeai.d.ts
vendored
@ -14,6 +14,8 @@
|
||||
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { ImageMetadata, ImageType } from 'services/api';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
|
||||
/**
|
||||
* TODO:
|
||||
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
|
||||
};
|
||||
|
||||
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
|
||||
export declare type Image = {
|
||||
export declare type _Image = {
|
||||
uuid: string;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
@ -124,11 +126,23 @@ export declare type Image = {
|
||||
category: GalleryCategory;
|
||||
isBase64?: boolean;
|
||||
dreamPrompt?: 'string';
|
||||
name?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* ResultImage
|
||||
*/
|
||||
export declare type Image = {
|
||||
name: string;
|
||||
type: ImageType;
|
||||
url: string;
|
||||
thumbnail: string;
|
||||
metadata: ImageMetadata;
|
||||
};
|
||||
|
||||
// GalleryImages is an array of Image.
|
||||
export declare type GalleryImages = {
|
||||
images: Array<Image>;
|
||||
images: Array<_Image>;
|
||||
};
|
||||
|
||||
/**
|
||||
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
|
||||
|
||||
export declare type SystemConfigResponse = SystemConfig;
|
||||
|
||||
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
|
||||
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
|
||||
boundingBox?: IRect;
|
||||
generationMode: InvokeTabName;
|
||||
};
|
||||
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
|
||||
};
|
||||
|
||||
export declare type GalleryImagesResponse = {
|
||||
images: Array<Omit<Image, 'uuid'>>;
|
||||
images: Array<Omit<_Image, 'uuid'>>;
|
||||
areMoreImagesAvailable: boolean;
|
||||
category: GalleryCategory;
|
||||
};
|
||||
|
@ -20,6 +20,7 @@ export const readinessSelector = createSelector(
|
||||
seedWeights,
|
||||
initialImage,
|
||||
seed,
|
||||
isImageToImageEnabled,
|
||||
} = generation;
|
||||
|
||||
const { isProcessing, isConnected } = system;
|
||||
@ -33,7 +34,7 @@ export const readinessSelector = createSelector(
|
||||
reasonsWhyNotReady.push('Missing prompt');
|
||||
}
|
||||
|
||||
if (activeTabName === 'img2img' && !initialImage) {
|
||||
if (isImageToImageEnabled && !initialImage) {
|
||||
isReady = false;
|
||||
reasonsWhyNotReady.push('No initial image selected');
|
||||
}
|
||||
|
@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
export const generateImage = createAction<InvokeTabName>(
|
||||
'socketio/generateImage'
|
||||
);
|
||||
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
|
||||
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
|
||||
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
|
||||
export const runFacetool = createAction<InvokeAI._Image>(
|
||||
'socketio/runFacetool'
|
||||
);
|
||||
export const deleteImage = createAction<InvokeAI._Image>(
|
||||
'socketio/deleteImage'
|
||||
);
|
||||
export const requestImages = createAction<GalleryCategory>(
|
||||
'socketio/requestImages'
|
||||
);
|
||||
|
@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
|
||||
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
|
||||
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
|
||||
dispatch(setIsProcessing(true));
|
||||
|
||||
const {
|
||||
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
|
||||
})
|
||||
);
|
||||
},
|
||||
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
|
||||
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
|
||||
const { url, uuid, category, thumbnail } = imageToDelete;
|
||||
dispatch(removeImage(imageToDelete));
|
||||
socketio.emit('deleteImage', url, thumbnail, uuid, category);
|
||||
|
@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
clearInitialImage,
|
||||
initialImageSelected,
|
||||
setInfillMethod,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { tabMap } from 'features/ui/store/tabMap';
|
||||
@ -142,15 +143,17 @@ const makeSocketIOListeners = (
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldLoopback) {
|
||||
const activeTabName = tabMap[activeTab];
|
||||
switch (activeTabName) {
|
||||
case 'img2img': {
|
||||
dispatch(setInitialImage(newImage));
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: fix
|
||||
// if (shouldLoopback) {
|
||||
// const activeTabName = tabMap[activeTab];
|
||||
// switch (activeTabName) {
|
||||
// case 'img2img': {
|
||||
// dispatch(initialImageSelected(newImage.uuid));
|
||||
// // dispatch(setInitialImage(newImage));
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
dispatch(clearIntermediateImage());
|
||||
|
||||
@ -262,7 +265,7 @@ const makeSocketIOListeners = (
|
||||
*/
|
||||
|
||||
// Generate a UUID for each image
|
||||
const preparedImages = images.map((image): InvokeAI.Image => {
|
||||
const preparedImages = images.map((image): InvokeAI._Image => {
|
||||
return {
|
||||
uuid: uuidv4(),
|
||||
...image,
|
||||
@ -334,7 +337,7 @@ const makeSocketIOListeners = (
|
||||
|
||||
if (
|
||||
initialImage === url ||
|
||||
(initialImage as InvokeAI.Image)?.url === url
|
||||
(initialImage as InvokeAI._Image)?.url === url
|
||||
) {
|
||||
dispatch(clearInitialImage());
|
||||
}
|
||||
|
@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
|
||||
path: `${window.location.pathname}socket.io`,
|
||||
});
|
||||
|
||||
socketio.disconnect();
|
||||
|
||||
let areListenersSet = false;
|
||||
|
||||
const middleware: Middleware = (store) => (next) => (action) => {
|
||||
|
@ -2,18 +2,32 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
|
||||
|
||||
import { persistReducer } from 'redux-persist';
|
||||
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
|
||||
|
||||
import dynamicMiddlewares from 'redux-dynamic-middlewares';
|
||||
import { getPersistConfig } from 'redux-deep-persist';
|
||||
|
||||
import canvasReducer from 'features/canvas/store/canvasSlice';
|
||||
import galleryReducer from 'features/gallery/store/gallerySlice';
|
||||
import resultsReducer from 'features/gallery/store/resultsSlice';
|
||||
import uploadsReducer from 'features/gallery/store/uploadsSlice';
|
||||
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
|
||||
import generationReducer from 'features/parameters/store/generationSlice';
|
||||
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
|
||||
import systemReducer from 'features/system/store/systemSlice';
|
||||
import uiReducer from 'features/ui/store/uiSlice';
|
||||
import modelsReducer from 'features/system/store/modelSlice';
|
||||
import nodesReducer from 'features/nodes/store/nodesSlice';
|
||||
|
||||
import { socketioMiddleware } from './socketio/middleware';
|
||||
import { socketMiddleware } from 'services/events/middleware';
|
||||
import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
|
||||
import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
|
||||
import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
|
||||
import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
|
||||
import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
|
||||
import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
|
||||
import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
|
||||
import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
|
||||
import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';
|
||||
|
||||
/**
|
||||
* redux-persist provides an easy and reliable way to persist state across reloads.
|
||||
@ -29,49 +43,18 @@ import { socketioMiddleware } from './socketio/middleware';
|
||||
* The necesssary nested persistors with blacklists are configured below.
|
||||
*/
|
||||
|
||||
const canvasBlacklist = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
].map((blacklistItem) => `canvas.${blacklistItem}`);
|
||||
|
||||
const systemBlacklist = [
|
||||
'currentIteration',
|
||||
'currentStatus',
|
||||
'currentStep',
|
||||
'isCancelable',
|
||||
'isConnected',
|
||||
'isESRGANAvailable',
|
||||
'isGFPGANAvailable',
|
||||
'isProcessing',
|
||||
'socketId',
|
||||
'totalIterations',
|
||||
'totalSteps',
|
||||
'openModel',
|
||||
'cancelOptions.cancelAfter',
|
||||
].map((blacklistItem) => `system.${blacklistItem}`);
|
||||
|
||||
const galleryBlacklist = [
|
||||
'categories',
|
||||
'currentCategory',
|
||||
'currentImage',
|
||||
'currentImageUuid',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
'intermediateImage',
|
||||
].map((blacklistItem) => `gallery.${blacklistItem}`);
|
||||
|
||||
const lightboxBlacklist = ['isLightboxOpen'].map(
|
||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||
);
|
||||
|
||||
const rootReducer = combineReducers({
|
||||
generation: generationReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
gallery: galleryReducer,
|
||||
system: systemReducer,
|
||||
canvas: canvasReducer,
|
||||
ui: uiReducer,
|
||||
gallery: galleryReducer,
|
||||
generation: generationReducer,
|
||||
lightbox: lightboxReducer,
|
||||
models: modelsReducer,
|
||||
nodes: nodesReducer,
|
||||
postprocessing: postprocessingReducer,
|
||||
results: resultsReducer,
|
||||
system: systemReducer,
|
||||
ui: uiReducer,
|
||||
uploads: uploadsReducer,
|
||||
});
|
||||
|
||||
const rootPersistConfig = getPersistConfig({
|
||||
@ -80,23 +63,40 @@ const rootPersistConfig = getPersistConfig({
|
||||
rootReducer,
|
||||
blacklist: [
|
||||
...canvasBlacklist,
|
||||
...systemBlacklist,
|
||||
...galleryBlacklist,
|
||||
...generationBlacklist,
|
||||
...lightboxBlacklist,
|
||||
...modelsBlacklist,
|
||||
...nodesBlacklist,
|
||||
...postprocessingBlacklist,
|
||||
// ...resultsBlacklist,
|
||||
'results',
|
||||
...systemBlacklist,
|
||||
...uiBlacklist,
|
||||
// ...uploadsBlacklist,
|
||||
'uploads',
|
||||
],
|
||||
debounce: 300,
|
||||
});
|
||||
|
||||
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
|
||||
|
||||
// Continue with store setup
|
||||
// TODO: rip the old middleware out when nodes is complete
|
||||
export function buildMiddleware() {
|
||||
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
|
||||
return socketMiddleware();
|
||||
} else {
|
||||
return socketioMiddleware();
|
||||
}
|
||||
}
|
||||
|
||||
export const store = configureStore({
|
||||
reducer: persistedReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
immutableCheck: false,
|
||||
serializableCheck: false,
|
||||
}).concat(socketioMiddleware()),
|
||||
}).concat(dynamicMiddlewares),
|
||||
devTools: {
|
||||
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
|
||||
actionsDenylist: [
|
||||
|
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
8
invokeai/frontend/web/src/app/storeUtils.ts
Normal file
@ -0,0 +1,8 @@
|
||||
import { createAsyncThunk } from '@reduxjs/toolkit';
|
||||
import { AppDispatch, RootState } from './store';
|
||||
|
||||
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
|
||||
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
|
||||
state: RootState;
|
||||
dispatch: AppDispatch;
|
||||
}>();
|
@ -44,12 +44,10 @@ export type IAIFullSliderProps = {
|
||||
inputReadOnly?: boolean;
|
||||
withReset?: boolean;
|
||||
handleReset?: () => void;
|
||||
isResetDisabled?: boolean;
|
||||
isSliderDisabled?: boolean;
|
||||
isInputDisabled?: boolean;
|
||||
tooltipSuffix?: string;
|
||||
hideTooltip?: boolean;
|
||||
isCompact?: boolean;
|
||||
isDisabled?: boolean;
|
||||
sliderFormControlProps?: FormControlProps;
|
||||
sliderFormLabelProps?: FormLabelProps;
|
||||
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
|
||||
@ -80,10 +78,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
withReset = false,
|
||||
hideTooltip = false,
|
||||
isCompact = false,
|
||||
isDisabled = false,
|
||||
handleReset,
|
||||
isResetDisabled,
|
||||
isSliderDisabled,
|
||||
isInputDisabled,
|
||||
sliderFormControlProps,
|
||||
sliderFormLabelProps,
|
||||
sliderMarkProps,
|
||||
@ -149,6 +145,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
}
|
||||
: {}
|
||||
}
|
||||
isDisabled={isDisabled}
|
||||
{...sliderFormControlProps}
|
||||
>
|
||||
<FormLabel {...sliderFormLabelProps} mb={-1}>
|
||||
@ -166,15 +163,13 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
onMouseEnter={() => setShowTooltip(true)}
|
||||
onMouseLeave={() => setShowTooltip(false)}
|
||||
focusThumbOnChange={false}
|
||||
isDisabled={isSliderDisabled}
|
||||
// width={width}
|
||||
isDisabled={isDisabled}
|
||||
{...rest}
|
||||
>
|
||||
{withSliderMarks && (
|
||||
<>
|
||||
<SliderMark
|
||||
value={min}
|
||||
// insetInlineStart={0}
|
||||
sx={{
|
||||
insetInlineStart: '0 !important',
|
||||
insetInlineEnd: 'unset !important',
|
||||
@ -185,7 +180,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
</SliderMark>
|
||||
<SliderMark
|
||||
value={max}
|
||||
// insetInlineEnd={0}
|
||||
sx={{
|
||||
insetInlineStart: 'unset !important',
|
||||
insetInlineEnd: '0 !important',
|
||||
@ -221,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
value={localInputValue}
|
||||
onChange={handleInputChange}
|
||||
onBlur={handleInputBlur}
|
||||
isDisabled={isInputDisabled}
|
||||
{...sliderNumberInputProps}
|
||||
>
|
||||
<NumberInputField
|
||||
@ -246,8 +239,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
|
||||
aria-label={t('accessibility.reset')}
|
||||
tooltip="Reset"
|
||||
icon={<BiReset />}
|
||||
isDisabled={isDisabled}
|
||||
onClick={handleResetDisable}
|
||||
isDisabled={isResetDisabled}
|
||||
{...sliderIAIIconButtonProps}
|
||||
/>
|
||||
)}
|
||||
|
@ -0,0 +1,79 @@
|
||||
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { clearInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { useCallback } from 'react';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Image } from 'app/invokeai';
|
||||
|
||||
type ImageToImageOverlayProps = {
|
||||
setIsLoaded: (isLoaded: boolean) => void;
|
||||
image: Image;
|
||||
};
|
||||
|
||||
const ImageToImageOverlay = ({
|
||||
setIsLoaded,
|
||||
image,
|
||||
}: ImageToImageOverlayProps) => {
|
||||
const isImageToImageEnabled = useAppSelector(
|
||||
(state: RootState) => state.generation.isImageToImageEnabled
|
||||
);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const handleResetInitialImage = useCallback(() => {
|
||||
dispatch(clearInitialImage());
|
||||
setIsLoaded(false);
|
||||
}, [dispatch, setIsLoaded]);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
top: 0,
|
||||
left: 0,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
position: 'absolute',
|
||||
}}
|
||||
>
|
||||
<ButtonGroup
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
p: 2,
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
isDisabled={!isImageToImageEnabled}
|
||||
icon={<FaUndo />}
|
||||
aria-label={t('accessibility.reset')}
|
||||
onClick={handleResetInitialImage}
|
||||
/>
|
||||
<IAIIconButton
|
||||
size="sm"
|
||||
isDisabled={!isImageToImageEnabled}
|
||||
icon={<FaUpload />}
|
||||
aria-label={t('common.upload')}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
<Flex
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
bottom: 0,
|
||||
left: 0,
|
||||
p: 2,
|
||||
alignItems: 'flex-start',
|
||||
}}
|
||||
>
|
||||
<Badge variant="solid" colorScheme="base">
|
||||
{image.metadata?.width} × {image.metadata?.height}
|
||||
</Badge>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default ImageToImageOverlay;
|
@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
|
||||
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import useImageUploader from 'common/hooks/useImageUploader';
|
||||
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import {
|
||||
@ -15,6 +14,7 @@ import {
|
||||
} from 'react';
|
||||
import { FileRejection, useDropzone } from 'react-dropzone';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import ImageUploadOverlay from './ImageUploadOverlay';
|
||||
|
||||
type ImageUploaderProps = {
|
||||
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
|
||||
const fileAcceptedCallback = useCallback(
|
||||
async (file: File) => {
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(uploadImage({ imageFile: file }));
|
||||
dispatch(imageUploaded({ formData: { file } }));
|
||||
};
|
||||
document.addEventListener('paste', pasteImageListener);
|
||||
return () => {
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { Flex, Icon } from '@chakra-ui/react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
|
||||
const SelectImagePlaceholder = () => {
|
||||
return (
|
||||
<Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
|
||||
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default SelectImagePlaceholder;
|
@ -1,27 +1,160 @@
|
||||
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import WorkInProgress from './WorkInProgress';
|
||||
// import WorkInProgress from './WorkInProgress';
|
||||
// import ReactFlow, {
|
||||
// applyEdgeChanges,
|
||||
// applyNodeChanges,
|
||||
// Background,
|
||||
// Controls,
|
||||
// Edge,
|
||||
// Handle,
|
||||
// Node,
|
||||
// NodeTypes,
|
||||
// OnEdgesChange,
|
||||
// OnNodesChange,
|
||||
// Position,
|
||||
// } from 'reactflow';
|
||||
|
||||
export default function NodesWIP() {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<WorkInProgress>
|
||||
<Flex
|
||||
sx={{
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
w: '100%',
|
||||
h: '100%',
|
||||
gap: 4,
|
||||
textAlign: 'center',
|
||||
}}
|
||||
>
|
||||
<Heading>{t('common.nodes')}</Heading>
|
||||
<VStack maxW="50rem" gap={4}>
|
||||
<Text>{t('common.nodesDesc')}</Text>
|
||||
</VStack>
|
||||
</Flex>
|
||||
</WorkInProgress>
|
||||
);
|
||||
}
|
||||
// import 'reactflow/dist/style.css';
|
||||
// import {
|
||||
// Fragment,
|
||||
// FunctionComponent,
|
||||
// ReactNode,
|
||||
// useCallback,
|
||||
// useMemo,
|
||||
// useState,
|
||||
// } from 'react';
|
||||
// import { OpenAPIV3 } from 'openapi-types';
|
||||
// import { filter, map, reduce } from 'lodash';
|
||||
// import {
|
||||
// Box,
|
||||
// Flex,
|
||||
// FormControl,
|
||||
// FormLabel,
|
||||
// Input,
|
||||
// Select,
|
||||
// Switch,
|
||||
// Text,
|
||||
// NumberInput,
|
||||
// NumberInputField,
|
||||
// NumberInputStepper,
|
||||
// NumberIncrementStepper,
|
||||
// NumberDecrementStepper,
|
||||
// Tooltip,
|
||||
// chakra,
|
||||
// Badge,
|
||||
// Heading,
|
||||
// VStack,
|
||||
// HStack,
|
||||
// Menu,
|
||||
// MenuButton,
|
||||
// MenuList,
|
||||
// MenuItem,
|
||||
// MenuItemOption,
|
||||
// MenuGroup,
|
||||
// MenuOptionGroup,
|
||||
// MenuDivider,
|
||||
// IconButton,
|
||||
// } from '@chakra-ui/react';
|
||||
// import { FaPlus } from 'react-icons/fa';
|
||||
// import {
|
||||
// FIELD_NAMES as FIELD_NAMES,
|
||||
// FIELDS,
|
||||
// INVOCATION_NAMES as INVOCATION_NAMES,
|
||||
// INVOCATIONS,
|
||||
// } from 'features/nodeEditor/constants';
|
||||
|
||||
// console.log('invocations', INVOCATIONS);
|
||||
|
||||
// const nodeTypes = reduce(
|
||||
// INVOCATIONS,
|
||||
// (acc, val, key) => {
|
||||
// acc[key] = val.component;
|
||||
// return acc;
|
||||
// },
|
||||
// {} as NodeTypes
|
||||
// );
|
||||
|
||||
// console.log('nodeTypes', nodeTypes);
|
||||
|
||||
// // make initial nodes one of every node for now
|
||||
// let n = 0;
|
||||
// const initialNodes = map(INVOCATIONS, (i) => ({
|
||||
// id: i.type,
|
||||
// type: i.title,
|
||||
// position: { x: (n += 20), y: (n += 20) },
|
||||
// data: {},
|
||||
// }));
|
||||
|
||||
// console.log('initialNodes', initialNodes);
|
||||
|
||||
// export default function NodesWIP() {
|
||||
// const [nodes, setNodes] = useState<Node[]>([]);
|
||||
// const [edges, setEdges] = useState<Edge[]>([]);
|
||||
|
||||
// const onNodesChange: OnNodesChange = useCallback(
|
||||
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// const onEdgesChange: OnEdgesChange = useCallback(
|
||||
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
|
||||
// []
|
||||
// );
|
||||
|
||||
// return (
|
||||
// <Box
|
||||
// sx={{
|
||||
// position: 'relative',
|
||||
// width: 'full',
|
||||
// height: 'full',
|
||||
// borderRadius: 'md',
|
||||
// }}
|
||||
// >
|
||||
// <ReactFlow
|
||||
// nodeTypes={nodeTypes}
|
||||
// nodes={nodes}
|
||||
// edges={edges}
|
||||
// onNodesChange={onNodesChange}
|
||||
// onEdgesChange={onEdgesChange}
|
||||
// >
|
||||
// <Background />
|
||||
// <Controls />
|
||||
// </ReactFlow>
|
||||
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
|
||||
// {FIELD_NAMES.map((field) => (
|
||||
// <Badge
|
||||
// key={field}
|
||||
// colorScheme={FIELDS[field].color}
|
||||
// sx={{ userSelect: 'none' }}
|
||||
// >
|
||||
// {field}
|
||||
// </Badge>
|
||||
// ))}
|
||||
// </HStack>
|
||||
// <Menu>
|
||||
// <MenuButton
|
||||
// as={IconButton}
|
||||
// aria-label="Options"
|
||||
// icon={<FaPlus />}
|
||||
// sx={{ position: 'absolute', top: 2, left: 2 }}
|
||||
// />
|
||||
// <MenuList>
|
||||
// {INVOCATION_NAMES.map((name) => {
|
||||
// const invocation = INVOCATIONS[name];
|
||||
// return (
|
||||
// <Tooltip
|
||||
// key={name}
|
||||
// label={invocation.description}
|
||||
// placement="end"
|
||||
// hasArrow
|
||||
// >
|
||||
// <MenuItem>{invocation.title}</MenuItem>
|
||||
// </Tooltip>
|
||||
// );
|
||||
// })}
|
||||
// </MenuList>
|
||||
// </Menu>
|
||||
// </Box>
|
||||
// );
|
||||
// }
|
||||
|
||||
export default {};
|
||||
|
@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
bg: 'base.850',
|
||||
borderRadius: 'base',
|
||||
position: 'relative',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
|
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
119
invokeai/frontend/web/src/common/util/_parseMetadataZod.ts
Normal file
@ -0,0 +1,119 @@
|
||||
/**
|
||||
* PARTIAL ZOD IMPLEMENTATION
|
||||
*
|
||||
* doesn't work well bc like most validators, zod is not built to skip invalid values.
|
||||
* it mostly works but just seems clearer and simpler to manually parse for now.
|
||||
*
|
||||
* in the future it would be really nice if we could use zod for some things:
|
||||
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
|
||||
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
|
||||
*/
|
||||
|
||||
// import { z } from 'zod';
|
||||
|
||||
// const zMetadataStringField = z.string();
|
||||
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
|
||||
|
||||
// const zMetadataIntegerField = z.number().int();
|
||||
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
|
||||
|
||||
// const zMetadataFloatField = z.number();
|
||||
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
|
||||
|
||||
// const zMetadataBooleanField = z.boolean();
|
||||
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
|
||||
|
||||
// const zMetadataImageField = z.object({
|
||||
// image_type: z.union([
|
||||
// z.literal('results'),
|
||||
// z.literal('uploads'),
|
||||
// z.literal('intermediates'),
|
||||
// ]),
|
||||
// image_name: z.string().min(1),
|
||||
// });
|
||||
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
|
||||
|
||||
// const zMetadataLatentsField = z.object({
|
||||
// latents_name: z.string().min(1),
|
||||
// });
|
||||
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
|
||||
|
||||
// /**
|
||||
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
|
||||
// */
|
||||
// const zAnyMetadataField = z.any().transform((val, ctx) => {
|
||||
// // Grab the field name from the path
|
||||
// const fieldName = String(ctx.path[ctx.path.length - 1]);
|
||||
|
||||
// // `id` and `type` must be strings if they exist
|
||||
// if (['id', 'type'].includes(fieldName)) {
|
||||
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
|
||||
// if (reservedStringPropertyResult.success) {
|
||||
// return reservedStringPropertyResult.data;
|
||||
// }
|
||||
|
||||
// return;
|
||||
// }
|
||||
|
||||
// // Parse the rest of the fields, only returning the data if the parsing is successful
|
||||
|
||||
// const stringFieldResult = zMetadataStringField.safeParse(val);
|
||||
// if (stringFieldResult.success) {
|
||||
// return stringFieldResult.data;
|
||||
// }
|
||||
|
||||
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
|
||||
// if (integerFieldResult.success) {
|
||||
// return integerFieldResult.data;
|
||||
// }
|
||||
|
||||
// const floatFieldResult = zMetadataFloatField.safeParse(val);
|
||||
// if (floatFieldResult.success) {
|
||||
// return floatFieldResult.data;
|
||||
// }
|
||||
|
||||
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
|
||||
// if (booleanFieldResult.success) {
|
||||
// return booleanFieldResult.data;
|
||||
// }
|
||||
|
||||
// const imageFieldResult = zMetadataImageField.safeParse(val);
|
||||
// if (imageFieldResult.success) {
|
||||
// return imageFieldResult.data;
|
||||
// }
|
||||
|
||||
// const latentsFieldResult = zMetadataImageField.safeParse(val);
|
||||
// if (latentsFieldResult.success) {
|
||||
// return latentsFieldResult.data;
|
||||
// }
|
||||
// });
|
||||
|
||||
// /**
|
||||
// * The node metadata schema.
|
||||
// */
|
||||
// const zNodeMetadata = z.object({
|
||||
// session_id: z.string().min(1).optional(),
|
||||
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
|
||||
// });
|
||||
|
||||
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
|
||||
|
||||
// const zMetadata = z.object({
|
||||
// invokeai: zNodeMetadata.optional(),
|
||||
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
|
||||
// });
|
||||
// export type Metadata = z.infer<typeof zMetadata>;
|
||||
|
||||
// export const parseMetadata = (
|
||||
// metadata: Record<string, any>
|
||||
// ): Metadata | undefined => {
|
||||
// const result = zMetadata.safeParse(metadata);
|
||||
// if (!result.success) {
|
||||
// console.log(result.error.issues);
|
||||
// return;
|
||||
// }
|
||||
|
||||
// return result.data;
|
||||
// };
|
||||
|
||||
export default {};
|
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
6
invokeai/frontend/web/src/common/util/getTimestamp.ts
Normal file
@ -0,0 +1,6 @@
|
||||
import dateFormat from 'dateformat';
|
||||
|
||||
/**
|
||||
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
|
||||
*/
|
||||
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');
|
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
28
invokeai/frontend/web/src/common/util/getUrl.ts
Normal file
@ -0,0 +1,28 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { OpenAPI } from 'services/api';
|
||||
|
||||
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
};
|
||||
|
||||
export const useGetUrl = () => {
|
||||
const shouldTransformUrls = useAppSelector(
|
||||
(state: RootState) => state.system.shouldTransformUrls
|
||||
);
|
||||
|
||||
return {
|
||||
shouldTransformUrls,
|
||||
getUrl: (url?: string) => {
|
||||
if (OpenAPI.BASE && shouldTransformUrls) {
|
||||
return [OpenAPI.BASE, url].join('/');
|
||||
}
|
||||
|
||||
return url;
|
||||
},
|
||||
};
|
||||
};
|
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
169
invokeai/frontend/web/src/common/util/parseMetadata.ts
Normal file
@ -0,0 +1,169 @@
|
||||
import { forEach, size } from 'lodash';
|
||||
import { ImageField, LatentsField } from 'services/api';
|
||||
|
||||
const OBJECT_TYPESTRING = '[object Object]';
|
||||
const STRING_TYPESTRING = '[object String]';
|
||||
const NUMBER_TYPESTRING = '[object Number]';
|
||||
const BOOLEAN_TYPESTRING = '[object Boolean]';
|
||||
const ARRAY_TYPESTRING = '[object Array]';
|
||||
|
||||
const isObject = (obj: unknown): obj is Record<string | number, any> =>
|
||||
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
|
||||
|
||||
const isString = (obj: unknown): obj is string =>
|
||||
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
|
||||
|
||||
const isNumber = (obj: unknown): obj is number =>
|
||||
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
|
||||
|
||||
const isBoolean = (obj: unknown): obj is boolean =>
|
||||
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
|
||||
|
||||
const isArray = (obj: unknown): obj is Array<any> =>
|
||||
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
|
||||
|
||||
const parseImageField = (imageField: unknown): ImageField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField must have both `image_name` and `image_type`
|
||||
if (!('image_name' in imageField && 'image_type' in imageField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_type` must be one of the allowed values
|
||||
if (
|
||||
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
// An ImageField's `image_name` must be a string
|
||||
if (typeof imageField.image_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid ImageField
|
||||
return {
|
||||
image_type: imageField.image_type,
|
||||
image_name: imageField.image_name,
|
||||
};
|
||||
};
|
||||
|
||||
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
|
||||
// Must be an object
|
||||
if (!isObject(latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField must have a `latents_name`
|
||||
if (!('latents_name' in latentsField)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// A LatentsField's `latents_name` must be a string
|
||||
if (typeof latentsField.latents_name !== 'string') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a valid LatentsField
|
||||
return {
|
||||
latents_name: latentsField.latents_name,
|
||||
};
|
||||
};
|
||||
|
||||
type NodeMetadata = {
|
||||
[key: string]: string | number | boolean | ImageField | LatentsField;
|
||||
};
|
||||
|
||||
type InvokeAIMetadata = {
|
||||
session_id?: string;
|
||||
node?: NodeMetadata;
|
||||
};
|
||||
|
||||
export const parseNodeMetadata = (
|
||||
nodeMetadata: Record<string | number, any>
|
||||
): NodeMetadata | undefined => {
|
||||
if (!isObject(nodeMetadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: NodeMetadata = {};
|
||||
|
||||
forEach(nodeMetadata, (nodeItem, nodeKey) => {
|
||||
// `id` and `type` must be strings if they are present
|
||||
if (['id', 'type'].includes(nodeKey)) {
|
||||
if (isString(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// the only valid object types are ImageField and LatentsField
|
||||
if (isObject(nodeItem)) {
|
||||
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
|
||||
const imageField = parseImageField(nodeItem);
|
||||
if (imageField) {
|
||||
parsed[nodeKey] = imageField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if ('latents_name' in nodeItem) {
|
||||
const latentsField = parseLatentsField(nodeItem);
|
||||
if (latentsField) {
|
||||
parsed[nodeKey] = latentsField;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// otherwise we accept any string, number or boolean
|
||||
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
|
||||
parsed[nodeKey] = nodeItem;
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
||||
|
||||
export const parseInvokeAIMetadata = (
|
||||
metadata: Record<string | number, any> | undefined
|
||||
): InvokeAIMetadata | undefined => {
|
||||
if (metadata === undefined) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!isObject(metadata)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const parsed: InvokeAIMetadata = {};
|
||||
|
||||
forEach(metadata, (item, key) => {
|
||||
if (key === 'session_id' && isString(item)) {
|
||||
parsed['session_id'] = item;
|
||||
}
|
||||
|
||||
if (key === 'node' && isObject(item)) {
|
||||
const nodeMetadata = parseNodeMetadata(item);
|
||||
|
||||
if (nodeMetadata) {
|
||||
parsed['node'] = nodeMetadata;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (size(parsed) === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
return parsed;
|
||||
};
|
@ -1,8 +1,10 @@
|
||||
import React, { lazy, PropsWithChildren } from 'react';
|
||||
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
|
||||
import { Provider } from 'react-redux';
|
||||
import { PersistGate } from 'redux-persist/integration/react';
|
||||
import { store } from './app/store';
|
||||
import { buildMiddleware, store } from './app/store';
|
||||
import { persistor } from './persistor';
|
||||
import { OpenAPI } from 'services/api';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import '@fontsource/inter/100.css';
|
||||
import '@fontsource/inter/200.css';
|
||||
import '@fontsource/inter/300.css';
|
||||
@ -17,18 +19,61 @@ import Loading from './Loading';
|
||||
|
||||
// Localization
|
||||
import './i18n';
|
||||
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
|
||||
|
||||
const App = lazy(() => import('./app/App'));
|
||||
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
|
||||
|
||||
export default function Component(props: PropsWithChildren) {
|
||||
interface Props extends PropsWithChildren {
|
||||
apiUrl?: string;
|
||||
disabledPanels?: string[];
|
||||
disabledTabs?: InvokeTabName[];
|
||||
token?: string;
|
||||
shouldTransformUrls?: boolean;
|
||||
}
|
||||
|
||||
export default function Component({
|
||||
apiUrl,
|
||||
disabledPanels = [],
|
||||
disabledTabs = [],
|
||||
token,
|
||||
children,
|
||||
shouldTransformUrls,
|
||||
}: Props) {
|
||||
useEffect(() => {
|
||||
// configure API client token
|
||||
if (token) {
|
||||
OpenAPI.TOKEN = token;
|
||||
}
|
||||
|
||||
// configure API client base url
|
||||
if (apiUrl) {
|
||||
OpenAPI.BASE = apiUrl;
|
||||
}
|
||||
|
||||
// reset dynamically added middlewares
|
||||
resetMiddlewares();
|
||||
|
||||
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
|
||||
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
|
||||
// outside the provider. it's not needed until there is the possibility that we will change
|
||||
// the `apiUrl`/`token` dynamically.
|
||||
|
||||
// rebuild socket middleware with token and apiUrl
|
||||
addMiddleware(buildMiddleware());
|
||||
}, [apiUrl, token]);
|
||||
|
||||
return (
|
||||
<React.StrictMode>
|
||||
<Provider store={store}>
|
||||
<PersistGate loading={<Loading />} persistor={persistor}>
|
||||
<React.Suspense fallback={<Loading showText />}>
|
||||
<ThemeLocaleProvider>
|
||||
<App>{props.children}</App>
|
||||
<App
|
||||
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
|
||||
>
|
||||
{children}
|
||||
</App>
|
||||
</ThemeLocaleProvider>
|
||||
</React.Suspense>
|
||||
</PersistGate>
|
||||
|
@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
|
||||
import IAIPopover from './common/components/IAIPopover';
|
||||
import IAIIconButton from './common/components/IAIIconButton';
|
||||
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
|
||||
import StatusIndicator from './features/system/components/StatusIndicator';
|
||||
import ModelSelect from 'features/system/components/ModelSelect';
|
||||
|
||||
export default Component;
|
||||
export {
|
||||
@ -13,4 +15,6 @@ export {
|
||||
IAIPopover,
|
||||
IAIIconButton,
|
||||
SettingsModal,
|
||||
StatusIndicator,
|
||||
ModelSelect,
|
||||
};
|
||||
|
@ -1,6 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { ImageConfig } from 'konva/lib/shapes/Image';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
|
||||
const IAICanvasIntermediateImage = (props: Props) => {
|
||||
const { ...rest } = props;
|
||||
const intermediateImage = useAppSelector(selector);
|
||||
|
||||
const { getUrl } = useGetUrl();
|
||||
const [loadedImageElement, setLoadedImageElement] =
|
||||
useState<HTMLImageElement | null>(null);
|
||||
|
||||
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
|
||||
tempImage.onload = () => {
|
||||
setLoadedImageElement(tempImage);
|
||||
};
|
||||
tempImage.src = intermediateImage.url;
|
||||
}, [intermediateImage]);
|
||||
tempImage.src = getUrl(intermediateImage.url);
|
||||
}, [intermediateImage, getUrl]);
|
||||
|
||||
if (!intermediateImage?.boundingBox) return null;
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { rgbaColorToString } from 'features/canvas/util/colorToString';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -32,6 +33,7 @@ const selector = createSelector(
|
||||
|
||||
const IAICanvasObjectRenderer = () => {
|
||||
const { objects } = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
if (!objects) return null;
|
||||
|
||||
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
|
||||
{objects.map((obj, i) => {
|
||||
if (isCanvasBaseImage(obj)) {
|
||||
return (
|
||||
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
|
||||
<IAICanvasImage
|
||||
key={i}
|
||||
x={obj.x}
|
||||
y={obj.y}
|
||||
url={getUrl(obj.image.url)}
|
||||
/>
|
||||
);
|
||||
} else if (isCanvasBaseLine(obj)) {
|
||||
const line = (
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { GroupConfig } from 'konva/lib/Group';
|
||||
import { isEqual } from 'lodash';
|
||||
@ -53,11 +54,16 @@ const IAICanvasStagingArea = (props: Props) => {
|
||||
width,
|
||||
height,
|
||||
} = useAppSelector(selector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
return (
|
||||
<Group {...rest}>
|
||||
{shouldShowStagingImage && currentStagingAreaImage && (
|
||||
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
|
||||
<IAICanvasImage
|
||||
url={getUrl(currentStagingAreaImage.image.url)}
|
||||
x={x}
|
||||
y={y}
|
||||
/>
|
||||
)}
|
||||
{shouldShowStagingOutline && (
|
||||
<Group>
|
||||
|
@ -0,0 +1,14 @@
|
||||
import { CanvasState } from './canvasTypes';
|
||||
|
||||
/**
|
||||
* Canvas slice persist blacklist
|
||||
*/
|
||||
const itemsToBlacklist: (keyof CanvasState)[] = [
|
||||
'cursorPosition',
|
||||
'isCanvasInitialized',
|
||||
'doesCanvasNeedScaling',
|
||||
];
|
||||
|
||||
export const canvasBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `canvas.${blacklistItem}`
|
||||
);
|
@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
|
||||
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
|
||||
state.cursorPosition = action.payload;
|
||||
},
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
const image = action.payload;
|
||||
const { stageDimensions } = state;
|
||||
|
||||
@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
boundingBox: IRect;
|
||||
image: InvokeAI.Image;
|
||||
image: InvokeAI._Image;
|
||||
}>
|
||||
) => {
|
||||
const { boundingBox, image } = action.payload;
|
||||
|
@ -37,7 +37,7 @@ export type CanvasImage = {
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
image: InvokeAI.Image;
|
||||
image: InvokeAI._Image;
|
||||
};
|
||||
|
||||
export type CanvasMaskLine = {
|
||||
@ -125,7 +125,7 @@ export interface CanvasState {
|
||||
cursorPosition: Vector2d | null;
|
||||
doesCanvasNeedScaling: boolean;
|
||||
futureLayerStates: CanvasLayerState[];
|
||||
intermediateImage?: InvokeAI.Image;
|
||||
intermediateImage?: InvokeAI._Image;
|
||||
isCanvasInitialized: boolean;
|
||||
isDrawing: boolean;
|
||||
isMaskEnabled: boolean;
|
||||
|
@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
|
||||
|
||||
const { url, width, height } = image;
|
||||
|
||||
const newImage: InvokeAI.Image = {
|
||||
const newImage: InvokeAI._Image = {
|
||||
uuid: uuidv4(),
|
||||
category: shouldSaveToGallery ? 'result' : 'user',
|
||||
...image,
|
||||
|
@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
|
||||
@ -48,11 +49,15 @@ import {
|
||||
FaShareAlt,
|
||||
FaTrash,
|
||||
} from 'react-icons/fa';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import DeleteImageModal from './DeleteImageModal';
|
||||
import { useCallback } from 'react';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
const currentImageButtonsSelector = createSelector(
|
||||
[
|
||||
@ -62,6 +67,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
uiSelector,
|
||||
lightboxSelector,
|
||||
activeTabNameSelector,
|
||||
selectedImageSelector,
|
||||
],
|
||||
(
|
||||
system: SystemState,
|
||||
@ -69,7 +75,8 @@ const currentImageButtonsSelector = createSelector(
|
||||
postprocessing,
|
||||
ui,
|
||||
lightbox,
|
||||
activeTabName
|
||||
activeTabName,
|
||||
selectedImage
|
||||
) => {
|
||||
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
|
||||
system;
|
||||
@ -95,6 +102,7 @@ const currentImageButtonsSelector = createSelector(
|
||||
activeTabName,
|
||||
isLightboxOpen,
|
||||
shouldHidePreview,
|
||||
selectedImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -121,27 +129,33 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
facetoolStrength,
|
||||
shouldDisableToolbarButtons,
|
||||
shouldShowImageDetails,
|
||||
currentImage,
|
||||
// currentImage,
|
||||
isLightboxOpen,
|
||||
activeTabName,
|
||||
shouldHidePreview,
|
||||
selectedImage,
|
||||
} = useAppSelector(currentImageButtonsSelector);
|
||||
const { getUrl, shouldTransformUrls } = useGetUrl();
|
||||
|
||||
const toast = useToast();
|
||||
const { t } = useTranslation();
|
||||
const setBothPrompts = useSetBothPrompts();
|
||||
|
||||
const handleClickUseAsInitialImage = () => {
|
||||
if (!currentImage) return;
|
||||
if (!selectedImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
dispatch(setInitialImage(currentImage));
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(initialImageSelected(selectedImage.name));
|
||||
// dispatch(setInitialImage(currentImage));
|
||||
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
};
|
||||
|
||||
const handleCopyImage = async () => {
|
||||
if (!currentImage) return;
|
||||
if (!selectedImage) return;
|
||||
|
||||
const blob = await fetch(currentImage.url).then((res) => res.blob());
|
||||
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
|
||||
res.blob()
|
||||
);
|
||||
const data = [new ClipboardItem({ [blob.type]: blob })];
|
||||
|
||||
await navigator.clipboard.write(data);
|
||||
@ -155,24 +169,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
};
|
||||
|
||||
const handleCopyImageLink = () => {
|
||||
navigator.clipboard
|
||||
.writeText(
|
||||
currentImage ? window.location.toString() + currentImage.url : ''
|
||||
)
|
||||
.then(() => {
|
||||
toast({
|
||||
title: t('toast.imageLinkCopied'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
const url = selectedImage
|
||||
? shouldTransformUrls
|
||||
? getUrl(selectedImage.url)
|
||||
: window.location.toString() + selectedImage.url
|
||||
: '';
|
||||
|
||||
navigator.clipboard.writeText(url).then(() => {
|
||||
toast({
|
||||
title: t('toast.imageLinkCopied'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
'shift+i',
|
||||
() => {
|
||||
if (currentImage) {
|
||||
if (selectedImage) {
|
||||
handleClickUseAsInitialImage();
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
@ -190,7 +206,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[currentImage]
|
||||
[selectedImage]
|
||||
);
|
||||
|
||||
const handlePreviewVisibility = () => {
|
||||
@ -198,20 +214,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
};
|
||||
|
||||
const handleClickUseAllParameters = () => {
|
||||
if (!currentImage) return;
|
||||
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
|
||||
if (currentImage.metadata?.image.type === 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
} else if (currentImage.metadata?.image.type === 'txt2img') {
|
||||
dispatch(setActiveTab('txt2img'));
|
||||
}
|
||||
if (!selectedImage) return;
|
||||
// selectedImage.metadata &&
|
||||
// dispatch(setAllParameters(selectedImage.metadata));
|
||||
// if (selectedImage.metadata?.image.type === 'img2img') {
|
||||
// dispatch(setActiveTab('img2img'));
|
||||
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
|
||||
// dispatch(setActiveTab('txt2img'));
|
||||
// }
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
'a',
|
||||
() => {
|
||||
if (
|
||||
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
|
||||
['txt2img', 'img2img'].includes(
|
||||
selectedImage?.metadata?.sd_metadata?.type
|
||||
)
|
||||
) {
|
||||
handleClickUseAllParameters();
|
||||
toast({
|
||||
@ -230,18 +249,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[currentImage]
|
||||
[selectedImage]
|
||||
);
|
||||
|
||||
const handleClickUseSeed = () => {
|
||||
currentImage?.metadata &&
|
||||
dispatch(setSeed(currentImage.metadata.image.seed));
|
||||
selectedImage?.metadata &&
|
||||
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
's',
|
||||
() => {
|
||||
if (currentImage?.metadata?.image?.seed) {
|
||||
if (selectedImage?.metadata?.sd_metadata?.seed) {
|
||||
handleClickUseSeed();
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
@ -259,19 +278,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[currentImage]
|
||||
[selectedImage]
|
||||
);
|
||||
|
||||
const handleClickUsePrompt = useCallback(() => {
|
||||
if (currentImage?.metadata?.image?.prompt) {
|
||||
setBothPrompts(currentImage?.metadata?.image?.prompt);
|
||||
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
|
||||
}
|
||||
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
|
||||
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
|
||||
|
||||
useHotkeys(
|
||||
'p',
|
||||
() => {
|
||||
if (currentImage?.metadata?.image?.prompt) {
|
||||
if (selectedImage?.metadata?.sd_metadata?.prompt) {
|
||||
handleClickUsePrompt();
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
@ -289,11 +308,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[currentImage]
|
||||
[selectedImage]
|
||||
);
|
||||
|
||||
const handleClickUpscale = () => {
|
||||
currentImage && dispatch(runESRGAN(currentImage));
|
||||
// selectedImage && dispatch(runESRGAN(selectedImage));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
@ -317,7 +336,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
},
|
||||
[
|
||||
currentImage,
|
||||
selectedImage,
|
||||
isESRGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
@ -327,7 +346,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
);
|
||||
|
||||
const handleClickFixFaces = () => {
|
||||
currentImage && dispatch(runFacetool(currentImage));
|
||||
// selectedImage && dispatch(runFacetool(selectedImage));
|
||||
};
|
||||
|
||||
useHotkeys(
|
||||
@ -351,7 +370,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
}
|
||||
},
|
||||
[
|
||||
currentImage,
|
||||
selectedImage,
|
||||
isGFPGANAvailable,
|
||||
shouldDisableToolbarButtons,
|
||||
isConnected,
|
||||
@ -364,10 +383,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
if (!currentImage) return;
|
||||
if (!selectedImage) return;
|
||||
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
|
||||
|
||||
dispatch(setInitialCanvasImage(currentImage));
|
||||
// dispatch(setInitialCanvasImage(selectedImage));
|
||||
dispatch(requestCanvasRescale());
|
||||
|
||||
if (activeTabName !== 'unifiedCanvas') {
|
||||
@ -385,7 +404,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
useHotkeys(
|
||||
'i',
|
||||
() => {
|
||||
if (currentImage) {
|
||||
if (selectedImage) {
|
||||
handleClickShowImageDetails();
|
||||
} else {
|
||||
toast({
|
||||
@ -396,7 +415,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
});
|
||||
}
|
||||
},
|
||||
[currentImage, shouldShowImageDetails]
|
||||
[selectedImage, shouldShowImageDetails]
|
||||
);
|
||||
|
||||
const handleLightBox = () => {
|
||||
@ -458,7 +477,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
{t('parameters.copyImageToLink')}
|
||||
</IAIButton>
|
||||
|
||||
<Link download={true} href={currentImage?.url}>
|
||||
<Link download={true} href={getUrl(selectedImage!.url)}>
|
||||
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
|
||||
{t('parameters.downloadImage')}
|
||||
</IAIButton>
|
||||
@ -502,7 +521,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaQuoteRight />}
|
||||
tooltip={`${t('parameters.usePrompt')} (P)`}
|
||||
aria-label={`${t('parameters.usePrompt')} (P)`}
|
||||
isDisabled={!currentImage?.metadata?.image?.prompt}
|
||||
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
|
||||
onClick={handleClickUsePrompt}
|
||||
/>
|
||||
|
||||
@ -510,7 +529,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
icon={<FaSeedling />}
|
||||
tooltip={`${t('parameters.useSeed')} (S)`}
|
||||
aria-label={`${t('parameters.useSeed')} (S)`}
|
||||
isDisabled={!currentImage?.metadata?.image?.seed}
|
||||
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
|
||||
onClick={handleClickUseSeed}
|
||||
/>
|
||||
|
||||
@ -520,7 +539,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
aria-label={`${t('parameters.useAll')} (A)`}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(
|
||||
currentImage?.metadata?.image?.type
|
||||
selectedImage?.metadata?.sd_metadata?.type
|
||||
)
|
||||
}
|
||||
onClick={handleClickUseAllParameters}
|
||||
@ -546,7 +565,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isGFPGANAvailable ||
|
||||
!currentImage ||
|
||||
!selectedImage ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!facetoolStrength
|
||||
}
|
||||
@ -575,7 +594,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
<IAIButton
|
||||
isDisabled={
|
||||
!isESRGANAvailable ||
|
||||
!currentImage ||
|
||||
!selectedImage ||
|
||||
!(isConnected && !isProcessing) ||
|
||||
!upscalingLevel
|
||||
}
|
||||
@ -597,15 +616,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
|
||||
/>
|
||||
</ButtonGroup>
|
||||
|
||||
<DeleteImageModal image={currentImage}>
|
||||
{/* <DeleteImageModal image={selectedImage}>
|
||||
<IAIIconButton
|
||||
icon={<FaTrash />}
|
||||
tooltip={`${t('parameters.deleteImage')} (Del)`}
|
||||
aria-label={`${t('parameters.deleteImage')} (Del)`}
|
||||
isDisabled={!currentImage || !isConnected || isProcessing}
|
||||
isDisabled={!selectedImage || !isConnected || isProcessing}
|
||||
colorScheme="error"
|
||||
/>
|
||||
</DeleteImageModal>
|
||||
</DeleteImageModal> */}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
|
||||
import { isEqual } from 'lodash';
|
||||
|
||||
import { MdPhoto } from 'react-icons/md';
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import {
|
||||
gallerySelector,
|
||||
selectedImageSelector,
|
||||
} from '../store/gallerySelectors';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
|
||||
export const currentImageDisplaySelector = createSelector(
|
||||
[gallerySelector],
|
||||
(gallery) => {
|
||||
[gallerySelector, selectedImageSelector],
|
||||
(gallery, selectedImage) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
|
||||
return {
|
||||
hasAnImageToDisplay: currentImage || intermediateImage,
|
||||
hasAnImageToDisplay: selectedImage || intermediateImage,
|
||||
};
|
||||
},
|
||||
{
|
||||
|
@ -1,28 +1,48 @@
|
||||
import { Box, Flex, Image } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { GalleryState } from 'features/gallery/store/gallerySlice';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import { systemSelector } from 'features/system/store/systemSelectors';
|
||||
import { uiSelector } from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import { ReactEventHandler } from 'react';
|
||||
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
|
||||
|
||||
import { gallerySelector } from '../store/gallerySelectors';
|
||||
import { selectedImageSelector } from '../store/gallerySelectors';
|
||||
import CurrentImageFallback from './CurrentImageFallback';
|
||||
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from './NextPrevImageButtons';
|
||||
import CurrentImageHidden from './CurrentImageHidden';
|
||||
|
||||
export const imagesSelector = createSelector(
|
||||
[gallerySelector, uiSelector],
|
||||
(gallery: GalleryState, ui) => {
|
||||
const { currentImage, intermediateImage } = gallery;
|
||||
[uiSelector, selectedImageSelector, systemSelector],
|
||||
(ui, selectedImage, system) => {
|
||||
const { shouldShowImageDetails, shouldHidePreview } = ui;
|
||||
const { progressImage } = system;
|
||||
|
||||
// TODO: Clean this up, this is really gross
|
||||
const imageToDisplay = progressImage
|
||||
? {
|
||||
url: progressImage.dataURL,
|
||||
width: progressImage.width,
|
||||
height: progressImage.height,
|
||||
isProgressImage: true,
|
||||
image: progressImage,
|
||||
}
|
||||
: selectedImage
|
||||
? {
|
||||
url: selectedImage.url,
|
||||
width: selectedImage.metadata.width,
|
||||
height: selectedImage.metadata.height,
|
||||
isProgressImage: false,
|
||||
image: selectedImage,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
|
||||
isIntermediate: Boolean(intermediateImage),
|
||||
shouldShowImageDetails,
|
||||
shouldHidePreview,
|
||||
imageToDisplay,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -33,12 +53,9 @@ export const imagesSelector = createSelector(
|
||||
);
|
||||
|
||||
export default function CurrentImagePreview() {
|
||||
const {
|
||||
shouldShowImageDetails,
|
||||
imageToDisplay,
|
||||
isIntermediate,
|
||||
shouldHidePreview,
|
||||
} = useAppSelector(imagesSelector);
|
||||
const { shouldShowImageDetails, imageToDisplay, shouldHidePreview } =
|
||||
useAppSelector(imagesSelector);
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -52,13 +69,19 @@ export default function CurrentImagePreview() {
|
||||
>
|
||||
{imageToDisplay && (
|
||||
<Image
|
||||
src={shouldHidePreview ? undefined : imageToDisplay.url}
|
||||
src={
|
||||
shouldHidePreview
|
||||
? undefined
|
||||
: imageToDisplay.isProgressImage
|
||||
? imageToDisplay.url
|
||||
: getUrl(imageToDisplay.url)
|
||||
}
|
||||
width={imageToDisplay.width}
|
||||
height={imageToDisplay.height}
|
||||
fallback={
|
||||
shouldHidePreview ? (
|
||||
<CurrentImageHidden />
|
||||
) : !isIntermediate ? (
|
||||
) : !imageToDisplay.isProgressImage ? (
|
||||
<CurrentImageFallback />
|
||||
) : undefined
|
||||
}
|
||||
@ -68,27 +91,31 @@ export default function CurrentImagePreview() {
|
||||
maxHeight: '100%',
|
||||
height: 'auto',
|
||||
position: 'absolute',
|
||||
imageRendering: isIntermediate ? 'pixelated' : 'initial',
|
||||
imageRendering: imageToDisplay.isProgressImage
|
||||
? 'pixelated'
|
||||
: 'initial',
|
||||
borderRadius: 'base',
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{!shouldShowImageDetails && <NextPrevImageButtons />}
|
||||
{shouldShowImageDetails && imageToDisplay && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay} />
|
||||
</Box>
|
||||
)}
|
||||
{shouldShowImageDetails &&
|
||||
imageToDisplay &&
|
||||
'metadata' in imageToDisplay.image && (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
top: '0',
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
borderRadius: 'base',
|
||||
overflow: 'scroll',
|
||||
maxHeight: APP_METADATA_HEIGHT,
|
||||
}}
|
||||
>
|
||||
<ImageMetadataViewer image={imageToDisplay.image} />
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
@ -52,7 +52,7 @@ interface DeleteImageModalProps {
|
||||
/**
|
||||
* The image to delete.
|
||||
*/
|
||||
image?: InvokeAI.Image;
|
||||
image?: InvokeAI._Image;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -9,11 +9,14 @@ import {
|
||||
useToast,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
setCurrentImage,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
initialImageSelected,
|
||||
setAllImageToImageParameters,
|
||||
setAllParameters,
|
||||
setInitialImage,
|
||||
setSeed,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { DragEvent, memo, useState } from 'react';
|
||||
@ -31,6 +34,7 @@ import { useTranslation } from 'react-i18next';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
interface HoverableImageProps {
|
||||
image: InvokeAI.Image;
|
||||
@ -40,7 +44,7 @@ interface HoverableImageProps {
|
||||
const memoEqualityCheck = (
|
||||
prev: HoverableImageProps,
|
||||
next: HoverableImageProps
|
||||
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
|
||||
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
|
||||
|
||||
/**
|
||||
* Gallery image component with delete/use all/use seed buttons on hover.
|
||||
@ -55,7 +59,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(hoverableImageSelector);
|
||||
const { image, isSelected } = props;
|
||||
const { url, thumbnail, uuid, metadata } = image;
|
||||
const { url, thumbnail, name, metadata } = image;
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const [isHovered, setIsHovered] = useState<boolean>(false);
|
||||
|
||||
@ -69,10 +74,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
const handleMouseOut = () => setIsHovered(false);
|
||||
|
||||
const handleUsePrompt = () => {
|
||||
if (image.metadata?.image?.prompt) {
|
||||
setBothPrompts(image.metadata?.image?.prompt);
|
||||
if (image.metadata?.sd_metadata?.prompt) {
|
||||
setBothPrompts(image.metadata?.sd_metadata?.prompt);
|
||||
}
|
||||
|
||||
toast({
|
||||
title: t('toast.promptSet'),
|
||||
status: 'success',
|
||||
@ -82,7 +86,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseSeed = () => {
|
||||
image.metadata && dispatch(setSeed(image.metadata.image.seed));
|
||||
image.metadata.sd_metadata &&
|
||||
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
|
||||
toast({
|
||||
title: t('toast.seedSet'),
|
||||
status: 'success',
|
||||
@ -92,20 +97,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleSendToImageToImage = () => {
|
||||
dispatch(setInitialImage(image));
|
||||
if (activeTabName !== 'img2img') {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
}
|
||||
toast({
|
||||
title: t('toast.sentToImageToImage'),
|
||||
status: 'success',
|
||||
duration: 2500,
|
||||
isClosable: true,
|
||||
});
|
||||
dispatch(initialImageSelected(image.name));
|
||||
};
|
||||
|
||||
const handleSendToCanvas = () => {
|
||||
dispatch(setInitialCanvasImage(image));
|
||||
// dispatch(setInitialCanvasImage(image));
|
||||
|
||||
dispatch(resizeAndScaleCanvas());
|
||||
|
||||
@ -122,7 +118,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseAllParameters = () => {
|
||||
metadata && dispatch(setAllParameters(metadata));
|
||||
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
|
||||
toast({
|
||||
title: t('toast.parametersSet'),
|
||||
status: 'success',
|
||||
@ -132,11 +128,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
};
|
||||
|
||||
const handleUseInitialImage = async () => {
|
||||
if (metadata?.image?.init_image_path) {
|
||||
const response = await fetch(metadata.image.init_image_path);
|
||||
if (metadata.sd_metadata?.image?.init_image_path) {
|
||||
const response = await fetch(
|
||||
metadata.sd_metadata?.image?.init_image_path
|
||||
);
|
||||
if (response.ok) {
|
||||
dispatch(setActiveTab('img2img'));
|
||||
dispatch(setAllImageToImageParameters(metadata));
|
||||
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
|
||||
toast({
|
||||
title: t('toast.initialImageSet'),
|
||||
status: 'success',
|
||||
@ -155,16 +153,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
});
|
||||
};
|
||||
|
||||
const handleSelectImage = () => dispatch(setCurrentImage(image));
|
||||
const handleSelectImage = () => {
|
||||
dispatch(imageSelected(image.name));
|
||||
};
|
||||
|
||||
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
|
||||
e.dataTransfer.setData('invokeai/imageUuid', uuid);
|
||||
console.log('drag started');
|
||||
e.dataTransfer.setData('invokeai/imageName', image.name);
|
||||
e.dataTransfer.setData('invokeai/imageType', image.type);
|
||||
e.dataTransfer.effectAllowed = 'move';
|
||||
};
|
||||
|
||||
const handleLightBox = () => {
|
||||
dispatch(setCurrentImage(image));
|
||||
dispatch(setIsLightboxOpen(true));
|
||||
// dispatch(setCurrentImage(image));
|
||||
// dispatch(setIsLightboxOpen(true));
|
||||
};
|
||||
|
||||
return (
|
||||
@ -177,28 +179,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUsePrompt}
|
||||
isDisabled={image?.metadata?.image?.prompt === undefined}
|
||||
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
|
||||
>
|
||||
{t('parameters.usePrompt')}
|
||||
</MenuItem>
|
||||
|
||||
<MenuItem
|
||||
onClickCapture={handleUseSeed}
|
||||
isDisabled={image?.metadata?.image?.seed === undefined}
|
||||
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
|
||||
>
|
||||
{t('parameters.useSeed')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseAllParameters}
|
||||
isDisabled={
|
||||
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
|
||||
!['txt2img', 'img2img'].includes(
|
||||
image?.metadata?.sd_metadata?.type
|
||||
)
|
||||
}
|
||||
>
|
||||
{t('parameters.useAll')}
|
||||
</MenuItem>
|
||||
<MenuItem
|
||||
onClickCapture={handleUseInitialImage}
|
||||
isDisabled={image?.metadata?.image?.type !== 'img2img'}
|
||||
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
|
||||
>
|
||||
{t('parameters.useInitImg')}
|
||||
</MenuItem>
|
||||
@ -209,9 +213,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{t('parameters.sendToUnifiedCanvas')}
|
||||
</MenuItem>
|
||||
<MenuItem data-warning>
|
||||
<DeleteImageModal image={image}>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<p>{t('parameters.deleteImage')}</p>
|
||||
</DeleteImageModal>
|
||||
</DeleteImageModal> */}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
)}
|
||||
@ -219,7 +223,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
{(ref) => (
|
||||
<Box
|
||||
position="relative"
|
||||
key={uuid}
|
||||
key={name}
|
||||
onMouseOver={handleMouseOver}
|
||||
onMouseOut={handleMouseOut}
|
||||
userSelect="none"
|
||||
@ -244,7 +248,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
|
||||
}
|
||||
rounded="md"
|
||||
src={thumbnail || url}
|
||||
src={getUrl(thumbnail || url)}
|
||||
loading="lazy"
|
||||
sx={{
|
||||
position: 'absolute',
|
||||
@ -290,7 +294,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
insetInlineEnd: 1,
|
||||
}}
|
||||
>
|
||||
<DeleteImageModal image={image}>
|
||||
{/* <DeleteImageModal image={image}>
|
||||
<IAIIconButton
|
||||
aria-label={t('parameters.deleteImage')}
|
||||
icon={<FaTrashAlt />}
|
||||
@ -298,7 +302,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
|
||||
fontSize={14}
|
||||
isDisabled={!mayDeleteImage}
|
||||
/>
|
||||
</DeleteImageModal>
|
||||
</DeleteImageModal> */}
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
|
||||
import { requestImages } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@ -25,9 +25,44 @@ import HoverableImage from './HoverableImage';
|
||||
|
||||
import Scrollable from 'features/ui/components/common/Scrollable';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
import {
|
||||
resultsAdapter,
|
||||
selectResultsAll,
|
||||
selectResultsTotal,
|
||||
} from '../store/resultsSlice';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
receivedUploadImagesPage,
|
||||
} from 'services/thunks/gallery';
|
||||
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
|
||||
|
||||
const gallerySelector = createSelector(
|
||||
[
|
||||
(state: RootState) => state.uploads,
|
||||
(state: RootState) => state.results,
|
||||
(state: RootState) => state.gallery,
|
||||
],
|
||||
(uploads, results, gallery) => {
|
||||
const { currentCategory } = gallery;
|
||||
|
||||
return currentCategory === 'result'
|
||||
? {
|
||||
images: resultsAdapter.getSelectors().selectAll(results),
|
||||
isLoading: results.isLoading,
|
||||
areMoreImagesAvailable: results.page < results.pages - 1,
|
||||
}
|
||||
: {
|
||||
images: uploadsAdapter.getSelectors().selectAll(uploads),
|
||||
isLoading: uploads.isLoading,
|
||||
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
@ -35,7 +70,7 @@ const ImageGalleryContent = () => {
|
||||
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
|
||||
|
||||
const {
|
||||
images,
|
||||
// images,
|
||||
currentCategory,
|
||||
currentImageUuid,
|
||||
shouldPinGallery,
|
||||
@ -43,12 +78,24 @@ const ImageGalleryContent = () => {
|
||||
galleryGridTemplateColumns,
|
||||
galleryImageObjectFit,
|
||||
shouldAutoSwitchToNewImages,
|
||||
areMoreImagesAvailable,
|
||||
// areMoreImagesAvailable,
|
||||
shouldUseSingleGalleryColumn,
|
||||
} = useAppSelector(imageGallerySelector);
|
||||
|
||||
const { images, areMoreImagesAvailable, isLoading } =
|
||||
useAppSelector(gallerySelector);
|
||||
|
||||
// const handleClickLoadMore = () => {
|
||||
// dispatch(requestImages(currentCategory));
|
||||
// };
|
||||
const handleClickLoadMore = () => {
|
||||
dispatch(requestImages(currentCategory));
|
||||
if (currentCategory === 'result') {
|
||||
dispatch(receivedResultImagesPage());
|
||||
}
|
||||
|
||||
if (currentCategory === 'user') {
|
||||
dispatch(receivedUploadImagesPage());
|
||||
}
|
||||
};
|
||||
|
||||
const handleChangeGalleryImageMinimumWidth = (v: number) => {
|
||||
@ -203,11 +250,11 @@ const ImageGalleryContent = () => {
|
||||
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
|
||||
>
|
||||
{images.map((image) => {
|
||||
const { uuid } = image;
|
||||
const isSelected = currentImageUuid === uuid;
|
||||
const { name } = image;
|
||||
const isSelected = currentImageUuid === name;
|
||||
return (
|
||||
<HoverableImage
|
||||
key={uuid}
|
||||
key={name}
|
||||
image={image}
|
||||
isSelected={isSelected}
|
||||
/>
|
||||
@ -217,6 +264,7 @@ const ImageGalleryContent = () => {
|
||||
<IAIButton
|
||||
onClick={handleClickLoadMore}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isLoading}
|
||||
flexShrink={0}
|
||||
>
|
||||
{areMoreImagesAvailable
|
||||
|
@ -33,12 +33,13 @@ const GALLERY_TAB_WIDTHS: Record<
|
||||
InvokeTabName,
|
||||
{ galleryMinWidth: number; galleryMaxWidth: number }
|
||||
> = {
|
||||
txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
linear: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
|
||||
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
|
||||
};
|
||||
|
||||
const galleryPanelSelector = createSelector(
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import promptToString from 'common/util/promptToString';
|
||||
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
@ -18,7 +19,7 @@ import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
setInitialImage,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
@ -120,7 +121,7 @@ type ImageMetadataViewerProps = {
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.uuid === next.image.uuid;
|
||||
) => prev.image.name === next.image.name;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
@ -137,34 +138,13 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
|
||||
const metadata = image?.metadata?.image || {};
|
||||
const dreamPrompt = image?.dreamPrompt;
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
fit,
|
||||
height,
|
||||
hires_fix,
|
||||
init_image_path,
|
||||
mask_image_path,
|
||||
orig_path,
|
||||
perlin,
|
||||
postprocessing,
|
||||
prompt,
|
||||
sampler,
|
||||
seamless,
|
||||
seed,
|
||||
steps,
|
||||
strength,
|
||||
threshold,
|
||||
type,
|
||||
variations,
|
||||
width,
|
||||
} = metadata;
|
||||
const sessionId = image.metadata.invokeai?.session_id;
|
||||
const node = image.metadata.invokeai?.node as Record<string, any>;
|
||||
|
||||
const { t } = useTranslation();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const metadataJSON = JSON.stringify(image.metadata, null, 2);
|
||||
const metadataJSON = JSON.stringify(image, null, 2);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@ -183,262 +163,134 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight="semibold">File:</Text>
|
||||
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
|
||||
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
||||
{image.url.length > 64
|
||||
? image.url.substring(0, 64).concat('...')
|
||||
: image.url}
|
||||
<ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
</Flex>
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
{node && Object.keys(node).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Generation type" value={type} />}
|
||||
{image.metadata?.model_weights && (
|
||||
<MetadataItem label="Model" value={image.metadata.model_weights} />
|
||||
{node.type && (
|
||||
<MetadataItem label="Invocation type" value={node.type} />
|
||||
)}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
)}
|
||||
{prompt && (
|
||||
{node.model && <MetadataItem label="Model" value={node.model} />}
|
||||
{node.prompt && (
|
||||
<MetadataItem
|
||||
label="Prompt"
|
||||
labelPosition="top"
|
||||
value={
|
||||
typeof prompt === 'string' ? prompt : promptToString(prompt)
|
||||
typeof node.prompt === 'string'
|
||||
? node.prompt
|
||||
: promptToString(node.prompt)
|
||||
}
|
||||
onClick={() => setBothPrompts(prompt)}
|
||||
onClick={() => setBothPrompts(node.prompt)}
|
||||
/>
|
||||
)}
|
||||
{seed !== undefined && (
|
||||
{node.seed !== undefined && (
|
||||
<MetadataItem
|
||||
label="Seed"
|
||||
value={seed}
|
||||
onClick={() => dispatch(setSeed(seed))}
|
||||
value={node.seed}
|
||||
onClick={() => dispatch(setSeed(Number(node.seed)))}
|
||||
/>
|
||||
)}
|
||||
{threshold !== undefined && (
|
||||
{node.threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label="Noise Threshold"
|
||||
value={threshold}
|
||||
onClick={() => dispatch(setThreshold(threshold))}
|
||||
value={node.threshold}
|
||||
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
|
||||
/>
|
||||
)}
|
||||
{perlin !== undefined && (
|
||||
{node.perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label="Perlin Noise"
|
||||
value={perlin}
|
||||
onClick={() => dispatch(setPerlin(perlin))}
|
||||
value={node.perlin}
|
||||
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
|
||||
/>
|
||||
)}
|
||||
{sampler && (
|
||||
{node.scheduler && (
|
||||
<MetadataItem
|
||||
label="Sampler"
|
||||
value={sampler}
|
||||
onClick={() => dispatch(setSampler(sampler))}
|
||||
value={node.scheduler}
|
||||
onClick={() => dispatch(setSampler(node.scheduler))}
|
||||
/>
|
||||
)}
|
||||
{steps && (
|
||||
{node.steps && (
|
||||
<MetadataItem
|
||||
label="Steps"
|
||||
value={steps}
|
||||
onClick={() => dispatch(setSteps(steps))}
|
||||
value={node.steps}
|
||||
onClick={() => dispatch(setSteps(Number(node.steps)))}
|
||||
/>
|
||||
)}
|
||||
{cfg_scale !== undefined && (
|
||||
{node.cfg_scale !== undefined && (
|
||||
<MetadataItem
|
||||
label="CFG scale"
|
||||
value={cfg_scale}
|
||||
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
||||
value={node.cfg_scale}
|
||||
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
|
||||
/>
|
||||
)}
|
||||
{variations && variations.length > 0 && (
|
||||
{node.variations && node.variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="Seed-weight pairs"
|
||||
value={seedWeightsToString(variations)}
|
||||
value={seedWeightsToString(node.variations)}
|
||||
onClick={() =>
|
||||
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
||||
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{seamless && (
|
||||
{node.seamless && (
|
||||
<MetadataItem
|
||||
label="Seamless"
|
||||
value={seamless}
|
||||
onClick={() => dispatch(setSeamless(seamless))}
|
||||
value={node.seamless}
|
||||
onClick={() => dispatch(setSeamless(node.seamless))}
|
||||
/>
|
||||
)}
|
||||
{hires_fix && (
|
||||
{node.hires_fix && (
|
||||
<MetadataItem
|
||||
label="High Resolution Optimization"
|
||||
value={hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(hires_fix))}
|
||||
value={node.hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(node.hires_fix))}
|
||||
/>
|
||||
)}
|
||||
{width && (
|
||||
{node.width && (
|
||||
<MetadataItem
|
||||
label="Width"
|
||||
value={width}
|
||||
onClick={() => dispatch(setWidth(width))}
|
||||
value={node.width}
|
||||
onClick={() => dispatch(setWidth(Number(node.width)))}
|
||||
/>
|
||||
)}
|
||||
{height && (
|
||||
{node.height && (
|
||||
<MetadataItem
|
||||
label="Height"
|
||||
value={height}
|
||||
onClick={() => dispatch(setHeight(height))}
|
||||
value={node.height}
|
||||
onClick={() => dispatch(setHeight(Number(node.height)))}
|
||||
/>
|
||||
)}
|
||||
{init_image_path && (
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)}
|
||||
{mask_image_path && (
|
||||
<MetadataItem
|
||||
label="Mask image"
|
||||
value={mask_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
||||
/>
|
||||
)}
|
||||
{type === 'img2img' && strength && (
|
||||
)} */}
|
||||
{node.strength && (
|
||||
<MetadataItem
|
||||
label="Image to image strength"
|
||||
value={strength}
|
||||
onClick={() => dispatch(setImg2imgStrength(strength))}
|
||||
value={node.strength}
|
||||
onClick={() =>
|
||||
dispatch(setImg2imgStrength(Number(node.strength)))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{fit && (
|
||||
{node.fit && (
|
||||
<MetadataItem
|
||||
label="Image to image fit"
|
||||
value={fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
||||
value={node.fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
|
||||
/>
|
||||
)}
|
||||
{postprocessing && postprocessing.length > 0 && (
|
||||
<>
|
||||
<Heading size="sm">Postprocessing</Heading>
|
||||
{postprocessing.map(
|
||||
(
|
||||
postprocess: InvokeAI.PostProcessedImageMetadata,
|
||||
i: number
|
||||
) => {
|
||||
if (postprocess.type === 'esrgan') {
|
||||
const { scale, strength, denoise_str } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
|
||||
<MetadataItem
|
||||
label="Scale"
|
||||
value={scale}
|
||||
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||
/>
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() =>
|
||||
dispatch(setUpscalingStrength(strength))
|
||||
}
|
||||
/>
|
||||
{denoise_str !== undefined && (
|
||||
<MetadataItem
|
||||
label="Denoising strength"
|
||||
value={denoise_str}
|
||||
onClick={() =>
|
||||
dispatch(setUpscalingDenoising(denoise_str))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
} else if (postprocess.type === 'gfpgan') {
|
||||
const { strength } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${
|
||||
i + 1
|
||||
}: Face restoration (GFPGAN)`}</Text>
|
||||
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() => {
|
||||
dispatch(setFacetoolStrength(strength));
|
||||
dispatch(setFacetoolType('gfpgan'));
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
} else if (postprocess.type === 'codeformer') {
|
||||
const { strength, fidelity } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${
|
||||
i + 1
|
||||
}: Face restoration (Codeformer)`}</Text>
|
||||
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() => {
|
||||
dispatch(setFacetoolStrength(strength));
|
||||
dispatch(setFacetoolType('codeformer'));
|
||||
}}
|
||||
/>
|
||||
{fidelity && (
|
||||
<MetadataItem
|
||||
label="Fidelity"
|
||||
value={fidelity}
|
||||
onClick={() => {
|
||||
dispatch(setCodeformerFidelity(fidelity));
|
||||
dispatch(setFacetoolType('codeformer'));
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
}
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{dreamPrompt && (
|
||||
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
||||
)}
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</Flex>
|
||||
</>
|
||||
) : (
|
||||
<Center width="100%" pt={10}>
|
||||
@ -447,6 +299,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
|
@ -0,0 +1,470 @@
|
||||
import { ExternalLinkIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Center,
|
||||
Flex,
|
||||
Heading,
|
||||
IconButton,
|
||||
Link,
|
||||
Text,
|
||||
Tooltip,
|
||||
} from '@chakra-ui/react';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import promptToString from 'common/util/promptToString';
|
||||
import { seedWeightsToString } from 'common/util/seedWeightPairs';
|
||||
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
|
||||
import {
|
||||
setCfgScale,
|
||||
setHeight,
|
||||
setImg2imgStrength,
|
||||
// setInitialImage,
|
||||
setMaskPath,
|
||||
setPerlin,
|
||||
setSampler,
|
||||
setSeamless,
|
||||
setSeed,
|
||||
setSeedWeights,
|
||||
setShouldFitToWidthHeight,
|
||||
setSteps,
|
||||
setThreshold,
|
||||
setWidth,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import {
|
||||
setCodeformerFidelity,
|
||||
setFacetoolStrength,
|
||||
setFacetoolType,
|
||||
setHiresFix,
|
||||
setUpscalingDenoising,
|
||||
setUpscalingLevel,
|
||||
setUpscalingStrength,
|
||||
} from 'features/parameters/store/postprocessingSlice';
|
||||
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
|
||||
import { memo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaCopy } from 'react-icons/fa';
|
||||
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
|
||||
import * as png from '@stevebel/png';
|
||||
|
||||
type MetadataItemProps = {
|
||||
isLink?: boolean;
|
||||
label: string;
|
||||
onClick?: () => void;
|
||||
value: number | string | boolean;
|
||||
labelPosition?: string;
|
||||
withCopy?: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Component to display an individual metadata item or parameter.
|
||||
*/
|
||||
const MetadataItem = ({
|
||||
label,
|
||||
value,
|
||||
onClick,
|
||||
isLink,
|
||||
labelPosition,
|
||||
withCopy = false,
|
||||
}: MetadataItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<Flex gap={2}>
|
||||
{onClick && (
|
||||
<Tooltip label={`Recall ${label}`}>
|
||||
<IconButton
|
||||
aria-label={t('accessibility.useThisParameter')}
|
||||
icon={<IoArrowUndoCircleOutline />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={20}
|
||||
onClick={onClick}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
{withCopy && (
|
||||
<Tooltip label={`Copy ${label}`}>
|
||||
<IconButton
|
||||
aria-label={`Copy ${label}`}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(value.toString())}
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<Flex direction={labelPosition ? 'column' : 'row'}>
|
||||
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
|
||||
{label}:
|
||||
</Text>
|
||||
{isLink ? (
|
||||
<Link href={value.toString()} isExternal wordBreak="break-all">
|
||||
{value.toString()} <ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
) : (
|
||||
<Text overflowY="scroll" wordBreak="break-all">
|
||||
{value.toString()}
|
||||
</Text>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
type ImageMetadataViewerProps = {
|
||||
image: InvokeAI.Image;
|
||||
};
|
||||
|
||||
// TODO: I don't know if this is needed.
|
||||
const memoEqualityCheck = (
|
||||
prev: ImageMetadataViewerProps,
|
||||
next: ImageMetadataViewerProps
|
||||
) => prev.image.name === next.image.name;
|
||||
|
||||
// TODO: Show more interesting information in this component.
|
||||
|
||||
/**
|
||||
* Image metadata viewer overlays currently selected image and provides
|
||||
* access to any of its metadata for use in processing.
|
||||
*/
|
||||
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const setBothPrompts = useSetBothPrompts();
|
||||
|
||||
useHotkeys('esc', () => {
|
||||
dispatch(setShouldShowImageDetails(false));
|
||||
});
|
||||
|
||||
const metadata = image?.metadata.sd_metadata || {};
|
||||
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
|
||||
|
||||
const {
|
||||
cfg_scale,
|
||||
fit,
|
||||
height,
|
||||
hires_fix,
|
||||
init_image_path,
|
||||
mask_image_path,
|
||||
orig_path,
|
||||
perlin,
|
||||
postprocessing,
|
||||
prompt,
|
||||
sampler,
|
||||
seamless,
|
||||
seed,
|
||||
steps,
|
||||
strength,
|
||||
threshold,
|
||||
type,
|
||||
variations,
|
||||
width,
|
||||
model_weights,
|
||||
} = metadata;
|
||||
|
||||
const { t } = useTranslation();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const metadataJSON = JSON.stringify(image, null, 2);
|
||||
|
||||
// fetch(getUrl(image.url))
|
||||
// .then((r) => r.arrayBuffer())
|
||||
// .then((buffer) => {
|
||||
// const { text } = png.decode(buffer);
|
||||
// const metadata = text?.['sd-metadata']
|
||||
// ? JSON.parse(text['sd-metadata'] ?? {})
|
||||
// : {};
|
||||
// console.log(metadata);
|
||||
// });
|
||||
|
||||
return (
|
||||
<Flex
|
||||
sx={{
|
||||
padding: 4,
|
||||
gap: 1,
|
||||
flexDirection: 'column',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
backdropFilter: 'blur(20px)',
|
||||
bg: 'whiteAlpha.600',
|
||||
_dark: {
|
||||
bg: 'blackAlpha.600',
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Flex gap={2}>
|
||||
<Text fontWeight="semibold">File:</Text>
|
||||
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
|
||||
{image.url.length > 64
|
||||
? image.url.substring(0, 64).concat('...')
|
||||
: image.url}
|
||||
<ExternalLinkIcon mx="2px" />
|
||||
</Link>
|
||||
</Flex>
|
||||
<Flex gap={2} direction="column">
|
||||
<Flex gap={2}>
|
||||
<Tooltip label="Copy metadata JSON">
|
||||
<IconButton
|
||||
aria-label={t('accessibility.copyMetadataJson')}
|
||||
icon={<FaCopy />}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
fontSize={14}
|
||||
onClick={() => navigator.clipboard.writeText(metadataJSON)}
|
||||
/>
|
||||
</Tooltip>
|
||||
<Text fontWeight="semibold">Metadata JSON:</Text>
|
||||
</Flex>
|
||||
<Box
|
||||
sx={{
|
||||
mt: 0,
|
||||
mr: 2,
|
||||
mb: 4,
|
||||
ml: 2,
|
||||
padding: 4,
|
||||
borderRadius: 'base',
|
||||
overflowX: 'scroll',
|
||||
wordBreak: 'break-all',
|
||||
bg: 'whiteAlpha.500',
|
||||
_dark: { bg: 'blackAlpha.500' },
|
||||
}}
|
||||
>
|
||||
<pre>{metadataJSON}</pre>
|
||||
</Box>
|
||||
</Flex>
|
||||
{Object.keys(metadata).length > 0 ? (
|
||||
<>
|
||||
{type && <MetadataItem label="Generation type" value={type} />}
|
||||
{model_weights && (
|
||||
<MetadataItem label="Model" value={model_weights} />
|
||||
)}
|
||||
{['esrgan', 'gfpgan'].includes(type) && (
|
||||
<MetadataItem label="Original image" value={orig_path} />
|
||||
)}
|
||||
{prompt && (
|
||||
<MetadataItem
|
||||
label="Prompt"
|
||||
labelPosition="top"
|
||||
value={
|
||||
typeof prompt === 'string' ? prompt : promptToString(prompt)
|
||||
}
|
||||
onClick={() => setBothPrompts(prompt)}
|
||||
/>
|
||||
)}
|
||||
{seed !== undefined && (
|
||||
<MetadataItem
|
||||
label="Seed"
|
||||
value={seed}
|
||||
onClick={() => dispatch(setSeed(seed))}
|
||||
/>
|
||||
)}
|
||||
{threshold !== undefined && (
|
||||
<MetadataItem
|
||||
label="Noise Threshold"
|
||||
value={threshold}
|
||||
onClick={() => dispatch(setThreshold(threshold))}
|
||||
/>
|
||||
)}
|
||||
{perlin !== undefined && (
|
||||
<MetadataItem
|
||||
label="Perlin Noise"
|
||||
value={perlin}
|
||||
onClick={() => dispatch(setPerlin(perlin))}
|
||||
/>
|
||||
)}
|
||||
{sampler && (
|
||||
<MetadataItem
|
||||
label="Sampler"
|
||||
value={sampler}
|
||||
onClick={() => dispatch(setSampler(sampler))}
|
||||
/>
|
||||
)}
|
||||
{steps && (
|
||||
<MetadataItem
|
||||
label="Steps"
|
||||
value={steps}
|
||||
onClick={() => dispatch(setSteps(steps))}
|
||||
/>
|
||||
)}
|
||||
{cfg_scale !== undefined && (
|
||||
<MetadataItem
|
||||
label="CFG scale"
|
||||
value={cfg_scale}
|
||||
onClick={() => dispatch(setCfgScale(cfg_scale))}
|
||||
/>
|
||||
)}
|
||||
{variations && variations.length > 0 && (
|
||||
<MetadataItem
|
||||
label="Seed-weight pairs"
|
||||
value={seedWeightsToString(variations)}
|
||||
onClick={() =>
|
||||
dispatch(setSeedWeights(seedWeightsToString(variations)))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{seamless && (
|
||||
<MetadataItem
|
||||
label="Seamless"
|
||||
value={seamless}
|
||||
onClick={() => dispatch(setSeamless(seamless))}
|
||||
/>
|
||||
)}
|
||||
{hires_fix && (
|
||||
<MetadataItem
|
||||
label="High Resolution Optimization"
|
||||
value={hires_fix}
|
||||
onClick={() => dispatch(setHiresFix(hires_fix))}
|
||||
/>
|
||||
)}
|
||||
{width && (
|
||||
<MetadataItem
|
||||
label="Width"
|
||||
value={width}
|
||||
onClick={() => dispatch(setWidth(width))}
|
||||
/>
|
||||
)}
|
||||
{height && (
|
||||
<MetadataItem
|
||||
label="Height"
|
||||
value={height}
|
||||
onClick={() => dispatch(setHeight(height))}
|
||||
/>
|
||||
)}
|
||||
{/* {init_image_path && (
|
||||
<MetadataItem
|
||||
label="Initial image"
|
||||
value={init_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setInitialImage(init_image_path))}
|
||||
/>
|
||||
)} */}
|
||||
{mask_image_path && (
|
||||
<MetadataItem
|
||||
label="Mask image"
|
||||
value={mask_image_path}
|
||||
isLink
|
||||
onClick={() => dispatch(setMaskPath(mask_image_path))}
|
||||
/>
|
||||
)}
|
||||
{type === 'img2img' && strength && (
|
||||
<MetadataItem
|
||||
label="Image to image strength"
|
||||
value={strength}
|
||||
onClick={() => dispatch(setImg2imgStrength(strength))}
|
||||
/>
|
||||
)}
|
||||
{fit && (
|
||||
<MetadataItem
|
||||
label="Image to image fit"
|
||||
value={fit}
|
||||
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
|
||||
/>
|
||||
)}
|
||||
{postprocessing && postprocessing.length > 0 && (
|
||||
<>
|
||||
<Heading size="sm">Postprocessing</Heading>
|
||||
{postprocessing.map(
|
||||
(
|
||||
postprocess: InvokeAI.PostProcessedImageMetadata,
|
||||
i: number
|
||||
) => {
|
||||
if (postprocess.type === 'esrgan') {
|
||||
const { scale, strength, denoise_str } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
|
||||
<MetadataItem
|
||||
label="Scale"
|
||||
value={scale}
|
||||
onClick={() => dispatch(setUpscalingLevel(scale))}
|
||||
/>
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() =>
|
||||
dispatch(setUpscalingStrength(strength))
|
||||
}
|
||||
/>
|
||||
{denoise_str !== undefined && (
|
||||
<MetadataItem
|
||||
label="Denoising strength"
|
||||
value={denoise_str}
|
||||
onClick={() =>
|
||||
dispatch(setUpscalingDenoising(denoise_str))
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
} else if (postprocess.type === 'gfpgan') {
|
||||
const { strength } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${
|
||||
i + 1
|
||||
}: Face restoration (GFPGAN)`}</Text>
|
||||
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() => {
|
||||
dispatch(setFacetoolStrength(strength));
|
||||
dispatch(setFacetoolType('gfpgan'));
|
||||
}}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
} else if (postprocess.type === 'codeformer') {
|
||||
const { strength, fidelity } = postprocess;
|
||||
return (
|
||||
<Flex key={i} pl={8} gap={1} direction="column">
|
||||
<Text size="md">{`${
|
||||
i + 1
|
||||
}: Face restoration (Codeformer)`}</Text>
|
||||
|
||||
<MetadataItem
|
||||
label="Strength"
|
||||
value={strength}
|
||||
onClick={() => {
|
||||
dispatch(setFacetoolStrength(strength));
|
||||
dispatch(setFacetoolType('codeformer'));
|
||||
}}
|
||||
/>
|
||||
{fidelity && (
|
||||
<MetadataItem
|
||||
label="Fidelity"
|
||||
value={fidelity}
|
||||
onClick={() => {
|
||||
dispatch(setCodeformerFidelity(fidelity));
|
||||
dispatch(setFacetoolType('codeformer'));
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
}
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{dreamPrompt && (
|
||||
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<Center width="100%" pt={10}>
|
||||
<Text fontSize="lg" fontWeight="semibold">
|
||||
No metadata available
|
||||
</Text>
|
||||
</Center>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}, memoEqualityCheck);
|
||||
|
||||
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
|
||||
|
||||
export default ImageMetadataViewer;
|
@ -0,0 +1,35 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { ImageType } from 'services/api';
|
||||
import { selectResultsEntities } from '../store/resultsSlice';
|
||||
import { selectUploadsEntities } from '../store/uploadsSlice';
|
||||
|
||||
const useGetImageByNameSelector = createSelector(
|
||||
[selectResultsEntities, selectUploadsEntities],
|
||||
(allResults, allUploads) => {
|
||||
return { allResults, allUploads };
|
||||
}
|
||||
);
|
||||
|
||||
const useGetImageByNameAndType = () => {
|
||||
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
|
||||
|
||||
return (name: string, type: ImageType) => {
|
||||
if (type === 'results') {
|
||||
const resultImagesResult = allResults[name];
|
||||
|
||||
if (resultImagesResult) {
|
||||
return resultImagesResult;
|
||||
}
|
||||
}
|
||||
|
||||
if (type === 'uploads') {
|
||||
const userImagesResult = allUploads[name];
|
||||
if (userImagesResult) {
|
||||
return userImagesResult;
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
export default useGetImageByNameAndType;
|
@ -0,0 +1,17 @@
|
||||
import { GalleryState } from './gallerySlice';
|
||||
|
||||
/**
|
||||
* Gallery slice persist blacklist
|
||||
*/
|
||||
const itemsToBlacklist: (keyof GalleryState)[] = [
|
||||
'categories',
|
||||
'currentCategory',
|
||||
'currentImage',
|
||||
'currentImageUuid',
|
||||
'shouldAutoSwitchToNewImages',
|
||||
'intermediateImage',
|
||||
];
|
||||
|
||||
export const galleryBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `gallery.${blacklistItem}`
|
||||
);
|
@ -7,6 +7,16 @@ import {
|
||||
uiSelector,
|
||||
} from 'features/ui/store/uiSelectors';
|
||||
import { isEqual } from 'lodash';
|
||||
import {
|
||||
selectResultsAll,
|
||||
selectResultsById,
|
||||
selectResultsEntities,
|
||||
} from './resultsSlice';
|
||||
import {
|
||||
selectUploadsAll,
|
||||
selectUploadsById,
|
||||
selectUploadsEntities,
|
||||
} from './uploadsSlice';
|
||||
|
||||
export const gallerySelector = (state: RootState) => state.gallery;
|
||||
|
||||
@ -75,3 +85,18 @@ export const hoverableImageSelector = createSelector(
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const selectedImageSelector = createSelector(
|
||||
[gallerySelector, selectResultsEntities, selectUploadsEntities],
|
||||
(gallery, allResults, allUploads) => {
|
||||
const selectedImageName = gallery.selectedImageName;
|
||||
|
||||
if (selectedImageName in allResults) {
|
||||
return allResults[selectedImageName];
|
||||
}
|
||||
|
||||
if (selectedImageName in allUploads) {
|
||||
return allUploads[selectedImageName];
|
||||
}
|
||||
}
|
||||
);
|
||||
|
@ -1,14 +1,17 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { IRect } from 'konva/lib/types';
|
||||
import { clamp } from 'lodash';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
|
||||
export type GalleryCategory = 'user' | 'result';
|
||||
|
||||
export type AddImagesPayload = {
|
||||
images: Array<InvokeAI.Image>;
|
||||
images: Array<InvokeAI._Image>;
|
||||
areMoreImagesAvailable: boolean;
|
||||
category: GalleryCategory;
|
||||
};
|
||||
@ -16,16 +19,33 @@ export type AddImagesPayload = {
|
||||
type GalleryImageObjectFitType = 'contain' | 'cover';
|
||||
|
||||
export type Gallery = {
|
||||
images: InvokeAI.Image[];
|
||||
images: InvokeAI._Image[];
|
||||
latest_mtime?: number;
|
||||
earliest_mtime?: number;
|
||||
areMoreImagesAvailable: boolean;
|
||||
};
|
||||
|
||||
export interface GalleryState {
|
||||
currentImage?: InvokeAI.Image;
|
||||
/**
|
||||
* The selected image's unique name
|
||||
* Use `selectedImageSelector` to access the image
|
||||
*/
|
||||
selectedImageName: string;
|
||||
/**
|
||||
* The currently selected image
|
||||
* @deprecated See `state.gallery.selectedImageName`
|
||||
*/
|
||||
currentImage?: InvokeAI._Image;
|
||||
/**
|
||||
* The currently selected image's uuid.
|
||||
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
|
||||
*/
|
||||
currentImageUuid: string;
|
||||
intermediateImage?: InvokeAI.Image & {
|
||||
/**
|
||||
* The current progress image
|
||||
* @deprecated See `state.system.progressImage`
|
||||
*/
|
||||
intermediateImage?: InvokeAI._Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
};
|
||||
@ -42,6 +62,7 @@ export interface GalleryState {
|
||||
}
|
||||
|
||||
const initialState: GalleryState = {
|
||||
selectedImageName: '',
|
||||
currentImageUuid: '',
|
||||
galleryImageMinimumWidth: 64,
|
||||
galleryImageObjectFit: 'cover',
|
||||
@ -69,7 +90,10 @@ export const gallerySlice = createSlice({
|
||||
name: 'gallery',
|
||||
initialState,
|
||||
reducers: {
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
|
||||
imageSelected: (state, action: PayloadAction<string>) => {
|
||||
state.selectedImageName = action.payload;
|
||||
},
|
||||
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
|
||||
state.currentImage = action.payload;
|
||||
state.currentImageUuid = action.payload.uuid;
|
||||
},
|
||||
@ -124,7 +148,7 @@ export const gallerySlice = createSlice({
|
||||
addImage: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
image: InvokeAI.Image;
|
||||
image: InvokeAI._Image;
|
||||
category: GalleryCategory;
|
||||
}>
|
||||
) => {
|
||||
@ -150,7 +174,10 @@ export const gallerySlice = createSlice({
|
||||
setIntermediateImage: (
|
||||
state,
|
||||
action: PayloadAction<
|
||||
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
|
||||
InvokeAI._Image & {
|
||||
boundingBox?: IRect;
|
||||
generationMode?: InvokeTabName;
|
||||
}
|
||||
>
|
||||
) => {
|
||||
state.intermediateImage = action.payload;
|
||||
@ -252,9 +279,31 @@ export const gallerySlice = createSlice({
|
||||
state.shouldUseSingleGalleryColumn = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
if (isImageOutput(data.result)) {
|
||||
state.selectedImageName = data.result.image.image_name;
|
||||
state.intermediateImage = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { location } = action.payload;
|
||||
const imageName = location.split('/').pop() || '';
|
||||
state.selectedImageName = imageName;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
imageSelected,
|
||||
addImage,
|
||||
clearIntermediateImage,
|
||||
removeImage,
|
||||
|
@ -0,0 +1,12 @@
|
||||
import { ResultsState } from './resultsSlice';
|
||||
|
||||
/**
|
||||
* Results slice persist blacklist
|
||||
*
|
||||
* Currently blacklisting results slice entirely, see persist config in store.ts
|
||||
*/
|
||||
const itemsToBlacklist: (keyof ResultsState)[] = [];
|
||||
|
||||
export const resultsBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `results.${blacklistItem}`
|
||||
);
|
139
invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
Normal file
139
invokeai/frontend/web/src/features/gallery/store/resultsSlice.ts
Normal file
@ -0,0 +1,139 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
import { invocationComplete } from 'services/events/actions';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedResultImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { isImageOutput } from 'services/types/guards';
|
||||
import {
|
||||
buildImageUrls,
|
||||
extractTimestampFromImageName,
|
||||
} from 'services/util/deserializeImageField';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
// use `createEntityAdapter` to create a slice for results images
|
||||
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
|
||||
|
||||
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
|
||||
export const resultsAdapter = createEntityAdapter<Image>({
|
||||
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
|
||||
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
|
||||
selectId: (image) => image.name,
|
||||
// Order all images by their time (in descending order)
|
||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||
});
|
||||
|
||||
// This type is intersected with the Entity type to create the shape of the state
|
||||
type AdditionalResultsState = {
|
||||
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
|
||||
// to list all images directly at this time...
|
||||
page: number; // current page we are on
|
||||
pages: number; // the total number of pages available
|
||||
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
|
||||
nextPage: number; // the next page to request
|
||||
};
|
||||
|
||||
export const initialResultsState =
|
||||
resultsAdapter.getInitialState<AdditionalResultsState>({
|
||||
// provide the additional initial state
|
||||
page: 0,
|
||||
pages: 0,
|
||||
isLoading: false,
|
||||
nextPage: 0,
|
||||
});
|
||||
|
||||
export type ResultsState = typeof initialResultsState;
|
||||
|
||||
const resultsSlice = createSlice({
|
||||
name: 'results',
|
||||
initialState: initialResultsState,
|
||||
reducers: {
|
||||
// the adapter provides some helper reducers; see the docs for all of them
|
||||
// can use them as helper functions within a reducer, or use the function itself as a reducer
|
||||
|
||||
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
|
||||
// to add a single result
|
||||
resultAdded: resultsAdapter.upsertOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
|
||||
// because we pass in the fulfilled thunk action creator, everything is typed
|
||||
|
||||
/**
|
||||
* Received Result Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Result Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const resultImages = items.map((image) =>
|
||||
deserializeImageResponse(image)
|
||||
);
|
||||
|
||||
// use the adapter reducer to append all the results to state
|
||||
resultsAdapter.addMany(state, resultImages);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Complete
|
||||
*/
|
||||
builder.addCase(invocationComplete, (state, action) => {
|
||||
const { data } = action.payload;
|
||||
const { result, node, graph_execution_state_id } = data;
|
||||
|
||||
if (isImageOutput(result)) {
|
||||
const name = result.image.image_name;
|
||||
const type = result.image.image_type;
|
||||
const { url, thumbnail } = buildImageUrls(type, name);
|
||||
|
||||
const timestamp = extractTimestampFromImageName(name);
|
||||
|
||||
const image: Image = {
|
||||
name,
|
||||
type,
|
||||
url,
|
||||
thumbnail,
|
||||
metadata: {
|
||||
created: timestamp,
|
||||
width: result.width, // TODO: add tese dimensions
|
||||
height: result.height,
|
||||
invokeai: {
|
||||
session_id: graph_execution_state_id,
|
||||
...(node ? { node } : {}),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
resultsAdapter.addOne(state, image);
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
// Create a set of memoized selectors based on the location of this entity state
|
||||
// to be used as selectors in a `useAppSelector()` call
|
||||
export const {
|
||||
selectAll: selectResultsAll,
|
||||
selectById: selectResultsById,
|
||||
selectEntities: selectResultsEntities,
|
||||
selectIds: selectResultsIds,
|
||||
selectTotal: selectResultsTotal,
|
||||
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
|
||||
|
||||
export const { resultAdded } = resultsSlice.actions;
|
||||
|
||||
export default resultsSlice.reducer;
|
@ -1,54 +0,0 @@
|
||||
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { RootState } from 'app/store';
|
||||
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
|
||||
import { setInitialImage } from 'features/parameters/store/generationSlice';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { addImage } from '../gallerySlice';
|
||||
|
||||
type UploadImageConfig = {
|
||||
imageFile: File;
|
||||
};
|
||||
|
||||
export const uploadImage =
|
||||
(
|
||||
config: UploadImageConfig
|
||||
): ThunkAction<void, RootState, unknown, AnyAction> =>
|
||||
async (dispatch, getState) => {
|
||||
const { imageFile } = config;
|
||||
|
||||
const state = getState() as RootState;
|
||||
|
||||
const activeTabName = activeTabNameSelector(state);
|
||||
|
||||
const formData = new FormData();
|
||||
|
||||
formData.append('file', imageFile, imageFile.name);
|
||||
formData.append(
|
||||
'data',
|
||||
JSON.stringify({
|
||||
kind: 'init',
|
||||
})
|
||||
);
|
||||
|
||||
const response = await fetch(`${window.location.origin}/upload`, {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
|
||||
const newImage: InvokeAI.Image = {
|
||||
uuid: uuidv4(),
|
||||
category: 'user',
|
||||
...image,
|
||||
};
|
||||
|
||||
dispatch(addImage({ image: newImage, category: 'user' }));
|
||||
|
||||
if (activeTabName === 'unifiedCanvas') {
|
||||
dispatch(setInitialCanvasImage(newImage));
|
||||
} else if (activeTabName === 'img2img') {
|
||||
dispatch(setInitialImage(newImage));
|
||||
}
|
||||
};
|
@ -0,0 +1,12 @@
|
||||
import { UploadsState } from './uploadsSlice';
|
||||
|
||||
/**
|
||||
* Uploads slice persist blacklist
|
||||
*
|
||||
* Currently blacklisting uploads slice entirely, see persist config in store.ts
|
||||
*/
|
||||
const itemsToBlacklist: (keyof UploadsState)[] = [];
|
||||
|
||||
export const uploadsBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `uploads.${blacklistItem}`
|
||||
);
|
@ -0,0 +1,87 @@
|
||||
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
|
||||
import { Image } from 'app/invokeai';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
receivedUploadImagesPage,
|
||||
IMAGES_PER_PAGE,
|
||||
} from 'services/thunks/gallery';
|
||||
import { imageUploaded } from 'services/thunks/image';
|
||||
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
|
||||
|
||||
export const uploadsAdapter = createEntityAdapter<Image>({
|
||||
selectId: (image) => image.name,
|
||||
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
|
||||
});
|
||||
|
||||
type AdditionalUploadsState = {
|
||||
page: number;
|
||||
pages: number;
|
||||
isLoading: boolean;
|
||||
nextPage: number;
|
||||
};
|
||||
|
||||
const initialUploadsState =
|
||||
uploadsAdapter.getInitialState<AdditionalUploadsState>({
|
||||
page: 0,
|
||||
pages: 0,
|
||||
nextPage: 0,
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
export type UploadsState = typeof initialUploadsState;
|
||||
|
||||
const uploadsSlice = createSlice({
|
||||
name: 'uploads',
|
||||
initialState: initialUploadsState,
|
||||
reducers: {
|
||||
uploadAdded: uploadsAdapter.addOne,
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
/**
|
||||
* Received Upload Images Page - PENDING
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.pending, (state) => {
|
||||
state.isLoading = true;
|
||||
});
|
||||
|
||||
/**
|
||||
* Received Upload Images Page - FULFILLED
|
||||
*/
|
||||
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
|
||||
const { items, page, pages } = action.payload;
|
||||
|
||||
const images = items.map((image) => deserializeImageResponse(image));
|
||||
|
||||
uploadsAdapter.addMany(state, images);
|
||||
|
||||
state.page = page;
|
||||
state.pages = pages;
|
||||
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
|
||||
state.isLoading = false;
|
||||
});
|
||||
|
||||
/**
|
||||
* Upload Image - FULFILLED
|
||||
*/
|
||||
builder.addCase(imageUploaded.fulfilled, (state, action) => {
|
||||
const { location, response } = action.payload;
|
||||
|
||||
const uploadedImage = deserializeImageResponse(response);
|
||||
|
||||
uploadsAdapter.addOne(state, uploadedImage);
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
selectAll: selectUploadsAll,
|
||||
selectById: selectUploadsById,
|
||||
selectEntities: selectUploadsEntities,
|
||||
selectIds: selectUploadsIds,
|
||||
selectTotal: selectUploadsTotal,
|
||||
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
|
||||
|
||||
export const { uploadAdded } = uploadsSlice.actions;
|
||||
|
||||
export default uploadsSlice.reducer;
|
@ -1,9 +1,10 @@
|
||||
import * as React from 'react';
|
||||
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
|
||||
import * as InvokeAI from 'app/invokeai';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
|
||||
type ReactPanZoomProps = {
|
||||
image: InvokeAI.Image;
|
||||
image: InvokeAI._Image;
|
||||
styleClass?: string;
|
||||
alt?: string;
|
||||
ref?: React.Ref<HTMLImageElement>;
|
||||
@ -22,6 +23,7 @@ export default function ReactPanZoomImage({
|
||||
scaleY,
|
||||
}: ReactPanZoomProps) {
|
||||
const { centerView } = useTransformContext();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
return (
|
||||
<TransformComponent
|
||||
@ -35,7 +37,7 @@ export default function ReactPanZoomImage({
|
||||
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
|
||||
width: '100%',
|
||||
}}
|
||||
src={image.url}
|
||||
src={getUrl(image.url)}
|
||||
alt={alt}
|
||||
ref={ref}
|
||||
className={styleClass ? styleClass : ''}
|
||||
|
@ -0,0 +1,10 @@
|
||||
import { LightboxState } from './lightboxSlice';
|
||||
|
||||
/**
|
||||
* Lightbox slice persist blacklist
|
||||
*/
|
||||
const itemsToBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
|
||||
|
||||
export const lightboxBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `lightbox.${blacklistItem}`
|
||||
);
|
@ -0,0 +1,63 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import 'reactflow/dist/style.css';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
Tooltip,
|
||||
Menu,
|
||||
MenuButton,
|
||||
MenuList,
|
||||
MenuItem,
|
||||
IconButton,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaPlus } from 'react-icons/fa';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { nodeAdded } from '../store/nodesSlice';
|
||||
import { cloneDeep, map } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
import { useBuildInvocation } from '../hooks/useBuildInvocation';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/hooks/useToastWatcher';
|
||||
|
||||
export const AddNodeMenu = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const invocationTemplates = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocationTemplates
|
||||
);
|
||||
|
||||
const buildInvocation = useBuildInvocation();
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
|
||||
if (!invocation) {
|
||||
const toast = makeToast({
|
||||
status: 'error',
|
||||
title: `Unknown Invocation type ${nodeType}`,
|
||||
});
|
||||
dispatch(addToast(toast));
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(nodeAdded(invocation));
|
||||
},
|
||||
[dispatch, buildInvocation]
|
||||
);
|
||||
|
||||
return (
|
||||
<Menu>
|
||||
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
|
||||
<MenuList>
|
||||
{map(invocationTemplates, ({ title, description, type }, key) => {
|
||||
return (
|
||||
<Tooltip key={key} label={description} placement="end" hasArrow>
|
||||
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</MenuList>
|
||||
</Menu>
|
||||
);
|
||||
};
|
@ -0,0 +1,69 @@
|
||||
import { Tooltip } from '@chakra-ui/react';
|
||||
import { CSSProperties, useMemo } from 'react';
|
||||
import {
|
||||
Handle,
|
||||
Position,
|
||||
Connection,
|
||||
HandleType,
|
||||
useReactFlow,
|
||||
} from 'reactflow';
|
||||
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
|
||||
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
|
||||
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
|
||||
|
||||
const handleBaseStyles: CSSProperties = {
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
borderWidth: 0,
|
||||
};
|
||||
|
||||
const inputHandleStyles: CSSProperties = {
|
||||
left: '-1.7rem',
|
||||
};
|
||||
|
||||
const outputHandleStyles: CSSProperties = {
|
||||
right: '-1.7rem',
|
||||
};
|
||||
|
||||
const requiredConnectionStyles: CSSProperties = {
|
||||
boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
|
||||
};
|
||||
|
||||
type FieldHandleProps = {
|
||||
nodeId: string;
|
||||
field: InputFieldTemplate | OutputFieldTemplate;
|
||||
isValidConnection: (connection: Connection) => boolean;
|
||||
handleType: HandleType;
|
||||
styles?: CSSProperties;
|
||||
};
|
||||
|
||||
export const FieldHandle = (props: FieldHandleProps) => {
|
||||
const { nodeId, field, isValidConnection, handleType, styles } = props;
|
||||
const { name, title, type, description } = field;
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
key={name}
|
||||
label={type}
|
||||
placement={handleType === 'target' ? 'start' : 'end'}
|
||||
hasArrow
|
||||
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
|
||||
>
|
||||
<Handle
|
||||
type={handleType}
|
||||
id={name}
|
||||
isValidConnection={isValidConnection}
|
||||
position={handleType === 'target' ? Position.Left : Position.Right}
|
||||
style={{
|
||||
backgroundColor: FIELDS[type].colorCssVar,
|
||||
...styles,
|
||||
...handleBaseStyles,
|
||||
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
|
||||
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
|
||||
// ...connectionEventStyles,
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
@ -0,0 +1,18 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
|
||||
import { map } from 'lodash';
|
||||
import { FIELDS } from '../types/constants';
|
||||
|
||||
export const FieldTypeLegend = () => {
|
||||
return (
|
||||
<HStack>
|
||||
{map(FIELDS, ({ title, description, color }, key) => (
|
||||
<Tooltip key={key} label={description}>
|
||||
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
|
||||
{title}
|
||||
</Badge>
|
||||
</Tooltip>
|
||||
))}
|
||||
</HStack>
|
||||
);
|
||||
};
|
104
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
104
invokeai/frontend/web/src/features/nodes/components/Flow.tsx
Normal file
@ -0,0 +1,104 @@
|
||||
import {
|
||||
Background,
|
||||
Controls,
|
||||
MiniMap,
|
||||
OnConnect,
|
||||
OnEdgesChange,
|
||||
OnNodesChange,
|
||||
ReactFlow,
|
||||
ConnectionLineType,
|
||||
OnConnectStart,
|
||||
OnConnectEnd,
|
||||
Panel,
|
||||
} from 'reactflow';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
} from '../store/nodesSlice';
|
||||
import { useCallback } from 'react';
|
||||
import { InvocationComponent } from './InvocationComponent';
|
||||
import { AddNodeMenu } from './AddNodeMenu';
|
||||
import { FieldTypeLegend } from './FieldTypeLegend';
|
||||
import { Button } from '@chakra-ui/react';
|
||||
import { nodesGraphBuilt } from 'services/thunks/session';
|
||||
|
||||
const nodeTypes = { invocation: InvocationComponent };
|
||||
|
||||
export const Flow = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
|
||||
const edges = useAppSelector((state: RootState) => state.nodes.edges);
|
||||
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(changes) => {
|
||||
dispatch(nodesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
(changes) => {
|
||||
dispatch(edgesChanged(changes));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectStart: OnConnectStart = useCallback(
|
||||
(event, params) => {
|
||||
dispatch(connectionStarted(params));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnect: OnConnect = useCallback(
|
||||
(connection) => {
|
||||
dispatch(connectionMade(connection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onConnectEnd: OnConnectEnd = useCallback(
|
||||
(event) => {
|
||||
dispatch(connectionEnded());
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleInvoke = useCallback(() => {
|
||||
dispatch(nodesGraphBuilt());
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<ReactFlow
|
||||
nodeTypes={nodeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
defaultEdgeOptions={{
|
||||
style: { strokeWidth: 2 },
|
||||
}}
|
||||
>
|
||||
<Panel position="top-left">
|
||||
<AddNodeMenu />
|
||||
</Panel>
|
||||
<Panel position="top-center">
|
||||
<Button onClick={handleInvoke}>Will it blend?</Button>
|
||||
</Panel>
|
||||
<Panel position="top-right">
|
||||
<FieldTypeLegend />
|
||||
</Panel>
|
||||
<Background />
|
||||
<Controls />
|
||||
<MiniMap nodeStrokeWidth={3} zoomable pannable />
|
||||
</ReactFlow>
|
||||
);
|
||||
};
|
@ -0,0 +1,107 @@
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
|
||||
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
|
||||
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
|
||||
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
|
||||
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
|
||||
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
|
||||
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
|
||||
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
|
||||
|
||||
type InputFieldComponentProps = {
|
||||
nodeId: string;
|
||||
field: InputFieldValue;
|
||||
template: InputFieldTemplate;
|
||||
};
|
||||
|
||||
// build an individual input element based on the schema
|
||||
export const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
const { nodeId, field, template } = props;
|
||||
const { type, value } = field;
|
||||
|
||||
if (type === 'string' && template.type === 'string') {
|
||||
return (
|
||||
<StringInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'boolean' && template.type === 'boolean') {
|
||||
return (
|
||||
<BooleanInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(type === 'integer' && template.type === 'integer') ||
|
||||
(type === 'float' && template.type === 'float')
|
||||
) {
|
||||
return (
|
||||
<NumberInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'enum' && template.type === 'enum') {
|
||||
return (
|
||||
<EnumInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'image' && template.type === 'image') {
|
||||
return (
|
||||
<ImageInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'latents' && template.type === 'latents') {
|
||||
return (
|
||||
<LatentsInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'model' && template.type === 'model') {
|
||||
return (
|
||||
<ModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return <Box p={2}>Unknown field type: {type}</Box>;
|
||||
};
|
@ -0,0 +1,243 @@
|
||||
import { NodeProps, useReactFlow } from 'reactflow';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Heading,
|
||||
HStack,
|
||||
Tooltip,
|
||||
Icon,
|
||||
Code,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import { FaExclamationCircle, FaInfoCircle } from 'react-icons/fa';
|
||||
import { InvocationValue } from '../types/types';
|
||||
import { InputFieldComponent } from './InputFieldComponent';
|
||||
import { FieldHandle } from './FieldHandle';
|
||||
import { isEqual, map, size } from 'lodash';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
import { useIsValidConnection } from '../hooks/useIsValidConnection';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
|
||||
|
||||
const connectedInputFieldsSelector = createSelector(
|
||||
[(state: RootState) => state.nodes.edges],
|
||||
(edges) => {
|
||||
// return edges.map((e) => e.targetHandle);
|
||||
return edges;
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { type, inputs, outputs } = data;
|
||||
|
||||
const isValidConnection = useIsValidConnection();
|
||||
|
||||
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
|
||||
const getInvocationTemplate = useGetInvocationTemplate();
|
||||
// TODO: determine if a field/handle is connected and disable the input if so
|
||||
|
||||
const template = useRef(getInvocationTemplate(type));
|
||||
|
||||
if (!template.current) {
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
padding: 4,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
borderWidth: 2,
|
||||
borderColor: selected ? 'base.400' : 'transparent',
|
||||
}}
|
||||
>
|
||||
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
|
||||
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
|
||||
</Flex>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
padding: 4,
|
||||
bg: 'base.800',
|
||||
borderRadius: 'md',
|
||||
boxShadow: 'dark-lg',
|
||||
borderWidth: 2,
|
||||
borderColor: selected ? 'base.400' : 'transparent',
|
||||
}}
|
||||
>
|
||||
<Flex flexDirection="column" gap={2}>
|
||||
<>
|
||||
<Code>{nodeId}</Code>
|
||||
<HStack justifyContent="space-between">
|
||||
<Heading size="sm" fontWeight={500} color="base.100">
|
||||
{template.current.title}
|
||||
</Heading>
|
||||
<Tooltip
|
||||
label={template.current.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.300" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
{map(inputs, (input, i) => {
|
||||
const { id: fieldId } = input;
|
||||
const inputTemplate = template.current?.inputs[input.name];
|
||||
|
||||
if (!inputTemplate) {
|
||||
return (
|
||||
<Box
|
||||
key={fieldId}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor: 'error.400',
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={true}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown input: {input.name}</FormLabel>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.target === nodeId &&
|
||||
connectedInput.targetHandle === input.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={fieldId}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor:
|
||||
!isConnected &&
|
||||
['always', 'connectionOnly'].includes(
|
||||
String(inputTemplate?.inputRequirement)
|
||||
) &&
|
||||
input.value === undefined
|
||||
? 'warning.400'
|
||||
: undefined,
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={isConnected}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>{inputTemplate?.title}</FormLabel>
|
||||
<Tooltip
|
||||
label={inputTemplate?.description}
|
||||
placement="top"
|
||||
hasArrow
|
||||
shouldWrapChildren
|
||||
>
|
||||
<Icon color="base.400" as={FaInfoCircle} />
|
||||
</Tooltip>
|
||||
</HStack>
|
||||
<InputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={input}
|
||||
template={inputTemplate}
|
||||
/>
|
||||
</FormControl>
|
||||
{!['never', 'directOnly'].includes(
|
||||
inputTemplate?.inputRequirement ?? ''
|
||||
) && (
|
||||
<FieldHandle
|
||||
nodeId={nodeId}
|
||||
field={inputTemplate}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="target"
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
{map(outputs).map((output, i) => {
|
||||
const outputTemplate = template.current?.outputs[output.name];
|
||||
|
||||
const isConnected = Boolean(
|
||||
connectedInputs.filter((connectedInput) => {
|
||||
return (
|
||||
connectedInput.source === nodeId &&
|
||||
connectedInput.sourceHandle === output.name
|
||||
);
|
||||
}).length
|
||||
);
|
||||
|
||||
if (!outputTemplate) {
|
||||
return (
|
||||
<Box
|
||||
key={output.id}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
sx={{
|
||||
borderColor: 'error.400',
|
||||
}}
|
||||
>
|
||||
<FormControl isDisabled={true}>
|
||||
<HStack justifyContent="space-between" alignItems="center">
|
||||
<FormLabel>Unknown output: {output.name}</FormLabel>
|
||||
</HStack>
|
||||
</FormControl>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Box
|
||||
key={output.id}
|
||||
position="relative"
|
||||
p={2}
|
||||
borderWidth={1}
|
||||
borderRadius="md"
|
||||
>
|
||||
<FormControl isDisabled={isConnected}>
|
||||
<FormLabel textAlign="end">
|
||||
{outputTemplate?.title} Output
|
||||
</FormLabel>
|
||||
</FormControl>
|
||||
<FieldHandle
|
||||
key={output.id}
|
||||
nodeId={nodeId}
|
||||
field={outputTemplate}
|
||||
isValidConnection={isValidConnection}
|
||||
handleType="source"
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
</Flex>
|
||||
<Flex></Flex>
|
||||
</Box>
|
||||
);
|
||||
});
|
||||
|
||||
InvocationComponent.displayName = 'InvocationComponent';
|
@ -0,0 +1,46 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
import { Box } from '@chakra-ui/react';
|
||||
import { ReactFlowProvider } from 'reactflow';
|
||||
|
||||
import { Flow } from './Flow';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { RootState } from 'app/store';
|
||||
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
|
||||
|
||||
const NodeEditor = () => {
|
||||
const state = useAppSelector((state: RootState) => state);
|
||||
|
||||
const graph = buildNodesGraph(state);
|
||||
|
||||
return (
|
||||
<Box
|
||||
sx={{
|
||||
position: 'relative',
|
||||
width: 'full',
|
||||
height: 'full',
|
||||
borderRadius: 'md',
|
||||
bg: 'base.850',
|
||||
}}
|
||||
>
|
||||
<ReactFlowProvider>
|
||||
<Flow />
|
||||
</ReactFlowProvider>
|
||||
<Box
|
||||
as="pre"
|
||||
fontFamily="monospace"
|
||||
position="absolute"
|
||||
top={2}
|
||||
left={2}
|
||||
width="full"
|
||||
height="full"
|
||||
userSelect="none"
|
||||
pointerEvents="none"
|
||||
opacity={0.7}
|
||||
>
|
||||
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
export default NodeEditor;
|
@ -0,0 +1,14 @@
|
||||
import {
|
||||
ArrayInputFieldTemplate,
|
||||
ArrayInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { FaImage, FaList } from 'react-icons/fa';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const ArrayInputFieldComponent = (
|
||||
props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return <FaList />;
|
||||
};
|
@ -0,0 +1,31 @@
|
||||
import { Switch } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
BooleanInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const BooleanInputFieldComponent = (
|
||||
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.checked,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
|
||||
);
|
||||
};
|
@ -0,0 +1,35 @@
|
||||
import { Select } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
EnumInputFieldTemplate,
|
||||
EnumInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const EnumInputFieldComponent = (
|
||||
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field, template } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.value,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Select onChange={handleValueChanged} value={field.value}>
|
||||
{template.options.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
};
|
@ -0,0 +1,64 @@
|
||||
import { Box, Image, Icon, Flex } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
|
||||
import { useGetUrl } from 'common/util/getUrl';
|
||||
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName';
|
||||
import useGetImageByUuid from 'features/gallery/hooks/useGetImageByUuid';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ImageInputFieldTemplate,
|
||||
ImageInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { DragEvent, useCallback, useState } from 'react';
|
||||
import { FaImage } from 'react-icons/fa';
|
||||
import { ImageType } from 'services/api';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const ImageInputFieldComponent = (
|
||||
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const { value } = field;
|
||||
|
||||
const getImageByNameAndType = useGetImageByNameAndType();
|
||||
const dispatch = useAppDispatch();
|
||||
const [url, setUrl] = useState<string>();
|
||||
const { getUrl } = useGetUrl();
|
||||
|
||||
const handleDrop = useCallback(
|
||||
(e: DragEvent<HTMLDivElement>) => {
|
||||
const name = e.dataTransfer.getData('invokeai/imageName');
|
||||
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
|
||||
|
||||
if (!name || !type) {
|
||||
return;
|
||||
}
|
||||
|
||||
const image = getImageByNameAndType(name, type);
|
||||
|
||||
if (!image) {
|
||||
return;
|
||||
}
|
||||
|
||||
setUrl(image.url);
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: {
|
||||
image_name: name,
|
||||
image_type: type,
|
||||
},
|
||||
})
|
||||
);
|
||||
},
|
||||
[getImageByNameAndType, dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<Box onDrop={handleDrop}>
|
||||
<Image src={getUrl(url)} fallback={<SelectImagePlaceholder />} />
|
||||
</Box>
|
||||
);
|
||||
};
|
@ -0,0 +1,13 @@
|
||||
import {
|
||||
LatentsInputFieldTemplate,
|
||||
LatentsInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const LatentsInputFieldComponent = (
|
||||
props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
return null;
|
||||
};
|
@ -0,0 +1,57 @@
|
||||
import { Select } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ModelInputFieldTemplate,
|
||||
ModelInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import {
|
||||
selectModelsById,
|
||||
selectModelsIds,
|
||||
} from 'features/system/store/modelSlice';
|
||||
import { isEqual, map } from 'lodash';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const availableModelsSelector = createSelector(
|
||||
[selectModelsIds],
|
||||
(allModelNames) => {
|
||||
return { allModelNames };
|
||||
// return map(modelList, (_, name) => name);
|
||||
},
|
||||
{
|
||||
memoizeOptions: {
|
||||
resultEqualityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export const ModelInputFieldComponent = (
|
||||
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { allModelNames } = useAppSelector(availableModelsSelector);
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.value,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<Select onChange={handleValueChanged} value={field.value}>
|
||||
{allModelNames.map((option) => (
|
||||
<option key={option}>{option}</option>
|
||||
))}
|
||||
</Select>
|
||||
);
|
||||
};
|
@ -0,0 +1,41 @@
|
||||
import {
|
||||
NumberDecrementStepper,
|
||||
NumberIncrementStepper,
|
||||
NumberInput,
|
||||
NumberInputField,
|
||||
NumberInputStepper,
|
||||
} from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
FloatInputFieldTemplate,
|
||||
FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate,
|
||||
IntegerInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const NumberInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
IntegerInputFieldValue | FloatInputFieldValue,
|
||||
IntegerInputFieldTemplate | FloatInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleValueChanged = (_: string, value: number) => {
|
||||
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
|
||||
};
|
||||
|
||||
return (
|
||||
<NumberInput onChange={handleValueChanged} value={field.value}>
|
||||
<NumberInputField />
|
||||
<NumberInputStepper>
|
||||
<NumberIncrementStepper />
|
||||
<NumberDecrementStepper />
|
||||
</NumberInputStepper>
|
||||
</NumberInput>
|
||||
);
|
||||
};
|
@ -0,0 +1,29 @@
|
||||
import { Input } from '@chakra-ui/react';
|
||||
import { useAppDispatch } from 'app/storeHooks';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
StringInputFieldTemplate,
|
||||
StringInputFieldValue,
|
||||
} from 'features/nodes/types';
|
||||
import { ChangeEvent } from 'react';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
export const StringInputFieldComponent = (
|
||||
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: e.target.value,
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
return <Input onChange={handleValueChanged} value={field.value}></Input>;
|
||||
};
|
@ -0,0 +1,10 @@
|
||||
import { InputFieldTemplate, InputFieldValue } from 'features/nodes/types';
|
||||
|
||||
export type FieldComponentProps<
|
||||
V extends InputFieldValue,
|
||||
T extends InputFieldTemplate
|
||||
> = {
|
||||
nodeId: string;
|
||||
field: V;
|
||||
template: T;
|
||||
};
|
@ -0,0 +1,52 @@
|
||||
export const iterationGraph = {
|
||||
nodes: {
|
||||
'0': {
|
||||
id: '0',
|
||||
type: 'range',
|
||||
start: 0,
|
||||
stop: 5,
|
||||
step: 1,
|
||||
},
|
||||
'1': {
|
||||
collection: [],
|
||||
id: '1',
|
||||
index: 0,
|
||||
type: 'iterate',
|
||||
},
|
||||
'2': {
|
||||
cfg_scale: 7.5,
|
||||
height: 512,
|
||||
id: '2',
|
||||
model: '',
|
||||
progress_images: false,
|
||||
prompt: 'dog',
|
||||
sampler_name: 'k_lms',
|
||||
seamless: false,
|
||||
steps: 11,
|
||||
type: 'txt2img',
|
||||
width: 512,
|
||||
},
|
||||
},
|
||||
edges: [
|
||||
{
|
||||
source: {
|
||||
field: 'collection',
|
||||
node_id: '0',
|
||||
},
|
||||
destination: {
|
||||
field: 'collection',
|
||||
node_id: '1',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
field: 'item',
|
||||
node_id: '1',
|
||||
},
|
||||
destination: {
|
||||
field: 'seed',
|
||||
node_id: '2',
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
@ -0,0 +1,78 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { reduce } from 'lodash';
|
||||
import { Node } from 'reactflow';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import {
|
||||
InputFieldValue,
|
||||
InvocationValue,
|
||||
OutputFieldValue,
|
||||
} from '../types/types';
|
||||
import { buildInputFieldValue } from '../util/fieldValueBuilders';
|
||||
|
||||
export const useBuildInvocation = () => {
|
||||
const invocationTemplates = useAppSelector(
|
||||
(state: RootState) => state.nodes.invocationTemplates
|
||||
);
|
||||
|
||||
return (type: AnyInvocationType) => {
|
||||
const template = invocationTemplates[type];
|
||||
|
||||
if (template === undefined) {
|
||||
console.error(`Unable to find template ${type}.`);
|
||||
return;
|
||||
}
|
||||
|
||||
const nodeId = uuidv4();
|
||||
|
||||
const inputs = reduce(
|
||||
template.inputs,
|
||||
(inputsAccumulator, inputTemplate, inputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const inputFieldValue: InputFieldValue = buildInputFieldValue(
|
||||
fieldId,
|
||||
inputTemplate
|
||||
);
|
||||
|
||||
inputsAccumulator[inputName] = inputFieldValue;
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldValue>
|
||||
);
|
||||
|
||||
const outputs = reduce(
|
||||
template.outputs,
|
||||
(outputsAccumulator, outputTemplate, outputName) => {
|
||||
const fieldId = uuidv4();
|
||||
|
||||
const outputFieldValue: OutputFieldValue = {
|
||||
id: fieldId,
|
||||
name: outputName,
|
||||
type: outputTemplate.type,
|
||||
};
|
||||
|
||||
outputsAccumulator[outputName] = outputFieldValue;
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldValue>
|
||||
);
|
||||
|
||||
const invocation: Node<InvocationValue> = {
|
||||
id: nodeId,
|
||||
type: 'invocation',
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id: nodeId,
|
||||
type,
|
||||
inputs,
|
||||
outputs,
|
||||
},
|
||||
};
|
||||
|
||||
return invocation;
|
||||
};
|
||||
};
|
@ -0,0 +1,16 @@
|
||||
import { useAppSelector } from 'app/storeHooks';
|
||||
import { invocationTemplatesSelector } from '../store/selectors/invocationTemplatesSelector';
|
||||
|
||||
export const useGetInvocationTemplate = () => {
|
||||
const invocationTemplates = useAppSelector(invocationTemplatesSelector);
|
||||
|
||||
return (invocationType: string) => {
|
||||
const template = invocationTemplates[invocationType];
|
||||
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
};
|
@ -0,0 +1,93 @@
|
||||
import { useCallback } from 'react';
|
||||
import { Connection, Node, useReactFlow } from 'reactflow';
|
||||
import graphlib from '@dagrejs/graphlib';
|
||||
import { InvocationValue } from '../types/types';
|
||||
|
||||
export const useIsValidConnection = () => {
|
||||
const flow = useReactFlow();
|
||||
|
||||
// Check if an in-progress connection is valid
|
||||
const isValidConnection = useCallback(
|
||||
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
|
||||
const edges = flow.getEdges();
|
||||
const nodes = flow.getNodes();
|
||||
|
||||
return true;
|
||||
|
||||
// Connection must have valid targets
|
||||
if (!(source && sourceHandle && target && targetHandle)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection is invalid if target already has a connection
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Find the source and target nodes
|
||||
const sourceNode = flow.getNode(source) as Node<InvocationValue>;
|
||||
|
||||
const targetNode = flow.getNode(target) as Node<InvocationValue>;
|
||||
|
||||
// Conditional guards against undefined nodes/handles
|
||||
if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Connection types must be the same for a connection
|
||||
if (
|
||||
sourceNode.data.outputs[sourceHandle].type !==
|
||||
targetNode.data.inputs[targetHandle].type
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Graphs much be acyclic (no loops!)
|
||||
|
||||
/**
|
||||
* TODO: use `graphlib.alg.findCycles()` to identify strong connections
|
||||
*
|
||||
* this validation func only runs when the cursor hits the second handle of the connection,
|
||||
* and only on that second handle - so it cannot tell us exhaustively which connections
|
||||
* are valid.
|
||||
*
|
||||
* ideally, we check when the connection starts to calculate all invalid handles at once.
|
||||
*
|
||||
* requires making a new graphlib graph - and calling `findCycles()` - for each potential
|
||||
* handle. instead of using the `isValidConnection` prop, it would use the `onConnectStart`
|
||||
* prop.
|
||||
*
|
||||
* the strong connections should be stored in global state.
|
||||
*
|
||||
* then, `isValidConnection` would simple loop through the strong connections and if the
|
||||
* source and target are in a single strong connection, return false.
|
||||
*
|
||||
* and also, we can use this knowledge to style every handle when a connection starts,
|
||||
* which is otherwise not possible.
|
||||
*/
|
||||
|
||||
// build a graphlib graph
|
||||
const g = new graphlib.Graph();
|
||||
|
||||
nodes.forEach((n) => {
|
||||
g.setNode(n.id);
|
||||
});
|
||||
|
||||
edges.forEach((e) => {
|
||||
g.setEdge(e.source, e.target);
|
||||
});
|
||||
|
||||
// Add the candidate edge to the graph
|
||||
g.setEdge(source, target);
|
||||
|
||||
return graphlib.alg.isAcyclic(g);
|
||||
},
|
||||
[flow]
|
||||
);
|
||||
|
||||
return isValidConnection;
|
||||
};
|
@ -0,0 +1,10 @@
|
||||
import { NodesState } from './nodesSlice';
|
||||
|
||||
/**
|
||||
* Nodes slice persist blacklist
|
||||
*/
|
||||
const itemsToBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
|
||||
|
||||
export const nodesBlacklist = itemsToBlacklist.map(
|
||||
(blacklistItem) => `nodes.${blacklistItem}`
|
||||
);
|
103
invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Normal file
103
invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
Normal file
@ -0,0 +1,103 @@
|
||||
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import {
|
||||
addEdge,
|
||||
applyEdgeChanges,
|
||||
applyNodeChanges,
|
||||
Connection,
|
||||
Edge,
|
||||
EdgeChange,
|
||||
Node,
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
} from 'reactflow';
|
||||
import { Graph, ImageField } from 'services/api';
|
||||
import { receivedOpenAPISchema } from 'services/thunks/schema';
|
||||
import { isFulfilledAnyGraphBuilt } from 'services/thunks/session';
|
||||
import { InvocationTemplate, InvocationValue } from '../types/types';
|
||||
import { parseSchema } from '../util/parseSchema';
|
||||
|
||||
export type NodesState = {
|
||||
nodes: Node<InvocationValue>[];
|
||||
edges: Edge[];
|
||||
schema: OpenAPIV3.Document | null;
|
||||
invocationTemplates: Record<string, InvocationTemplate>;
|
||||
connectionStartParams: OnConnectStartParams | null;
|
||||
lastGraph: Graph | null;
|
||||
};
|
||||
|
||||
export const initialNodesState: NodesState = {
|
||||
nodes: [],
|
||||
edges: [],
|
||||
schema: null,
|
||||
invocationTemplates: {},
|
||||
connectionStartParams: null,
|
||||
lastGraph: null,
|
||||
};
|
||||
|
||||
const nodesSlice = createSlice({
|
||||
name: 'nodes',
|
||||
initialState: initialNodesState,
|
||||
reducers: {
|
||||
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
|
||||
state.nodes = applyNodeChanges(action.payload, state.nodes);
|
||||
},
|
||||
nodeAdded: (state, action: PayloadAction<Node<InvocationValue>>) => {
|
||||
state.nodes.push(action.payload);
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
},
|
||||
connectionMade: (state, action: PayloadAction<Connection>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
connectionEnded: (state) => {
|
||||
state.connectionStartParams = null;
|
||||
},
|
||||
fieldValueChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value:
|
||||
| string
|
||||
| number
|
||||
| boolean
|
||||
| Pick<ImageField, 'image_name' | 'image_type'>
|
||||
| undefined;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
if (nodeIndex > -1) {
|
||||
state.nodes[nodeIndex].data.inputs[fieldName].value = value;
|
||||
}
|
||||
},
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
|
||||
state.schema = action.payload;
|
||||
state.invocationTemplates = parseSchema(action.payload);
|
||||
});
|
||||
|
||||
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
|
||||
state.lastGraph = action.payload;
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const {
|
||||
nodesChanged,
|
||||
edgesChanged,
|
||||
nodeAdded,
|
||||
fieldValueChanged,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
@ -0,0 +1,7 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store';
|
||||
|
||||
export const invocationTemplatesSelector = createSelector(
|
||||
(state: RootState) => state.nodes,
|
||||
(nodes) => nodes.invocationTemplates
|
||||
);
|
69
invokeai/frontend/web/src/features/nodes/types/constants.ts
Normal file
69
invokeai/frontend/web/src/features/nodes/types/constants.ts
Normal file
@ -0,0 +1,69 @@
|
||||
import { getCSSVar } from '@chakra-ui/utils';
|
||||
import { FieldType, FieldUIConfig } from './types';
|
||||
|
||||
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
|
||||
|
||||
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
integer: 'integer',
|
||||
number: 'float',
|
||||
string: 'string',
|
||||
boolean: 'boolean',
|
||||
enum: 'enum',
|
||||
ImageField: 'image',
|
||||
LatentsField: 'latents',
|
||||
model: 'model',
|
||||
array: 'array',
|
||||
};
|
||||
|
||||
const COLOR_TOKEN_VALUE = 500;
|
||||
|
||||
const getColorTokenCssVariable = (color: string) =>
|
||||
`var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
|
||||
|
||||
export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
integer: {
|
||||
colorCssVar: getColorTokenCssVariable('red'),
|
||||
title: 'Integer',
|
||||
description: 'Integers are whole numbers, without a decimal point.',
|
||||
},
|
||||
float: {
|
||||
colorCssVar: getColorTokenCssVariable('orange'),
|
||||
title: 'Float',
|
||||
description: 'Floats are numbers with a decimal point.',
|
||||
},
|
||||
string: {
|
||||
colorCssVar: getColorTokenCssVariable('yellow'),
|
||||
title: 'String',
|
||||
description: 'Strings are text.',
|
||||
},
|
||||
boolean: {
|
||||
colorCssVar: getColorTokenCssVariable('green'),
|
||||
title: 'Boolean',
|
||||
description: 'Booleans are true or false.',
|
||||
},
|
||||
enum: {
|
||||
colorCssVar: getColorTokenCssVariable('blue'),
|
||||
title: 'Enum',
|
||||
description: 'Enums are values that may be one of a number of options.',
|
||||
},
|
||||
image: {
|
||||
colorCssVar: getColorTokenCssVariable('purple'),
|
||||
title: 'Image',
|
||||
description: 'Images may be passed between nodes.',
|
||||
},
|
||||
latents: {
|
||||
colorCssVar: getColorTokenCssVariable('pink'),
|
||||
title: 'Latents',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
},
|
||||
model: {
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'Model',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
title: 'Array',
|
||||
description: 'TODO: Array type description.',
|
||||
},
|
||||
};
|
@ -0,0 +1,9 @@
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
|
||||
export const isReferenceObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
||||
): obj is OpenAPIV3.ReferenceObject => '$ref' in obj;
|
||||
|
||||
export const isSchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
|
||||
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);
|
296
invokeai/frontend/web/src/features/nodes/types/types.ts
Normal file
296
invokeai/frontend/web/src/features/nodes/types/types.ts
Normal file
@ -0,0 +1,296 @@
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { ImageField } from 'services/api';
|
||||
import { AnyInvocationType } from 'services/events/types';
|
||||
|
||||
export type InvocationValue = {
|
||||
id: string;
|
||||
type: AnyInvocationType;
|
||||
inputs: Record<string, InputFieldValue>;
|
||||
outputs: Record<string, OutputFieldValue>;
|
||||
};
|
||||
|
||||
export type InvocationTemplate = {
|
||||
/**
|
||||
* Unique type of the invocation
|
||||
*/
|
||||
type: AnyInvocationType;
|
||||
/**
|
||||
* Display name of the invocation
|
||||
*/
|
||||
title: string;
|
||||
/**
|
||||
* Description of the invocation
|
||||
*/
|
||||
description: string;
|
||||
/**
|
||||
* Invocation tags
|
||||
*/
|
||||
tags: string[];
|
||||
/**
|
||||
* Array of invocation inputs
|
||||
*/
|
||||
inputs: Record<string, InputFieldTemplate>;
|
||||
// inputs: InputField[];
|
||||
/**
|
||||
* Array of the invocation outputs
|
||||
*/
|
||||
outputs: Record<string, OutputFieldTemplate>;
|
||||
// outputs: OutputField[];
|
||||
};
|
||||
|
||||
export type FieldUIConfig = {
|
||||
colorCssVar: string;
|
||||
title: string;
|
||||
description: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* The valid invocation field types
|
||||
*/
|
||||
export type FieldType =
|
||||
| 'integer'
|
||||
| 'float'
|
||||
| 'string'
|
||||
| 'boolean'
|
||||
| 'enum'
|
||||
| 'image'
|
||||
| 'latents'
|
||||
| 'model'
|
||||
| 'array';
|
||||
|
||||
/**
|
||||
* An input field is persisted across reloads as part of the user's local state.
|
||||
*
|
||||
* An input field has three properties:
|
||||
* - `id` a unique identifier
|
||||
* - `name` the name of the field, which comes from the python dataclass
|
||||
* - `value` the field's value
|
||||
*/
|
||||
export type InputFieldValue =
|
||||
| IntegerInputFieldValue
|
||||
| FloatInputFieldValue
|
||||
| StringInputFieldValue
|
||||
| BooleanInputFieldValue
|
||||
| ImageInputFieldValue
|
||||
| LatentsInputFieldValue
|
||||
| EnumInputFieldValue
|
||||
| ModelInputFieldValue
|
||||
| ArrayInputFieldValue;
|
||||
|
||||
/**
|
||||
* An input field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the field type and other field metadata (e.g. title, description,
|
||||
* maximum length, pattern to match, etc).
|
||||
*/
|
||||
export type InputFieldTemplate =
|
||||
| IntegerInputFieldTemplate
|
||||
| FloatInputFieldTemplate
|
||||
| StringInputFieldTemplate
|
||||
| BooleanInputFieldTemplate
|
||||
| ImageInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| EnumInputFieldTemplate
|
||||
| ModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate;
|
||||
|
||||
/**
|
||||
* An output field is persisted across as part of the user's local state.
|
||||
*
|
||||
* An output field has two properties:
|
||||
* - `id` a unique identifier
|
||||
* - `name` the name of the field, which comes from the python dataclass
|
||||
*/
|
||||
export type OutputFieldValue = FieldValueBase;
|
||||
|
||||
/**
|
||||
* An output field template is generated on each page load from the OpenAPI schema.
|
||||
*
|
||||
* The template provides the output field's name, type, title, and description.
|
||||
*/
|
||||
export type OutputFieldTemplate = {
|
||||
name: string;
|
||||
type: FieldType;
|
||||
title: string;
|
||||
description: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Indicates when/if this field needs an input.
|
||||
*/
|
||||
export type InputRequirement = 'always' | 'never' | 'optional';
|
||||
|
||||
/**
|
||||
* Indicates the kind of input(s) this field may have.
|
||||
*/
|
||||
export type InputKind = 'connection' | 'direct' | 'any';
|
||||
|
||||
export type FieldValueBase = {
|
||||
id: string;
|
||||
name: string;
|
||||
type: FieldType;
|
||||
};
|
||||
|
||||
export type IntegerInputFieldValue = FieldValueBase & {
|
||||
type: 'integer';
|
||||
value?: number;
|
||||
};
|
||||
|
||||
export type FloatInputFieldValue = FieldValueBase & {
|
||||
type: 'float';
|
||||
value?: number;
|
||||
};
|
||||
|
||||
export type StringInputFieldValue = FieldValueBase & {
|
||||
type: 'string';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type BooleanInputFieldValue = FieldValueBase & {
|
||||
type: 'boolean';
|
||||
value?: boolean;
|
||||
};
|
||||
|
||||
export type EnumInputFieldValue = FieldValueBase & {
|
||||
type: 'enum';
|
||||
value?: number | string;
|
||||
};
|
||||
|
||||
export type LatentsInputFieldValue = FieldValueBase & {
|
||||
type: 'latents';
|
||||
value?: undefined;
|
||||
};
|
||||
|
||||
export type ImageInputFieldValue = FieldValueBase & {
|
||||
type: 'image';
|
||||
value?: Pick<ImageField, 'image_name' | 'image_type'>;
|
||||
};
|
||||
|
||||
export type ModelInputFieldValue = FieldValueBase & {
|
||||
type: 'model';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
};
|
||||
|
||||
export type InputFieldTemplateBase = {
|
||||
name: string;
|
||||
title: string;
|
||||
description: string;
|
||||
type: FieldType;
|
||||
inputRequirement: InputRequirement;
|
||||
inputKind: InputKind;
|
||||
};
|
||||
|
||||
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'integer';
|
||||
default: number;
|
||||
multipleOf?: number;
|
||||
maximum?: number;
|
||||
exclusiveMaximum?: boolean;
|
||||
minimum?: number;
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'float';
|
||||
default: number;
|
||||
multipleOf?: number;
|
||||
maximum?: number;
|
||||
exclusiveMaximum?: boolean;
|
||||
minimum?: number;
|
||||
exclusiveMinimum?: boolean;
|
||||
};
|
||||
|
||||
export type StringInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'string';
|
||||
default: string;
|
||||
maxLength?: number;
|
||||
minLength?: number;
|
||||
pattern?: string;
|
||||
};
|
||||
|
||||
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: boolean;
|
||||
type: 'boolean';
|
||||
};
|
||||
|
||||
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: Pick<ImageField, 'image_name' | 'image_type'>;
|
||||
type: 'image';
|
||||
};
|
||||
|
||||
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: undefined;
|
||||
type: 'latents';
|
||||
};
|
||||
|
||||
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string | number;
|
||||
type: 'enum';
|
||||
enumType: 'string' | 'number';
|
||||
options: Array<string | number>;
|
||||
};
|
||||
|
||||
export type ModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: (string | number)[];
|
||||
type: 'array';
|
||||
};
|
||||
|
||||
/**
|
||||
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
|
||||
*/
|
||||
|
||||
export type TypeHints = {
|
||||
[fieldName: string]: FieldType;
|
||||
};
|
||||
|
||||
export type InvocationSchemaExtra = {
|
||||
output: OpenAPIV3.ReferenceObject; // the output of the invocation
|
||||
ui?: {
|
||||
tags?: string[];
|
||||
type_hints?: TypeHints;
|
||||
title?: string;
|
||||
};
|
||||
title: string;
|
||||
properties: Omit<
|
||||
NonNullable<OpenAPIV3.SchemaObject['properties']>,
|
||||
'type'
|
||||
> & {
|
||||
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
|
||||
default: AnyInvocationType;
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
export type InvocationSchemaType = {
|
||||
default: string; // the type of the invocation
|
||||
};
|
||||
|
||||
export type InvocationBaseSchemaObject = Omit<
|
||||
OpenAPIV3.BaseSchemaObject,
|
||||
'title' | 'type' | 'properties'
|
||||
> &
|
||||
InvocationSchemaExtra;
|
||||
|
||||
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type: OpenAPIV3.ArraySchemaObjectType;
|
||||
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
|
||||
}
|
||||
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
|
||||
type?: OpenAPIV3.NonArraySchemaObjectType;
|
||||
}
|
||||
|
||||
export type InvocationSchemaObject = ArraySchemaObject | NonArraySchemaObject;
|
||||
|
||||
export const isInvocationSchemaObject = (
|
||||
obj: OpenAPIV3.ReferenceObject | InvocationSchemaObject
|
||||
): obj is InvocationSchemaObject => !('$ref' in obj);
|
@ -0,0 +1,336 @@
|
||||
import { reduce } from 'lodash';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { FIELD_TYPE_MAP } from '../types/constants';
|
||||
import { isSchemaObject } from '../types/typeGuards';
|
||||
import {
|
||||
BooleanInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FloatInputFieldTemplate,
|
||||
ImageInputFieldTemplate,
|
||||
IntegerInputFieldTemplate,
|
||||
LatentsInputFieldTemplate,
|
||||
StringInputFieldTemplate,
|
||||
ModelInputFieldTemplate,
|
||||
InputFieldTemplateBase,
|
||||
OutputFieldTemplate,
|
||||
TypeHints,
|
||||
FieldType,
|
||||
} from '../types/types';
|
||||
|
||||
export type BaseFieldProperties = 'name' | 'title' | 'description';
|
||||
|
||||
export type BuildInputFieldArg = {
|
||||
schemaObject: OpenAPIV3.SchemaObject;
|
||||
baseField: Omit<
|
||||
InputFieldTemplateBase,
|
||||
'type' | 'inputRequirement' | 'inputKind'
|
||||
>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Transforms an invocation output ref object to field type.
|
||||
* @param ref The ref string to transform
|
||||
* @returns The field type.
|
||||
*
|
||||
* @example
|
||||
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
|
||||
*/
|
||||
export const refObjectToFieldType = (
|
||||
refObject: OpenAPIV3.ReferenceObject
|
||||
): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
const buildIntegerInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IntegerInputFieldTemplate => {
|
||||
const template: IntegerInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'integer',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'any',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): FloatInputFieldTemplate => {
|
||||
const template: FloatInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'float',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'any',
|
||||
default: schemaObject.default ?? 0,
|
||||
};
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): StringInputFieldTemplate => {
|
||||
const template: StringInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'string',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'any',
|
||||
default: schemaObject.default ?? '',
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.pattern !== undefined) {
|
||||
template.pattern = schemaObject.pattern;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): BooleanInputFieldTemplate => {
|
||||
const template: BooleanInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'boolean',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'any',
|
||||
default: schemaObject.default ?? false,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ModelInputFieldTemplate => {
|
||||
const template: ModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ImageInputFieldTemplate => {
|
||||
const template: ImageInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'image',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'any',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildLatentsInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): LatentsInputFieldTemplate => {
|
||||
const template: LatentsInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'latents',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'connection',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildEnumInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): EnumInputFieldTemplate => {
|
||||
const options = schemaObject.enum ?? [];
|
||||
const template: EnumInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'enum',
|
||||
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
|
||||
options: options,
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? options[0],
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const getFieldType = (
|
||||
schemaObject: OpenAPIV3.SchemaObject,
|
||||
name: string,
|
||||
typeHints?: TypeHints
|
||||
): FieldType => {
|
||||
let rawFieldType = '';
|
||||
|
||||
if (typeHints && name in typeHints) {
|
||||
rawFieldType = typeHints[name];
|
||||
} else if (!schemaObject.type) {
|
||||
rawFieldType = refObjectToFieldType(
|
||||
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
|
||||
);
|
||||
} else if (schemaObject.enum) {
|
||||
rawFieldType = 'enum';
|
||||
} else if (schemaObject.type) {
|
||||
rawFieldType = schemaObject.type;
|
||||
}
|
||||
|
||||
const fieldType = FIELD_TYPE_MAP[rawFieldType];
|
||||
|
||||
if (!fieldType) {
|
||||
throw `Field type "${rawFieldType}" is unknown!`;
|
||||
}
|
||||
|
||||
return fieldType;
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds an input field from an invocation schema property.
|
||||
* @param schemaObject The schema object
|
||||
* @returns An input field
|
||||
*/
|
||||
export const buildInputFieldTemplate = (
|
||||
schemaObject: OpenAPIV3.SchemaObject,
|
||||
name: string,
|
||||
typeHints?: TypeHints
|
||||
) => {
|
||||
const fieldType = getFieldType(schemaObject, name, typeHints);
|
||||
|
||||
const baseField = {
|
||||
name,
|
||||
title: schemaObject.title ?? '',
|
||||
description: schemaObject.description ?? '',
|
||||
};
|
||||
|
||||
if (['image'].includes(fieldType)) {
|
||||
return buildImageInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['latents'].includes(fieldType)) {
|
||||
return buildLatentsInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['model'].includes(fieldType)) {
|
||||
return buildModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['integer'].includes(fieldType)) {
|
||||
return buildIntegerInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['number', 'float'].includes(fieldType)) {
|
||||
return buildFloatInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['string'].includes(fieldType)) {
|
||||
return buildStringInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['boolean'].includes(fieldType)) {
|
||||
return buildBooleanInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
/**
|
||||
* Builds invocation output fields from an invocation's output reference object.
|
||||
* @param openAPI The OpenAPI schema
|
||||
* @param refObject The output reference object
|
||||
* @returns A record of outputs
|
||||
*/
|
||||
export const buildOutputFieldTemplates = (
|
||||
refObject: OpenAPIV3.ReferenceObject,
|
||||
openAPI: OpenAPIV3.Document,
|
||||
typeHints?: TypeHints
|
||||
): Record<string, OutputFieldTemplate> => {
|
||||
// extract output schema name from ref
|
||||
const outputSchemaName = refObject.$ref.split('/').slice(-1)[0];
|
||||
|
||||
// get the output schema itself
|
||||
const outputSchema = openAPI.components!.schemas![outputSchemaName];
|
||||
|
||||
if (isSchemaObject(outputSchema)) {
|
||||
const outputFields = reduce(
|
||||
outputSchema.properties as OpenAPIV3.SchemaObject,
|
||||
(outputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
!['type', 'id'].includes(propertyName) &&
|
||||
!['object'].includes(property.type) && // TODO: handle objects?
|
||||
isSchemaObject(property)
|
||||
) {
|
||||
const fieldType = getFieldType(property, propertyName, typeHints);
|
||||
|
||||
outputsAccumulator[propertyName] = {
|
||||
name: propertyName,
|
||||
title: property.title ?? '',
|
||||
description: property.description ?? '',
|
||||
type: fieldType,
|
||||
};
|
||||
}
|
||||
|
||||
return outputsAccumulator;
|
||||
},
|
||||
{} as Record<string, OutputFieldTemplate>
|
||||
);
|
||||
|
||||
return outputFields;
|
||||
}
|
||||
|
||||
return {};
|
||||
};
|
@ -0,0 +1,57 @@
|
||||
import { InputFieldTemplate, InputFieldValue } from '../types/types';
|
||||
|
||||
export const buildInputFieldValue = (
|
||||
id: string,
|
||||
template: InputFieldTemplate
|
||||
): InputFieldValue => {
|
||||
const fieldValue: InputFieldValue = {
|
||||
id,
|
||||
name: template.name,
|
||||
type: template.type,
|
||||
};
|
||||
|
||||
if (template.inputRequirement !== 'never') {
|
||||
if (template.type === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
|
||||
if (template.type === 'integer') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'float') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
|
||||
if (template.type === 'boolean') {
|
||||
fieldValue.value = template.default ?? false;
|
||||
}
|
||||
|
||||
if (template.type === 'enum') {
|
||||
if (template.enumType === 'number') {
|
||||
fieldValue.value = template.default ?? 0;
|
||||
}
|
||||
if (template.enumType === 'string') {
|
||||
fieldValue.value = template.default ?? '';
|
||||
}
|
||||
}
|
||||
|
||||
if (template.type === 'array') {
|
||||
fieldValue.value = template.default ?? 1;
|
||||
}
|
||||
|
||||
if (template.type === 'image') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'latents') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
};
|
@ -0,0 +1,39 @@
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
IterateInvocation,
|
||||
RandomRangeInvocation,
|
||||
RangeInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
|
||||
export const buildEdges = (
|
||||
baseNode: TextToImageInvocation | ImageToImageInvocation,
|
||||
rangeNode: RangeInvocation | RandomRangeInvocation,
|
||||
iterateNode: IterateInvocation
|
||||
): Edge[] => {
|
||||
const edges: Edge[] = [
|
||||
{
|
||||
source: {
|
||||
node_id: rangeNode.id,
|
||||
field: 'collection',
|
||||
},
|
||||
destination: {
|
||||
node_id: iterateNode.id,
|
||||
field: 'collection',
|
||||
},
|
||||
},
|
||||
{
|
||||
source: {
|
||||
node_id: iterateNode.id,
|
||||
field: 'item',
|
||||
},
|
||||
destination: {
|
||||
node_id: baseNode.id,
|
||||
field: 'seed',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
return edges;
|
||||
};
|
@ -0,0 +1,99 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { RootState } from 'app/store';
|
||||
import {
|
||||
Edge,
|
||||
ImageToImageInvocation,
|
||||
TextToImageInvocation,
|
||||
} from 'services/api';
|
||||
import { _Image } from 'app/invokeai';
|
||||
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
|
||||
|
||||
export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
|
||||
const nodeId = uuidv4();
|
||||
const { generation, system, models } = state;
|
||||
|
||||
const { selectedModelName } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale,
|
||||
sampler,
|
||||
seamless,
|
||||
img2imgStrength: strength,
|
||||
shouldFitToWidthHeight: fit,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
const initialImage = initialImageSelector(state);
|
||||
|
||||
if (!initialImage) {
|
||||
// TODO: handle this
|
||||
throw 'no initial image';
|
||||
}
|
||||
|
||||
const imageToImageNode: ImageToImageInvocation = {
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
prompt,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale: cfgScale,
|
||||
scheduler: sampler as ImageToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model: selectedModelName,
|
||||
progress_images: true,
|
||||
image: {
|
||||
image_name: initialImage.name,
|
||||
image_type: initialImage.type,
|
||||
},
|
||||
strength,
|
||||
fit,
|
||||
};
|
||||
|
||||
if (!shouldRandomizeSeed) {
|
||||
imageToImageNode.seed = seed;
|
||||
}
|
||||
|
||||
return imageToImageNode;
|
||||
};
|
||||
|
||||
type hiresReturnType = {
|
||||
node: Record<string, ImageToImageInvocation>;
|
||||
edge: Edge;
|
||||
};
|
||||
|
||||
export const buildHiResNode = (
|
||||
baseNode: Record<string, TextToImageInvocation>,
|
||||
strength?: number
|
||||
): hiresReturnType => {
|
||||
const nodeId = uuidv4();
|
||||
const baseNodeId = Object.keys(baseNode)[0];
|
||||
const baseNodeValues = Object.values(baseNode)[0];
|
||||
|
||||
return {
|
||||
node: {
|
||||
[nodeId]: {
|
||||
...baseNodeValues,
|
||||
id: nodeId,
|
||||
type: 'img2img',
|
||||
strength,
|
||||
fit: true,
|
||||
},
|
||||
},
|
||||
edge: {
|
||||
source: {
|
||||
field: 'image',
|
||||
node_id: baseNodeId,
|
||||
},
|
||||
destination: {
|
||||
field: 'image',
|
||||
node_id: nodeId,
|
||||
},
|
||||
},
|
||||
};
|
||||
};
|
@ -0,0 +1,13 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { IterateInvocation } from 'services/api';
|
||||
|
||||
export const buildIterateNode = (): IterateInvocation => {
|
||||
const nodeId = uuidv4();
|
||||
return {
|
||||
id: nodeId,
|
||||
type: 'iterate',
|
||||
collection: [],
|
||||
index: 0,
|
||||
};
|
||||
};
|
@ -0,0 +1,39 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { Graph } from 'services/api';
|
||||
import { buildImg2ImgNode } from './buildImageToImageNode';
|
||||
import { buildTxt2ImgNode } from './buildTextToImageNode';
|
||||
import { buildRangeNode } from './buildRangeNode';
|
||||
import { buildIterateNode } from './buildIterateNode';
|
||||
import { buildEdges } from './buildEdges';
|
||||
|
||||
/**
|
||||
* Builds the Linear workflow graph.
|
||||
*/
|
||||
export const buildLinearGraph = (state: RootState): Graph => {
|
||||
// The base node is either a txt2img or img2img node
|
||||
const baseNode = state.generation.isImageToImageEnabled
|
||||
? buildImg2ImgNode(state)
|
||||
: buildTxt2ImgNode(state);
|
||||
|
||||
// We always range and iterate nodes, no matter the iteration count
|
||||
// This is required to provide the correct seeds to the backend engine
|
||||
const rangeNode = buildRangeNode(state);
|
||||
const iterateNode = buildIterateNode();
|
||||
|
||||
// Build the edges for the nodes selected.
|
||||
const edges = buildEdges(baseNode, rangeNode, iterateNode);
|
||||
|
||||
// Assemble!
|
||||
const graph = {
|
||||
nodes: {
|
||||
[rangeNode.id]: rangeNode,
|
||||
[iterateNode.id]: iterateNode,
|
||||
[baseNode.id]: baseNode,
|
||||
},
|
||||
edges,
|
||||
};
|
||||
|
||||
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
|
||||
|
||||
return graph;
|
||||
};
|
@ -0,0 +1,26 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
import { RootState } from 'app/store';
|
||||
import { RandomRangeInvocation, RangeInvocation } from 'services/api';
|
||||
|
||||
export const buildRangeNode = (
|
||||
state: RootState
|
||||
): RangeInvocation | RandomRangeInvocation => {
|
||||
const nodeId = uuidv4();
|
||||
const { shouldRandomizeSeed, iterations, seed } = state.generation;
|
||||
|
||||
if (shouldRandomizeSeed) {
|
||||
return {
|
||||
id: nodeId,
|
||||
type: 'random_range',
|
||||
size: iterations,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
id: nodeId,
|
||||
type: 'range',
|
||||
start: seed,
|
||||
stop: seed + iterations,
|
||||
};
|
||||
};
|
@ -0,0 +1,42 @@
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { RootState } from 'app/store';
|
||||
import { TextToImageInvocation } from 'services/api';
|
||||
|
||||
export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => {
|
||||
const nodeId = uuidv4();
|
||||
const { generation, models } = state;
|
||||
|
||||
const { selectedModelName } = models;
|
||||
|
||||
const {
|
||||
prompt,
|
||||
seed,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfgScale: cfg_scale,
|
||||
sampler,
|
||||
seamless,
|
||||
shouldRandomizeSeed,
|
||||
} = generation;
|
||||
|
||||
const textToImageNode: NonNullable<TextToImageInvocation> = {
|
||||
id: nodeId,
|
||||
type: 'txt2img',
|
||||
prompt,
|
||||
steps,
|
||||
width,
|
||||
height,
|
||||
cfg_scale,
|
||||
scheduler: sampler as TextToImageInvocation['scheduler'],
|
||||
seamless,
|
||||
model: selectedModelName,
|
||||
progress_images: true,
|
||||
};
|
||||
|
||||
if (!shouldRandomizeSeed) {
|
||||
textToImageNode.seed = seed;
|
||||
}
|
||||
|
||||
return textToImageNode;
|
||||
};
|
@ -0,0 +1,77 @@
|
||||
import { Graph } from 'services/api';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { reduce } from 'lodash';
|
||||
import { RootState } from 'app/store';
|
||||
import { AnyInvocation } from 'services/events/types';
|
||||
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
export const buildNodesGraph = (state: RootState): Graph => {
|
||||
const { nodes, edges } = state.nodes;
|
||||
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = nodes.reduce<NonNullable<Graph['nodes']>>(
|
||||
(nodesAccumulator, node, nodeIndex) => {
|
||||
const { id, data } = node;
|
||||
const { type, inputs } = data;
|
||||
|
||||
// Transform each node's inputs to simple key-value pairs
|
||||
const transformedInputs = reduce(
|
||||
inputs,
|
||||
(inputsAccumulator, input, name) => {
|
||||
inputsAccumulator[name] = input.value;
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<Exclude<string, 'id' | 'type'>, any>
|
||||
);
|
||||
|
||||
// Build this specific node
|
||||
const graphNode = {
|
||||
type,
|
||||
id,
|
||||
...transformedInputs,
|
||||
};
|
||||
|
||||
// Add it to the nodes object
|
||||
Object.assign(nodesAccumulator, {
|
||||
[id]: graphNode,
|
||||
});
|
||||
|
||||
return nodesAccumulator;
|
||||
},
|
||||
{}
|
||||
);
|
||||
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = edges.reduce<NonNullable<Graph['edges']>>(
|
||||
(edgesAccumulator, edge, edgeIndex) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
// Format the edges and add to the edges array
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle as string,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle as string,
|
||||
},
|
||||
});
|
||||
|
||||
return edgesAccumulator;
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
// Assemble!
|
||||
const graph = {
|
||||
id: uuidv4(),
|
||||
nodes: parsedNodes,
|
||||
edges: parsedEdges,
|
||||
};
|
||||
|
||||
return graph;
|
||||
};
|
120
invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
Normal file
120
invokeai/frontend/web/src/features/nodes/util/parseSchema.ts
Normal file
@ -0,0 +1,120 @@
|
||||
import { filter, reduce } from 'lodash';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { isSchemaObject } from '../types/typeGuards';
|
||||
import {
|
||||
InputFieldTemplate,
|
||||
InvocationSchemaObject,
|
||||
InvocationTemplate,
|
||||
isInvocationSchemaObject,
|
||||
OutputFieldTemplate,
|
||||
} from '../types/types';
|
||||
import {
|
||||
buildInputFieldTemplate,
|
||||
buildOutputFieldTemplates,
|
||||
} from './fieldTemplateBuilders';
|
||||
|
||||
const invocationBlacklist = ['Graph', 'Collect', 'LoadImage'];
|
||||
|
||||
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
|
||||
// filter out non-invocation schemas, plus some tricky invocations for now
|
||||
const filteredSchemas = filter(
|
||||
openAPI.components!.schemas,
|
||||
(schema, key) =>
|
||||
key.includes('Invocation') &&
|
||||
!key.includes('InvocationOutput') &&
|
||||
!invocationBlacklist.some((blacklistItem) => key.includes(blacklistItem))
|
||||
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
|
||||
|
||||
const invocations = filteredSchemas.reduce<
|
||||
Record<string, InvocationTemplate>
|
||||
>((acc, schema) => {
|
||||
// only want SchemaObjects
|
||||
if (isInvocationSchemaObject(schema)) {
|
||||
const type = schema.properties.type.default;
|
||||
|
||||
const title =
|
||||
schema.ui?.title ??
|
||||
schema.title
|
||||
.replace('Invocation', '')
|
||||
.split(/(?=[A-Z])/) // split PascalCase into array
|
||||
.join(' ');
|
||||
|
||||
const typeHints = schema.ui?.type_hints;
|
||||
|
||||
const inputs = reduce(
|
||||
schema.properties,
|
||||
(inputsAccumulator, property, propertyName) => {
|
||||
if (
|
||||
// `type` and `id` are not valid inputs/outputs
|
||||
!['type', 'id'].includes(propertyName) &&
|
||||
isSchemaObject(property)
|
||||
) {
|
||||
let field: InputFieldTemplate | undefined;
|
||||
if (propertyName === 'collection') {
|
||||
field = {
|
||||
default: property.default ?? [],
|
||||
name: 'collection',
|
||||
title: property.title ?? '',
|
||||
description: property.description ?? '',
|
||||
type: 'array',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'connection',
|
||||
};
|
||||
} else {
|
||||
field = buildInputFieldTemplate(
|
||||
property,
|
||||
propertyName,
|
||||
typeHints
|
||||
);
|
||||
}
|
||||
if (field) {
|
||||
inputsAccumulator[propertyName] = field;
|
||||
}
|
||||
}
|
||||
return inputsAccumulator;
|
||||
},
|
||||
{} as Record<string, InputFieldTemplate>
|
||||
);
|
||||
|
||||
const rawOutput = (schema as InvocationSchemaObject).output;
|
||||
|
||||
let outputs: Record<string, OutputFieldTemplate>;
|
||||
|
||||
// some special handling is needed for collect, iterate and range nodes
|
||||
if (type === 'iterate') {
|
||||
// this is guaranteed to be a SchemaObject
|
||||
const iterationOutput = openAPI.components!.schemas![
|
||||
'IterateInvocationOutput'
|
||||
] as OpenAPIV3.SchemaObject;
|
||||
|
||||
outputs = {
|
||||
item: {
|
||||
name: 'item',
|
||||
title: iterationOutput.title ?? '',
|
||||
description: iterationOutput.description ?? '',
|
||||
type: 'array',
|
||||
},
|
||||
};
|
||||
} else {
|
||||
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
|
||||
}
|
||||
|
||||
const invocation: InvocationTemplate = {
|
||||
title,
|
||||
type,
|
||||
tags: schema.ui?.tags ?? [],
|
||||
description: schema.description ?? '',
|
||||
inputs,
|
||||
outputs,
|
||||
};
|
||||
|
||||
Object.assign(acc, { [type]: invocation });
|
||||
}
|
||||
|
||||
return acc;
|
||||
}, {});
|
||||
|
||||
console.debug('Generated invocations: ', invocations);
|
||||
|
||||
return invocations;
|
||||
};
|
@ -6,19 +6,18 @@ import {
|
||||
Box,
|
||||
Flex,
|
||||
} from '@chakra-ui/react';
|
||||
import { Feature } from 'app/features';
|
||||
import GuideIcon from 'common/components/GuideIcon';
|
||||
import { ReactNode } from 'react';
|
||||
import { ParametersAccordionItem } from '../ParametersAccordion';
|
||||
|
||||
export interface InvokeAccordionItemProps {
|
||||
header: string;
|
||||
content: ReactNode;
|
||||
feature?: Feature;
|
||||
additionalHeaderComponents?: ReactNode;
|
||||
}
|
||||
type InvokeAccordionItemProps = {
|
||||
accordionItem: ParametersAccordionItem;
|
||||
};
|
||||
|
||||
export default function InvokeAccordionItem(props: InvokeAccordionItemProps) {
|
||||
const { header, feature, content, additionalHeaderComponents } = props;
|
||||
export default function InvokeAccordionItem({
|
||||
accordionItem,
|
||||
}: InvokeAccordionItemProps) {
|
||||
const { header, feature, content, additionalHeaderComponents } =
|
||||
accordionItem;
|
||||
|
||||
return (
|
||||
<AccordionItem>
|
||||
@ -32,7 +31,7 @@ export default function InvokeAccordionItem(props: InvokeAccordionItemProps) {
|
||||
<AccordionIcon />
|
||||
</Flex>
|
||||
</AccordionButton>
|
||||
<AccordionPanel>{content}</AccordionPanel>
|
||||
<AccordionPanel p={4}>{content}</AccordionPanel>
|
||||
</AccordionItem>
|
||||
);
|
||||
}
|
||||
|
@ -115,9 +115,7 @@ const InfillAndScalingSettings = () => {
|
||||
onChange={handleChangeBoundingBoxScaleMethod}
|
||||
/>
|
||||
<IAISlider
|
||||
isInputDisabled={!isManual}
|
||||
isResetDisabled={!isManual}
|
||||
isSliderDisabled={!isManual}
|
||||
isDisabled={!isManual}
|
||||
label={t('parameters.scaledWidth')}
|
||||
min={64}
|
||||
max={1024}
|
||||
@ -132,9 +130,7 @@ const InfillAndScalingSettings = () => {
|
||||
handleReset={handleResetScaledWidth}
|
||||
/>
|
||||
<IAISlider
|
||||
isInputDisabled={!isManual}
|
||||
isResetDisabled={!isManual}
|
||||
isSliderDisabled={!isManual}
|
||||
isDisabled={!isManual}
|
||||
label={t('parameters.scaledHeight')}
|
||||
min={64}
|
||||
max={1024}
|
||||
@ -155,9 +151,7 @@ const InfillAndScalingSettings = () => {
|
||||
onChange={(e) => dispatch(setInfillMethod(e.target.value))}
|
||||
/>
|
||||
<IAISlider
|
||||
isInputDisabled={infillMethod !== 'tile'}
|
||||
isResetDisabled={infillMethod !== 'tile'}
|
||||
isSliderDisabled={infillMethod !== 'tile'}
|
||||
isDisabled={infillMethod !== 'tile'}
|
||||
label={t('parameters.tileSize')}
|
||||
min={16}
|
||||
max={64}
|
||||
|
@ -18,9 +18,7 @@ export default function CodeformerFidelity() {
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isSliderDisabled={!isGFPGANAvailable}
|
||||
isInputDisabled={!isGFPGANAvailable}
|
||||
isResetDisabled={!isGFPGANAvailable}
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label={t('parameters.codeformerFidelity')}
|
||||
step={0.05}
|
||||
min={0}
|
||||
|
@ -18,9 +18,7 @@ export default function FaceRestoreStrength() {
|
||||
|
||||
return (
|
||||
<IAISlider
|
||||
isSliderDisabled={!isGFPGANAvailable}
|
||||
isInputDisabled={!isGFPGANAvailable}
|
||||
isResetDisabled={!isGFPGANAvailable}
|
||||
isDisabled={!isGFPGANAvailable}
|
||||
label={t('parameters.strength')}
|
||||
step={0.05}
|
||||
min={0}
|
||||
|
@ -12,6 +12,10 @@ export default function ImageFit() {
|
||||
(state: RootState) => state.generation.shouldFitToWidthHeight
|
||||
);
|
||||
|
||||
const isImageToImageEnabled = useAppSelector(
|
||||
(state: RootState) => state.generation.isImageToImageEnabled
|
||||
);
|
||||
|
||||
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
|
||||
dispatch(setShouldFitToWidthHeight(e.target.checked));
|
||||
|
||||
@ -19,6 +23,7 @@ export default function ImageFit() {
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
isDisabled={!isImageToImageEnabled}
|
||||
label={t('parameters.imageFit')}
|
||||
isChecked={shouldFitToWidthHeight}
|
||||
onChange={handleChangeFit}
|
||||
|
@ -1,13 +1,15 @@
|
||||
import { VStack } from '@chakra-ui/react';
|
||||
import { Flex, Image, VStack } from '@chakra-ui/react';
|
||||
import ImageFit from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageFit';
|
||||
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
|
||||
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import InitialImagePreview from './InitialImagePreview';
|
||||
|
||||
export default function ImageToImageSettings() {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<VStack gap={2} alignItems="stretch">
|
||||
<InitialImagePreview />
|
||||
<ImageToImageStrength label={t('parameters.img2imgStrength')} />
|
||||
<ImageFit />
|
||||
</VStack>
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user