Merge branch 'invoke-ai:main' into main

This commit is contained in:
Millun Atluri 2023-09-27 10:10:52 +10:00 committed by GitHub
commit f35dfa06bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 462 additions and 734 deletions

View File

@ -332,6 +332,7 @@ class InvokeAiInstance:
Configure the InvokeAI runtime directory
"""
auto_install = False
# set sys.argv to a consistent state
new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)):
@ -340,13 +341,17 @@ class InvokeAiInstance:
new_argv.append(el)
new_argv.append(sys.argv[i + 1])
elif el in ["-y", "--yes", "--yes-to-all"]:
new_argv.append(el)
auto_install = True
sys.argv = new_argv
import messages
import requests # to catch download exceptions
from messages import introduction
introduction()
auto_install = auto_install or messages.user_wants_auto_configuration()
if auto_install:
sys.argv.append("--yes")
else:
messages.introduction()
from invokeai.frontend.install.invokeai_configure import invokeai_configure

View File

@ -7,7 +7,7 @@ import os
import platform
from pathlib import Path
from prompt_toolkit import prompt
from prompt_toolkit import HTML, prompt
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.validation import Validator
from rich import box, print
@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:")
dest_confirmed = Confirm.ask(
":stop_sign: Are you sure you want to (re)install in this location?",
":stop_sign: (re)install in this location?",
default=False,
)
else:
print(f"InvokeAI will be installed in {dest}")
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
dest_confirmed = Confirm.ask("Use this location?", default=True)
console.line()
return dest_confirmed
def user_wants_auto_configuration() -> bool:
"""Prompt the user to choose between manual and auto configuration."""
console.rule("InvokeAI Configuration Section")
console.print(
Panel(
Group(
"\n".join(
[
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
"",
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
"",
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
]
),
),
box=box.MINIMAL,
padding=(1, 1),
)
)
choice = (
prompt(
HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "),
validator=Validator.from_callable(
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
),
)
or "a"
)
return choice.lower().startswith("a")
def dest_path(dest=None) -> Path:
"""
Prompt the user for the destination path and create the path

View File

@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# CACHE
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )

View File

@ -1,7 +1,6 @@
from collections import OrderedDict
from dataclasses import dataclass, field
from threading import Lock
from time import time
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -59,7 +58,7 @@ class MemoryInvocationCache(InvocationCacheBase):
# If the cache is full, we need to remove the least used
number_to_delete = len(self._cache) + 1 - self._max_cache_size
self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json())
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
def _delete_oldest_access(self, number_to_delete: int) -> None:
number_to_delete = min(number_to_delete, len(self._cache))

View File

