Partial migration of UI to nodes API (#3195)

* feat(ui): add axios client generator and simple example

* fix(ui): update client & nodes test code w/ new Edge type

* chore(ui): organize generated files

* chore(ui): update .eslintignore, .prettierignore

* chore(ui): update openapi.json

* feat(backend): fixes for nodes/generator

* feat(ui): generate object args for api client

* feat(ui): more nodes api prototyping

* feat(ui): nodes cancel

* chore(ui): regenerate api client

* fix(ui): disable OG web server socket connection

* fix(ui): fix scrollbar styles typing and prop

just noticed the typo, and made the types stronger.

* feat(ui): add socketio types

* feat(ui): wip nodes

- extract api client method arg types instead of manually declaring them
- update example to display images
- general tidy up

* start building out node translations from frontend state and add notes about missing features

* use reference to sampler_name

* use reference to sampler_name

* add optional apiUrl prop

* feat(ui): start hooking up dynamic txt2img node generation, create middleware for session invocation

* feat(ui): write separate nodes socket layer, txt2img generating and rendering w single node

* feat(ui): img2img implementation

* feat(ui): get intermediate images working but types are stubbed out

* chore(ui): add support for package mode

* feat(ui): add nodes mode script

* feat(ui): handle random seeds

* fix(ui): fix middleware types

* feat(ui): add rtk action type guard

* feat(ui): disable NodeAPITest

This was polluting the network/socket logs.

* feat(ui): fix parameters panel border color

This commit should be elsewhere but I don't want to break my flow

* feat(ui): make thunk types more consistent

* feat(ui): add type guards for outputs

* feat(ui): load images on socket connect

Rudimentary

* chore(ui): bump redux-toolkit

* docs(ui): update readme

* chore(ui): regenerate api client

* chore(ui): add typescript as dev dependency

I am having trouble with TS versions after vscode updated and now uses TS 5. `madge` has installed 3.9.10 and for whatever reason my vscode wants to use that. Manually specifying 4.9.5 and then setting vscode to use that as the workspace TS fixes the issue.

* feat(ui): begin migrating gallery to nodes

Along the way, migrate to use RTK `createEntityAdapter` for gallery images, and separate `results` and `uploads` into separate slices. Much cleaner this way.

* feat(ui): clean up & comment results slice

* fix(ui): separate thunk for initial gallery load so it properly gets index 0

* feat(ui): POST upload working

* fix(ui): restore removed type

* feat(ui): patch api generation for headers access

* chore(ui): regenerate api

* feat(ui): wip gallery migration

* feat(ui): wip gallery migration

* chore(ui): regenerate api

* feat(ui): wip refactor socket events

* feat(ui): disable panels based on app props

* feat(ui): invert logic to be disabled

* disable panels when app mounts

* feat(ui): add support to disableTabs

* docs(ui): organise and update docs

* lang(ui): add toast strings

* feat(ui): wip events, comments, and general refactoring

* feat(ui): add optional token for auth

* feat(ui): export StatusIndicator and ModelSelect for header use

* feat(ui) working on making socket URL dynamic

* feat(ui): dynamic middleware loading

* feat(ui): prep for socket jwt

* feat(ui): migrate cancelation

also updated action names to be event-like instead of declaration-like

sorry, i was scattered and this commit has a lot of unrelated stuff in it.

* fix(ui): fix img2img type

* chore(ui): regenerate api client

* feat(ui): improve InvocationCompleteEvent types

* feat(ui): increase StatusIndicator font size

* fix(ui): fix middleware order for multi-node graphs

* feat(ui): add exampleGraphs object w/ iterations example

* feat(ui): generate iterations graph

* feat(ui): update ModelSelect for nodes API

* feat(ui): add hi-res functionality for txt2img generations

* feat(ui): "subscribe" to particular nodes

feels like a dirty hack but oh well it works

* feat(ui): first steps to node editor ui

* fix(ui): disable event subscription

it is not fully baked just yet

* feat(ui): wip node editor

* feat(ui): remove extraneous field types

* feat(ui): nodes before deleting stuff

* feat(ui): cleanup nodes ui stuff

* feat(ui): hook up nodes to redux

* fix(ui): fix handle

* fix(ui): add basic node edges & connection validation

* feat(ui): add connection validation styling

* feat(ui): increase edge width

* feat(ui): it blends

* feat(ui): wip model handling and graph topology validation

* feat(ui): validation connections w/ graphlib

* docs(ui): update nodes doc

* feat(ui): wip node editor

* chore(ui): rebuild api, update types

* add redux-dynamic-middlewares as a dependency

* feat(ui): add url host transformation

* feat(ui): handle already-connected fields

* feat(ui): rewrite SqliteItemStore in sqlalchemy

* fix(ui): fix sqlalchemy dynamic model instantiation

* feat(ui, nodes): metadata wip

* feat(ui, nodes): models

* feat(ui, nodes): more metadata wip

* feat(ui): wip range/iterate

* fix(nodes): fix sqlite typing

* feat(ui): export new type for invoke component

* tests(nodes): fix test instantiation of ImageField

* feat(nodes): fix LoadImageInvocation

* feat(nodes): add `title` ui hint

* feat(nodes): make ImageField attrs optional

* feat(ui): wip nodes etc

* feat(nodes): roll back sqlalchemy

* fix(nodes): partially address feedback

* fix(backend): roll back changes to pngwriter

* feat(nodes): wip address metadata feedback

* feat(nodes): add seeded rng to RandomRange

* feat(nodes): address feedback

* feat(nodes): move GET images error handling to DiskImageStorage

* feat(nodes): move GET images error handling to DiskImageStorage

* fix(nodes): fix image output schema customization

* feat(ui): img2img/txt2img -> linear

- remove txt2img and img2img tabs
- add linear tab
- add initial image selection to linear parameters accordion

* feat(ui): tidy graph builders

* feat(ui): tidy misc

* feat(ui): improve invocation union types

* feat(ui): wip metadata viewer recall

* feat(ui): move fonts to normal deps

* feat(nodes): fix broken upload

* feat(nodes): add metadata module + tests, thumbnails

- `MetadataModule` is stateless and needed in places where the `InvocationContext` is not available, so have not made it a `service`
- Handles loading/parsing/building metadata, and creating png info objects
- added tests for MetadataModule
- Lifted thumbnail stuff to util

* fix(nodes): revert change to RandomRangeInvocation

* feat(nodes): address feedback

- make metadata a service
- rip out pydantic validation, implement metadata parsing as simple functions
- update tests
- address other minor feedback items

* fix(nodes): fix other tests

* fix(nodes): add metadata service to cli

* fix(nodes): fix latents/image field parsing

* feat(nodes): customise LatentsField schema

* feat(nodes): move metadata parsing to frontend

* fix(nodes): fix metadata test

---------

Co-authored-by: maryhipp <maryhipp@gmail.com>
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
This commit is contained in:
psychedelicious
2023-04-22 13:10:20 +10:00
committed by GitHub
parent fdad62e88b
commit 5f498e10bd
324 changed files with 13051 additions and 1400 deletions

View File

@ -6,3 +6,5 @@ stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -3,4 +3,8 @@ dist/
node_modules/
patches/
stats.html
index.html
.yarn/
*.scss
src/services/api/
src/services/fixtures/*

View File

@ -0,0 +1,87 @@
# Generated axios API client
- [Generated axios API client](#generated-axios-api-client)
- [Generation](#generation)
- [Generate the API client from the nodes web server](#generate-the-api-client-from-the-nodes-web-server)
- [Generate the API client from JSON](#generate-the-api-client-from-json)
- [Getting the JSON from the nodes web server](#getting-the-json-from-the-nodes-web-server)
- [Getting the JSON with a python script](#getting-the-json-with-a-python-script)
- [Generate the API client](#generate-the-api-client)
- [The generated client](#the-generated-client)
- [API client customisation](#api-client-customisation)
This API client is generated by an [openapi code generator](https://github.com/ferdikoomen/openapi-typescript-codegen).
All files in `invokeai/frontend/web/src/services/api/` are made by the generator.
## Generation
The axios client may be generated by from the OpenAPI schema from the nodes web server, or from JSON.
### Generate the API client from the nodes web server
We need to start the nodes web server, which serves the OpenAPI schema to the generator.
1. Start the nodes web server.
```bash
# from the repo root
python scripts/invoke-new.py --web
```
2. Generate the API client.
```bash
# from invokeai/frontend/web/
yarn api:web
```
### Generate the API client from JSON
The JSON can be acquired from the nodes web server, or with a python script.
#### Getting the JSON from the nodes web server
Start the nodes web server as described above, then download the file.
```bash
# from invokeai/frontend/web/
curl http://localhost:9090/openapi.json -o openapi.json
```
#### Getting the JSON with a python script
Run this python script from the repo root, so it can access the nodes server modules.
The script will output `openapi.json` in the repo root. Then we need to move it to `invokeai/frontend/web/`.
```bash
# from the repo root
python invokeai/app/util/generate_openapi_json.py
mv invokeai/app/util/openapi.json invokeai/frontend/web/services/fixtures/
```
#### Generate the API client
Now we can generate the API client from the JSON.
```bash
# from invokeai/frontend/web/
yarn api:file
```
## The generated client
The client will be written to `invokeai/frontend/web/services/api/`:
- `axios` client
- TS types
- An easily parseable schema, which we can use to generate UI
## API client customisation
The generator has a default `request.ts` file that implements a base `axios` client. The generated client uses this base client.
One shortcoming of this is base client is it does not provide response headers unless the response body is empty. To fix this, we provide our own lightly-patched `request.ts`.
To access the headers, call `getHeaders(response)` on any response from the generated api client. This function is exported from `invokeai/frontend/web/src/services/util/getHeaders.ts`.

View File

@ -0,0 +1,21 @@
# Events
Events via `socket.io`
## `actions.ts`
Redux actions for all socket events. Payloads all include a timestamp, and optionally some other data.
Any reducer (or middleware) can respond to the actions.
## `middleware.ts`
Redux middleware for events.
Handles dispatching the event actions. Only put logic here if it can't really go anywhere else.
For example, on connect we want to load images to the gallery if it's not populated. This requires dispatching a thunk, so we need to directly dispatch this in the middleware.
## `types.ts`
Hand-written types for the socket events. Cannot generate these from the server, but fortunately they are few and simple.

View File

@ -0,0 +1,17 @@
# Node Editor Design
WIP
nodes
everything in `src/features/nodes/`
have a look at `state.nodes.invocation`
- on socket connect, if no schema saved, fetch `localhost:9090/openapi.json`, save JSON to `state.nodes.schema`
- on fulfilled schema fetch, `parseSchema()` the schema. this outputs a `Record<string, Invocation>` which is saved to `state.nodes.invocations` - `Invocation` is like a template for the node
- when you add a node, the the `Invocation` template is passed to `InvocationComponent.tsx` to build the UI component for that node
- inputs/outputs have field types - and each field type gets an `FieldComponent` which includes a dispatcher to write state changes to redux `nodesSlice`
- `reactflow` sends changes to nodes/edges to redux
- to invoke, `buildNodesGraph()` state, then send this
- changed onClick Invoke button actions to build the schema, then when schema builds it dispatches the actual network request to create the session - see `session.ts`

View File

@ -0,0 +1,29 @@
# Package Scripts
WIP walkthrough of `package.json` scripts.
## `theme` & `theme:watch`
These run the Chakra CLI to generate types for the theme, or watch for code change and re-generate the types.
The CLI essentially monkeypatches Chakra's files in `node_modules`.
## `postinstall`
The `postinstall` script patches a few packages and runs the Chakra CLI to generate types for the theme.
### Patch `@chakra-ui/cli`
See: <https://github.com/chakra-ui/chakra-ui/issues/7394>
### Patch `redux-persist`
We want to persist the canvas state to `localStorage` but many canvas operations change data very quickly, so we need to debounce the writes to `localStorage`.
`redux-persist` is unfortunately unmaintained. The repo's current code is nonfunctional, but the last release's code depends on a package that was removed from `npm` for being malware, so we cannot just fork it.
So, we have to patch it directly. Perhaps a better way would be to write a debounced storage adapter, but I couldn't figure out how to do that.
### Patch `redux-deep-persist`
This package makes blacklisting and whitelisting persist configs very simple, but we have to patch it to match `redux-persist` for the types to work.

View File

@ -1,10 +1,16 @@
# InvokeAI Web UI
- [InvokeAI Web UI](#invokeai-web-ui)
- [Stack](#stack)
- [Contributing](#contributing)
- [Dev Environment](#dev-environment)
- [Production builds](#production-builds)
The UI is a fairly straightforward Typescript React app. The only really fancy stuff is the Unified Canvas.
Code in `invokeai/frontend/web/` if you want to have a look.
## Details
## Stack
State management is Redux via [Redux Toolkit](https://github.com/reduxjs/redux-toolkit). Communication with server is a mix of HTTP and [socket.io](https://github.com/socketio/socket.io-client) (with a custom redux middleware to help).
@ -32,7 +38,7 @@ Start everything in dev mode:
1. Start the dev server: `yarn dev`
2. Start the InvokeAI UI per usual: `invokeai --web`
3. Point your browser to the dev server address e.g. `http://localhost:5173/`
3. Point your browser to the dev server address e.g. <http://localhost:5173/>
### Production builds

View File

@ -1,6 +1,7 @@
import React, { PropsWithChildren } from 'react';
import { IAIPopoverProps } from '../web/src/common/components/IAIPopover';
import { IAIIconButtonProps } from '../web/src/common/components/IAIIconButton';
import { InvokeTabName } from 'features/ui/store/tabMap';
export {};
@ -64,9 +65,25 @@ declare module '@invoke-ai/invoke-ai-ui' {
declare class SettingsModal extends React.Component<SettingsModalProps> {
public constructor(props: SettingsModalProps);
}
declare class StatusIndicator extends React.Component<StatusIndicatorProps> {
public constructor(props: StatusIndicatorProps);
}
declare class ModelSelect extends React.Component<ModelSelectProps> {
public constructor(props: ModelSelectProps);
}
}
declare function Invoke(props: PropsWithChildren): JSX.Element;
interface InvokeProps extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
declare function Invoke(props: InvokeProps): JSX.Element;
export {
ThemeChanger,
@ -74,5 +91,7 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};
export = Invoke;

View File

@ -5,7 +5,10 @@
"scripts": {
"prepare": "cd ../../../ && husky install invokeai/frontend/web/.husky",
"dev": "concurrently \"vite dev\" \"yarn run theme:watch\"",
"dev:nodes": "concurrently \"vite dev --mode nodes\" \"yarn run theme:watch\"",
"build": "yarn run lint && vite build",
"api:web": "openapi -i http://localhost:9090/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"api:file": "openapi -i src/services/fixtures/openapi.json -o src/services/api --client axios --useOptions --useUnionTypes --exportSchemas true --indent 2 --request src/services/fixtures/request.ts",
"preview": "vite preview",
"lint:madge": "madge --circular src/main.tsx",
"lint:eslint": "eslint --max-warnings=0 .",
@ -41,9 +44,11 @@
"@chakra-ui/react": "^2.5.1",
"@chakra-ui/styled-system": "^2.6.1",
"@chakra-ui/theme-tools": "^2.0.16",
"@dagrejs/graphlib": "^2.1.12",
"@emotion/react": "^11.10.6",
"@emotion/styled": "^11.10.6",
"@reduxjs/toolkit": "^1.9.2",
"@fontsource/inter": "^4.5.15",
"@reduxjs/toolkit": "^1.9.3",
"chakra-ui-contextmenu": "^1.0.5",
"dateformat": "^5.0.3",
"formik": "^2.2.9",
@ -67,15 +72,17 @@
"react-redux": "^8.0.5",
"react-transition-group": "^4.4.5",
"react-zoom-pan-pinch": "^2.6.1",
"reactflow": "^11.7.0",
"redux-deep-persist": "^1.0.7",
"redux-dynamic-middlewares": "^2.2.0",
"redux-persist": "^6.0.0",
"socket.io-client": "^4.6.0",
"use-image": "^1.1.0",
"uuid": "^9.0.0"
},
"devDependencies": {
"@fontsource/inter": "^4.5.15",
"@types/dateformat": "^5.0.0",
"@types/lodash": "^4.14.194",
"@types/react": "^18.0.28",
"@types/react-dom": "^18.0.11",
"@types/react-transition-group": "^4.4.5",
@ -83,6 +90,7 @@
"@typescript-eslint/eslint-plugin": "^5.52.0",
"@typescript-eslint/parser": "^5.52.0",
"@vitejs/plugin-react-swc": "^3.2.0",
"axios": "^1.3.4",
"babel-plugin-transform-imports": "^2.0.0",
"concurrently": "^7.6.0",
"eslint": "^8.34.0",
@ -90,13 +98,17 @@
"eslint-plugin-prettier": "^4.2.1",
"eslint-plugin-react": "^7.32.2",
"eslint-plugin-react-hooks": "^4.6.0",
"form-data": "^4.0.0",
"husky": "^8.0.3",
"lint-staged": "^13.1.2",
"madge": "^6.0.0",
"openapi-types": "^12.1.0",
"openapi-typescript-codegen": "^0.23.0",
"postinstall-postinstall": "^2.1.0",
"prettier": "^2.8.4",
"rollup-plugin-visualizer": "^5.9.0",
"terser": "^5.16.4",
"typescript": "4.9.5",
"vite": "^4.1.2",
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.0.5",

View File

@ -52,6 +52,7 @@
"txt2img": "Text To Image",
"img2img": "Image To Image",
"unifiedCanvas": "Unified Canvas",
"linear": "Linear",
"nodes": "Nodes",
"postprocessing": "Post Processing",
"nodesDesc": "A node based system for the generation of images is under development currently. Stay tuned for updates about this amazing feature.",
@ -524,6 +525,10 @@
"resetComplete": "Web UI has been reset. Refresh the page to reload."
},
"toast": {
"serverError": "Server Error",
"disconnected": "Disconnected from Server",
"connected": "Connected to Server",
"canceled": "Processing Canceled",
"tempFoldersEmptied": "Temp Folder Emptied",
"uploadFailed": "Upload failed",
"uploadFailedMultipleImagesDesc": "Multiple images pasted, may only upload one image at a time",

View File

@ -13,16 +13,42 @@ import { Box, Flex, Grid, Portal, useColorMode } from '@chakra-ui/react';
import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import ImageGalleryPanel from 'features/gallery/components/ImageGalleryPanel';
import Lightbox from 'features/lightbox/components/Lightbox';
import { useAppSelector } from './storeHooks';
import { useAppDispatch, useAppSelector } from './storeHooks';
import { PropsWithChildren, useEffect } from 'react';
import { setDisabledPanels, setDisabledTabs } from 'features/ui/store/uiSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { shouldTransformUrlsChanged } from 'features/system/store/systemSlice';
keepGUIAlive();
const App = (props: PropsWithChildren) => {
interface Props extends PropsWithChildren {
options: {
disabledPanels: string[];
disabledTabs: InvokeTabName[];
shouldTransformUrls?: boolean;
};
}
const App = (props: Props) => {
useToastWatcher();
const currentTheme = useAppSelector((state) => state.ui.currentTheme);
const { setColorMode } = useColorMode();
const dispatch = useAppDispatch();
useEffect(() => {
dispatch(setDisabledPanels(props.options.disabledPanels));
}, [dispatch, props.options.disabledPanels]);
useEffect(() => {
dispatch(setDisabledTabs(props.options.disabledTabs));
}, [dispatch, props.options.disabledTabs]);
useEffect(() => {
dispatch(
shouldTransformUrlsChanged(Boolean(props.options.shouldTransformUrls))
);
}, [dispatch, props.options.shouldTransformUrls]);
useEffect(() => {
setColorMode(['light'].includes(currentTheme) ? 'light' : 'dark');

View File

@ -14,6 +14,8 @@
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { ImageMetadata, ImageType } from 'services/api';
import { AnyInvocation } from 'services/events/types';
/**
* TODO:
@ -113,7 +115,7 @@ export declare type Metadata = SystemGenerationMetadata & {
};
// An Image has a UUID, url, modified timestamp, width, height and maybe metadata
export declare type Image = {
export declare type _Image = {
uuid: string;
url: string;
thumbnail: string;
@ -124,11 +126,23 @@ export declare type Image = {
category: GalleryCategory;
isBase64?: boolean;
dreamPrompt?: 'string';
name?: string;
};
/**
* ResultImage
*/
export declare type Image = {
name: string;
type: ImageType;
url: string;
thumbnail: string;
metadata: ImageMetadata;
};
// GalleryImages is an array of Image.
export declare type GalleryImages = {
images: Array<Image>;
images: Array<_Image>;
};
/**
@ -275,7 +289,7 @@ export declare type SystemStatusResponse = SystemStatus;
export declare type SystemConfigResponse = SystemConfig;
export declare type ImageResultResponse = Omit<Image, 'uuid'> & {
export declare type ImageResultResponse = Omit<_Image, 'uuid'> & {
boundingBox?: IRect;
generationMode: InvokeTabName;
};
@ -296,7 +310,7 @@ export declare type ErrorResponse = {
};
export declare type GalleryImagesResponse = {
images: Array<Omit<Image, 'uuid'>>;
images: Array<Omit<_Image, 'uuid'>>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};

View File

@ -20,6 +20,7 @@ export const readinessSelector = createSelector(
seedWeights,
initialImage,
seed,
isImageToImageEnabled,
} = generation;
const { isProcessing, isConnected } = system;
@ -33,7 +34,7 @@ export const readinessSelector = createSelector(
reasonsWhyNotReady.push('Missing prompt');
}
if (activeTabName === 'img2img' && !initialImage) {
if (isImageToImageEnabled && !initialImage) {
isReady = false;
reasonsWhyNotReady.push('No initial image selected');
}

View File

@ -13,9 +13,13 @@ import { InvokeTabName } from 'features/ui/store/tabMap';
export const generateImage = createAction<InvokeTabName>(
'socketio/generateImage'
);
export const runESRGAN = createAction<InvokeAI.Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI.Image>('socketio/runFacetool');
export const deleteImage = createAction<InvokeAI.Image>('socketio/deleteImage');
export const runESRGAN = createAction<InvokeAI._Image>('socketio/runESRGAN');
export const runFacetool = createAction<InvokeAI._Image>(
'socketio/runFacetool'
);
export const deleteImage = createAction<InvokeAI._Image>(
'socketio/deleteImage'
);
export const requestImages = createAction<GalleryCategory>(
'socketio/requestImages'
);

View File

@ -91,7 +91,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunESRGAN: (imageToProcess: InvokeAI.Image) => {
emitRunESRGAN: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -119,7 +119,7 @@ const makeSocketIOEmitters = (
})
);
},
emitRunFacetool: (imageToProcess: InvokeAI.Image) => {
emitRunFacetool: (imageToProcess: InvokeAI._Image) => {
dispatch(setIsProcessing(true));
const {
@ -150,7 +150,7 @@ const makeSocketIOEmitters = (
})
);
},
emitDeleteImage: (imageToDelete: InvokeAI.Image) => {
emitDeleteImage: (imageToDelete: InvokeAI._Image) => {
const { url, uuid, category, thumbnail } = imageToDelete;
dispatch(removeImage(imageToDelete));
socketio.emit('deleteImage', url, thumbnail, uuid, category);

View File

@ -34,8 +34,9 @@ import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
clearInitialImage,
initialImageSelected,
setInfillMethod,
setInitialImage,
// setInitialImage,
setMaskPath,
} from 'features/parameters/store/generationSlice';
import { tabMap } from 'features/ui/store/tabMap';
@ -142,15 +143,17 @@ const makeSocketIOListeners = (
}
}
if (shouldLoopback) {
const activeTabName = tabMap[activeTab];
switch (activeTabName) {
case 'img2img': {
dispatch(setInitialImage(newImage));
break;
}
}
}
// TODO: fix
// if (shouldLoopback) {
// const activeTabName = tabMap[activeTab];
// switch (activeTabName) {
// case 'img2img': {
// dispatch(initialImageSelected(newImage.uuid));
// // dispatch(setInitialImage(newImage));
// break;
// }
// }
// }
dispatch(clearIntermediateImage());
@ -262,7 +265,7 @@ const makeSocketIOListeners = (
*/
// Generate a UUID for each image
const preparedImages = images.map((image): InvokeAI.Image => {
const preparedImages = images.map((image): InvokeAI._Image => {
return {
uuid: uuidv4(),
...image,
@ -334,7 +337,7 @@ const makeSocketIOListeners = (
if (
initialImage === url ||
(initialImage as InvokeAI.Image)?.url === url
(initialImage as InvokeAI._Image)?.url === url
) {
dispatch(clearInitialImage());
}

View File

@ -29,6 +29,8 @@ export const socketioMiddleware = () => {
path: `${window.location.pathname}socket.io`,
});
socketio.disconnect();
let areListenersSet = false;
const middleware: Middleware = (store) => (next) => (action) => {

View File

@ -2,18 +2,32 @@ import { combineReducers, configureStore } from '@reduxjs/toolkit';
import { persistReducer } from 'redux-persist';
import storage from 'redux-persist/lib/storage'; // defaults to localStorage for web
import dynamicMiddlewares from 'redux-dynamic-middlewares';
import { getPersistConfig } from 'redux-deep-persist';
import canvasReducer from 'features/canvas/store/canvasSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
import resultsReducer from 'features/gallery/store/resultsSlice';
import uploadsReducer from 'features/gallery/store/uploadsSlice';
import lightboxReducer from 'features/lightbox/store/lightboxSlice';
import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
import systemReducer from 'features/system/store/systemSlice';
import uiReducer from 'features/ui/store/uiSlice';
import modelsReducer from 'features/system/store/modelSlice';
import nodesReducer from 'features/nodes/store/nodesSlice';
import { socketioMiddleware } from './socketio/middleware';
import { socketMiddleware } from 'services/events/middleware';
import { canvasBlacklist } from 'features/canvas/store/canvasPersistBlacklist';
import { galleryBlacklist } from 'features/gallery/store/galleryPersistBlacklist';
import { generationBlacklist } from 'features/parameters/store/generationPersistBlacklist';
import { lightboxBlacklist } from 'features/lightbox/store/lightboxPersistBlacklist';
import { modelsBlacklist } from 'features/system/store/modelsPersistBlacklist';
import { nodesBlacklist } from 'features/nodes/store/nodesPersistBlacklist';
import { postprocessingBlacklist } from 'features/parameters/store/postprocessingPersistBlacklist';
import { systemBlacklist } from 'features/system/store/systemPersistsBlacklist';
import { uiBlacklist } from 'features/ui/store/uiPersistBlacklist';
/**
* redux-persist provides an easy and reliable way to persist state across reloads.
@ -29,49 +43,18 @@ import { socketioMiddleware } from './socketio/middleware';
* The necesssary nested persistors with blacklists are configured below.
*/
const canvasBlacklist = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
].map((blacklistItem) => `canvas.${blacklistItem}`);
const systemBlacklist = [
'currentIteration',
'currentStatus',
'currentStep',
'isCancelable',
'isConnected',
'isESRGANAvailable',
'isGFPGANAvailable',
'isProcessing',
'socketId',
'totalIterations',
'totalSteps',
'openModel',
'cancelOptions.cancelAfter',
].map((blacklistItem) => `system.${blacklistItem}`);
const galleryBlacklist = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
].map((blacklistItem) => `gallery.${blacklistItem}`);
const lightboxBlacklist = ['isLightboxOpen'].map(
(blacklistItem) => `lightbox.${blacklistItem}`
);
const rootReducer = combineReducers({
generation: generationReducer,
postprocessing: postprocessingReducer,
gallery: galleryReducer,
system: systemReducer,
canvas: canvasReducer,
ui: uiReducer,
gallery: galleryReducer,
generation: generationReducer,
lightbox: lightboxReducer,
models: modelsReducer,
nodes: nodesReducer,
postprocessing: postprocessingReducer,
results: resultsReducer,
system: systemReducer,
ui: uiReducer,
uploads: uploadsReducer,
});
const rootPersistConfig = getPersistConfig({
@ -80,23 +63,40 @@ const rootPersistConfig = getPersistConfig({
rootReducer,
blacklist: [
...canvasBlacklist,
...systemBlacklist,
...galleryBlacklist,
...generationBlacklist,
...lightboxBlacklist,
...modelsBlacklist,
...nodesBlacklist,
...postprocessingBlacklist,
// ...resultsBlacklist,
'results',
...systemBlacklist,
...uiBlacklist,
// ...uploadsBlacklist,
'uploads',
],
debounce: 300,
});
const persistedReducer = persistReducer(rootPersistConfig, rootReducer);
// Continue with store setup
// TODO: rip the old middleware out when nodes is complete
export function buildMiddleware() {
if (import.meta.env.MODE === 'nodes' || import.meta.env.MODE === 'package') {
return socketMiddleware();
} else {
return socketioMiddleware();
}
}
export const store = configureStore({
reducer: persistedReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
immutableCheck: false,
serializableCheck: false,
}).concat(socketioMiddleware()),
}).concat(dynamicMiddlewares),
devTools: {
// Uncommenting these very rapidly called actions makes the redux dev tools output much more readable
actionsDenylist: [

View File

@ -0,0 +1,8 @@
import { createAsyncThunk } from '@reduxjs/toolkit';
import { AppDispatch, RootState } from './store';
// https://redux-toolkit.js.org/usage/usage-with-typescript#defining-a-pre-typed-createasyncthunk
export const createAppAsyncThunk = createAsyncThunk.withTypes<{
state: RootState;
dispatch: AppDispatch;
}>();

View File

@ -44,12 +44,10 @@ export type IAIFullSliderProps = {
inputReadOnly?: boolean;
withReset?: boolean;
handleReset?: () => void;
isResetDisabled?: boolean;
isSliderDisabled?: boolean;
isInputDisabled?: boolean;
tooltipSuffix?: string;
hideTooltip?: boolean;
isCompact?: boolean;
isDisabled?: boolean;
sliderFormControlProps?: FormControlProps;
sliderFormLabelProps?: FormLabelProps;
sliderMarkProps?: Omit<SliderMarkProps, 'value'>;
@ -80,10 +78,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
withReset = false,
hideTooltip = false,
isCompact = false,
isDisabled = false,
handleReset,
isResetDisabled,
isSliderDisabled,
isInputDisabled,
sliderFormControlProps,
sliderFormLabelProps,
sliderMarkProps,
@ -149,6 +145,7 @@ const IAISlider = (props: IAIFullSliderProps) => {
}
: {}
}
isDisabled={isDisabled}
{...sliderFormControlProps}
>
<FormLabel {...sliderFormLabelProps} mb={-1}>
@ -166,15 +163,13 @@ const IAISlider = (props: IAIFullSliderProps) => {
onMouseEnter={() => setShowTooltip(true)}
onMouseLeave={() => setShowTooltip(false)}
focusThumbOnChange={false}
isDisabled={isSliderDisabled}
// width={width}
isDisabled={isDisabled}
{...rest}
>
{withSliderMarks && (
<>
<SliderMark
value={min}
// insetInlineStart={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
@ -185,7 +180,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
</SliderMark>
<SliderMark
value={max}
// insetInlineEnd={0}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
@ -221,7 +215,6 @@ const IAISlider = (props: IAIFullSliderProps) => {
value={localInputValue}
onChange={handleInputChange}
onBlur={handleInputBlur}
isDisabled={isInputDisabled}
{...sliderNumberInputProps}
>
<NumberInputField
@ -246,8 +239,8 @@ const IAISlider = (props: IAIFullSliderProps) => {
aria-label={t('accessibility.reset')}
tooltip="Reset"
icon={<BiReset />}
isDisabled={isDisabled}
onClick={handleResetDisable}
isDisabled={isResetDisabled}
{...sliderIAIIconButtonProps}
/>
)}

View File

@ -0,0 +1,79 @@
import { Badge, Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
import { useCallback } from 'react';
import IAIIconButton from 'common/components/IAIIconButton';
import { FaUndo, FaUpload } from 'react-icons/fa';
import { useTranslation } from 'react-i18next';
import { Image } from 'app/invokeai';
type ImageToImageOverlayProps = {
setIsLoaded: (isLoaded: boolean) => void;
image: Image;
};
const ImageToImageOverlay = ({
setIsLoaded,
image,
}: ImageToImageOverlayProps) => {
const isImageToImageEnabled = useAppSelector(
(state: RootState) => state.generation.isImageToImageEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleResetInitialImage = useCallback(() => {
dispatch(clearInitialImage());
setIsLoaded(false);
}, [dispatch, setIsLoaded]);
return (
<Box
sx={{
top: 0,
left: 0,
w: 'full',
h: 'full',
position: 'absolute',
}}
>
<ButtonGroup
sx={{
position: 'absolute',
top: 0,
right: 0,
p: 2,
}}
>
<IAIIconButton
size="sm"
isDisabled={!isImageToImageEnabled}
icon={<FaUndo />}
aria-label={t('accessibility.reset')}
onClick={handleResetInitialImage}
/>
<IAIIconButton
size="sm"
isDisabled={!isImageToImageEnabled}
icon={<FaUpload />}
aria-label={t('common.upload')}
/>
</ButtonGroup>
<Flex
sx={{
position: 'absolute',
bottom: 0,
left: 0,
p: 2,
alignItems: 'flex-start',
}}
>
<Badge variant="solid" colorScheme="base">
{image.metadata?.width} × {image.metadata?.height}
</Badge>
</Flex>
</Box>
);
};
export default ImageToImageOverlay;

View File

@ -2,7 +2,6 @@ import { Box, useToast } from '@chakra-ui/react';
import { ImageUploaderTriggerContext } from 'app/contexts/ImageUploaderTriggerContext';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import useImageUploader from 'common/hooks/useImageUploader';
import { uploadImage } from 'features/gallery/store/thunks/uploadImage';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ResourceKey } from 'i18next';
import {
@ -15,6 +14,7 @@ import {
} from 'react';
import { FileRejection, useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { imageUploaded } from 'services/thunks/image';
import ImageUploadOverlay from './ImageUploadOverlay';
type ImageUploaderProps = {
@ -49,7 +49,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
const fileAcceptedCallback = useCallback(
async (file: File) => {
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
},
[dispatch]
);
@ -124,7 +124,7 @@ const ImageUploader = (props: ImageUploaderProps) => {
return;
}
dispatch(uploadImage({ imageFile: file }));
dispatch(imageUploaded({ formData: { file } }));
};
document.addEventListener('paste', pasteImageListener);
return () => {

View File

@ -0,0 +1,12 @@
import { Flex, Icon } from '@chakra-ui/react';
import { FaImage } from 'react-icons/fa';
const SelectImagePlaceholder = () => {
return (
<Flex sx={{ h: 36, alignItems: 'center', justifyContent: 'center' }}>
<Icon color="base.400" boxSize={32} as={FaImage}></Icon>
</Flex>
);
};
export default SelectImagePlaceholder;

View File

@ -1,27 +1,160 @@
import { Flex, Heading, Text, VStack } from '@chakra-ui/react';
import { useTranslation } from 'react-i18next';
import WorkInProgress from './WorkInProgress';
// import WorkInProgress from './WorkInProgress';
// import ReactFlow, {
// applyEdgeChanges,
// applyNodeChanges,
// Background,
// Controls,
// Edge,
// Handle,
// Node,
// NodeTypes,
// OnEdgesChange,
// OnNodesChange,
// Position,
// } from 'reactflow';
export default function NodesWIP() {
const { t } = useTranslation();
return (
<WorkInProgress>
<Flex
sx={{
flexDirection: 'column',
alignItems: 'center',
justifyContent: 'center',
w: '100%',
h: '100%',
gap: 4,
textAlign: 'center',
}}
>
<Heading>{t('common.nodes')}</Heading>
<VStack maxW="50rem" gap={4}>
<Text>{t('common.nodesDesc')}</Text>
</VStack>
</Flex>
</WorkInProgress>
);
}
// import 'reactflow/dist/style.css';
// import {
// Fragment,
// FunctionComponent,
// ReactNode,
// useCallback,
// useMemo,
// useState,
// } from 'react';
// import { OpenAPIV3 } from 'openapi-types';
// import { filter, map, reduce } from 'lodash';
// import {
// Box,
// Flex,
// FormControl,
// FormLabel,
// Input,
// Select,
// Switch,
// Text,
// NumberInput,
// NumberInputField,
// NumberInputStepper,
// NumberIncrementStepper,
// NumberDecrementStepper,
// Tooltip,
// chakra,
// Badge,
// Heading,
// VStack,
// HStack,
// Menu,
// MenuButton,
// MenuList,
// MenuItem,
// MenuItemOption,
// MenuGroup,
// MenuOptionGroup,
// MenuDivider,
// IconButton,
// } from '@chakra-ui/react';
// import { FaPlus } from 'react-icons/fa';
// import {
// FIELD_NAMES as FIELD_NAMES,
// FIELDS,
// INVOCATION_NAMES as INVOCATION_NAMES,
// INVOCATIONS,
// } from 'features/nodeEditor/constants';
// console.log('invocations', INVOCATIONS);
// const nodeTypes = reduce(
// INVOCATIONS,
// (acc, val, key) => {
// acc[key] = val.component;
// return acc;
// },
// {} as NodeTypes
// );
// console.log('nodeTypes', nodeTypes);
// // make initial nodes one of every node for now
// let n = 0;
// const initialNodes = map(INVOCATIONS, (i) => ({
// id: i.type,
// type: i.title,
// position: { x: (n += 20), y: (n += 20) },
// data: {},
// }));
// console.log('initialNodes', initialNodes);
// export default function NodesWIP() {
// const [nodes, setNodes] = useState<Node[]>([]);
// const [edges, setEdges] = useState<Edge[]>([]);
// const onNodesChange: OnNodesChange = useCallback(
// (changes) => setNodes((nds) => applyNodeChanges(changes, nds)),
// []
// );
// const onEdgesChange: OnEdgesChange = useCallback(
// (changes) => setEdges((eds: Edge[]) => applyEdgeChanges(changes, eds)),
// []
// );
// return (
// <Box
// sx={{
// position: 'relative',
// width: 'full',
// height: 'full',
// borderRadius: 'md',
// }}
// >
// <ReactFlow
// nodeTypes={nodeTypes}
// nodes={nodes}
// edges={edges}
// onNodesChange={onNodesChange}
// onEdgesChange={onEdgesChange}
// >
// <Background />
// <Controls />
// </ReactFlow>
// <HStack sx={{ position: 'absolute', top: 2, right: 2 }}>
// {FIELD_NAMES.map((field) => (
// <Badge
// key={field}
// colorScheme={FIELDS[field].color}
// sx={{ userSelect: 'none' }}
// >
// {field}
// </Badge>
// ))}
// </HStack>
// <Menu>
// <MenuButton
// as={IconButton}
// aria-label="Options"
// icon={<FaPlus />}
// sx={{ position: 'absolute', top: 2, left: 2 }}
// />
// <MenuList>
// {INVOCATION_NAMES.map((name) => {
// const invocation = INVOCATIONS[name];
// return (
// <Tooltip
// key={name}
// label={invocation.description}
// placement="end"
// hasArrow
// >
// <MenuItem>{invocation.title}</MenuItem>
// </Tooltip>
// );
// })}
// </MenuList>
// </Menu>
// </Box>
// );
// }
export default {};

View File

@ -14,6 +14,8 @@ const WorkInProgress = (props: WorkInProgressProps) => {
width: '100%',
height: '100%',
bg: 'base.850',
borderRadius: 'base',
position: 'relative',
}}
>
{children}

View File

@ -0,0 +1,119 @@
/**
* PARTIAL ZOD IMPLEMENTATION
*
* doesn't work well bc like most validators, zod is not built to skip invalid values.
* it mostly works but just seems clearer and simpler to manually parse for now.
*
* in the future it would be really nice if we could use zod for some things:
* - zodios (axios + zod): https://github.com/ecyrbe/zodios
* - openapi to zodios: https://github.com/astahmer/openapi-zod-client
*/
// import { z } from 'zod';
// const zMetadataStringField = z.string();
// export type MetadataStringField = z.infer<typeof zMetadataStringField>;
// const zMetadataIntegerField = z.number().int();
// export type MetadataIntegerField = z.infer<typeof zMetadataIntegerField>;
// const zMetadataFloatField = z.number();
// export type MetadataFloatField = z.infer<typeof zMetadataFloatField>;
// const zMetadataBooleanField = z.boolean();
// export type MetadataBooleanField = z.infer<typeof zMetadataBooleanField>;
// const zMetadataImageField = z.object({
// image_type: z.union([
// z.literal('results'),
// z.literal('uploads'),
// z.literal('intermediates'),
// ]),
// image_name: z.string().min(1),
// });
// export type MetadataImageField = z.infer<typeof zMetadataImageField>;
// const zMetadataLatentsField = z.object({
// latents_name: z.string().min(1),
// });
// export type MetadataLatentsField = z.infer<typeof zMetadataLatentsField>;
// /**
// * zod Schema for any node field. Use a `transform()` to manually parse, skipping invalid values.
// */
// const zAnyMetadataField = z.any().transform((val, ctx) => {
// // Grab the field name from the path
// const fieldName = String(ctx.path[ctx.path.length - 1]);
// // `id` and `type` must be strings if they exist
// if (['id', 'type'].includes(fieldName)) {
// const reservedStringPropertyResult = zMetadataStringField.safeParse(val);
// if (reservedStringPropertyResult.success) {
// return reservedStringPropertyResult.data;
// }
// return;
// }
// // Parse the rest of the fields, only returning the data if the parsing is successful
// const stringFieldResult = zMetadataStringField.safeParse(val);
// if (stringFieldResult.success) {
// return stringFieldResult.data;
// }
// const integerFieldResult = zMetadataIntegerField.safeParse(val);
// if (integerFieldResult.success) {
// return integerFieldResult.data;
// }
// const floatFieldResult = zMetadataFloatField.safeParse(val);
// if (floatFieldResult.success) {
// return floatFieldResult.data;
// }
// const booleanFieldResult = zMetadataBooleanField.safeParse(val);
// if (booleanFieldResult.success) {
// return booleanFieldResult.data;
// }
// const imageFieldResult = zMetadataImageField.safeParse(val);
// if (imageFieldResult.success) {
// return imageFieldResult.data;
// }
// const latentsFieldResult = zMetadataImageField.safeParse(val);
// if (latentsFieldResult.success) {
// return latentsFieldResult.data;
// }
// });
// /**
// * The node metadata schema.
// */
// const zNodeMetadata = z.object({
// session_id: z.string().min(1).optional(),
// node: z.record(z.string().min(1), zAnyMetadataField).optional(),
// });
// export type NodeMetadata = z.infer<typeof zNodeMetadata>;
// const zMetadata = z.object({
// invokeai: zNodeMetadata.optional(),
// 'sd-metadata': z.record(z.string().min(1), z.any()).optional(),
// });
// export type Metadata = z.infer<typeof zMetadata>;
// export const parseMetadata = (
// metadata: Record<string, any>
// ): Metadata | undefined => {
// const result = zMetadata.safeParse(metadata);
// if (!result.success) {
// console.log(result.error.issues);
// return;
// }
// return result.data;
// };
export default {};

View File

@ -0,0 +1,6 @@
import dateFormat from 'dateformat';
/**
* Get a `now` timestamp with 1s precision, formatted as ISO datetime.
*/
export const getTimestamp = () => dateFormat(new Date(), 'isoDateTime');

View File

@ -0,0 +1,28 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { OpenAPI } from 'services/api';
export const getUrlAlt = (url: string, shouldTransformUrls: boolean) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
};
export const useGetUrl = () => {
const shouldTransformUrls = useAppSelector(
(state: RootState) => state.system.shouldTransformUrls
);
return {
shouldTransformUrls,
getUrl: (url?: string) => {
if (OpenAPI.BASE && shouldTransformUrls) {
return [OpenAPI.BASE, url].join('/');
}
return url;
},
};
};

View File

@ -0,0 +1,169 @@
import { forEach, size } from 'lodash';
import { ImageField, LatentsField } from 'services/api';
const OBJECT_TYPESTRING = '[object Object]';
const STRING_TYPESTRING = '[object String]';
const NUMBER_TYPESTRING = '[object Number]';
const BOOLEAN_TYPESTRING = '[object Boolean]';
const ARRAY_TYPESTRING = '[object Array]';
const isObject = (obj: unknown): obj is Record<string | number, any> =>
Object.prototype.toString.call(obj) === OBJECT_TYPESTRING;
const isString = (obj: unknown): obj is string =>
Object.prototype.toString.call(obj) === STRING_TYPESTRING;
const isNumber = (obj: unknown): obj is number =>
Object.prototype.toString.call(obj) === NUMBER_TYPESTRING;
const isBoolean = (obj: unknown): obj is boolean =>
Object.prototype.toString.call(obj) === BOOLEAN_TYPESTRING;
const isArray = (obj: unknown): obj is Array<any> =>
Object.prototype.toString.call(obj) === ARRAY_TYPESTRING;
const parseImageField = (imageField: unknown): ImageField | undefined => {
// Must be an object
if (!isObject(imageField)) {
return;
}
// An ImageField must have both `image_name` and `image_type`
if (!('image_name' in imageField && 'image_type' in imageField)) {
return;
}
// An ImageField's `image_type` must be one of the allowed values
if (
!['results', 'uploads', 'intermediates'].includes(imageField.image_type)
) {
return;
}
// An ImageField's `image_name` must be a string
if (typeof imageField.image_name !== 'string') {
return;
}
// Build a valid ImageField
return {
image_type: imageField.image_type,
image_name: imageField.image_name,
};
};
const parseLatentsField = (latentsField: unknown): LatentsField | undefined => {
// Must be an object
if (!isObject(latentsField)) {
return;
}
// A LatentsField must have a `latents_name`
if (!('latents_name' in latentsField)) {
return;
}
// A LatentsField's `latents_name` must be a string
if (typeof latentsField.latents_name !== 'string') {
return;
}
// Build a valid LatentsField
return {
latents_name: latentsField.latents_name,
};
};
type NodeMetadata = {
[key: string]: string | number | boolean | ImageField | LatentsField;
};
type InvokeAIMetadata = {
session_id?: string;
node?: NodeMetadata;
};
export const parseNodeMetadata = (
nodeMetadata: Record<string | number, any>
): NodeMetadata | undefined => {
if (!isObject(nodeMetadata)) {
return;
}
const parsed: NodeMetadata = {};
forEach(nodeMetadata, (nodeItem, nodeKey) => {
// `id` and `type` must be strings if they are present
if (['id', 'type'].includes(nodeKey)) {
if (isString(nodeItem)) {
parsed[nodeKey] = nodeItem;
}
return;
}
// the only valid object types are ImageField and LatentsField
if (isObject(nodeItem)) {
if ('image_name' in nodeItem || 'image_type' in nodeItem) {
const imageField = parseImageField(nodeItem);
if (imageField) {
parsed[nodeKey] = imageField;
}
return;
}
if ('latents_name' in nodeItem) {
const latentsField = parseLatentsField(nodeItem);
if (latentsField) {
parsed[nodeKey] = latentsField;
}
return;
}
}
// otherwise we accept any string, number or boolean
if (isString(nodeItem) || isNumber(nodeItem) || isBoolean(nodeItem)) {
parsed[nodeKey] = nodeItem;
return;
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};
export const parseInvokeAIMetadata = (
metadata: Record<string | number, any> | undefined
): InvokeAIMetadata | undefined => {
if (metadata === undefined) {
return;
}
if (!isObject(metadata)) {
return;
}
const parsed: InvokeAIMetadata = {};
forEach(metadata, (item, key) => {
if (key === 'session_id' && isString(item)) {
parsed['session_id'] = item;
}
if (key === 'node' && isObject(item)) {
const nodeMetadata = parseNodeMetadata(item);
if (nodeMetadata) {
parsed['node'] = nodeMetadata;
}
}
});
if (size(parsed) === 0) {
return;
}
return parsed;
};

View File

@ -1,8 +1,10 @@
import React, { lazy, PropsWithChildren } from 'react';
import React, { lazy, PropsWithChildren, useEffect, useState } from 'react';
import { Provider } from 'react-redux';
import { PersistGate } from 'redux-persist/integration/react';
import { store } from './app/store';
import { buildMiddleware, store } from './app/store';
import { persistor } from './persistor';
import { OpenAPI } from 'services/api';
import { InvokeTabName } from 'features/ui/store/tabMap';
import '@fontsource/inter/100.css';
import '@fontsource/inter/200.css';
import '@fontsource/inter/300.css';
@ -17,18 +19,61 @@ import Loading from './Loading';
// Localization
import './i18n';
import { addMiddleware, resetMiddlewares } from 'redux-dynamic-middlewares';
const App = lazy(() => import('./app/App'));
const ThemeLocaleProvider = lazy(() => import('./app/ThemeLocaleProvider'));
export default function Component(props: PropsWithChildren) {
interface Props extends PropsWithChildren {
apiUrl?: string;
disabledPanels?: string[];
disabledTabs?: InvokeTabName[];
token?: string;
shouldTransformUrls?: boolean;
}
export default function Component({
apiUrl,
disabledPanels = [],
disabledTabs = [],
token,
children,
shouldTransformUrls,
}: Props) {
useEffect(() => {
// configure API client token
if (token) {
OpenAPI.TOKEN = token;
}
// configure API client base url
if (apiUrl) {
OpenAPI.BASE = apiUrl;
}
// reset dynamically added middlewares
resetMiddlewares();
// TODO: at this point, after resetting the middleware, we really ought to clean up the socket
// stuff by calling `dispatch(socketReset())`. but we cannot dispatch from here as we are
// outside the provider. it's not needed until there is the possibility that we will change
// the `apiUrl`/`token` dynamically.
// rebuild socket middleware with token and apiUrl
addMiddleware(buildMiddleware());
}, [apiUrl, token]);
return (
<React.StrictMode>
<Provider store={store}>
<PersistGate loading={<Loading />} persistor={persistor}>
<React.Suspense fallback={<Loading showText />}>
<ThemeLocaleProvider>
<App>{props.children}</App>
<App
options={{ disabledPanels, disabledTabs, shouldTransformUrls }}
>
{children}
</App>
</ThemeLocaleProvider>
</React.Suspense>
</PersistGate>

View File

@ -5,6 +5,8 @@ import ThemeChanger from './features/system/components/ThemeChanger';
import IAIPopover from './common/components/IAIPopover';
import IAIIconButton from './common/components/IAIIconButton';
import SettingsModal from './features/system/components/SettingsModal/SettingsModal';
import StatusIndicator from './features/system/components/StatusIndicator';
import ModelSelect from 'features/system/components/ModelSelect';
export default Component;
export {
@ -13,4 +15,6 @@ export {
IAIPopover,
IAIIconButton,
SettingsModal,
StatusIndicator,
ModelSelect,
};

View File

@ -1,6 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { ImageConfig } from 'konva/lib/shapes/Image';
import { isEqual } from 'lodash';
@ -25,7 +26,7 @@ type Props = Omit<ImageConfig, 'image'>;
const IAICanvasIntermediateImage = (props: Props) => {
const { ...rest } = props;
const intermediateImage = useAppSelector(selector);
const { getUrl } = useGetUrl();
const [loadedImageElement, setLoadedImageElement] =
useState<HTMLImageElement | null>(null);
@ -36,8 +37,8 @@ const IAICanvasIntermediateImage = (props: Props) => {
tempImage.onload = () => {
setLoadedImageElement(tempImage);
};
tempImage.src = intermediateImage.url;
}, [intermediateImage]);
tempImage.src = getUrl(intermediateImage.url);
}, [intermediateImage, getUrl]);
if (!intermediateImage?.boundingBox) return null;

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { rgbaColorToString } from 'features/canvas/util/colorToString';
import { isEqual } from 'lodash';
@ -32,6 +33,7 @@ const selector = createSelector(
const IAICanvasObjectRenderer = () => {
const { objects } = useAppSelector(selector);
const { getUrl } = useGetUrl();
if (!objects) return null;
@ -40,7 +42,12 @@ const IAICanvasObjectRenderer = () => {
{objects.map((obj, i) => {
if (isCanvasBaseImage(obj)) {
return (
<IAICanvasImage key={i} x={obj.x} y={obj.y} url={obj.image.url} />
<IAICanvasImage
key={i}
x={obj.x}
y={obj.y}
url={getUrl(obj.image.url)}
/>
);
} else if (isCanvasBaseLine(obj)) {
const line = (

View File

@ -1,5 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import { canvasSelector } from 'features/canvas/store/canvasSelectors';
import { GroupConfig } from 'konva/lib/Group';
import { isEqual } from 'lodash';
@ -53,11 +54,16 @@ const IAICanvasStagingArea = (props: Props) => {
width,
height,
} = useAppSelector(selector);
const { getUrl } = useGetUrl();
return (
<Group {...rest}>
{shouldShowStagingImage && currentStagingAreaImage && (
<IAICanvasImage url={currentStagingAreaImage.image.url} x={x} y={y} />
<IAICanvasImage
url={getUrl(currentStagingAreaImage.image.url)}
x={x}
y={y}
/>
)}
{shouldShowStagingOutline && (
<Group>

View File

@ -0,0 +1,14 @@
import { CanvasState } from './canvasTypes';
/**
* Canvas slice persist blacklist
*/
const itemsToBlacklist: (keyof CanvasState)[] = [
'cursorPosition',
'isCanvasInitialized',
'doesCanvasNeedScaling',
];
export const canvasBlacklist = itemsToBlacklist.map(
(blacklistItem) => `canvas.${blacklistItem}`
);

View File

@ -156,7 +156,7 @@ export const canvasSlice = createSlice({
setCursorPosition: (state, action: PayloadAction<Vector2d | null>) => {
state.cursorPosition = action.payload;
},
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI.Image>) => {
setInitialCanvasImage: (state, action: PayloadAction<InvokeAI._Image>) => {
const image = action.payload;
const { stageDimensions } = state;
@ -291,7 +291,7 @@ export const canvasSlice = createSlice({
state,
action: PayloadAction<{
boundingBox: IRect;
image: InvokeAI.Image;
image: InvokeAI._Image;
}>
) => {
const { boundingBox, image } = action.payload;

View File

@ -37,7 +37,7 @@ export type CanvasImage = {
y: number;
width: number;
height: number;
image: InvokeAI.Image;
image: InvokeAI._Image;
};
export type CanvasMaskLine = {
@ -125,7 +125,7 @@ export interface CanvasState {
cursorPosition: Vector2d | null;
doesCanvasNeedScaling: boolean;
futureLayerStates: CanvasLayerState[];
intermediateImage?: InvokeAI.Image;
intermediateImage?: InvokeAI._Image;
isCanvasInitialized: boolean;
isDrawing: boolean;
isMaskEnabled: boolean;

View File

@ -105,7 +105,7 @@ export const mergeAndUploadCanvas =
const { url, width, height } = image;
const newImage: InvokeAI.Image = {
const newImage: InvokeAI._Image = {
uuid: uuidv4(),
category: shouldSaveToGallery ? 'result' : 'user',
...image,

View File

@ -14,8 +14,9 @@ import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
import {
initialImageSelected,
setAllParameters,
setInitialImage,
// setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
@ -48,11 +49,15 @@ import {
FaShareAlt,
FaTrash,
} from 'react-icons/fa';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import DeleteImageModal from './DeleteImageModal';
import { useCallback } from 'react';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import { useGetUrl } from 'common/util/getUrl';
const currentImageButtonsSelector = createSelector(
[
@ -62,6 +67,7 @@ const currentImageButtonsSelector = createSelector(
uiSelector,
lightboxSelector,
activeTabNameSelector,
selectedImageSelector,
],
(
system: SystemState,
@ -69,7 +75,8 @@ const currentImageButtonsSelector = createSelector(
postprocessing,
ui,
lightbox,
activeTabName
activeTabName,
selectedImage
) => {
const { isProcessing, isConnected, isGFPGANAvailable, isESRGANAvailable } =
system;
@ -95,6 +102,7 @@ const currentImageButtonsSelector = createSelector(
activeTabName,
isLightboxOpen,
shouldHidePreview,
selectedImage,
};
},
{
@ -121,27 +129,33 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
facetoolStrength,
shouldDisableToolbarButtons,
shouldShowImageDetails,
currentImage,
// currentImage,
isLightboxOpen,
activeTabName,
shouldHidePreview,
selectedImage,
} = useAppSelector(currentImageButtonsSelector);
const { getUrl, shouldTransformUrls } = useGetUrl();
const toast = useToast();
const { t } = useTranslation();
const setBothPrompts = useSetBothPrompts();
const handleClickUseAsInitialImage = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialImage(currentImage));
dispatch(setActiveTab('img2img'));
dispatch(initialImageSelected(selectedImage.name));
// dispatch(setInitialImage(currentImage));
// dispatch(setActiveTab('img2img'));
};
const handleCopyImage = async () => {
if (!currentImage) return;
if (!selectedImage) return;
const blob = await fetch(currentImage.url).then((res) => res.blob());
const blob = await fetch(getUrl(selectedImage.url)).then((res) =>
res.blob()
);
const data = [new ClipboardItem({ [blob.type]: blob })];
await navigator.clipboard.write(data);
@ -155,24 +169,26 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleCopyImageLink = () => {
navigator.clipboard
.writeText(
currentImage ? window.location.toString() + currentImage.url : ''
)
.then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
const url = selectedImage
? shouldTransformUrls
? getUrl(selectedImage.url)
: window.location.toString() + selectedImage.url
: '';
navigator.clipboard.writeText(url).then(() => {
toast({
title: t('toast.imageLinkCopied'),
status: 'success',
duration: 2500,
isClosable: true,
});
});
};
useHotkeys(
'shift+i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickUseAsInitialImage();
toast({
title: t('toast.sentToImageToImage'),
@ -190,7 +206,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handlePreviewVisibility = () => {
@ -198,20 +214,23 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
};
const handleClickUseAllParameters = () => {
if (!currentImage) return;
currentImage.metadata && dispatch(setAllParameters(currentImage.metadata));
if (currentImage.metadata?.image.type === 'img2img') {
dispatch(setActiveTab('img2img'));
} else if (currentImage.metadata?.image.type === 'txt2img') {
dispatch(setActiveTab('txt2img'));
}
if (!selectedImage) return;
// selectedImage.metadata &&
// dispatch(setAllParameters(selectedImage.metadata));
// if (selectedImage.metadata?.image.type === 'img2img') {
// dispatch(setActiveTab('img2img'));
// } else if (selectedImage.metadata?.image.type === 'txt2img') {
// dispatch(setActiveTab('txt2img'));
// }
};
useHotkeys(
'a',
() => {
if (
['txt2img', 'img2img'].includes(currentImage?.metadata?.image?.type)
['txt2img', 'img2img'].includes(
selectedImage?.metadata?.sd_metadata?.type
)
) {
handleClickUseAllParameters();
toast({
@ -230,18 +249,18 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUseSeed = () => {
currentImage?.metadata &&
dispatch(setSeed(currentImage.metadata.image.seed));
selectedImage?.metadata &&
dispatch(setSeed(selectedImage.metadata.sd_metadata.seed));
};
useHotkeys(
's',
() => {
if (currentImage?.metadata?.image?.seed) {
if (selectedImage?.metadata?.sd_metadata?.seed) {
handleClickUseSeed();
toast({
title: t('toast.seedSet'),
@ -259,19 +278,19 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUsePrompt = useCallback(() => {
if (currentImage?.metadata?.image?.prompt) {
setBothPrompts(currentImage?.metadata?.image?.prompt);
if (selectedImage?.metadata?.sd_metadata?.prompt) {
setBothPrompts(selectedImage?.metadata?.sd_metadata?.prompt);
}
}, [currentImage?.metadata?.image?.prompt, setBothPrompts]);
}, [selectedImage?.metadata?.sd_metadata?.prompt, setBothPrompts]);
useHotkeys(
'p',
() => {
if (currentImage?.metadata?.image?.prompt) {
if (selectedImage?.metadata?.sd_metadata?.prompt) {
handleClickUsePrompt();
toast({
title: t('toast.promptSet'),
@ -289,11 +308,11 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage]
[selectedImage]
);
const handleClickUpscale = () => {
currentImage && dispatch(runESRGAN(currentImage));
// selectedImage && dispatch(runESRGAN(selectedImage));
};
useHotkeys(
@ -317,7 +336,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isESRGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -327,7 +346,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
);
const handleClickFixFaces = () => {
currentImage && dispatch(runFacetool(currentImage));
// selectedImage && dispatch(runFacetool(selectedImage));
};
useHotkeys(
@ -351,7 +370,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
}
},
[
currentImage,
selectedImage,
isGFPGANAvailable,
shouldDisableToolbarButtons,
isConnected,
@ -364,10 +383,10 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
dispatch(setShouldShowImageDetails(!shouldShowImageDetails));
const handleSendToCanvas = () => {
if (!currentImage) return;
if (!selectedImage) return;
if (isLightboxOpen) dispatch(setIsLightboxOpen(false));
dispatch(setInitialCanvasImage(currentImage));
// dispatch(setInitialCanvasImage(selectedImage));
dispatch(requestCanvasRescale());
if (activeTabName !== 'unifiedCanvas') {
@ -385,7 +404,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
useHotkeys(
'i',
() => {
if (currentImage) {
if (selectedImage) {
handleClickShowImageDetails();
} else {
toast({
@ -396,7 +415,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
});
}
},
[currentImage, shouldShowImageDetails]
[selectedImage, shouldShowImageDetails]
);
const handleLightBox = () => {
@ -457,7 +476,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
{t('parameters.copyImageToLink')}
</IAIButton>
<Link download={true} href={currentImage?.url}>
<Link download={true} href={getUrl(selectedImage!.url)}>
<IAIButton leftIcon={<FaDownload />} size="sm" w="100%">
{t('parameters.downloadImage')}
</IAIButton>
@ -501,7 +520,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaQuoteRight />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!currentImage?.metadata?.image?.prompt}
isDisabled={!selectedImage?.metadata?.sd_metadata?.prompt}
onClick={handleClickUsePrompt}
/>
@ -509,7 +528,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
icon={<FaSeedling />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!currentImage?.metadata?.image?.seed}
isDisabled={!selectedImage?.metadata?.sd_metadata?.seed}
onClick={handleClickUseSeed}
/>
@ -519,7 +538,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={
!['txt2img', 'img2img'].includes(
currentImage?.metadata?.image?.type
selectedImage?.metadata?.sd_metadata?.type
)
}
onClick={handleClickUseAllParameters}
@ -545,7 +564,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isGFPGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!facetoolStrength
}
@ -574,7 +593,7 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
<IAIButton
isDisabled={
!isESRGANAvailable ||
!currentImage ||
!selectedImage ||
!(isConnected && !isProcessing) ||
!upscalingLevel
}
@ -596,15 +615,15 @@ const CurrentImageButtons = (props: CurrentImageButtonsProps) => {
/>
</ButtonGroup>
<DeleteImageModal image={currentImage}>
{/* <DeleteImageModal image={selectedImage}>
<IAIIconButton
icon={<FaTrash />}
tooltip={`${t('parameters.deleteImage')} (Del)`}
aria-label={`${t('parameters.deleteImage')} (Del)`}
isDisabled={!currentImage || !isConnected || isProcessing}
isDisabled={!selectedImage || !isConnected || isProcessing}
colorScheme="error"
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Flex>
);
};

View File

@ -4,17 +4,20 @@ import { useAppSelector } from 'app/storeHooks';
import { isEqual } from 'lodash';
import { MdPhoto } from 'react-icons/md';
import { gallerySelector } from '../store/gallerySelectors';
import {
gallerySelector,
selectedImageSelector,
} from '../store/gallerySelectors';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const currentImageDisplaySelector = createSelector(
[gallerySelector],
(gallery) => {
[gallerySelector, selectedImageSelector],
(gallery, selectedImage) => {
const { currentImage, intermediateImage } = gallery;
return {
hasAnImageToDisplay: currentImage || intermediateImage,
hasAnImageToDisplay: selectedImage || intermediateImage,
};
},
{

View File

@ -1,28 +1,48 @@
import { Box, Flex, Image } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { GalleryState } from 'features/gallery/store/gallerySlice';
import { useGetUrl } from 'common/util/getUrl';
import { systemSelector } from 'features/system/store/systemSelectors';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import { ReactEventHandler } from 'react';
import { APP_METADATA_HEIGHT } from 'theme/util/constants';
import { gallerySelector } from '../store/gallerySelectors';
import { selectedImageSelector } from '../store/gallerySelectors';
import CurrentImageFallback from './CurrentImageFallback';
import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer';
import NextPrevImageButtons from './NextPrevImageButtons';
import CurrentImageHidden from './CurrentImageHidden';
export const imagesSelector = createSelector(
[gallerySelector, uiSelector],
(gallery: GalleryState, ui) => {
const { currentImage, intermediateImage } = gallery;
[uiSelector, selectedImageSelector, systemSelector],
(ui, selectedImage, system) => {
const { shouldShowImageDetails, shouldHidePreview } = ui;
const { progressImage } = system;
// TODO: Clean this up, this is really gross
const imageToDisplay = progressImage
? {
url: progressImage.dataURL,
width: progressImage.width,
height: progressImage.height,
isProgressImage: true,
image: progressImage,
}
: selectedImage
? {
url: selectedImage.url,
width: selectedImage.metadata.width,
height: selectedImage.metadata.height,
isProgressImage: false,
image: selectedImage,
}
: null;
return {
imageToDisplay: intermediateImage ? intermediateImage : currentImage,
isIntermediate: Boolean(intermediateImage),
shouldShowImageDetails,
shouldHidePreview,
imageToDisplay,
};
},
{
@ -33,12 +53,9 @@ export const imagesSelector = createSelector(
);
export default function CurrentImagePreview() {
const {
shouldShowImageDetails,
imageToDisplay,
isIntermediate,
shouldHidePreview,
} = useAppSelector(imagesSelector);
const { shouldShowImageDetails, imageToDisplay, shouldHidePreview } =
useAppSelector(imagesSelector);
const { getUrl } = useGetUrl();
return (
<Flex
@ -52,13 +69,19 @@ export default function CurrentImagePreview() {
>
{imageToDisplay && (
<Image
src={shouldHidePreview ? undefined : imageToDisplay.url}
src={
shouldHidePreview
? undefined
: imageToDisplay.isProgressImage
? imageToDisplay.url
: getUrl(imageToDisplay.url)
}
width={imageToDisplay.width}
height={imageToDisplay.height}
fallback={
shouldHidePreview ? (
<CurrentImageHidden />
) : !isIntermediate ? (
) : !imageToDisplay.isProgressImage ? (
<CurrentImageFallback />
) : undefined
}
@ -68,27 +91,31 @@ export default function CurrentImagePreview() {
maxHeight: '100%',
height: 'auto',
position: 'absolute',
imageRendering: isIntermediate ? 'pixelated' : 'initial',
imageRendering: imageToDisplay.isProgressImage
? 'pixelated'
: 'initial',
borderRadius: 'base',
}}
/>
)}
{!shouldShowImageDetails && <NextPrevImageButtons />}
{shouldShowImageDetails && imageToDisplay && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay} />
</Box>
)}
{shouldShowImageDetails &&
imageToDisplay &&
'metadata' in imageToDisplay.image && (
<Box
sx={{
position: 'absolute',
top: '0',
width: '100%',
height: '100%',
borderRadius: 'base',
overflow: 'scroll',
maxHeight: APP_METADATA_HEIGHT,
}}
>
<ImageMetadataViewer image={imageToDisplay.image} />
</Box>
)}
</Flex>
);
}

View File

@ -52,7 +52,7 @@ interface DeleteImageModalProps {
/**
* The image to delete.
*/
image?: InvokeAI.Image;
image?: InvokeAI._Image;
}
/**

View File

@ -9,11 +9,14 @@ import {
useToast,
} from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { setCurrentImage } from 'features/gallery/store/gallerySlice';
import {
imageSelected,
setCurrentImage,
} from 'features/gallery/store/gallerySlice';
import {
initialImageSelected,
setAllImageToImageParameters,
setAllParameters,
setInitialImage,
setSeed,
} from 'features/parameters/store/generationSlice';
import { DragEvent, memo, useState } from 'react';
@ -31,6 +34,7 @@ import { useTranslation } from 'react-i18next';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import IAIIconButton from 'common/components/IAIIconButton';
import { useGetUrl } from 'common/util/getUrl';
interface HoverableImageProps {
image: InvokeAI.Image;
@ -40,7 +44,7 @@ interface HoverableImageProps {
const memoEqualityCheck = (
prev: HoverableImageProps,
next: HoverableImageProps
) => prev.image.uuid === next.image.uuid && prev.isSelected === next.isSelected;
) => prev.image.name === next.image.name && prev.isSelected === next.isSelected;
/**
* Gallery image component with delete/use all/use seed buttons on hover.
@ -55,7 +59,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn,
} = useAppSelector(hoverableImageSelector);
const { image, isSelected } = props;
const { url, thumbnail, uuid, metadata } = image;
const { url, thumbnail, name, metadata } = image;
const { getUrl } = useGetUrl();
const [isHovered, setIsHovered] = useState<boolean>(false);
@ -69,10 +74,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
if (image.metadata?.image?.prompt) {
setBothPrompts(image.metadata?.image?.prompt);
if (image.metadata?.sd_metadata?.prompt) {
setBothPrompts(image.metadata?.sd_metadata?.prompt);
}
toast({
title: t('toast.promptSet'),
status: 'success',
@ -82,7 +86,8 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseSeed = () => {
image.metadata && dispatch(setSeed(image.metadata.image.seed));
image.metadata.sd_metadata &&
dispatch(setSeed(image.metadata.sd_metadata.image.seed));
toast({
title: t('toast.seedSet'),
status: 'success',
@ -92,20 +97,11 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleSendToImageToImage = () => {
dispatch(setInitialImage(image));
if (activeTabName !== 'img2img') {
dispatch(setActiveTab('img2img'));
}
toast({
title: t('toast.sentToImageToImage'),
status: 'success',
duration: 2500,
isClosable: true,
});
dispatch(initialImageSelected(image.name));
};
const handleSendToCanvas = () => {
dispatch(setInitialCanvasImage(image));
// dispatch(setInitialCanvasImage(image));
dispatch(resizeAndScaleCanvas());
@ -122,7 +118,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseAllParameters = () => {
metadata && dispatch(setAllParameters(metadata));
metadata.sd_metadata && dispatch(setAllParameters(metadata.sd_metadata));
toast({
title: t('toast.parametersSet'),
status: 'success',
@ -132,11 +128,13 @@ const HoverableImage = memo((props: HoverableImageProps) => {
};
const handleUseInitialImage = async () => {
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (metadata.sd_metadata?.image?.init_image_path) {
const response = await fetch(
metadata.sd_metadata?.image?.init_image_path
);
if (response.ok) {
dispatch(setActiveTab('img2img'));
dispatch(setAllImageToImageParameters(metadata));
dispatch(setAllImageToImageParameters(metadata?.sd_metadata));
toast({
title: t('toast.initialImageSet'),
status: 'success',
@ -155,16 +153,20 @@ const HoverableImage = memo((props: HoverableImageProps) => {
});
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
const handleSelectImage = () => {
dispatch(imageSelected(image.name));
};
const handleDragStart = (e: DragEvent<HTMLDivElement>) => {
e.dataTransfer.setData('invokeai/imageUuid', uuid);
console.log('drag started');
e.dataTransfer.setData('invokeai/imageName', image.name);
e.dataTransfer.setData('invokeai/imageType', image.type);
e.dataTransfer.effectAllowed = 'move';
};
const handleLightBox = () => {
dispatch(setCurrentImage(image));
dispatch(setIsLightboxOpen(true));
// dispatch(setCurrentImage(image));
// dispatch(setIsLightboxOpen(true));
};
return (
@ -177,28 +179,30 @@ const HoverableImage = memo((props: HoverableImageProps) => {
</MenuItem>
<MenuItem
onClickCapture={handleUsePrompt}
isDisabled={image?.metadata?.image?.prompt === undefined}
isDisabled={image?.metadata?.sd_metadata?.prompt === undefined}
>
{t('parameters.usePrompt')}
</MenuItem>
<MenuItem
onClickCapture={handleUseSeed}
isDisabled={image?.metadata?.image?.seed === undefined}
isDisabled={image?.metadata?.sd_metadata?.seed === undefined}
>
{t('parameters.useSeed')}
</MenuItem>
<MenuItem
onClickCapture={handleUseAllParameters}
isDisabled={
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
!['txt2img', 'img2img'].includes(
image?.metadata?.sd_metadata?.type
)
}
>
{t('parameters.useAll')}
</MenuItem>
<MenuItem
onClickCapture={handleUseInitialImage}
isDisabled={image?.metadata?.image?.type !== 'img2img'}
isDisabled={image?.metadata?.sd_metadata?.type !== 'img2img'}
>
{t('parameters.useInitImg')}
</MenuItem>
@ -209,9 +213,9 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
<MenuItem data-warning>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<p>{t('parameters.deleteImage')}</p>
</DeleteImageModal>
</DeleteImageModal> */}
</MenuItem>
</MenuList>
)}
@ -219,7 +223,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
{(ref) => (
<Box
position="relative"
key={uuid}
key={name}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
userSelect="none"
@ -244,7 +248,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
shouldUseSingleGalleryColumn ? 'contain' : galleryImageObjectFit
}
rounded="md"
src={thumbnail || url}
src={getUrl(thumbnail || url)}
loading="lazy"
sx={{
position: 'absolute',
@ -290,7 +294,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
insetInlineEnd: 1,
}}
>
<DeleteImageModal image={image}>
{/* <DeleteImageModal image={image}>
<IAIIconButton
aria-label={t('parameters.deleteImage')}
icon={<FaTrashAlt />}
@ -298,7 +302,7 @@ const HoverableImage = memo((props: HoverableImageProps) => {
fontSize={14}
isDisabled={!mayDeleteImage}
/>
</DeleteImageModal>
</DeleteImageModal> */}
</Box>
)}
</Box>

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex, Grid, Icon, Text } from '@chakra-ui/react';
import { ButtonGroup, Flex, Grid, Icon, Image, Text } from '@chakra-ui/react';
import { requestImages } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -25,9 +25,44 @@ import HoverableImage from './HoverableImage';
import Scrollable from 'features/ui/components/common/Scrollable';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
import {
resultsAdapter,
selectResultsAll,
selectResultsTotal,
} from '../store/resultsSlice';
import {
receivedResultImagesPage,
receivedUploadImagesPage,
} from 'services/thunks/gallery';
import { selectUploadsAll, uploadsAdapter } from '../store/uploadsSlice';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
const GALLERY_SHOW_BUTTONS_MIN_WIDTH = 290;
const gallerySelector = createSelector(
[
(state: RootState) => state.uploads,
(state: RootState) => state.results,
(state: RootState) => state.gallery,
],
(uploads, results, gallery) => {
const { currentCategory } = gallery;
return currentCategory === 'result'
? {
images: resultsAdapter.getSelectors().selectAll(results),
isLoading: results.isLoading,
areMoreImagesAvailable: results.page < results.pages - 1,
}
: {
images: uploadsAdapter.getSelectors().selectAll(uploads),
isLoading: uploads.isLoading,
areMoreImagesAvailable: uploads.page < uploads.pages - 1,
};
}
);
const ImageGalleryContent = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
@ -35,7 +70,7 @@ const ImageGalleryContent = () => {
const [shouldShouldIconButtons, setShouldShouldIconButtons] = useState(true);
const {
images,
// images,
currentCategory,
currentImageUuid,
shouldPinGallery,
@ -43,12 +78,24 @@ const ImageGalleryContent = () => {
galleryGridTemplateColumns,
galleryImageObjectFit,
shouldAutoSwitchToNewImages,
areMoreImagesAvailable,
// areMoreImagesAvailable,
shouldUseSingleGalleryColumn,
} = useAppSelector(imageGallerySelector);
const { images, areMoreImagesAvailable, isLoading } =
useAppSelector(gallerySelector);
// const handleClickLoadMore = () => {
// dispatch(requestImages(currentCategory));
// };
const handleClickLoadMore = () => {
dispatch(requestImages(currentCategory));
if (currentCategory === 'result') {
dispatch(receivedResultImagesPage());
}
if (currentCategory === 'user') {
dispatch(receivedUploadImagesPage());
}
};
const handleChangeGalleryImageMinimumWidth = (v: number) => {
@ -203,11 +250,11 @@ const ImageGalleryContent = () => {
style={{ gridTemplateColumns: galleryGridTemplateColumns }}
>
{images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
const { name } = image;
const isSelected = currentImageUuid === name;
return (
<HoverableImage
key={uuid}
key={name}
image={image}
isSelected={isSelected}
/>
@ -217,6 +264,7 @@ const ImageGalleryContent = () => {
<IAIButton
onClick={handleClickLoadMore}
isDisabled={!areMoreImagesAvailable}
isLoading={isLoading}
flexShrink={0}
>
{areMoreImagesAvailable

View File

@ -31,12 +31,13 @@ const GALLERY_TAB_WIDTHS: Record<
InvokeTabName,
{ galleryMinWidth: number; galleryMaxWidth: number }
> = {
txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// txt2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// img2img: { galleryMinWidth: 200, galleryMaxWidth: 500 },
linear: { galleryMinWidth: 200, galleryMaxWidth: 500 },
unifiedCanvas: { galleryMinWidth: 200, galleryMaxWidth: 200 },
nodes: { galleryMinWidth: 200, galleryMaxWidth: 500 },
postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// postprocessing: { galleryMinWidth: 200, galleryMaxWidth: 500 },
// training: { galleryMinWidth: 200, galleryMaxWidth: 500 },
};
const galleryPanelSelector = createSelector(

View File

@ -11,6 +11,7 @@ import {
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
@ -18,7 +19,7 @@ import {
setCfgScale,
setHeight,
setImg2imgStrength,
setInitialImage,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
@ -120,7 +121,7 @@ type ImageMetadataViewerProps = {
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.uuid === next.image.uuid;
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
@ -137,34 +138,13 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata?.image || {};
const dreamPrompt = image?.dreamPrompt;
const {
cfg_scale,
fit,
height,
hires_fix,
init_image_path,
mask_image_path,
orig_path,
perlin,
postprocessing,
prompt,
sampler,
seamless,
seed,
steps,
strength,
threshold,
type,
variations,
width,
} = metadata;
const sessionId = image.metadata.invokeai?.session_id;
const node = image.metadata.invokeai?.node as Record<string, any>;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image.metadata, null, 2);
const metadataJSON = JSON.stringify(image, null, 2);
return (
<Flex
@ -183,262 +163,134 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={image.url} isExternal maxW="calc(100% - 3rem)">
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
{Object.keys(metadata).length > 0 ? (
{node && Object.keys(node).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{image.metadata?.model_weights && (
<MetadataItem label="Model" value={image.metadata.model_weights} />
{node.type && (
<MetadataItem label="Invocation type" value={node.type} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
)}
{prompt && (
{node.model && <MetadataItem label="Model" value={node.model} />}
{node.prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
typeof node.prompt === 'string'
? node.prompt
: promptToString(node.prompt)
}
onClick={() => setBothPrompts(prompt)}
onClick={() => setBothPrompts(node.prompt)}
/>
)}
{seed !== undefined && (
{node.seed !== undefined && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
value={node.seed}
onClick={() => dispatch(setSeed(Number(node.seed)))}
/>
)}
{threshold !== undefined && (
{node.threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={threshold}
onClick={() => dispatch(setThreshold(threshold))}
value={node.threshold}
onClick={() => dispatch(setThreshold(Number(node.threshold)))}
/>
)}
{perlin !== undefined && (
{node.perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={perlin}
onClick={() => dispatch(setPerlin(perlin))}
value={node.perlin}
onClick={() => dispatch(setPerlin(Number(node.perlin)))}
/>
)}
{sampler && (
{node.scheduler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
value={node.scheduler}
onClick={() => dispatch(setSampler(node.scheduler))}
/>
)}
{steps && (
{node.steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
value={node.steps}
onClick={() => dispatch(setSteps(Number(node.steps)))}
/>
)}
{cfg_scale !== undefined && (
{node.cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
value={node.cfg_scale}
onClick={() => dispatch(setCfgScale(Number(node.cfg_scale)))}
/>
)}
{variations && variations.length > 0 && (
{node.variations && node.variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
value={seedWeightsToString(node.variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
dispatch(setSeedWeights(seedWeightsToString(node.variations)))
}
/>
)}
{seamless && (
{node.seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setSeamless(seamless))}
value={node.seamless}
onClick={() => dispatch(setSeamless(node.seamless))}
/>
)}
{hires_fix && (
{node.hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={hires_fix}
onClick={() => dispatch(setHiresFix(hires_fix))}
value={node.hires_fix}
onClick={() => dispatch(setHiresFix(node.hires_fix))}
/>
)}
{width && (
{node.width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
value={node.width}
onClick={() => dispatch(setWidth(Number(node.width)))}
/>
)}
{height && (
{node.height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
value={node.height}
onClick={() => dispatch(setHeight(Number(node.height)))}
/>
)}
{init_image_path && (
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
)} */}
{node.strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
value={node.strength}
onClick={() =>
dispatch(setImg2imgStrength(Number(node.strength)))
}
/>
)}
{fit && (
{node.fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
value={node.fit}
onClick={() => dispatch(setShouldFitToWidthHeight(node.fit))}
/>
)}
{postprocessing && postprocessing.length > 0 && (
<>
<Heading size="sm">Postprocessing</Heading>
{postprocessing.map(
(
postprocess: InvokeAI.PostProcessedImageMetadata,
i: number
) => {
if (postprocess.type === 'esrgan') {
const { scale, strength, denoise_str } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
<MetadataItem
label="Scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Strength"
value={strength}
onClick={() =>
dispatch(setUpscalingStrength(strength))
}
/>
{denoise_str !== undefined && (
<MetadataItem
label="Denoising strength"
value={denoise_str}
onClick={() =>
dispatch(setUpscalingDenoising(denoise_str))
}
/>
)}
</Flex>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (GFPGAN)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('gfpgan'));
}}
/>
</Flex>
);
} else if (postprocess.type === 'codeformer') {
const { strength, fidelity } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (Codeformer)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('codeformer'));
}}
/>
{fidelity && (
<MetadataItem
label="Fidelity"
value={fidelity}
onClick={() => {
dispatch(setCodeformerFidelity(fidelity));
dispatch(setFacetoolType('codeformer'));
}}
/>
)}
</Flex>
);
}
}
)}
</>
)}
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</>
) : (
<Center width="100%" pt={10}>
@ -447,6 +299,37 @@ const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
</Text>
</Center>
)}
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
</Flex>
);
}, memoEqualityCheck);

View File

@ -0,0 +1,470 @@
import { ExternalLinkIcon } from '@chakra-ui/icons';
import {
Box,
Center,
Flex,
Heading,
IconButton,
Link,
Text,
Tooltip,
} from '@chakra-ui/react';
import * as InvokeAI from 'app/invokeai';
import { useAppDispatch } from 'app/storeHooks';
import { useGetUrl } from 'common/util/getUrl';
import promptToString from 'common/util/promptToString';
import { seedWeightsToString } from 'common/util/seedWeightPairs';
import useSetBothPrompts from 'features/parameters/hooks/usePrompt';
import {
setCfgScale,
setHeight,
setImg2imgStrength,
// setInitialImage,
setMaskPath,
setPerlin,
setSampler,
setSeamless,
setSeed,
setSeedWeights,
setShouldFitToWidthHeight,
setSteps,
setThreshold,
setWidth,
} from 'features/parameters/store/generationSlice';
import {
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
setHiresFix,
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
} from 'features/parameters/store/postprocessingSlice';
import { setShouldShowImageDetails } from 'features/ui/store/uiSlice';
import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { FaCopy } from 'react-icons/fa';
import { IoArrowUndoCircleOutline } from 'react-icons/io5';
import * as png from '@stevebel/png';
type MetadataItemProps = {
isLink?: boolean;
label: string;
onClick?: () => void;
value: number | string | boolean;
labelPosition?: string;
withCopy?: boolean;
};
/**
* Component to display an individual metadata item or parameter.
*/
const MetadataItem = ({
label,
value,
onClick,
isLink,
labelPosition,
withCopy = false,
}: MetadataItemProps) => {
const { t } = useTranslation();
return (
<Flex gap={2}>
{onClick && (
<Tooltip label={`Recall ${label}`}>
<IconButton
aria-label={t('accessibility.useThisParameter')}
icon={<IoArrowUndoCircleOutline />}
size="xs"
variant="ghost"
fontSize={20}
onClick={onClick}
/>
</Tooltip>
)}
{withCopy && (
<Tooltip label={`Copy ${label}`}>
<IconButton
aria-label={`Copy ${label}`}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(value.toString())}
/>
</Tooltip>
)}
<Flex direction={labelPosition ? 'column' : 'row'}>
<Text fontWeight="semibold" whiteSpace="pre-wrap" pr={2}>
{label}:
</Text>
{isLink ? (
<Link href={value.toString()} isExternal wordBreak="break-all">
{value.toString()} <ExternalLinkIcon mx="2px" />
</Link>
) : (
<Text overflowY="scroll" wordBreak="break-all">
{value.toString()}
</Text>
)}
</Flex>
</Flex>
);
};
type ImageMetadataViewerProps = {
image: InvokeAI.Image;
};
// TODO: I don't know if this is needed.
const memoEqualityCheck = (
prev: ImageMetadataViewerProps,
next: ImageMetadataViewerProps
) => prev.image.name === next.image.name;
// TODO: Show more interesting information in this component.
/**
* Image metadata viewer overlays currently selected image and provides
* access to any of its metadata for use in processing.
*/
const ImageMetadataViewer = memo(({ image }: ImageMetadataViewerProps) => {
const dispatch = useAppDispatch();
const setBothPrompts = useSetBothPrompts();
useHotkeys('esc', () => {
dispatch(setShouldShowImageDetails(false));
});
const metadata = image?.metadata.sd_metadata || {};
const dreamPrompt = image?.metadata.sd_metadata?.dreamPrompt;
const {
cfg_scale,
fit,
height,
hires_fix,
init_image_path,
mask_image_path,
orig_path,
perlin,
postprocessing,
prompt,
sampler,
seamless,
seed,
steps,
strength,
threshold,
type,
variations,
width,
model_weights,
} = metadata;
const { t } = useTranslation();
const { getUrl } = useGetUrl();
const metadataJSON = JSON.stringify(image, null, 2);
// fetch(getUrl(image.url))
// .then((r) => r.arrayBuffer())
// .then((buffer) => {
// const { text } = png.decode(buffer);
// const metadata = text?.['sd-metadata']
// ? JSON.parse(text['sd-metadata'] ?? {})
// : {};
// console.log(metadata);
// });
return (
<Flex
sx={{
padding: 4,
gap: 1,
flexDirection: 'column',
width: 'full',
height: 'full',
backdropFilter: 'blur(20px)',
bg: 'whiteAlpha.600',
_dark: {
bg: 'blackAlpha.600',
},
}}
>
<Flex gap={2}>
<Text fontWeight="semibold">File:</Text>
<Link href={getUrl(image.url)} isExternal maxW="calc(100% - 3rem)">
{image.url.length > 64
? image.url.substring(0, 64).concat('...')
: image.url}
<ExternalLinkIcon mx="2px" />
</Link>
</Flex>
<Flex gap={2} direction="column">
<Flex gap={2}>
<Tooltip label="Copy metadata JSON">
<IconButton
aria-label={t('accessibility.copyMetadataJson')}
icon={<FaCopy />}
size="xs"
variant="ghost"
fontSize={14}
onClick={() => navigator.clipboard.writeText(metadataJSON)}
/>
</Tooltip>
<Text fontWeight="semibold">Metadata JSON:</Text>
</Flex>
<Box
sx={{
mt: 0,
mr: 2,
mb: 4,
ml: 2,
padding: 4,
borderRadius: 'base',
overflowX: 'scroll',
wordBreak: 'break-all',
bg: 'whiteAlpha.500',
_dark: { bg: 'blackAlpha.500' },
}}
>
<pre>{metadataJSON}</pre>
</Box>
</Flex>
{Object.keys(metadata).length > 0 ? (
<>
{type && <MetadataItem label="Generation type" value={type} />}
{model_weights && (
<MetadataItem label="Model" value={model_weights} />
)}
{['esrgan', 'gfpgan'].includes(type) && (
<MetadataItem label="Original image" value={orig_path} />
)}
{prompt && (
<MetadataItem
label="Prompt"
labelPosition="top"
value={
typeof prompt === 'string' ? prompt : promptToString(prompt)
}
onClick={() => setBothPrompts(prompt)}
/>
)}
{seed !== undefined && (
<MetadataItem
label="Seed"
value={seed}
onClick={() => dispatch(setSeed(seed))}
/>
)}
{threshold !== undefined && (
<MetadataItem
label="Noise Threshold"
value={threshold}
onClick={() => dispatch(setThreshold(threshold))}
/>
)}
{perlin !== undefined && (
<MetadataItem
label="Perlin Noise"
value={perlin}
onClick={() => dispatch(setPerlin(perlin))}
/>
)}
{sampler && (
<MetadataItem
label="Sampler"
value={sampler}
onClick={() => dispatch(setSampler(sampler))}
/>
)}
{steps && (
<MetadataItem
label="Steps"
value={steps}
onClick={() => dispatch(setSteps(steps))}
/>
)}
{cfg_scale !== undefined && (
<MetadataItem
label="CFG scale"
value={cfg_scale}
onClick={() => dispatch(setCfgScale(cfg_scale))}
/>
)}
{variations && variations.length > 0 && (
<MetadataItem
label="Seed-weight pairs"
value={seedWeightsToString(variations)}
onClick={() =>
dispatch(setSeedWeights(seedWeightsToString(variations)))
}
/>
)}
{seamless && (
<MetadataItem
label="Seamless"
value={seamless}
onClick={() => dispatch(setSeamless(seamless))}
/>
)}
{hires_fix && (
<MetadataItem
label="High Resolution Optimization"
value={hires_fix}
onClick={() => dispatch(setHiresFix(hires_fix))}
/>
)}
{width && (
<MetadataItem
label="Width"
value={width}
onClick={() => dispatch(setWidth(width))}
/>
)}
{height && (
<MetadataItem
label="Height"
value={height}
onClick={() => dispatch(setHeight(height))}
/>
)}
{/* {init_image_path && (
<MetadataItem
label="Initial image"
value={init_image_path}
isLink
onClick={() => dispatch(setInitialImage(init_image_path))}
/>
)} */}
{mask_image_path && (
<MetadataItem
label="Mask image"
value={mask_image_path}
isLink
onClick={() => dispatch(setMaskPath(mask_image_path))}
/>
)}
{type === 'img2img' && strength && (
<MetadataItem
label="Image to image strength"
value={strength}
onClick={() => dispatch(setImg2imgStrength(strength))}
/>
)}
{fit && (
<MetadataItem
label="Image to image fit"
value={fit}
onClick={() => dispatch(setShouldFitToWidthHeight(fit))}
/>
)}
{postprocessing && postprocessing.length > 0 && (
<>
<Heading size="sm">Postprocessing</Heading>
{postprocessing.map(
(
postprocess: InvokeAI.PostProcessedImageMetadata,
i: number
) => {
if (postprocess.type === 'esrgan') {
const { scale, strength, denoise_str } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${i + 1}: Upscale (ESRGAN)`}</Text>
<MetadataItem
label="Scale"
value={scale}
onClick={() => dispatch(setUpscalingLevel(scale))}
/>
<MetadataItem
label="Strength"
value={strength}
onClick={() =>
dispatch(setUpscalingStrength(strength))
}
/>
{denoise_str !== undefined && (
<MetadataItem
label="Denoising strength"
value={denoise_str}
onClick={() =>
dispatch(setUpscalingDenoising(denoise_str))
}
/>
)}
</Flex>
);
} else if (postprocess.type === 'gfpgan') {
const { strength } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (GFPGAN)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('gfpgan'));
}}
/>
</Flex>
);
} else if (postprocess.type === 'codeformer') {
const { strength, fidelity } = postprocess;
return (
<Flex key={i} pl={8} gap={1} direction="column">
<Text size="md">{`${
i + 1
}: Face restoration (Codeformer)`}</Text>
<MetadataItem
label="Strength"
value={strength}
onClick={() => {
dispatch(setFacetoolStrength(strength));
dispatch(setFacetoolType('codeformer'));
}}
/>
{fidelity && (
<MetadataItem
label="Fidelity"
value={fidelity}
onClick={() => {
dispatch(setCodeformerFidelity(fidelity));
dispatch(setFacetoolType('codeformer'));
}}
/>
)}
</Flex>
);
}
}
)}
</>
)}
{dreamPrompt && (
<MetadataItem withCopy label="Dream Prompt" value={dreamPrompt} />
)}
</>
) : (
<Center width="100%" pt={10}>
<Text fontSize="lg" fontWeight="semibold">
No metadata available
</Text>
</Center>
)}
</Flex>
);
}, memoEqualityCheck);
ImageMetadataViewer.displayName = 'ImageMetadataViewer';
export default ImageMetadataViewer;

View File

@ -0,0 +1,35 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/storeHooks';
import { ImageType } from 'services/api';
import { selectResultsEntities } from '../store/resultsSlice';
import { selectUploadsEntities } from '../store/uploadsSlice';
const useGetImageByNameSelector = createSelector(
[selectResultsEntities, selectUploadsEntities],
(allResults, allUploads) => {
return { allResults, allUploads };
}
);
const useGetImageByNameAndType = () => {
const { allResults, allUploads } = useAppSelector(useGetImageByNameSelector);
return (name: string, type: ImageType) => {
if (type === 'results') {
const resultImagesResult = allResults[name];
if (resultImagesResult) {
return resultImagesResult;
}
}
if (type === 'uploads') {
const userImagesResult = allUploads[name];
if (userImagesResult) {
return userImagesResult;
}
}
};
};
export default useGetImageByNameAndType;

View File

@ -0,0 +1,17 @@
import { GalleryState } from './gallerySlice';
/**
* Gallery slice persist blacklist
*/
const itemsToBlacklist: (keyof GalleryState)[] = [
'categories',
'currentCategory',
'currentImage',
'currentImageUuid',
'shouldAutoSwitchToNewImages',
'intermediateImage',
];
export const galleryBlacklist = itemsToBlacklist.map(
(blacklistItem) => `gallery.${blacklistItem}`
);

View File

@ -7,6 +7,16 @@ import {
uiSelector,
} from 'features/ui/store/uiSelectors';
import { isEqual } from 'lodash';
import {
selectResultsAll,
selectResultsById,
selectResultsEntities,
} from './resultsSlice';
import {
selectUploadsAll,
selectUploadsById,
selectUploadsEntities,
} from './uploadsSlice';
export const gallerySelector = (state: RootState) => state.gallery;
@ -75,3 +85,18 @@ export const hoverableImageSelector = createSelector(
},
}
);
export const selectedImageSelector = createSelector(
[gallerySelector, selectResultsEntities, selectUploadsEntities],
(gallery, allResults, allUploads) => {
const selectedImageName = gallery.selectedImageName;
if (selectedImageName in allResults) {
return allResults[selectedImageName];
}
if (selectedImageName in allUploads) {
return allUploads[selectedImageName];
}
}
);

View File

@ -1,14 +1,17 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { IRect } from 'konva/lib/types';
import { clamp } from 'lodash';
import { isImageOutput } from 'services/types/guards';
import { imageUploaded } from 'services/thunks/image';
export type GalleryCategory = 'user' | 'result';
export type AddImagesPayload = {
images: Array<InvokeAI.Image>;
images: Array<InvokeAI._Image>;
areMoreImagesAvailable: boolean;
category: GalleryCategory;
};
@ -16,16 +19,33 @@ export type AddImagesPayload = {
type GalleryImageObjectFitType = 'contain' | 'cover';
export type Gallery = {
images: InvokeAI.Image[];
images: InvokeAI._Image[];
latest_mtime?: number;
earliest_mtime?: number;
areMoreImagesAvailable: boolean;
};
export interface GalleryState {
currentImage?: InvokeAI.Image;
/**
* The selected image's unique name
* Use `selectedImageSelector` to access the image
*/
selectedImageName: string;
/**
* The currently selected image
* @deprecated See `state.gallery.selectedImageName`
*/
currentImage?: InvokeAI._Image;
/**
* The currently selected image's uuid.
* @deprecated See `state.gallery.selectedImageName`, use `selectedImageSelector` to access the image
*/
currentImageUuid: string;
intermediateImage?: InvokeAI.Image & {
/**
* The current progress image
* @deprecated See `state.system.progressImage`
*/
intermediateImage?: InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
};
@ -42,6 +62,7 @@ export interface GalleryState {
}
const initialState: GalleryState = {
selectedImageName: '',
currentImageUuid: '',
galleryImageMinimumWidth: 64,
galleryImageObjectFit: 'cover',
@ -69,7 +90,10 @@ export const gallerySlice = createSlice({
name: 'gallery',
initialState,
reducers: {
setCurrentImage: (state, action: PayloadAction<InvokeAI.Image>) => {
imageSelected: (state, action: PayloadAction<string>) => {
state.selectedImageName = action.payload;
},
setCurrentImage: (state, action: PayloadAction<InvokeAI._Image>) => {
state.currentImage = action.payload;
state.currentImageUuid = action.payload.uuid;
},
@ -124,7 +148,7 @@ export const gallerySlice = createSlice({
addImage: (
state,
action: PayloadAction<{
image: InvokeAI.Image;
image: InvokeAI._Image;
category: GalleryCategory;
}>
) => {
@ -150,7 +174,10 @@ export const gallerySlice = createSlice({
setIntermediateImage: (
state,
action: PayloadAction<
InvokeAI.Image & { boundingBox?: IRect; generationMode?: InvokeTabName }
InvokeAI._Image & {
boundingBox?: IRect;
generationMode?: InvokeTabName;
}
>
) => {
state.intermediateImage = action.payload;
@ -252,9 +279,31 @@ export const gallerySlice = createSlice({
state.shouldUseSingleGalleryColumn = action.payload;
},
},
extraReducers(builder) {
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
if (isImageOutput(data.result)) {
state.selectedImageName = data.result.image.image_name;
state.intermediateImage = undefined;
}
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location } = action.payload;
const imageName = location.split('/').pop() || '';
state.selectedImageName = imageName;
});
},
});
export const {
imageSelected,
addImage,
clearIntermediateImage,
removeImage,

View File

@ -0,0 +1,12 @@
import { ResultsState } from './resultsSlice';
/**
* Results slice persist blacklist
*
* Currently blacklisting results slice entirely, see persist config in store.ts
*/
const itemsToBlacklist: (keyof ResultsState)[] = [];
export const resultsBlacklist = itemsToBlacklist.map(
(blacklistItem) => `results.${blacklistItem}`
);

View File

@ -0,0 +1,139 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { invocationComplete } from 'services/events/actions';
import { RootState } from 'app/store';
import {
receivedResultImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { isImageOutput } from 'services/types/guards';
import {
buildImageUrls,
extractTimestampFromImageName,
} from 'services/util/deserializeImageField';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
// use `createEntityAdapter` to create a slice for results images
// https://redux-toolkit.js.org/api/createEntityAdapter#overview
// the "Entity" is InvokeAI.ResultImage, while the "entities" are instances of that type
export const resultsAdapter = createEntityAdapter<Image>({
// Provide a callback to get a stable, unique identifier for each entity. This defaults to
// `(item) => item.id`, but for our result images, the `name` is the unique identifier.
selectId: (image) => image.name,
// Order all images by their time (in descending order)
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
// This type is intersected with the Entity type to create the shape of the state
type AdditionalResultsState = {
// these are a bit misleading; they refer to sessions, not results, but we don't have a route
// to list all images directly at this time...
page: number; // current page we are on
pages: number; // the total number of pages available
isLoading: boolean; // whether we are loading more images or not, mostly a placeholder
nextPage: number; // the next page to request
};
export const initialResultsState =
resultsAdapter.getInitialState<AdditionalResultsState>({
// provide the additional initial state
page: 0,
pages: 0,
isLoading: false,
nextPage: 0,
});
export type ResultsState = typeof initialResultsState;
const resultsSlice = createSlice({
name: 'results',
initialState: initialResultsState,
reducers: {
// the adapter provides some helper reducers; see the docs for all of them
// can use them as helper functions within a reducer, or use the function itself as a reducer
// here we just use the function itself as the reducer. we'll call this on `invocation_complete`
// to add a single result
resultAdded: resultsAdapter.upsertOne,
},
extraReducers: (builder) => {
// here we can respond to a fulfilled call of the `getNextResultsPage` thunk
// because we pass in the fulfilled thunk action creator, everything is typed
/**
* Received Result Images Page - PENDING
*/
builder.addCase(receivedResultImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Result Images Page - FULFILLED
*/
builder.addCase(receivedResultImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const resultImages = items.map((image) =>
deserializeImageResponse(image)
);
// use the adapter reducer to append all the results to state
resultsAdapter.addMany(state, resultImages);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Invocation Complete
*/
builder.addCase(invocationComplete, (state, action) => {
const { data } = action.payload;
const { result, node, graph_execution_state_id } = data;
if (isImageOutput(result)) {
const name = result.image.image_name;
const type = result.image.image_type;
const { url, thumbnail } = buildImageUrls(type, name);
const timestamp = extractTimestampFromImageName(name);
const image: Image = {
name,
type,
url,
thumbnail,
metadata: {
created: timestamp,
width: result.width, // TODO: add tese dimensions
height: result.height,
invokeai: {
session_id: graph_execution_state_id,
...(node ? { node } : {}),
},
},
};
resultsAdapter.addOne(state, image);
}
});
},
});
// Create a set of memoized selectors based on the location of this entity state
// to be used as selectors in a `useAppSelector()` call
export const {
selectAll: selectResultsAll,
selectById: selectResultsById,
selectEntities: selectResultsEntities,
selectIds: selectResultsIds,
selectTotal: selectResultsTotal,
} = resultsAdapter.getSelectors<RootState>((state) => state.results);
export const { resultAdded } = resultsSlice.actions;
export default resultsSlice.reducer;

View File

@ -1,54 +0,0 @@
import { AnyAction, ThunkAction } from '@reduxjs/toolkit';
import * as InvokeAI from 'app/invokeai';
import { RootState } from 'app/store';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { setInitialImage } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { v4 as uuidv4 } from 'uuid';
import { addImage } from '../gallerySlice';
type UploadImageConfig = {
imageFile: File;
};
export const uploadImage =
(
config: UploadImageConfig
): ThunkAction<void, RootState, unknown, AnyAction> =>
async (dispatch, getState) => {
const { imageFile } = config;
const state = getState() as RootState;
const activeTabName = activeTabNameSelector(state);
const formData = new FormData();
formData.append('file', imageFile, imageFile.name);
formData.append(
'data',
JSON.stringify({
kind: 'init',
})
);
const response = await fetch(`${window.location.origin}/upload`, {
method: 'POST',
body: formData,
});
const image = (await response.json()) as InvokeAI.ImageUploadResponse;
const newImage: InvokeAI.Image = {
uuid: uuidv4(),
category: 'user',
...image,
};
dispatch(addImage({ image: newImage, category: 'user' }));
if (activeTabName === 'unifiedCanvas') {
dispatch(setInitialCanvasImage(newImage));
} else if (activeTabName === 'img2img') {
dispatch(setInitialImage(newImage));
}
};

View File

@ -0,0 +1,12 @@
import { UploadsState } from './uploadsSlice';
/**
* Uploads slice persist blacklist
*
* Currently blacklisting uploads slice entirely, see persist config in store.ts
*/
const itemsToBlacklist: (keyof UploadsState)[] = [];
export const uploadsBlacklist = itemsToBlacklist.map(
(blacklistItem) => `uploads.${blacklistItem}`
);

View File

@ -0,0 +1,87 @@
import { createEntityAdapter, createSlice } from '@reduxjs/toolkit';
import { Image } from 'app/invokeai';
import { RootState } from 'app/store';
import {
receivedUploadImagesPage,
IMAGES_PER_PAGE,
} from 'services/thunks/gallery';
import { imageUploaded } from 'services/thunks/image';
import { deserializeImageResponse } from 'services/util/deserializeImageResponse';
export const uploadsAdapter = createEntityAdapter<Image>({
selectId: (image) => image.name,
sortComparer: (a, b) => b.metadata.created - a.metadata.created,
});
type AdditionalUploadsState = {
page: number;
pages: number;
isLoading: boolean;
nextPage: number;
};
const initialUploadsState =
uploadsAdapter.getInitialState<AdditionalUploadsState>({
page: 0,
pages: 0,
nextPage: 0,
isLoading: false,
});
export type UploadsState = typeof initialUploadsState;
const uploadsSlice = createSlice({
name: 'uploads',
initialState: initialUploadsState,
reducers: {
uploadAdded: uploadsAdapter.addOne,
},
extraReducers: (builder) => {
/**
* Received Upload Images Page - PENDING
*/
builder.addCase(receivedUploadImagesPage.pending, (state) => {
state.isLoading = true;
});
/**
* Received Upload Images Page - FULFILLED
*/
builder.addCase(receivedUploadImagesPage.fulfilled, (state, action) => {
const { items, page, pages } = action.payload;
const images = items.map((image) => deserializeImageResponse(image));
uploadsAdapter.addMany(state, images);
state.page = page;
state.pages = pages;
state.nextPage = items.length < IMAGES_PER_PAGE ? page : page + 1;
state.isLoading = false;
});
/**
* Upload Image - FULFILLED
*/
builder.addCase(imageUploaded.fulfilled, (state, action) => {
const { location, response } = action.payload;
const uploadedImage = deserializeImageResponse(response);
uploadsAdapter.addOne(state, uploadedImage);
});
},
});
export const {
selectAll: selectUploadsAll,
selectById: selectUploadsById,
selectEntities: selectUploadsEntities,
selectIds: selectUploadsIds,
selectTotal: selectUploadsTotal,
} = uploadsAdapter.getSelectors<RootState>((state) => state.uploads);
export const { uploadAdded } = uploadsSlice.actions;
export default uploadsSlice.reducer;

View File

@ -1,9 +1,10 @@
import * as React from 'react';
import { TransformComponent, useTransformContext } from 'react-zoom-pan-pinch';
import * as InvokeAI from 'app/invokeai';
import { useGetUrl } from 'common/util/getUrl';
type ReactPanZoomProps = {
image: InvokeAI.Image;
image: InvokeAI._Image;
styleClass?: string;
alt?: string;
ref?: React.Ref<HTMLImageElement>;
@ -22,6 +23,7 @@ export default function ReactPanZoomImage({
scaleY,
}: ReactPanZoomProps) {
const { centerView } = useTransformContext();
const { getUrl } = useGetUrl();
return (
<TransformComponent
@ -35,7 +37,7 @@ export default function ReactPanZoomImage({
transform: `rotate(${rotation}deg) scaleX(${scaleX}) scaleY(${scaleY})`,
width: '100%',
}}
src={image.url}
src={getUrl(image.url)}
alt={alt}
ref={ref}
className={styleClass ? styleClass : ''}

View File

@ -0,0 +1,10 @@
import { LightboxState } from './lightboxSlice';
/**
* Lightbox slice persist blacklist
*/
const itemsToBlacklist: (keyof LightboxState)[] = ['isLightboxOpen'];
export const lightboxBlacklist = itemsToBlacklist.map(
(blacklistItem) => `lightbox.${blacklistItem}`
);

View File

@ -0,0 +1,63 @@
import { v4 as uuidv4 } from 'uuid';
import 'reactflow/dist/style.css';
import { useCallback } from 'react';
import {
Tooltip,
Menu,
MenuButton,
MenuList,
MenuItem,
IconButton,
} from '@chakra-ui/react';
import { FaPlus } from 'react-icons/fa';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { nodeAdded } from '../store/nodesSlice';
import { cloneDeep, map } from 'lodash';
import { RootState } from 'app/store';
import { useBuildInvocation } from '../hooks/useBuildInvocation';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/hooks/useToastWatcher';
export const AddNodeMenu = () => {
const dispatch = useAppDispatch();
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
const buildInvocation = useBuildInvocation();
const addNode = useCallback(
(nodeType: string) => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const toast = makeToast({
status: 'error',
title: `Unknown Invocation type ${nodeType}`,
});
dispatch(addToast(toast));
return;
}
dispatch(nodeAdded(invocation));
},
[dispatch, buildInvocation]
);
return (
<Menu>
<MenuButton as={IconButton} aria-label="Add Node" icon={<FaPlus />} />
<MenuList>
{map(invocationTemplates, ({ title, description, type }, key) => {
return (
<Tooltip key={key} label={description} placement="end" hasArrow>
<MenuItem onClick={() => addNode(type)}>{title}</MenuItem>
</Tooltip>
);
})}
</MenuList>
</Menu>
);
};

View File

@ -0,0 +1,69 @@
import { Tooltip } from '@chakra-ui/react';
import { CSSProperties, useMemo } from 'react';
import {
Handle,
Position,
Connection,
HandleType,
useReactFlow,
} from 'reactflow';
import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants';
// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles';
import { InputFieldTemplate, OutputFieldTemplate } from '../types/types';
const handleBaseStyles: CSSProperties = {
position: 'absolute',
width: '1rem',
height: '1rem',
borderWidth: 0,
};
const inputHandleStyles: CSSProperties = {
left: '-1.7rem',
};
const outputHandleStyles: CSSProperties = {
right: '-1.7rem',
};
const requiredConnectionStyles: CSSProperties = {
boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)',
};
type FieldHandleProps = {
nodeId: string;
field: InputFieldTemplate | OutputFieldTemplate;
isValidConnection: (connection: Connection) => boolean;
handleType: HandleType;
styles?: CSSProperties;
};
export const FieldHandle = (props: FieldHandleProps) => {
const { nodeId, field, isValidConnection, handleType, styles } = props;
const { name, title, type, description } = field;
return (
<Tooltip
key={name}
label={type}
placement={handleType === 'target' ? 'start' : 'end'}
hasArrow
openDelay={HANDLE_TOOLTIP_OPEN_DELAY}
>
<Handle
type={handleType}
id={name}
isValidConnection={isValidConnection}
position={handleType === 'target' ? Position.Left : Position.Right}
style={{
backgroundColor: FIELDS[type].colorCssVar,
...styles,
...handleBaseStyles,
...(handleType === 'target' ? inputHandleStyles : outputHandleStyles),
// ...(inputRequirement === 'always' ? requiredConnectionStyles : {}),
// ...connectionEventStyles,
}}
/>
</Tooltip>
);
};

View File

@ -0,0 +1,18 @@
import 'reactflow/dist/style.css';
import { Tooltip, Badge, HStack } from '@chakra-ui/react';
import { map } from 'lodash';
import { FIELDS } from '../types/constants';
export const FieldTypeLegend = () => {
return (
<HStack>
{map(FIELDS, ({ title, description, color }, key) => (
<Tooltip key={key} label={description}>
<Badge colorScheme={color} sx={{ userSelect: 'none' }}>
{title}
</Badge>
</Tooltip>
))}
</HStack>
);
};

View File

@ -0,0 +1,104 @@
import {
Background,
Controls,
MiniMap,
OnConnect,
OnEdgesChange,
OnNodesChange,
ReactFlow,
ConnectionLineType,
OnConnectStart,
OnConnectEnd,
Panel,
} from 'reactflow';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import {
connectionEnded,
connectionMade,
connectionStarted,
edgesChanged,
nodesChanged,
} from '../store/nodesSlice';
import { useCallback } from 'react';
import { InvocationComponent } from './InvocationComponent';
import { AddNodeMenu } from './AddNodeMenu';
import { FieldTypeLegend } from './FieldTypeLegend';
import { Button } from '@chakra-ui/react';
import { nodesGraphBuilt } from 'services/thunks/session';
const nodeTypes = { invocation: InvocationComponent };
export const Flow = () => {
const dispatch = useAppDispatch();
const nodes = useAppSelector((state: RootState) => state.nodes.nodes);
const edges = useAppSelector((state: RootState) => state.nodes.edges);
const onNodesChange: OnNodesChange = useCallback(
(changes) => {
dispatch(nodesChanged(changes));
},
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(
(changes) => {
dispatch(edgesChanged(changes));
},
[dispatch]
);
const onConnectStart: OnConnectStart = useCallback(
(event, params) => {
dispatch(connectionStarted(params));
},
[dispatch]
);
const onConnect: OnConnect = useCallback(
(connection) => {
dispatch(connectionMade(connection));
},
[dispatch]
);
const onConnectEnd: OnConnectEnd = useCallback(
(event) => {
dispatch(connectionEnded());
},
[dispatch]
);
const handleInvoke = useCallback(() => {
dispatch(nodesGraphBuilt());
}, [dispatch]);
return (
<ReactFlow
nodeTypes={nodeTypes}
nodes={nodes}
edges={edges}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
defaultEdgeOptions={{
style: { strokeWidth: 2 },
}}
>
<Panel position="top-left">
<AddNodeMenu />
</Panel>
<Panel position="top-center">
<Button onClick={handleInvoke}>Will it blend?</Button>
</Panel>
<Panel position="top-right">
<FieldTypeLegend />
</Panel>
<Background />
<Controls />
<MiniMap nodeStrokeWidth={3} zoomable pannable />
</ReactFlow>
);
};

View File

@ -0,0 +1,107 @@
import { Box } from '@chakra-ui/react';
import { InputFieldTemplate, InputFieldValue } from '../types/types';
import { ArrayInputFieldComponent } from './fields/ArrayInputField.tsx';
import { BooleanInputFieldComponent } from './fields/BooleanInputFieldComponent';
import { EnumInputFieldComponent } from './fields/EnumInputFieldComponent';
import { ImageInputFieldComponent } from './fields/ImageInputFieldComponent';
import { LatentsInputFieldComponent } from './fields/LatentsInputFieldComponent';
import { ModelInputFieldComponent } from './fields/ModelInputFieldComponent';
import { NumberInputFieldComponent } from './fields/NumberInputFieldComponent';
import { StringInputFieldComponent } from './fields/StringInputFieldComponent';
type InputFieldComponentProps = {
nodeId: string;
field: InputFieldValue;
template: InputFieldTemplate;
};
// build an individual input element based on the schema
export const InputFieldComponent = (props: InputFieldComponentProps) => {
const { nodeId, field, template } = props;
const { type, value } = field;
if (type === 'string' && template.type === 'string') {
return (
<StringInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'boolean' && template.type === 'boolean') {
return (
<BooleanInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (
(type === 'integer' && template.type === 'integer') ||
(type === 'float' && template.type === 'float')
) {
return (
<NumberInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'enum' && template.type === 'enum') {
return (
<EnumInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'image' && template.type === 'image') {
return (
<ImageInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'latents' && template.type === 'latents') {
return (
<LatentsInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'model' && template.type === 'model') {
return (
<ModelInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
if (type === 'array' && template.type === 'array') {
return (
<ArrayInputFieldComponent
nodeId={nodeId}
field={field}
template={template}
/>
);
}
return <Box p={2}>Unknown field type: {type}</Box>;
};

View File

@ -0,0 +1,243 @@
import { NodeProps, useReactFlow } from 'reactflow';
import {
Box,
Flex,
FormControl,
FormLabel,
Heading,
HStack,
Tooltip,
Icon,
Code,
Text,
} from '@chakra-ui/react';
import { FaExclamationCircle, FaInfoCircle } from 'react-icons/fa';
import { InvocationValue } from '../types/types';
import { InputFieldComponent } from './InputFieldComponent';
import { FieldHandle } from './FieldHandle';
import { isEqual, map, size } from 'lodash';
import { memo, useMemo, useRef } from 'react';
import { useIsValidConnection } from '../hooks/useIsValidConnection';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useGetInvocationTemplate } from '../hooks/useInvocationTemplate';
const connectedInputFieldsSelector = createSelector(
[(state: RootState) => state.nodes.edges],
(edges) => {
// return edges.map((e) => e.targetHandle);
return edges;
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const InvocationComponent = memo((props: NodeProps<InvocationValue>) => {
const { id: nodeId, data, selected } = props;
const { type, inputs, outputs } = data;
const isValidConnection = useIsValidConnection();
const connectedInputs = useAppSelector(connectedInputFieldsSelector);
const getInvocationTemplate = useGetInvocationTemplate();
// TODO: determine if a field/handle is connected and disable the input if so
const template = useRef(getInvocationTemplate(type));
if (!template.current) {
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex sx={{ alignItems: 'center', justifyContent: 'center' }}>
<Icon color="base.400" boxSize={32} as={FaExclamationCircle}></Icon>
</Flex>
</Box>
);
}
return (
<Box
sx={{
padding: 4,
bg: 'base.800',
borderRadius: 'md',
boxShadow: 'dark-lg',
borderWidth: 2,
borderColor: selected ? 'base.400' : 'transparent',
}}
>
<Flex flexDirection="column" gap={2}>
<>
<Code>{nodeId}</Code>
<HStack justifyContent="space-between">
<Heading size="sm" fontWeight={500} color="base.100">
{template.current.title}
</Heading>
<Tooltip
label={template.current.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.300" as={FaInfoCircle} />
</Tooltip>
</HStack>
{map(inputs, (input, i) => {
const { id: fieldId } = input;
const inputTemplate = template.current?.inputs[input.name];
if (!inputTemplate) {
return (
<Box
key={fieldId}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor: 'error.400',
}}
>
<FormControl isDisabled={true}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown input: {input.name}</FormLabel>
</HStack>
</FormControl>
</Box>
);
}
const isConnected = Boolean(
connectedInputs.filter((connectedInput) => {
return (
connectedInput.target === nodeId &&
connectedInput.targetHandle === input.name
);
}).length
);
return (
<Box
key={fieldId}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor:
!isConnected &&
['always', 'connectionOnly'].includes(
String(inputTemplate?.inputRequirement)
) &&
input.value === undefined
? 'warning.400'
: undefined,
}}
>
<FormControl isDisabled={isConnected}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>{inputTemplate?.title}</FormLabel>
<Tooltip
label={inputTemplate?.description}
placement="top"
hasArrow
shouldWrapChildren
>
<Icon color="base.400" as={FaInfoCircle} />
</Tooltip>
</HStack>
<InputFieldComponent
nodeId={nodeId}
field={input}
template={inputTemplate}
/>
</FormControl>
{!['never', 'directOnly'].includes(
inputTemplate?.inputRequirement ?? ''
) && (
<FieldHandle
nodeId={nodeId}
field={inputTemplate}
isValidConnection={isValidConnection}
handleType="target"
/>
)}
</Box>
);
})}
{map(outputs).map((output, i) => {
const outputTemplate = template.current?.outputs[output.name];
const isConnected = Boolean(
connectedInputs.filter((connectedInput) => {
return (
connectedInput.source === nodeId &&
connectedInput.sourceHandle === output.name
);
}).length
);
if (!outputTemplate) {
return (
<Box
key={output.id}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
sx={{
borderColor: 'error.400',
}}
>
<FormControl isDisabled={true}>
<HStack justifyContent="space-between" alignItems="center">
<FormLabel>Unknown output: {output.name}</FormLabel>
</HStack>
</FormControl>
</Box>
);
}
return (
<Box
key={output.id}
position="relative"
p={2}
borderWidth={1}
borderRadius="md"
>
<FormControl isDisabled={isConnected}>
<FormLabel textAlign="end">
{outputTemplate?.title} Output
</FormLabel>
</FormControl>
<FieldHandle
key={output.id}
nodeId={nodeId}
field={outputTemplate}
isValidConnection={isValidConnection}
handleType="source"
/>
</Box>
);
})}
</>
</Flex>
<Flex></Flex>
</Box>
);
});
InvocationComponent.displayName = 'InvocationComponent';

View File

@ -0,0 +1,46 @@
import 'reactflow/dist/style.css';
import { Box } from '@chakra-ui/react';
import { ReactFlowProvider } from 'reactflow';
import { Flow } from './Flow';
import { useAppSelector } from 'app/storeHooks';
import { RootState } from 'app/store';
import { buildNodesGraph } from '../util/nodesGraphBuilder/buildNodesGraph';
const NodeEditor = () => {
const state = useAppSelector((state: RootState) => state);
const graph = buildNodesGraph(state);
return (
<Box
sx={{
position: 'relative',
width: 'full',
height: 'full',
borderRadius: 'md',
bg: 'base.850',
}}
>
<ReactFlowProvider>
<Flow />
</ReactFlowProvider>
<Box
as="pre"
fontFamily="monospace"
position="absolute"
top={2}
left={2}
width="full"
height="full"
userSelect="none"
pointerEvents="none"
opacity={0.7}
>
<Box w="50%">{JSON.stringify(graph, null, 2)}</Box>
</Box>
</Box>
);
};
export default NodeEditor;

View File

@ -0,0 +1,14 @@
import {
ArrayInputFieldTemplate,
ArrayInputFieldValue,
} from 'features/nodes/types';
import { FaImage, FaList } from 'react-icons/fa';
import { FieldComponentProps } from './types';
export const ArrayInputFieldComponent = (
props: FieldComponentProps<ArrayInputFieldValue, ArrayInputFieldTemplate>
) => {
const { nodeId, field } = props;
return <FaList />;
};

View File

@ -0,0 +1,31 @@
import { Switch } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
BooleanInputFieldTemplate,
BooleanInputFieldValue,
} from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const BooleanInputFieldComponent = (
props: FieldComponentProps<BooleanInputFieldValue, BooleanInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.checked,
})
);
};
return (
<Switch onChange={handleValueChanged} isChecked={field.value}></Switch>
);
};

View File

@ -0,0 +1,35 @@
import { Select } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
EnumInputFieldTemplate,
EnumInputFieldValue,
} from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const EnumInputFieldComponent = (
props: FieldComponentProps<EnumInputFieldValue, EnumInputFieldTemplate>
) => {
const { nodeId, field, template } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{template.options.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@ -0,0 +1,64 @@
import { Box, Image, Icon, Flex } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import SelectImagePlaceholder from 'common/components/SelectImagePlaceholder';
import { useGetUrl } from 'common/util/getUrl';
import useGetImageByNameAndType from 'features/gallery/hooks/useGetImageByName';
import useGetImageByUuid from 'features/gallery/hooks/useGetImageByUuid';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ImageInputFieldTemplate,
ImageInputFieldValue,
} from 'features/nodes/types';
import { DragEvent, useCallback, useState } from 'react';
import { FaImage } from 'react-icons/fa';
import { ImageType } from 'services/api';
import { FieldComponentProps } from './types';
export const ImageInputFieldComponent = (
props: FieldComponentProps<ImageInputFieldValue, ImageInputFieldTemplate>
) => {
const { nodeId, field } = props;
const { value } = field;
const getImageByNameAndType = useGetImageByNameAndType();
const dispatch = useAppDispatch();
const [url, setUrl] = useState<string>();
const { getUrl } = useGetUrl();
const handleDrop = useCallback(
(e: DragEvent<HTMLDivElement>) => {
const name = e.dataTransfer.getData('invokeai/imageName');
const type = e.dataTransfer.getData('invokeai/imageType') as ImageType;
if (!name || !type) {
return;
}
const image = getImageByNameAndType(name, type);
if (!image) {
return;
}
setUrl(image.url);
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: {
image_name: name,
image_type: type,
},
})
);
},
[getImageByNameAndType, dispatch, field.name, nodeId]
);
return (
<Box onDrop={handleDrop}>
<Image src={getUrl(url)} fallback={<SelectImagePlaceholder />} />
</Box>
);
};

View File

@ -0,0 +1,13 @@
import {
LatentsInputFieldTemplate,
LatentsInputFieldValue,
} from 'features/nodes/types';
import { FieldComponentProps } from './types';
export const LatentsInputFieldComponent = (
props: FieldComponentProps<LatentsInputFieldValue, LatentsInputFieldTemplate>
) => {
const { nodeId, field } = props;
return null;
};

View File

@ -0,0 +1,57 @@
import { Select } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
ModelInputFieldTemplate,
ModelInputFieldValue,
} from 'features/nodes/types';
import {
selectModelsById,
selectModelsIds,
} from 'features/system/store/modelSlice';
import { isEqual, map } from 'lodash';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
const availableModelsSelector = createSelector(
[selectModelsIds],
(allModelNames) => {
return { allModelNames };
// return map(modelList, (_, name) => name);
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export const ModelInputFieldComponent = (
props: FieldComponentProps<ModelInputFieldValue, ModelInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { allModelNames } = useAppSelector(availableModelsSelector);
const handleValueChanged = (e: ChangeEvent<HTMLSelectElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return (
<Select onChange={handleValueChanged} value={field.value}>
{allModelNames.map((option) => (
<option key={option}>{option}</option>
))}
</Select>
);
};

View File

@ -0,0 +1,41 @@
import {
NumberDecrementStepper,
NumberIncrementStepper,
NumberInput,
NumberInputField,
NumberInputStepper,
} from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
FloatInputFieldTemplate,
FloatInputFieldValue,
IntegerInputFieldTemplate,
IntegerInputFieldValue,
} from 'features/nodes/types';
import { FieldComponentProps } from './types';
export const NumberInputFieldComponent = (
props: FieldComponentProps<
IntegerInputFieldValue | FloatInputFieldValue,
IntegerInputFieldTemplate | FloatInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (_: string, value: number) => {
dispatch(fieldValueChanged({ nodeId, fieldName: field.name, value }));
};
return (
<NumberInput onChange={handleValueChanged} value={field.value}>
<NumberInputField />
<NumberInputStepper>
<NumberIncrementStepper />
<NumberDecrementStepper />
</NumberInputStepper>
</NumberInput>
);
};

View File

@ -0,0 +1,29 @@
import { Input } from '@chakra-ui/react';
import { useAppDispatch } from 'app/storeHooks';
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
import {
StringInputFieldTemplate,
StringInputFieldValue,
} from 'features/nodes/types';
import { ChangeEvent } from 'react';
import { FieldComponentProps } from './types';
export const StringInputFieldComponent = (
props: FieldComponentProps<StringInputFieldValue, StringInputFieldTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const handleValueChanged = (e: ChangeEvent<HTMLInputElement>) => {
dispatch(
fieldValueChanged({
nodeId,
fieldName: field.name,
value: e.target.value,
})
);
};
return <Input onChange={handleValueChanged} value={field.value}></Input>;
};

View File

@ -0,0 +1,10 @@
import { InputFieldTemplate, InputFieldValue } from 'features/nodes/types';
export type FieldComponentProps<
V extends InputFieldValue,
T extends InputFieldTemplate
> = {
nodeId: string;
field: V;
template: T;
};

View File

@ -0,0 +1,52 @@
export const iterationGraph = {
nodes: {
'0': {
id: '0',
type: 'range',
start: 0,
stop: 5,
step: 1,
},
'1': {
collection: [],
id: '1',
index: 0,
type: 'iterate',
},
'2': {
cfg_scale: 7.5,
height: 512,
id: '2',
model: '',
progress_images: false,
prompt: 'dog',
sampler_name: 'k_lms',
seamless: false,
steps: 11,
type: 'txt2img',
width: 512,
},
},
edges: [
{
source: {
field: 'collection',
node_id: '0',
},
destination: {
field: 'collection',
node_id: '1',
},
},
{
source: {
field: 'item',
node_id: '1',
},
destination: {
field: 'seed',
node_id: '2',
},
},
],
};

View File

@ -0,0 +1,78 @@
import { RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { reduce } from 'lodash';
import { Node } from 'reactflow';
import { AnyInvocationType } from 'services/events/types';
import { v4 as uuidv4 } from 'uuid';
import {
InputFieldValue,
InvocationValue,
OutputFieldValue,
} from '../types/types';
import { buildInputFieldValue } from '../util/fieldValueBuilders';
export const useBuildInvocation = () => {
const invocationTemplates = useAppSelector(
(state: RootState) => state.nodes.invocationTemplates
);
return (type: AnyInvocationType) => {
const template = invocationTemplates[type];
if (template === undefined) {
console.error(`Unable to find template ${type}.`);
return;
}
const nodeId = uuidv4();
const inputs = reduce(
template.inputs,
(inputsAccumulator, inputTemplate, inputName) => {
const fieldId = uuidv4();
const inputFieldValue: InputFieldValue = buildInputFieldValue(
fieldId,
inputTemplate
);
inputsAccumulator[inputName] = inputFieldValue;
return inputsAccumulator;
},
{} as Record<string, InputFieldValue>
);
const outputs = reduce(
template.outputs,
(outputsAccumulator, outputTemplate, outputName) => {
const fieldId = uuidv4();
const outputFieldValue: OutputFieldValue = {
id: fieldId,
name: outputName,
type: outputTemplate.type,
};
outputsAccumulator[outputName] = outputFieldValue;
return outputsAccumulator;
},
{} as Record<string, OutputFieldValue>
);
const invocation: Node<InvocationValue> = {
id: nodeId,
type: 'invocation',
position: { x: 0, y: 0 },
data: {
id: nodeId,
type,
inputs,
outputs,
},
};
return invocation;
};
};

View File

@ -0,0 +1,16 @@
import { useAppSelector } from 'app/storeHooks';
import { invocationTemplatesSelector } from '../store/selectors/invocationTemplatesSelector';
export const useGetInvocationTemplate = () => {
const invocationTemplates = useAppSelector(invocationTemplatesSelector);
return (invocationType: string) => {
const template = invocationTemplates[invocationType];
if (!template) {
return;
}
return template;
};
};

View File

@ -0,0 +1,93 @@
import { useCallback } from 'react';
import { Connection, Node, useReactFlow } from 'reactflow';
import graphlib from '@dagrejs/graphlib';
import { InvocationValue } from '../types/types';
export const useIsValidConnection = () => {
const flow = useReactFlow();
// Check if an in-progress connection is valid
const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
const edges = flow.getEdges();
const nodes = flow.getNodes();
return true;
// Connection must have valid targets
if (!(source && sourceHandle && target && targetHandle)) {
return false;
}
// Connection is invalid if target already has a connection
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
) {
return false;
}
// Find the source and target nodes
const sourceNode = flow.getNode(source) as Node<InvocationValue>;
const targetNode = flow.getNode(target) as Node<InvocationValue>;
// Conditional guards against undefined nodes/handles
if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) {
return false;
}
// Connection types must be the same for a connection
if (
sourceNode.data.outputs[sourceHandle].type !==
targetNode.data.inputs[targetHandle].type
) {
return false;
}
// Graphs much be acyclic (no loops!)
/**
* TODO: use `graphlib.alg.findCycles()` to identify strong connections
*
* this validation func only runs when the cursor hits the second handle of the connection,
* and only on that second handle - so it cannot tell us exhaustively which connections
* are valid.
*
* ideally, we check when the connection starts to calculate all invalid handles at once.
*
* requires making a new graphlib graph - and calling `findCycles()` - for each potential
* handle. instead of using the `isValidConnection` prop, it would use the `onConnectStart`
* prop.
*
* the strong connections should be stored in global state.
*
* then, `isValidConnection` would simple loop through the strong connections and if the
* source and target are in a single strong connection, return false.
*
* and also, we can use this knowledge to style every handle when a connection starts,
* which is otherwise not possible.
*/
// build a graphlib graph
const g = new graphlib.Graph();
nodes.forEach((n) => {
g.setNode(n.id);
});
edges.forEach((e) => {
g.setEdge(e.source, e.target);
});
// Add the candidate edge to the graph
g.setEdge(source, target);
return graphlib.alg.isAcyclic(g);
},
[flow]
);
return isValidConnection;
};

View File

@ -0,0 +1,10 @@
import { NodesState } from './nodesSlice';
/**
* Nodes slice persist blacklist
*/
const itemsToBlacklist: (keyof NodesState)[] = ['schema', 'invocations'];
export const nodesBlacklist = itemsToBlacklist.map(
(blacklistItem) => `nodes.${blacklistItem}`
);

View File

@ -0,0 +1,103 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit';
import { OpenAPIV3 } from 'openapi-types';
import {
addEdge,
applyEdgeChanges,
applyNodeChanges,
Connection,
Edge,
EdgeChange,
Node,
NodeChange,
OnConnectStartParams,
} from 'reactflow';
import { Graph, ImageField } from 'services/api';
import { receivedOpenAPISchema } from 'services/thunks/schema';
import { isFulfilledAnyGraphBuilt } from 'services/thunks/session';
import { InvocationTemplate, InvocationValue } from '../types/types';
import { parseSchema } from '../util/parseSchema';
export type NodesState = {
nodes: Node<InvocationValue>[];
edges: Edge[];
schema: OpenAPIV3.Document | null;
invocationTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
lastGraph: Graph | null;
};
export const initialNodesState: NodesState = {
nodes: [],
edges: [],
schema: null,
invocationTemplates: {},
connectionStartParams: null,
lastGraph: null,
};
const nodesSlice = createSlice({
name: 'nodes',
initialState: initialNodesState,
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange[]>) => {
state.nodes = applyNodeChanges(action.payload, state.nodes);
},
nodeAdded: (state, action: PayloadAction<Node<InvocationValue>>) => {
state.nodes.push(action.payload);
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
},
connectionMade: (state, action: PayloadAction<Connection>) => {
state.edges = addEdge(action.payload, state.edges);
},
connectionEnded: (state) => {
state.connectionStartParams = null;
},
fieldValueChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
value:
| string
| number
| boolean
| Pick<ImageField, 'image_name' | 'image_type'>
| undefined;
}>
) => {
const { nodeId, fieldName, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
if (nodeIndex > -1) {
state.nodes[nodeIndex].data.inputs[fieldName].value = value;
}
},
},
extraReducers(builder) {
builder.addCase(receivedOpenAPISchema.fulfilled, (state, action) => {
state.schema = action.payload;
state.invocationTemplates = parseSchema(action.payload);
});
builder.addMatcher(isFulfilledAnyGraphBuilt, (state, action) => {
state.lastGraph = action.payload;
});
},
});
export const {
nodesChanged,
edgesChanged,
nodeAdded,
fieldValueChanged,
connectionMade,
connectionStarted,
connectionEnded,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -0,0 +1,7 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store';
export const invocationTemplatesSelector = createSelector(
(state: RootState) => state.nodes,
(nodes) => nodes.invocationTemplates
);

View File

@ -0,0 +1,69 @@
import { getCSSVar } from '@chakra-ui/utils';
import { FieldType, FieldUIConfig } from './types';
export const HANDLE_TOOLTIP_OPEN_DELAY = 500;
export const FIELD_TYPE_MAP: Record<string, FieldType> = {
integer: 'integer',
number: 'float',
string: 'string',
boolean: 'boolean',
enum: 'enum',
ImageField: 'image',
LatentsField: 'latents',
model: 'model',
array: 'array',
};
const COLOR_TOKEN_VALUE = 500;
const getColorTokenCssVariable = (color: string) =>
`var(--invokeai-colors-${color}-${COLOR_TOKEN_VALUE})`;
export const FIELDS: Record<FieldType, FieldUIConfig> = {
integer: {
colorCssVar: getColorTokenCssVariable('red'),
title: 'Integer',
description: 'Integers are whole numbers, without a decimal point.',
},
float: {
colorCssVar: getColorTokenCssVariable('orange'),
title: 'Float',
description: 'Floats are numbers with a decimal point.',
},
string: {
colorCssVar: getColorTokenCssVariable('yellow'),
title: 'String',
description: 'Strings are text.',
},
boolean: {
colorCssVar: getColorTokenCssVariable('green'),
title: 'Boolean',
description: 'Booleans are true or false.',
},
enum: {
colorCssVar: getColorTokenCssVariable('blue'),
title: 'Enum',
description: 'Enums are values that may be one of a number of options.',
},
image: {
colorCssVar: getColorTokenCssVariable('purple'),
title: 'Image',
description: 'Images may be passed between nodes.',
},
latents: {
colorCssVar: getColorTokenCssVariable('pink'),
title: 'Latents',
description: 'Latents may be passed between nodes.',
},
model: {
colorCssVar: getColorTokenCssVariable('teal'),
title: 'Model',
description: 'Models are models.',
},
array: {
colorCssVar: getColorTokenCssVariable('gray'),
title: 'Array',
description: 'TODO: Array type description.',
},
};

View File

@ -0,0 +1,9 @@
import { OpenAPIV3 } from 'openapi-types';
export const isReferenceObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.ReferenceObject => '$ref' in obj;
export const isSchemaObject = (
obj: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject
): obj is OpenAPIV3.SchemaObject => !('$ref' in obj);

View File

@ -0,0 +1,296 @@
import { OpenAPIV3 } from 'openapi-types';
import { ImageField } from 'services/api';
import { AnyInvocationType } from 'services/events/types';
export type InvocationValue = {
id: string;
type: AnyInvocationType;
inputs: Record<string, InputFieldValue>;
outputs: Record<string, OutputFieldValue>;
};
export type InvocationTemplate = {
/**
* Unique type of the invocation
*/
type: AnyInvocationType;
/**
* Display name of the invocation
*/
title: string;
/**
* Description of the invocation
*/
description: string;
/**
* Invocation tags
*/
tags: string[];
/**
* Array of invocation inputs
*/
inputs: Record<string, InputFieldTemplate>;
// inputs: InputField[];
/**
* Array of the invocation outputs
*/
outputs: Record<string, OutputFieldTemplate>;
// outputs: OutputField[];
};
export type FieldUIConfig = {
colorCssVar: string;
title: string;
description: string;
};
/**
* The valid invocation field types
*/
export type FieldType =
| 'integer'
| 'float'
| 'string'
| 'boolean'
| 'enum'
| 'image'
| 'latents'
| 'model'
| 'array';
/**
* An input field is persisted across reloads as part of the user's local state.
*
* An input field has three properties:
* - `id` a unique identifier
* - `name` the name of the field, which comes from the python dataclass
* - `value` the field's value
*/
export type InputFieldValue =
| IntegerInputFieldValue
| FloatInputFieldValue
| StringInputFieldValue
| BooleanInputFieldValue
| ImageInputFieldValue
| LatentsInputFieldValue
| EnumInputFieldValue
| ModelInputFieldValue
| ArrayInputFieldValue;
/**
* An input field template is generated on each page load from the OpenAPI schema.
*
* The template provides the field type and other field metadata (e.g. title, description,
* maximum length, pattern to match, etc).
*/
export type InputFieldTemplate =
| IntegerInputFieldTemplate
| FloatInputFieldTemplate
| StringInputFieldTemplate
| BooleanInputFieldTemplate
| ImageInputFieldTemplate
| LatentsInputFieldTemplate
| EnumInputFieldTemplate
| ModelInputFieldTemplate
| ArrayInputFieldTemplate;
/**
* An output field is persisted across as part of the user's local state.
*
* An output field has two properties:
* - `id` a unique identifier
* - `name` the name of the field, which comes from the python dataclass
*/
export type OutputFieldValue = FieldValueBase;
/**
* An output field template is generated on each page load from the OpenAPI schema.
*
* The template provides the output field's name, type, title, and description.
*/
export type OutputFieldTemplate = {
name: string;
type: FieldType;
title: string;
description: string;
};
/**
* Indicates when/if this field needs an input.
*/
export type InputRequirement = 'always' | 'never' | 'optional';
/**
* Indicates the kind of input(s) this field may have.
*/
export type InputKind = 'connection' | 'direct' | 'any';
export type FieldValueBase = {
id: string;
name: string;
type: FieldType;
};
export type IntegerInputFieldValue = FieldValueBase & {
type: 'integer';
value?: number;
};
export type FloatInputFieldValue = FieldValueBase & {
type: 'float';
value?: number;
};
export type StringInputFieldValue = FieldValueBase & {
type: 'string';
value?: string;
};
export type BooleanInputFieldValue = FieldValueBase & {
type: 'boolean';
value?: boolean;
};
export type EnumInputFieldValue = FieldValueBase & {
type: 'enum';
value?: number | string;
};
export type LatentsInputFieldValue = FieldValueBase & {
type: 'latents';
value?: undefined;
};
export type ImageInputFieldValue = FieldValueBase & {
type: 'image';
value?: Pick<ImageField, 'image_name' | 'image_type'>;
};
export type ModelInputFieldValue = FieldValueBase & {
type: 'model';
value?: string;
};
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
};
export type InputFieldTemplateBase = {
name: string;
title: string;
description: string;
type: FieldType;
inputRequirement: InputRequirement;
inputKind: InputKind;
};
export type IntegerInputFieldTemplate = InputFieldTemplateBase & {
type: 'integer';
default: number;
multipleOf?: number;
maximum?: number;
exclusiveMaximum?: boolean;
minimum?: number;
exclusiveMinimum?: boolean;
};
export type FloatInputFieldTemplate = InputFieldTemplateBase & {
type: 'float';
default: number;
multipleOf?: number;
maximum?: number;
exclusiveMaximum?: boolean;
minimum?: number;
exclusiveMinimum?: boolean;
};
export type StringInputFieldTemplate = InputFieldTemplateBase & {
type: 'string';
default: string;
maxLength?: number;
minLength?: number;
pattern?: string;
};
export type BooleanInputFieldTemplate = InputFieldTemplateBase & {
default: boolean;
type: 'boolean';
};
export type ImageInputFieldTemplate = InputFieldTemplateBase & {
default: Pick<ImageField, 'image_name' | 'image_type'>;
type: 'image';
};
export type LatentsInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'latents';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
enumType: 'string' | 'number';
options: Array<string | number>;
};
export type ModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'model';
};
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: (string | number)[];
type: 'array';
};
/**
* JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES
*/
export type TypeHints = {
[fieldName: string]: FieldType;
};
export type InvocationSchemaExtra = {
output: OpenAPIV3.ReferenceObject; // the output of the invocation
ui?: {
tags?: string[];
type_hints?: TypeHints;
title?: string;
};
title: string;
properties: Omit<
NonNullable<OpenAPIV3.SchemaObject['properties']>,
'type'
> & {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: AnyInvocationType;
};
};
};
export type InvocationSchemaType = {
default: string; // the type of the invocation
};
export type InvocationBaseSchemaObject = Omit<
OpenAPIV3.BaseSchemaObject,
'title' | 'type' | 'properties'
> &
InvocationSchemaExtra;
export interface ArraySchemaObject extends InvocationBaseSchemaObject {
type: OpenAPIV3.ArraySchemaObjectType;
items: OpenAPIV3.ReferenceObject | OpenAPIV3.SchemaObject;
}
export interface NonArraySchemaObject extends InvocationBaseSchemaObject {
type?: OpenAPIV3.NonArraySchemaObjectType;
}
export type InvocationSchemaObject = ArraySchemaObject | NonArraySchemaObject;
export const isInvocationSchemaObject = (
obj: OpenAPIV3.ReferenceObject | InvocationSchemaObject
): obj is InvocationSchemaObject => !('$ref' in obj);

View File

@ -0,0 +1,336 @@
import { reduce } from 'lodash';
import { OpenAPIV3 } from 'openapi-types';
import { FIELD_TYPE_MAP } from '../types/constants';
import { isSchemaObject } from '../types/typeGuards';
import {
BooleanInputFieldTemplate,
EnumInputFieldTemplate,
FloatInputFieldTemplate,
ImageInputFieldTemplate,
IntegerInputFieldTemplate,
LatentsInputFieldTemplate,
StringInputFieldTemplate,
ModelInputFieldTemplate,
InputFieldTemplateBase,
OutputFieldTemplate,
TypeHints,
FieldType,
} from '../types/types';
export type BaseFieldProperties = 'name' | 'title' | 'description';
export type BuildInputFieldArg = {
schemaObject: OpenAPIV3.SchemaObject;
baseField: Omit<
InputFieldTemplateBase,
'type' | 'inputRequirement' | 'inputKind'
>;
};
/**
* Transforms an invocation output ref object to field type.
* @param ref The ref string to transform
* @returns The field type.
*
* @example
* refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField'
*/
export const refObjectToFieldType = (
refObject: OpenAPIV3.ReferenceObject
): keyof typeof FIELD_TYPE_MAP => refObject.$ref.split('/').slice(-1)[0];
const buildIntegerInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IntegerInputFieldTemplate => {
const template: IntegerInputFieldTemplate = {
...baseField,
type: 'integer',
inputRequirement: 'always',
inputKind: 'any',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): FloatInputFieldTemplate => {
const template: FloatInputFieldTemplate = {
...baseField,
type: 'float',
inputRequirement: 'always',
inputKind: 'any',
default: schemaObject.default ?? 0,
};
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildStringInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): StringInputFieldTemplate => {
const template: StringInputFieldTemplate = {
...baseField,
type: 'string',
inputRequirement: 'always',
inputKind: 'any',
default: schemaObject.default ?? '',
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.pattern !== undefined) {
template.pattern = schemaObject.pattern;
}
return template;
};
const buildBooleanInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): BooleanInputFieldTemplate => {
const template: BooleanInputFieldTemplate = {
...baseField,
type: 'boolean',
inputRequirement: 'always',
inputKind: 'any',
default: schemaObject.default ?? false,
};
return template;
};
const buildModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ModelInputFieldTemplate => {
const template: ModelInputFieldTemplate = {
...baseField,
type: 'model',
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): ImageInputFieldTemplate => {
const template: ImageInputFieldTemplate = {
...baseField,
type: 'image',
inputRequirement: 'always',
inputKind: 'any',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildLatentsInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): LatentsInputFieldTemplate => {
const template: LatentsInputFieldTemplate = {
...baseField,
type: 'latents',
inputRequirement: 'always',
inputKind: 'connection',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): EnumInputFieldTemplate => {
const options = schemaObject.enum ?? [];
const template: EnumInputFieldTemplate = {
...baseField,
type: 'enum',
enumType: (schemaObject.type as 'string' | 'number') ?? 'string', // TODO: dangerous?
options: options,
inputRequirement: 'always',
inputKind: 'direct',
default: schemaObject.default ?? options[0],
};
return template;
};
export const getFieldType = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
typeHints?: TypeHints
): FieldType => {
let rawFieldType = '';
if (typeHints && name in typeHints) {
rawFieldType = typeHints[name];
} else if (!schemaObject.type) {
rawFieldType = refObjectToFieldType(
schemaObject.allOf![0] as OpenAPIV3.ReferenceObject
);
} else if (schemaObject.enum) {
rawFieldType = 'enum';
} else if (schemaObject.type) {
rawFieldType = schemaObject.type;
}
const fieldType = FIELD_TYPE_MAP[rawFieldType];
if (!fieldType) {
throw `Field type "${rawFieldType}" is unknown!`;
}
return fieldType;
};
/**
* Builds an input field from an invocation schema property.
* @param schemaObject The schema object
* @returns An input field
*/
export const buildInputFieldTemplate = (
schemaObject: OpenAPIV3.SchemaObject,
name: string,
typeHints?: TypeHints
) => {
const fieldType = getFieldType(schemaObject, name, typeHints);
const baseField = {
name,
title: schemaObject.title ?? '',
description: schemaObject.description ?? '',
};
if (['image'].includes(fieldType)) {
return buildImageInputFieldTemplate({ schemaObject, baseField });
}
if (['latents'].includes(fieldType)) {
return buildLatentsInputFieldTemplate({ schemaObject, baseField });
}
if (['model'].includes(fieldType)) {
return buildModelInputFieldTemplate({ schemaObject, baseField });
}
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}
if (['integer'].includes(fieldType)) {
return buildIntegerInputFieldTemplate({ schemaObject, baseField });
}
if (['number', 'float'].includes(fieldType)) {
return buildFloatInputFieldTemplate({ schemaObject, baseField });
}
if (['string'].includes(fieldType)) {
return buildStringInputFieldTemplate({ schemaObject, baseField });
}
if (['boolean'].includes(fieldType)) {
return buildBooleanInputFieldTemplate({ schemaObject, baseField });
}
return;
};
/**
* Builds invocation output fields from an invocation's output reference object.
* @param openAPI The OpenAPI schema
* @param refObject The output reference object
* @returns A record of outputs
*/
export const buildOutputFieldTemplates = (
refObject: OpenAPIV3.ReferenceObject,
openAPI: OpenAPIV3.Document,
typeHints?: TypeHints
): Record<string, OutputFieldTemplate> => {
// extract output schema name from ref
const outputSchemaName = refObject.$ref.split('/').slice(-1)[0];
// get the output schema itself
const outputSchema = openAPI.components!.schemas![outputSchemaName];
if (isSchemaObject(outputSchema)) {
const outputFields = reduce(
outputSchema.properties as OpenAPIV3.SchemaObject,
(outputsAccumulator, property, propertyName) => {
if (
!['type', 'id'].includes(propertyName) &&
!['object'].includes(property.type) && // TODO: handle objects?
isSchemaObject(property)
) {
const fieldType = getFieldType(property, propertyName, typeHints);
outputsAccumulator[propertyName] = {
name: propertyName,
title: property.title ?? '',
description: property.description ?? '',
type: fieldType,
};
}
return outputsAccumulator;
},
{} as Record<string, OutputFieldTemplate>
);
return outputFields;
}
return {};
};

View File

@ -0,0 +1,57 @@
import { InputFieldTemplate, InputFieldValue } from '../types/types';
export const buildInputFieldValue = (
id: string,
template: InputFieldTemplate
): InputFieldValue => {
const fieldValue: InputFieldValue = {
id,
name: template.name,
type: template.type,
};
if (template.inputRequirement !== 'never') {
if (template.type === 'string') {
fieldValue.value = template.default ?? '';
}
if (template.type === 'integer') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'float') {
fieldValue.value = template.default ?? 0;
}
if (template.type === 'boolean') {
fieldValue.value = template.default ?? false;
}
if (template.type === 'enum') {
if (template.enumType === 'number') {
fieldValue.value = template.default ?? 0;
}
if (template.enumType === 'string') {
fieldValue.value = template.default ?? '';
}
}
if (template.type === 'array') {
fieldValue.value = template.default ?? 1;
}
if (template.type === 'image') {
fieldValue.value = undefined;
}
if (template.type === 'latents') {
fieldValue.value = undefined;
}
if (template.type === 'model') {
fieldValue.value = undefined;
}
}
return fieldValue;
};

View File

@ -0,0 +1,39 @@
import {
Edge,
ImageToImageInvocation,
IterateInvocation,
RandomRangeInvocation,
RangeInvocation,
TextToImageInvocation,
} from 'services/api';
export const buildEdges = (
baseNode: TextToImageInvocation | ImageToImageInvocation,
rangeNode: RangeInvocation | RandomRangeInvocation,
iterateNode: IterateInvocation
): Edge[] => {
const edges: Edge[] = [
{
source: {
node_id: rangeNode.id,
field: 'collection',
},
destination: {
node_id: iterateNode.id,
field: 'collection',
},
},
{
source: {
node_id: iterateNode.id,
field: 'item',
},
destination: {
node_id: baseNode.id,
field: 'seed',
},
},
];
return edges;
};

View File

@ -0,0 +1,99 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import {
Edge,
ImageToImageInvocation,
TextToImageInvocation,
} from 'services/api';
import { _Image } from 'app/invokeai';
import { initialImageSelector } from 'features/parameters/store/generationSelectors';
export const buildImg2ImgNode = (state: RootState): ImageToImageInvocation => {
const nodeId = uuidv4();
const { generation, system, models } = state;
const { selectedModelName } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale,
sampler,
seamless,
img2imgStrength: strength,
shouldFitToWidthHeight: fit,
shouldRandomizeSeed,
} = generation;
const initialImage = initialImageSelector(state);
if (!initialImage) {
// TODO: handle this
throw 'no initial image';
}
const imageToImageNode: ImageToImageInvocation = {
id: nodeId,
type: 'img2img',
prompt,
steps,
width,
height,
cfg_scale: cfgScale,
scheduler: sampler as ImageToImageInvocation['scheduler'],
seamless,
model: selectedModelName,
progress_images: true,
image: {
image_name: initialImage.name,
image_type: initialImage.type,
},
strength,
fit,
};
if (!shouldRandomizeSeed) {
imageToImageNode.seed = seed;
}
return imageToImageNode;
};
type hiresReturnType = {
node: Record<string, ImageToImageInvocation>;
edge: Edge;
};
export const buildHiResNode = (
baseNode: Record<string, TextToImageInvocation>,
strength?: number
): hiresReturnType => {
const nodeId = uuidv4();
const baseNodeId = Object.keys(baseNode)[0];
const baseNodeValues = Object.values(baseNode)[0];
return {
node: {
[nodeId]: {
...baseNodeValues,
id: nodeId,
type: 'img2img',
strength,
fit: true,
},
},
edge: {
source: {
field: 'image',
node_id: baseNodeId,
},
destination: {
field: 'image',
node_id: nodeId,
},
},
};
};

View File

@ -0,0 +1,13 @@
import { v4 as uuidv4 } from 'uuid';
import { IterateInvocation } from 'services/api';
export const buildIterateNode = (): IterateInvocation => {
const nodeId = uuidv4();
return {
id: nodeId,
type: 'iterate',
collection: [],
index: 0,
};
};

View File

@ -0,0 +1,39 @@
import { RootState } from 'app/store';
import { Graph } from 'services/api';
import { buildImg2ImgNode } from './buildImageToImageNode';
import { buildTxt2ImgNode } from './buildTextToImageNode';
import { buildRangeNode } from './buildRangeNode';
import { buildIterateNode } from './buildIterateNode';
import { buildEdges } from './buildEdges';
/**
* Builds the Linear workflow graph.
*/
export const buildLinearGraph = (state: RootState): Graph => {
// The base node is either a txt2img or img2img node
const baseNode = state.generation.isImageToImageEnabled
? buildImg2ImgNode(state)
: buildTxt2ImgNode(state);
// We always range and iterate nodes, no matter the iteration count
// This is required to provide the correct seeds to the backend engine
const rangeNode = buildRangeNode(state);
const iterateNode = buildIterateNode();
// Build the edges for the nodes selected.
const edges = buildEdges(baseNode, rangeNode, iterateNode);
// Assemble!
const graph = {
nodes: {
[rangeNode.id]: rangeNode,
[iterateNode.id]: iterateNode,
[baseNode.id]: baseNode,
},
edges,
};
// TODO: hires fix requires latent space upscaling; we don't have nodes for this yet
return graph;
};

View File

@ -0,0 +1,26 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { RandomRangeInvocation, RangeInvocation } from 'services/api';
export const buildRangeNode = (
state: RootState
): RangeInvocation | RandomRangeInvocation => {
const nodeId = uuidv4();
const { shouldRandomizeSeed, iterations, seed } = state.generation;
if (shouldRandomizeSeed) {
return {
id: nodeId,
type: 'random_range',
size: iterations,
};
}
return {
id: nodeId,
type: 'range',
start: seed,
stop: seed + iterations,
};
};

View File

@ -0,0 +1,42 @@
import { v4 as uuidv4 } from 'uuid';
import { RootState } from 'app/store';
import { TextToImageInvocation } from 'services/api';
export const buildTxt2ImgNode = (state: RootState): TextToImageInvocation => {
const nodeId = uuidv4();
const { generation, models } = state;
const { selectedModelName } = models;
const {
prompt,
seed,
steps,
width,
height,
cfgScale: cfg_scale,
sampler,
seamless,
shouldRandomizeSeed,
} = generation;
const textToImageNode: NonNullable<TextToImageInvocation> = {
id: nodeId,
type: 'txt2img',
prompt,
steps,
width,
height,
cfg_scale,
scheduler: sampler as TextToImageInvocation['scheduler'],
seamless,
model: selectedModelName,
progress_images: true,
};
if (!shouldRandomizeSeed) {
textToImageNode.seed = seed;
}
return textToImageNode;
};

View File

@ -0,0 +1,77 @@
import { Graph } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import { reduce } from 'lodash';
import { RootState } from 'app/store';
import { AnyInvocation } from 'services/events/types';
/**
* Builds a graph from the node editor state.
*/
export const buildNodesGraph = (state: RootState): Graph => {
const { nodes, edges } = state.nodes;
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = nodes.reduce<NonNullable<Graph['nodes']>>(
(nodesAccumulator, node, nodeIndex) => {
const { id, data } = node;
const { type, inputs } = data;
// Transform each node's inputs to simple key-value pairs
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
return inputsAccumulator;
},
{} as Record<Exclude<string, 'id' | 'type'>, any>
);
// Build this specific node
const graphNode = {
type,
id,
...transformedInputs,
};
// Add it to the nodes object
Object.assign(nodesAccumulator, {
[id]: graphNode,
});
return nodesAccumulator;
},
{}
);
// Reduce the node editor edges into invocation graph edges
const parsedEdges = edges.reduce<NonNullable<Graph['edges']>>(
(edgesAccumulator, edge, edgeIndex) => {
const { source, target, sourceHandle, targetHandle } = edge;
// Format the edges and add to the edges array
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
},
destination: {
node_id: target,
field: targetHandle as string,
},
});
return edgesAccumulator;
},
[]
);
// Assemble!
const graph = {
id: uuidv4(),
nodes: parsedNodes,
edges: parsedEdges,
};
return graph;
};

View File

@ -0,0 +1,120 @@
import { filter, reduce } from 'lodash';
import { OpenAPIV3 } from 'openapi-types';
import { isSchemaObject } from '../types/typeGuards';
import {
InputFieldTemplate,
InvocationSchemaObject,
InvocationTemplate,
isInvocationSchemaObject,
OutputFieldTemplate,
} from '../types/types';
import {
buildInputFieldTemplate,
buildOutputFieldTemplates,
} from './fieldTemplateBuilders';
const invocationBlacklist = ['Graph', 'Collect', 'LoadImage'];
export const parseSchema = (openAPI: OpenAPIV3.Document) => {
// filter out non-invocation schemas, plus some tricky invocations for now
const filteredSchemas = filter(
openAPI.components!.schemas,
(schema, key) =>
key.includes('Invocation') &&
!key.includes('InvocationOutput') &&
!invocationBlacklist.some((blacklistItem) => key.includes(blacklistItem))
) as (OpenAPIV3.ReferenceObject | InvocationSchemaObject)[];
const invocations = filteredSchemas.reduce<
Record<string, InvocationTemplate>
>((acc, schema) => {
// only want SchemaObjects
if (isInvocationSchemaObject(schema)) {
const type = schema.properties.type.default;
const title =
schema.ui?.title ??
schema.title
.replace('Invocation', '')
.split(/(?=[A-Z])/) // split PascalCase into array
.join(' ');
const typeHints = schema.ui?.type_hints;
const inputs = reduce(
schema.properties,
(inputsAccumulator, property, propertyName) => {
if (
// `type` and `id` are not valid inputs/outputs
!['type', 'id'].includes(propertyName) &&
isSchemaObject(property)
) {
let field: InputFieldTemplate | undefined;
if (propertyName === 'collection') {
field = {
default: property.default ?? [],
name: 'collection',
title: property.title ?? '',
description: property.description ?? '',
type: 'array',
inputRequirement: 'always',
inputKind: 'connection',
};
} else {
field = buildInputFieldTemplate(
property,
propertyName,
typeHints
);
}
if (field) {
inputsAccumulator[propertyName] = field;
}
}
return inputsAccumulator;
},
{} as Record<string, InputFieldTemplate>
);
const rawOutput = (schema as InvocationSchemaObject).output;
let outputs: Record<string, OutputFieldTemplate>;
// some special handling is needed for collect, iterate and range nodes
if (type === 'iterate') {
// this is guaranteed to be a SchemaObject
const iterationOutput = openAPI.components!.schemas![
'IterateInvocationOutput'
] as OpenAPIV3.SchemaObject;
outputs = {
item: {
name: 'item',
title: iterationOutput.title ?? '',
description: iterationOutput.description ?? '',
type: 'array',
},
};
} else {
outputs = buildOutputFieldTemplates(rawOutput, openAPI, typeHints);
}
const invocation: InvocationTemplate = {
title,
type,
tags: schema.ui?.tags ?? [],
description: schema.description ?? '',
inputs,
outputs,
};
Object.assign(acc, { [type]: invocation });
}
return acc;
}, {});
console.debug('Generated invocations: ', invocations);
return invocations;
};

View File

@ -6,19 +6,18 @@ import {
Box,
Flex,
} from '@chakra-ui/react';
import { Feature } from 'app/features';
import GuideIcon from 'common/components/GuideIcon';
import { ReactNode } from 'react';
import { ParametersAccordionItem } from '../ParametersAccordion';
export interface InvokeAccordionItemProps {
header: string;
content: ReactNode;
feature?: Feature;
additionalHeaderComponents?: ReactNode;
}
type InvokeAccordionItemProps = {
accordionItem: ParametersAccordionItem;
};
export default function InvokeAccordionItem(props: InvokeAccordionItemProps) {
const { header, feature, content, additionalHeaderComponents } = props;
export default function InvokeAccordionItem({
accordionItem,
}: InvokeAccordionItemProps) {
const { header, feature, content, additionalHeaderComponents } =
accordionItem;
return (
<AccordionItem>
@ -32,7 +31,7 @@ export default function InvokeAccordionItem(props: InvokeAccordionItemProps) {
<AccordionIcon />
</Flex>
</AccordionButton>
<AccordionPanel>{content}</AccordionPanel>
<AccordionPanel p={4}>{content}</AccordionPanel>
</AccordionItem>
);
}

View File

@ -115,9 +115,7 @@ const InfillAndScalingSettings = () => {
onChange={handleChangeBoundingBoxScaleMethod}
/>
<IAISlider
isInputDisabled={!isManual}
isResetDisabled={!isManual}
isSliderDisabled={!isManual}
isDisabled={!isManual}
label={t('parameters.scaledWidth')}
min={64}
max={1024}
@ -132,9 +130,7 @@ const InfillAndScalingSettings = () => {
handleReset={handleResetScaledWidth}
/>
<IAISlider
isInputDisabled={!isManual}
isResetDisabled={!isManual}
isSliderDisabled={!isManual}
isDisabled={!isManual}
label={t('parameters.scaledHeight')}
min={64}
max={1024}
@ -155,9 +151,7 @@ const InfillAndScalingSettings = () => {
onChange={(e) => dispatch(setInfillMethod(e.target.value))}
/>
<IAISlider
isInputDisabled={infillMethod !== 'tile'}
isResetDisabled={infillMethod !== 'tile'}
isSliderDisabled={infillMethod !== 'tile'}
isDisabled={infillMethod !== 'tile'}
label={t('parameters.tileSize')}
min={16}
max={64}

View File

@ -18,9 +18,7 @@ export default function CodeformerFidelity() {
return (
<IAISlider
isSliderDisabled={!isGFPGANAvailable}
isInputDisabled={!isGFPGANAvailable}
isResetDisabled={!isGFPGANAvailable}
isDisabled={!isGFPGANAvailable}
label={t('parameters.codeformerFidelity')}
step={0.05}
min={0}

View File

@ -18,9 +18,7 @@ export default function FaceRestoreStrength() {
return (
<IAISlider
isSliderDisabled={!isGFPGANAvailable}
isInputDisabled={!isGFPGANAvailable}
isResetDisabled={!isGFPGANAvailable}
isDisabled={!isGFPGANAvailable}
label={t('parameters.strength')}
step={0.05}
min={0}

View File

@ -12,6 +12,10 @@ export default function ImageFit() {
(state: RootState) => state.generation.shouldFitToWidthHeight
);
const isImageToImageEnabled = useAppSelector(
(state: RootState) => state.generation.isImageToImageEnabled
);
const handleChangeFit = (e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldFitToWidthHeight(e.target.checked));
@ -19,6 +23,7 @@ export default function ImageFit() {
return (
<IAISwitch
isDisabled={!isImageToImageEnabled}
label={t('parameters.imageFit')}
isChecked={shouldFitToWidthHeight}
onChange={handleChangeFit}

View File

@ -1,13 +1,15 @@
import { VStack } from '@chakra-ui/react';
import { Flex, Image, VStack } from '@chakra-ui/react';
import ImageFit from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageFit';
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
import { useTranslation } from 'react-i18next';
import InitialImagePreview from './InitialImagePreview';
export default function ImageToImageSettings() {
const { t } = useTranslation();
return (
<VStack gap={2} alignItems="stretch">
<InitialImagePreview />
<ImageToImageStrength label={t('parameters.img2imgStrength')} />
<ImageFit />
</VStack>

Some files were not shown because too many files have changed in this diff Show More