@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
config = InvokeAIAppConfig.get_config()
Model_dir = "models"
Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path
@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
)
self.add_widget_intelligent(
npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
begin_entry_at=0,
editable=False,
color="CONTROL",
@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
# So we adjust everthing down a bit.
return (
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
) # 2.1 is just large enough for sd 1.5 ;-)
def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config()
opts.ram = default_ramcache()
return opts

View File

@ -58,6 +58,7 @@
"githubLabel": "Github",
"hotkeysLabel": "Hotkeys",
"imagePrompt": "Image Prompt",
"imageFailedToLoad": "Unable to Load Image",
"img2img": "Image To Image",
"langArabic": "العربية",
"langBrPortuguese": "Português do Brasil",
@ -716,6 +717,7 @@
"cannotConnectInputToInput": "Cannot connect input to input",
"cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"clipField": "Clip",
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
"collection": "Collection",

View File

@ -5,7 +5,7 @@ import {
} from '@chakra-ui/react';
import { ReactNode, memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme';
import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core';
@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return (
<MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}>
<ChakraProvider
theme={theme}
colorModeManager={manager}
toastOptions={TOAST_OPTIONS}
>
{children}
</ChakraProvider>
</MantineProvider>

View File

@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addDynamicPromptsListener } from './listeners/promptChanged';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import {
addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener,
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import {
addSessionCreatedFulfilledListener,
addSessionCreatedPendingListener,
addSessionCreatedRejectedListener,
} from './listeners/sessionCreated';
import {
addSessionInvokedFulfilledListener,
addSessionInvokedPendingListener,
addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
export const listenerMiddleware = createListenerMiddleware();
@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener();
addEnqueueRequestedNodes();
addEnqueueRequestedLinear();
addAnyEnqueuedListener();
addBatchEnqueuedListener();
// Canvas actions
addCanvasSavedToGalleryListener();
@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener();
addSocketQueueItemStatusChangedEventListener();
// Session Created
addSessionCreatedPendingListener();
addSessionCreatedFulfilledListener();
addSessionCreatedRejectedListener();
// Session Invoked
addSessionInvokedPendingListener();
addSessionInvokedFulfilledListener();
addSessionInvokedRejectedListener();
// Session Canceled
addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();
// ControlNet
addControlNetImageProcessedListener();
addControlNetAutoProcessListener();

View File

@ -0,0 +1,96 @@
import { createStandaloneToast } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { t } from 'i18next';
import { get, truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue';
import { TOAST_OPTIONS, theme } from 'theme/theme';
import { startAppListening } from '..';
const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});
export const addBatchEnqueuedListener = () => {
// success
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
logger('queue').debug(
{ enqueueResult: parseify(response) },
'Batch enqueued'
);
if (!toast.isActive('batch-queued')) {
toast({
id: 'batch-queued',
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
duration: 1000,
status: 'success',
});
}
},
});
// error
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
effect: async (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
if (!response) {
toast({
title: t('queue.batchFailedToQueue'),
status: 'error',
description: 'Unknown Error',
});
logger('queue').error(
{ batchConfig: parseify(arg), error: parseify(response) },
t('queue.batchFailedToQueue')
);
return;
}
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
toast({
id: 'batch-failed-to-queue',
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
});
});
} else {
let detail = 'Unknown Error';
if (response.status === 403 && 'body' in response) {
detail = get(response, 'body.detail', 'Unknown Error');
} else if (response.status === 403 && 'error' in response) {
detail = get(response, 'error.detail', 'Unknown Error');
}
toast({
title: t('queue.batchFailedToQueue'),
status: 'error',
description: detail,
});
}
logger('queue').error(
{ batchConfig: parseify(arg), error: parseify(response) },
t('queue.batchFailedToQueue')
);
},
});
};

View File

@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio
import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import { ImageDTO } from 'services/api/types';
@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized
@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => {
// Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId));
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
// no-op
}
},
});

View File

@ -1,13 +1,9 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..';
@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => {
(action.payload.tabName === 'txt2img' ||
action.payload.tabName === 'img2img'),
effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const state = getState();
const model = state.generation.model;
const { prepend } = action.payload;
@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => {
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
},
});
};

View File

@ -1,9 +1,5 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { BatchConfig } from 'services/api/types';
import { startAppListening } from '..';
@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => {
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const state = getState();
const { prepend } = action.payload;
const graph = buildNodesGraph(state.nodes);
const batchConfig: BatchConfig = {
batch: {
@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => {
prepend: action.payload.prepend,
};
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
'Failed to enqueue batch'
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
req.reset();
},
});
};

View File

@ -1,44 +0,0 @@
import { logger } from 'app/logging/logger';
import { serializeError } from 'serialize-error';
import { sessionCanceled } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionCanceledPendingListener = () => {
startAppListening({
actionCreator: sessionCanceled.pending,
effect: () => {
//
},
});
};
export const addSessionCanceledFulfilledListener = () => {
startAppListening({
actionCreator: sessionCanceled.fulfilled,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
log.debug({ session_id }, `Session canceled (${session_id})`);
},
});
};
export const addSessionCanceledRejectedListener = () => {
startAppListening({
actionCreator: sessionCanceled.rejected,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
log.error(
{
session_id,
error: serializeError(error),
},
`Problem canceling session`
);
}
},
});
};

View File

@ -1,45 +0,0 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionCreatedPendingListener = () => {
startAppListening({
actionCreator: sessionCreated.pending,
effect: () => {
//
},
});
};
export const addSessionCreatedFulfilledListener = () => {
startAppListening({
actionCreator: sessionCreated.fulfilled,
effect: (action) => {
const log = logger('session');
const session = action.payload;
log.debug(
{ session: parseify(session) },
`Session created (${session.id})`
);
},
});
};
export const addSessionCreatedRejectedListener = () => {
startAppListening({
actionCreator: sessionCreated.rejected,
effect: (action) => {
const log = logger('session');
if (action.payload) {
const { error, status } = action.payload;
const graph = parseify(action.meta.arg);
log.error(
{ graph, status, error: serializeError(error) },
`Problem creating session`
);
}
},
});
};

View File

@ -1,44 +0,0 @@
import { logger } from 'app/logging/logger';
import { serializeError } from 'serialize-error';
import { sessionInvoked } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionInvokedPendingListener = () => {
startAppListening({
actionCreator: sessionInvoked.pending,
effect: () => {
//
},
});
};
export const addSessionInvokedFulfilledListener = () => {
startAppListening({
actionCreator: sessionInvoked.fulfilled,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
log.debug({ session_id }, `Session invoked (${session_id})`);
},
});
};
export const addSessionInvokedRejectedListener = () => {
startAppListening({
actionCreator: sessionInvoked.rejected,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
log.error(
{
session_id,
error: serializeError(error),
},
`Problem invoking session`
);
}
},
});
};

View File

@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: queue_batch_id },

View File

@ -1,54 +0,0 @@
import { logger } from 'app/logging/logger';
import { AppThunkDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { BatchConfig } from 'services/api/types';
export const enqueueBatch = async (
batchConfig: BatchConfig,
dispatch: AppThunkDispatch
) => {
const log = logger('session');
const { prepend } = batchConfig;
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
};

View File

@ -1,18 +1,9 @@
import { chakra, ChakraProps } from '@chakra-ui/react';
import { Box, ChakraProps } from '@chakra-ui/react';
import { memo } from 'react';
import { RgbaColorPicker } from 'react-colorful';
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
type IAIColorPickerProps = Omit<ColorPickerBaseProps<RgbaColor>, 'color'> &
ChakraProps & {
pickerColor: RgbaColor;
styleClass?: string;
};
const ChakraRgbaColorPicker = chakra(RgbaColorPicker, {
baseStyle: { paddingInline: 4 },
shouldForwardProp: (prop) => !['pickerColor'].includes(prop),
});
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
width: 6,
@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
borderColor: 'base.100',
};
const IAIColorPicker = (props: IAIColorPickerProps) => {
const { styleClass = '', ...rest } = props;
const sx = {
'.react-colorful__hue-pointer': colorPickerStyles,
'.react-colorful__saturation-pointer': colorPickerStyles,
'.react-colorful__alpha-pointer': colorPickerStyles,
};
const IAIColorPicker = (props: IAIColorPickerProps) => {
return (
<ChakraRgbaColorPicker
sx={{
'.react-colorful__hue-pointer': colorPickerStyles,
'.react-colorful__saturation-pointer': colorPickerStyles,
'.react-colorful__alpha-pointer': colorPickerStyles,
}}
className={styleClass}
{...rest}
/>
<Box sx={sx}>
<RgbaColorPicker {...props} />
</Box>
);
};

View File

@ -1,26 +1,27 @@
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Image, Rect } from 'react-konva';
import { memo } from 'react';
import { Image } from 'react-konva';
import { $authToken } from 'services/api/client';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes';
import { $authToken } from 'services/api/client';
import { memo } from 'react';
import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback';
type IAICanvasImageProps = {
canvasImage: CanvasImage;
};
const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage;
const { x, y, imageName } = props.canvasImage;
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken
);
const [image] = useImage(
const [image, status] = useImage(
imageDTO?.image_url ?? '',
$authToken.get() ? 'use-credentials' : 'anonymous'
);
if (isError) {
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
if (isError || status === 'failed') {
return <IAICanvasImageErrorFallback canvasImage={props.canvasImage} />;
}
return <Image x={x} y={y} image={image} listening={false} />;

View File

@ -0,0 +1,44 @@
import { useColorModeValue, useToken } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Group, Rect, Text } from 'react-konva';
import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageErrorFallbackProps = {
canvasImage: CanvasImage;
};
const IAICanvasImageErrorFallback = ({
canvasImage,
}: IAICanvasImageErrorFallbackProps) => {
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
useToken('colors', ['gray.400', 'gray.500', 'base.700', 'base.900']);
const errorColor = useColorModeValue(errorColorLight, errorColorDark);
const fontColor = useColorModeValue(fontColorLight, fontColorDark);
const { t } = useTranslation();
return (
<Group>
<Rect
x={canvasImage.x}
y={canvasImage.y}
width={canvasImage.width}
height={canvasImage.height}
fill={errorColor}
/>
<Text
x={canvasImage.x}
y={canvasImage.y}
width={canvasImage.width}
height={canvasImage.height}
align="center"
verticalAlign="middle"
fontFamily='"Inter Variable", sans-serif'
fontSize={canvasImage.width / 16}
fontStyle="600"
text={t('common.imageFailedToLoad')}
fill={fontColor}
/>
</Group>
);
};
export default memo(IAICanvasImageErrorFallback);

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => {
dispatch(setShouldPreserveMaskedArea(e.target.checked))
}
/>
<IAIColorPicker
sx={{ paddingTop: 2, paddingBottom: 2 }}
pickerColor={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
<IAIColorPicker
color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
</Box>
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
Save Mask
</IAIButton>

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex } from '@chakra-ui/react';
import { ButtonGroup, Flex, Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => {
sliderNumberInputProps={{ max: 500 }}
/>
</Flex>
<IAIColorPicker
<Box
sx={{
width: '100%',
paddingTop: 2,
paddingBottom: 2,
}}
pickerColor={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
>
<IAIColorPicker
color={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
</Box>
</Flex>
</IAIPopover>
</ButtonGroup>

View File

@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice';
import { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es';
import { RgbaColor } from 'react-colorful';
import { sessionCanceled } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types';
import calculateCoordinates from '../util/calculateCoordinates';
import calculateScale from '../util/calculateScale';
@ -786,11 +785,6 @@ export const canvasSlice = createSlice({
},
},
extraReducers: (builder) => {
builder.addCase(sessionCanceled.pending, (state) => {
if (!state.layerState.stagingArea.images.length) {
state.layerState.stagingArea = initialLayerState.stagingArea;
}
});
builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload;
if (ratio) {

View File

@ -6,7 +6,6 @@ import {
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
@ -418,10 +417,6 @@ export const controlNetSlice = createSlice({
state.pendingControlImages = [];
});
builder.addMatcher(isAnySessionRejected, (state) => {
state.pendingControlImages = [];
});
builder.addMatcher(
imagesApi.endpoints.deleteImage.matchFulfilled,
(state, action) => {

View File

@ -12,6 +12,7 @@ import {
OnConnect,
OnConnectEnd,
OnConnectStart,
OnEdgeUpdateFunc,
OnEdgesChange,
OnEdgesDelete,
OnInit,
@ -21,6 +22,7 @@ import {
OnSelectionChangeFunc,
ProOptions,
ReactFlow,
ReactFlowProps,
XYPosition,
} from 'reactflow';
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
@ -28,6 +30,8 @@ import {
connectionEnded,
connectionMade,
connectionStarted,
edgeAdded,
edgeDeleted,
edgesChanged,
edgesDeleted,
nodesChanged,
@ -167,6 +171,63 @@ export const Flow = () => {
}
}, []);
// #region Updatable Edges
/**
* Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/
* and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/
*
* - Edges can be dragged from one handle to another.
* - If the user drags the edge away from the node and drops it, delete the edge.
* - Do not delete the edge if the cursor didn't move (resolves annoying behaviour
* where the edge is deleted if you click it accidentally).
*/
// We have a ref for cursor position, but it is the *projected* cursor position.
// Easiest to just keep track of the last mouse event for this particular feature
const edgeUpdateMouseEvent = useRef<MouseEvent>();
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> =
useCallback(
(e, edge, _handleType) => {
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
},
[dispatch]
);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(_oldEdge, newConnection) => {
// instead of updating the edge (we deleted it earlier), we instead create
// a new one.
dispatch(connectionMade(newConnection));
},
[dispatch]
);
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> =
useCallback(
(e, edge, _handleType) => {
// Handle the case where user begins a drag but didn't move the cursor -
// bc we deleted the edge, we need to add it back
if (
// ignore touch events
!('touches' in e) &&
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
edgeUpdateMouseEvent.current?.clientY === e.clientY
) {
dispatch(edgeAdded(edge));
}
// reset mouse event
edgeUpdateMouseEvent.current = undefined;
},
[dispatch]
);
// #endregion
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault();
dispatch(selectionCopied());
@ -196,6 +257,9 @@ export const Flow = () => {
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart}
onConnect={onConnect}

View File

@ -53,13 +53,12 @@ export const useIsValidConnection = () => {
}
if (
edges
.filter((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
})
.find((edge) => {
edge.source === source && edge.sourceHandle === sourceHandle;
})
edges.find((edge) => {
edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return false;

View File

@ -15,6 +15,7 @@ import {
NodeChange,
OnConnectStartParams,
SelectionMode,
updateEdge,
Viewport,
XYPosition,
} from 'reactflow';
@ -182,6 +183,16 @@ const nodesSlice = createSlice({
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
},
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
edgeUpdated: (
state,
action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }>
) => {
const { oldEdge, newConnection } = action.payload;
state.edges = updateEdge(oldEdge, newConnection, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload;
const { nodeId, handleId, handleType } = action.payload;
@ -366,6 +377,7 @@ const nodesSlice = createSlice({
target: edge.target,
type: 'collapsed',
data: { count: 1 },
updatable: false,
});
}
}
@ -388,6 +400,7 @@ const nodesSlice = createSlice({
target: edge.target,
type: 'collapsed',
data: { count: 1 },
updatable: false,
});
}
}
@ -400,6 +413,9 @@ const nodesSlice = createSlice({
}
}
},
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
@ -890,69 +906,72 @@ const nodesSlice = createSlice({
});
export const {
nodesChanged,
edgesChanged,
nodeAdded,
nodesDeleted,
addNodePopoverClosed,
addNodePopoverOpened,
addNodePopoverToggled,
connectionEnded,
connectionMade,
connectionStarted,
connectionEnded,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
nodeTemplatesBuilt,
nodeEditorReset,
imageCollectionFieldValueChanged,
fieldStringValueChanged,
fieldNumberValueChanged,
edgeDeleted,
edgesChanged,
edgesDeleted,
edgeUpdated,
fieldBoardValueChanged,
fieldBooleanValueChanged,
fieldImageValueChanged,
fieldColorValueChanged,
fieldMainModelValueChanged,
fieldVaeModelValueChanged,
fieldLoRAModelValueChanged,
fieldEnumModelValueChanged,
fieldControlNetModelValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldIPAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
imageCollectionFieldValueChanged,
mouseOverFieldChanged,
mouseOverNodeChanged,
nodeAdded,
nodeEditorReset,
nodeEmbedWorkflowChanged,
nodeExclusivelySelected,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
nodeLabelChanged,
nodeNotesChanged,
edgesDeleted,
shouldValidateGraphChanged,
shouldAnimateEdgesChanged,
nodeOpacityChanged,
shouldSnapToGridChanged,
shouldColorEdgesChanged,
selectedNodesChanged,
selectedEdgesChanged,
workflowNameChanged,
workflowDescriptionChanged,
workflowTagsChanged,
workflowAuthorChanged,
workflowNotesChanged,
workflowVersionChanged,
workflowContactChanged,
workflowLoaded,
nodesChanged,
nodesDeleted,
nodeTemplatesBuilt,
nodeUseCacheChanged,
notesNodeValueChanged,
selectedAll,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionModeChanged,
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
viewportChanged,
workflowAuthorChanged,
workflowContactChanged,
workflowDescriptionChanged,
workflowExposedFieldAdded,
workflowExposedFieldRemoved,
fieldLabelChanged,
viewportChanged,
mouseOverFieldChanged,
selectionCopied,
selectionPasted,
selectedAll,
addNodePopoverOpened,
addNodePopoverClosed,
addNodePopoverToggled,
selectionModeChanged,
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
nodeUseCacheChanged,
workflowLoaded,
workflowNameChanged,
workflowNotesChanged,
workflowTagsChanged,
workflowVersionChanged,
edgeAdded,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.cannotConnectInputToInput');
}
// we have to figure out which is the target and which is the source
const target = handleType === 'target' ? nodeId : connectionNodeId;
const targetHandle =
handleType === 'target' ? fieldName : connectionFieldName;
const source = handleType === 'source' ? nodeId : connectionNodeId;
const sourceHandle =
handleType === 'source' ? fieldName : connectionFieldName;
if (
edges.find((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName;
edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return i18n.t('nodes.cannotDuplicateConnection');
}
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
}) &&
// except CollectionItem inputs can have multiples
targetType !== 'CollectionItem'

View File

@ -1,9 +1,7 @@
import { ButtonGroup } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
import StatusStatGroup from './common/StatusStatGroup';
@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem';
const InvocationCacheStatus = () => {
const { t } = useTranslation();
const isConnected = useAppSelector((state) => state.system.isConnected);
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
pollingInterval:
isConnected &&
queueStatus?.processor.is_started &&
queueStatus?.queue.pending > 0
? 5000
: 0,
});
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined);
return (
<StatusStatGroup>

View File

@ -1,9 +1,8 @@
import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { t } from 'i18next';
import { get, startCase, truncate, upperFirst } from 'lodash-es';
import { startCase } from 'lodash-es';
import { LogLevelName } from 'roarr';
import { isAnySessionRejected } from 'services/api/thunks/session';
import {
appSocketConnected,
appSocketDisconnected,
@ -20,8 +19,7 @@ import {
} from 'services/events/actions';
import { calculateStepPercentage } from '../util/calculateStepPercentage';
import { makeToast } from '../util/makeToast';
import { SystemState, LANGUAGES } from './types';
import { zPydanticValidationError } from './zodSchemas';
import { LANGUAGES, SystemState } from './types';
export const initialSystemState: SystemState = {
isInitialized: false,
@ -175,50 +173,6 @@ export const systemSlice = createSlice({
// *** Matchers - must be after all cases ***
/**
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
builder.addMatcher(isAnySessionRejected, (state, action) => {
let errorDescription = undefined;
const duration = 5000;
if (action.payload?.status === 422) {
const result = zPydanticValidationError.safeParse(action.payload);
if (result.success) {
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
duration,
})
);
});
return;
}
} else if (action.payload?.error) {
errorDescription = action.payload?.error;
}
state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
status: 'error',
description: truncate(
get(errorDescription, 'detail', 'Unknown Error'),
{ length: 128 }
),
duration,
})
);
});
/**
* Any server error
*/

View File

@ -2,7 +2,7 @@ import { z } from 'zod';
export const zPydanticValidationError = z.object({
status: z.literal(422),
error: z.object({
data: z.object({
detail: z.array(
z.object({
loc: z.array(z.string()),

View File

@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { ResourceKey } from 'i18next';
import { isEqual } from 'lodash-es';
@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue'];
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
const InvokeTabs = () => {
const activeTab = useAppSelector(activeTabIndexSelector);
const activeTabIndex = useAppSelector(activeTabIndexSelector);
const activeTabName = useAppSelector(activeTabNameSelector);
const enabledTabs = useAppSelector(enabledTabsSelector);
const { t } = useTranslation();
@ -150,13 +150,13 @@ const InvokeTabs = () => {
const handleTabChange = useCallback(
(index: number) => {
const activeTabName = tabMap[index];
if (!activeTabName) {
const tab = enabledTabs[index];
if (!tab) {
return;
}
dispatch(setActiveTab(activeTabName));
dispatch(setActiveTab(tab.id));
},
[dispatch]
[dispatch, enabledTabs]
);
const {
@ -216,8 +216,8 @@ const InvokeTabs = () => {
return (
<Tabs
variant="appTabs"
defaultIndex={activeTab}
index={activeTab}
defaultIndex={activeTabIndex}
index={activeTabIndex}
onChange={handleTabChange}
sx={{
flexGrow: 1,

View File

@ -95,26 +95,32 @@ export default function UnifiedCanvasColorPicker() {
>
<Flex minWidth={60} direction="column" gap={4} width="100%">
{layer === 'base' && (
<IAIColorPicker
<Box
sx={{
width: '100%',
paddingTop: 2,
paddingBottom: 2,
}}
pickerColor={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
>
<IAIColorPicker
color={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
</Box>
)}
{layer === 'mask' && (
<IAIColorPicker
<Box
sx={{
width: '100%',
paddingTop: 2,
paddingBottom: 2,
}}
pickerColor={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
>
<IAIColorPicker
color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
</Box>
)}
</Flex>
</IAIPopover>

View File

@ -1,13 +0,0 @@
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';
export const setActiveTabReducer = (
state: UIState,
newActiveTab: number | InvokeTabName
) => {
if (typeof newActiveTab === 'number') {
state.activeTab = newActiveTab;
} else {
state.activeTab = tabMap.indexOf(newActiveTab);
}
};

View File

@ -1,27 +1,23 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es';
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';
import { isEqual, isString } from 'lodash-es';
import { tabMap } from './tabMap';
export const activeTabNameSelector = createSelector(
(state: RootState) => state.ui,
(ui: UIState) => tabMap[ui.activeTab] as InvokeTabName,
{
memoizeOptions: {
equalityCheck: isEqual,
},
}
(state: RootState) => state,
/**
* Previously `activeTab` was an integer, but now it's a string.
* Default to first tab in case user has integer.
*/
({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img')
);
export const activeTabIndexSelector = createSelector(
(state: RootState) => state.ui,
(ui: UIState) => ui.activeTab,
{
memoizeOptions: {
equalityCheck: isEqual,
},
(state: RootState) => state,
({ ui, config }) => {
const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t));
const idx = tabs.indexOf(ui.activeTab);
return idx === -1 ? 0 : idx;
}
);

View File

@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap';
import { UIState } from './uiTypes';
export const initialUIState: UIState = {
activeTab: 0,
activeTab: 'txt2img',
shouldShowImageDetails: false,
shouldUseCanvasBetaLayout: false,
shouldShowExistingModelsInSearch: false,
@ -26,7 +25,7 @@ export const uiSlice = createSlice({
initialState: initialUIState,
reducers: {
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
setActiveTabReducer(state, action.payload);
state.activeTab = action.payload;
},
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
state.shouldShowImageDetails = action.payload;
@ -73,7 +72,7 @@ export const uiSlice = createSlice({
},
extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => {
setActiveTabReducer(state, 'img2img');
state.activeTab = 'img2img';
});
},
});

View File

@ -1,4 +1,5 @@
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { InvokeTabName } from './tabMap';
export type Coordinates = {
x: number;
@ -13,7 +14,7 @@ export type Dimensions = {
export type Rect = Coordinates & Dimensions;
export interface UIState {
activeTab: number;
activeTab: InvokeTabName;
shouldShowImageDetails: boolean;
shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean;

View File

@ -1,184 +0,0 @@
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/queueNanoStore';
import { isObject } from 'lodash-es';
import { $client } from 'services/api/client';
import { paths } from 'services/api/schema';
import { O } from 'ts-toolbelt';
type CreateSessionArg = {
graph: NonNullable<
paths['/api/v1/sessions/']['post']['requestBody']
>['content']['application/json'];
};
type CreateSessionResponse = O.Required<
NonNullable<
paths['/api/v1/sessions/']['post']['requestBody']
>['content']['application/json'],
'id'
>;
type CreateSessionThunkConfig = {
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
};
/**
* `SessionsService.createSession()` thunk
*/
export const sessionCreated = createAsyncThunk<
CreateSessionResponse,
CreateSessionArg,
CreateSessionThunkConfig
>('api/sessionCreated', async (arg, { rejectWithValue }) => {
const { graph } = arg;
const { POST } = $client.get();
const { data, error, response } = await POST('/api/v1/sessions/', {
body: graph,
params: { query: { queue_id: $queueId.get() } },
});
if (error) {
return rejectWithValue({ arg, status: response.status, error });
}
return data;
});
type InvokedSessionArg = {
session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id'];
};
type InvokedSessionResponse =
paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json'];
type InvokedSessionThunkConfig = {
rejectValue: {
arg: InvokedSessionArg;
error: unknown;
status: number;
};
};
const isErrorWithStatus = (error: unknown): error is { status: number } =>
isObject(error) && 'status' in error;
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
isObject(error) && 'detail' in error;
/**
* `SessionsService.invokeSession()` thunk
*/
export const sessionInvoked = createAsyncThunk<
InvokedSessionResponse,
InvokedSessionArg,
InvokedSessionThunkConfig
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { PUT } = $client.get();
const { error, response } = await PUT(
'/api/v1/sessions/{session_id}/invoke',
{
params: {
query: { queue_id: $queueId.get(), all: true },
path: { session_id },
},
}
);
if (error) {
if (isErrorWithStatus(error) && error.status === 403) {
return rejectWithValue({
arg,
status: response.status,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (error as any).body.detail,
});
}
if (isErrorWithDetail(error) && response.status === 403) {
return rejectWithValue({
arg,
status: response.status,
error: error.detail,
});
}
if (error) {
return rejectWithValue({ arg, status: response.status, error });
}
}
});
type CancelSessionArg =
paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path'];
type CancelSessionResponse =
paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json'];
type CancelSessionThunkConfig = {
rejectValue: {
arg: CancelSessionArg;
error: unknown;
};
};
/**
* `SessionsService.cancelSession()` thunk
*/
export const sessionCanceled = createAsyncThunk<
CancelSessionResponse,
CancelSessionArg,
CancelSessionThunkConfig
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { DELETE } = $client.get();
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
params: {
path: { session_id },
},
});
if (error) {
return rejectWithValue({ arg, error });
}
return data;
});
type ListSessionsArg = {
params: paths['/api/v1/sessions/']['get']['parameters'];
};
type ListSessionsResponse =
paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json'];
type ListSessionsThunkConfig = {
rejectValue: {
arg: ListSessionsArg;
error: unknown;
};
};
/**
* `SessionsService.listSessions()` thunk
*/
export const listedSessions = createAsyncThunk<
ListSessionsResponse,
ListSessionsArg,
ListSessionsThunkConfig
>('api/listSessions', async (arg, { rejectWithValue }) => {
const { params } = arg;
const { GET } = $client.get();
const { data, error } = await GET('/api/v1/sessions/', {
params,
});
if (error) {
return rejectWithValue({ arg, error });
}
return data;
});
export const isAnySessionRejected = isAnyOf(
sessionCreated.rejected,
sessionInvoked.rejected
);

View File

@ -1,5 +1,4 @@
import { ThemeOverride } from '@chakra-ui/react';
import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react';
import { InvokeAIColors } from './colors/colors';
import { accordionTheme } from './components/accordion';
import { buttonTheme } from './components/button';
@ -149,3 +148,7 @@ export const theme: ThemeOverride = {
Tooltip: tooltipTheme,
},
};
export const TOAST_OPTIONS: ToastProviderProps = {
defaultOptions: { isClosable: true },
};