mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'invoke-ai:main' into main
This commit is contained in:
commit
f35dfa06bb
@ -332,6 +332,7 @@ class InvokeAiInstance:
|
||||
Configure the InvokeAI runtime directory
|
||||
"""
|
||||
|
||||
auto_install = False
|
||||
# set sys.argv to a consistent state
|
||||
new_argv = [sys.argv[0]]
|
||||
for i in range(1, len(sys.argv)):
|
||||
@ -340,13 +341,17 @@ class InvokeAiInstance:
|
||||
new_argv.append(el)
|
||||
new_argv.append(sys.argv[i + 1])
|
||||
elif el in ["-y", "--yes", "--yes-to-all"]:
|
||||
new_argv.append(el)
|
||||
auto_install = True
|
||||
sys.argv = new_argv
|
||||
|
||||
import messages
|
||||
import requests # to catch download exceptions
|
||||
from messages import introduction
|
||||
|
||||
introduction()
|
||||
auto_install = auto_install or messages.user_wants_auto_configuration()
|
||||
if auto_install:
|
||||
sys.argv.append("--yes")
|
||||
else:
|
||||
messages.introduction()
|
||||
|
||||
from invokeai.frontend.install.invokeai_configure import invokeai_configure
|
||||
|
||||
|
@ -7,7 +7,7 @@ import os
|
||||
import platform
|
||||
from pathlib import Path
|
||||
|
||||
from prompt_toolkit import prompt
|
||||
from prompt_toolkit import HTML, prompt
|
||||
from prompt_toolkit.completion import PathCompleter
|
||||
from prompt_toolkit.validation import Validator
|
||||
from rich import box, print
|
||||
@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
|
||||
if dest.exists():
|
||||
print(f":exclamation: Directory {dest} already exists :exclamation:")
|
||||
dest_confirmed = Confirm.ask(
|
||||
":stop_sign: Are you sure you want to (re)install in this location?",
|
||||
":stop_sign: (re)install in this location?",
|
||||
default=False,
|
||||
)
|
||||
else:
|
||||
print(f"InvokeAI will be installed in {dest}")
|
||||
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False)
|
||||
dest_confirmed = Confirm.ask("Use this location?", default=True)
|
||||
console.line()
|
||||
|
||||
return dest_confirmed
|
||||
|
||||
|
||||
def user_wants_auto_configuration() -> bool:
|
||||
"""Prompt the user to choose between manual and auto configuration."""
|
||||
console.rule("InvokeAI Configuration Section")
|
||||
console.print(
|
||||
Panel(
|
||||
Group(
|
||||
"\n".join(
|
||||
[
|
||||
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
|
||||
"",
|
||||
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
|
||||
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
|
||||
"",
|
||||
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
|
||||
]
|
||||
),
|
||||
),
|
||||
box=box.MINIMAL,
|
||||
padding=(1, 1),
|
||||
)
|
||||
)
|
||||
choice = (
|
||||
prompt(
|
||||
HTML("Choose <b><a></b>utomatic or <b><m></b>anual configuration [a/m] (a): "),
|
||||
validator=Validator.from_callable(
|
||||
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
|
||||
),
|
||||
)
|
||||
or "a"
|
||||
)
|
||||
return choice.lower().startswith("a")
|
||||
|
||||
|
||||
def dest_path(dest=None) -> Path:
|
||||
"""
|
||||
Prompt the user for the destination path and create the path
|
||||
|
@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
|
||||
|
||||
# CACHE
|
||||
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
|
||||
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
|
||||
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Lock
|
||||
from time import time
|
||||
from typing import Optional, Union
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
@ -59,7 +58,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
# If the cache is full, we need to remove the least used
|
||||
number_to_delete = len(self._cache) + 1 - self._max_cache_size
|
||||
self._delete_oldest_access(number_to_delete)
|
||||
self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json())
|
||||
self._cache[key] = CachedItem(invocation_output, invocation_output.json())
|
||||
|
||||
def _delete_oldest_access(self, number_to_delete: int) -> None:
|
||||
number_to_delete = min(number_to_delete, len(self._cache))
|
||||
|
@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
Model_dir = "models"
|
||||
|
||||
Default_config_file = config.model_conf_path
|
||||
SD_Configs = config.legacy_conf_path
|
||||
|
||||
@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.",
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
|
||||
# So we adjust everthing down a bit.
|
||||
return (
|
||||
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
|
@ -58,6 +58,7 @@
|
||||
"githubLabel": "Github",
|
||||
"hotkeysLabel": "Hotkeys",
|
||||
"imagePrompt": "Image Prompt",
|
||||
"imageFailedToLoad": "Unable to Load Image",
|
||||
"img2img": "Image To Image",
|
||||
"langArabic": "العربية",
|
||||
"langBrPortuguese": "Português do Brasil",
|
||||
@ -716,6 +717,7 @@
|
||||
"cannotConnectInputToInput": "Cannot connect input to input",
|
||||
"cannotConnectOutputToOutput": "Cannot connect output to output",
|
||||
"cannotConnectToSelf": "Cannot connect to self",
|
||||
"cannotDuplicateConnection": "Cannot create duplicate connections",
|
||||
"clipField": "Clip",
|
||||
"clipFieldDescription": "Tokenizer and text_encoder submodels.",
|
||||
"collection": "Collection",
|
||||
|
@ -5,7 +5,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { ReactNode, memo, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { theme as invokeAITheme } from 'theme/theme';
|
||||
import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
|
||||
|
||||
import '@fontsource-variable/inter';
|
||||
import { MantineProvider } from '@mantine/core';
|
||||
@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
|
||||
return (
|
||||
<MantineProvider theme={mantineTheme}>
|
||||
<ChakraProvider theme={theme} colorModeManager={manager}>
|
||||
<ChakraProvider
|
||||
theme={theme}
|
||||
colorModeManager={manager}
|
||||
toastOptions={TOAST_OPTIONS}
|
||||
>
|
||||
{children}
|
||||
</ChakraProvider>
|
||||
</MantineProvider>
|
||||
|
@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from './listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
||||
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
|
||||
import {
|
||||
addSessionCanceledFulfilledListener,
|
||||
addSessionCanceledPendingListener,
|
||||
addSessionCanceledRejectedListener,
|
||||
} from './listeners/sessionCanceled';
|
||||
import {
|
||||
addSessionCreatedFulfilledListener,
|
||||
addSessionCreatedPendingListener,
|
||||
addSessionCreatedRejectedListener,
|
||||
} from './listeners/sessionCreated';
|
||||
import {
|
||||
addSessionInvokedFulfilledListener,
|
||||
addSessionInvokedPendingListener,
|
||||
addSessionInvokedRejectedListener,
|
||||
} from './listeners/sessionInvoked';
|
||||
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
|
||||
@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
|
||||
import { addTabChangedListener } from './listeners/tabChanged';
|
||||
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener();
|
||||
addEnqueueRequestedNodes();
|
||||
addEnqueueRequestedLinear();
|
||||
addAnyEnqueuedListener();
|
||||
addBatchEnqueuedListener();
|
||||
|
||||
// Canvas actions
|
||||
addCanvasSavedToGalleryListener();
|
||||
@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener();
|
||||
addInvocationRetrievalErrorEventListener();
|
||||
addSocketQueueItemStatusChangedEventListener();
|
||||
|
||||
// Session Created
|
||||
addSessionCreatedPendingListener();
|
||||
addSessionCreatedFulfilledListener();
|
||||
addSessionCreatedRejectedListener();
|
||||
|
||||
// Session Invoked
|
||||
addSessionInvokedPendingListener();
|
||||
addSessionInvokedFulfilledListener();
|
||||
addSessionInvokedRejectedListener();
|
||||
|
||||
// Session Canceled
|
||||
addSessionCanceledPendingListener();
|
||||
addSessionCanceledFulfilledListener();
|
||||
addSessionCanceledRejectedListener();
|
||||
|
||||
// ControlNet
|
||||
addControlNetImageProcessedListener();
|
||||
addControlNetAutoProcessListener();
|
||||
|
@ -0,0 +1,96 @@
|
||||
import { createStandaloneToast } from '@chakra-ui/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
|
||||
import { t } from 'i18next';
|
||||
import { get, truncate, upperFirst } from 'lodash-es';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { TOAST_OPTIONS, theme } from 'theme/theme';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
const { toast } = createStandaloneToast({
|
||||
theme: theme,
|
||||
defaultOptions: TOAST_OPTIONS.defaultOptions,
|
||||
});
|
||||
|
||||
export const addBatchEnqueuedListener = () => {
|
||||
// success
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
|
||||
effect: async (action) => {
|
||||
const response = action.payload;
|
||||
const arg = action.meta.arg.originalArgs;
|
||||
logger('queue').debug(
|
||||
{ enqueueResult: parseify(response) },
|
||||
'Batch enqueued'
|
||||
);
|
||||
|
||||
if (!toast.isActive('batch-queued')) {
|
||||
toast({
|
||||
id: 'batch-queued',
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: response.enqueued,
|
||||
direction: arg.prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
duration: 1000,
|
||||
status: 'success',
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// error
|
||||
startAppListening({
|
||||
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
|
||||
effect: async (action) => {
|
||||
const response = action.payload;
|
||||
const arg = action.meta.arg.originalArgs;
|
||||
|
||||
if (!response) {
|
||||
toast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
description: 'Unknown Error',
|
||||
});
|
||||
logger('queue').error(
|
||||
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const result = zPydanticValidationError.safeParse(response);
|
||||
if (result.success) {
|
||||
result.data.data.detail.map((e) => {
|
||||
toast({
|
||||
id: 'batch-failed-to-queue',
|
||||
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
`Path:
|
||||
${e.loc.join('.')}`,
|
||||
{ length: 128 }
|
||||
),
|
||||
});
|
||||
});
|
||||
} else {
|
||||
let detail = 'Unknown Error';
|
||||
if (response.status === 403 && 'body' in response) {
|
||||
detail = get(response, 'body.detail', 'Unknown Error');
|
||||
} else if (response.status === 403 && 'error' in response) {
|
||||
detail = get(response, 'error.detail', 'Unknown Error');
|
||||
}
|
||||
toast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
description: detail,
|
||||
});
|
||||
}
|
||||
logger('queue').error(
|
||||
{ batchConfig: parseify(arg), error: parseify(response) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio
|
||||
import { canvasGraphBuilt } from 'features/nodes/store/actions';
|
||||
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
|
||||
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
|
||||
|
||||
// Prep the canvas staging area if it is not yet initialized
|
||||
@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
|
||||
// Associate the session with the canvas session ID
|
||||
dispatch(canvasBatchIdAdded(batchId));
|
||||
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
});
|
||||
|
@ -1,13 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
|
||||
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
|
||||
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
|
||||
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
|
||||
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => {
|
||||
(action.payload.tabName === 'txt2img' ||
|
||||
action.payload.tabName === 'img2img'),
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const log = logger('queue');
|
||||
const state = getState();
|
||||
const model = state.generation.model;
|
||||
const { prepend } = action.payload;
|
||||
@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => {
|
||||
|
||||
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,9 +1,5 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { BatchConfig } from 'services/api/types';
|
||||
import { startAppListening } from '..';
|
||||
@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => {
|
||||
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
|
||||
enqueueRequested.match(action) && action.payload.tabName === 'nodes',
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const log = logger('queue');
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
const graph = buildNodesGraph(state.nodes);
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => {
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
'Failed to enqueue batch'
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
req.reset();
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,44 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionCanceledPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCanceledFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
log.debug({ session_id }, `Session canceled (${session_id})`);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCanceledRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCanceled.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem canceling session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -1,45 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionCreated } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionCreatedPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCreatedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const session = action.payload;
|
||||
log.debug(
|
||||
{ session: parseify(session) },
|
||||
`Session created (${session.id})`
|
||||
);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionCreatedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionCreated.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
if (action.payload) {
|
||||
const { error, status } = action.payload;
|
||||
const graph = parseify(action.meta.arg);
|
||||
log.error(
|
||||
{ graph, status, error: serializeError(error) },
|
||||
`Problem creating session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -1,44 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { sessionInvoked } from 'services/api/thunks/session';
|
||||
import { startAppListening } from '..';
|
||||
|
||||
export const addSessionInvokedPendingListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.pending,
|
||||
effect: () => {
|
||||
//
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionInvokedFulfilledListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.fulfilled,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
log.debug({ session_id }, `Session invoked (${session_id})`);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
export const addSessionInvokedRejectedListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: sessionInvoked.rejected,
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem invoking session`
|
||||
);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
queueApi.util.invalidateTags([
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
{ type: 'SessionQueueItemDTO', id: item_id },
|
||||
{ type: 'BatchStatus', id: queue_batch_id },
|
||||
|
@ -1,54 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { AppThunkDispatch } from 'app/store/store';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import { BatchConfig } from 'services/api/types';
|
||||
|
||||
export const enqueueBatch = async (
|
||||
batchConfig: BatchConfig,
|
||||
dispatch: AppThunkDispatch
|
||||
) => {
|
||||
const log = logger('session');
|
||||
const { prepend } = batchConfig;
|
||||
|
||||
try {
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
})
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'resumeProcessor',
|
||||
})
|
||||
);
|
||||
|
||||
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchQueued'),
|
||||
description: t('queue.batchQueuedDesc', {
|
||||
item_count: enqueueResult.enqueued,
|
||||
direction: prepend ? t('queue.front') : t('queue.back'),
|
||||
}),
|
||||
status: 'success',
|
||||
})
|
||||
);
|
||||
} catch {
|
||||
log.error(
|
||||
{ batchConfig: parseify(batchConfig) },
|
||||
t('queue.batchFailedToQueue')
|
||||
);
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.batchFailedToQueue'),
|
||||
status: 'error',
|
||||
})
|
||||
);
|
||||
}
|
||||
};
|
@ -1,18 +1,9 @@
|
||||
import { chakra, ChakraProps } from '@chakra-ui/react';
|
||||
import { Box, ChakraProps } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { RgbaColorPicker } from 'react-colorful';
|
||||
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
|
||||
|
||||
type IAIColorPickerProps = Omit<ColorPickerBaseProps<RgbaColor>, 'color'> &
|
||||
ChakraProps & {
|
||||
pickerColor: RgbaColor;
|
||||
styleClass?: string;
|
||||
};
|
||||
|
||||
const ChakraRgbaColorPicker = chakra(RgbaColorPicker, {
|
||||
baseStyle: { paddingInline: 4 },
|
||||
shouldForwardProp: (prop) => !['pickerColor'].includes(prop),
|
||||
});
|
||||
type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
|
||||
|
||||
const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
width: 6,
|
||||
@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
|
||||
borderColor: 'base.100',
|
||||
};
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
const { styleClass = '', ...rest } = props;
|
||||
const sx = {
|
||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||
};
|
||||
|
||||
const IAIColorPicker = (props: IAIColorPickerProps) => {
|
||||
return (
|
||||
<ChakraRgbaColorPicker
|
||||
sx={{
|
||||
'.react-colorful__hue-pointer': colorPickerStyles,
|
||||
'.react-colorful__saturation-pointer': colorPickerStyles,
|
||||
'.react-colorful__alpha-pointer': colorPickerStyles,
|
||||
}}
|
||||
className={styleClass}
|
||||
{...rest}
|
||||
/>
|
||||
<Box sx={sx}>
|
||||
<RgbaColorPicker {...props} />
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
|
@ -1,26 +1,27 @@
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { Image, Rect } from 'react-konva';
|
||||
import { memo } from 'react';
|
||||
import { Image } from 'react-konva';
|
||||
import { $authToken } from 'services/api/client';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import useImage from 'use-image';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
import { $authToken } from 'services/api/client';
|
||||
import { memo } from 'react';
|
||||
import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback';
|
||||
|
||||
type IAICanvasImageProps = {
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImage = (props: IAICanvasImageProps) => {
|
||||
const { width, height, x, y, imageName } = props.canvasImage;
|
||||
const { x, y, imageName } = props.canvasImage;
|
||||
const { currentData: imageDTO, isError } = useGetImageDTOQuery(
|
||||
imageName ?? skipToken
|
||||
);
|
||||
const [image] = useImage(
|
||||
const [image, status] = useImage(
|
||||
imageDTO?.image_url ?? '',
|
||||
$authToken.get() ? 'use-credentials' : 'anonymous'
|
||||
);
|
||||
|
||||
if (isError) {
|
||||
return <Rect x={x} y={y} width={width} height={height} fill="red" />;
|
||||
if (isError || status === 'failed') {
|
||||
return <IAICanvasImageErrorFallback canvasImage={props.canvasImage} />;
|
||||
}
|
||||
|
||||
return <Image x={x} y={y} image={image} listening={false} />;
|
||||
|
@ -0,0 +1,44 @@
|
||||
import { useColorModeValue, useToken } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Group, Rect, Text } from 'react-konva';
|
||||
import { CanvasImage } from '../store/canvasTypes';
|
||||
|
||||
type IAICanvasImageErrorFallbackProps = {
|
||||
canvasImage: CanvasImage;
|
||||
};
|
||||
const IAICanvasImageErrorFallback = ({
|
||||
canvasImage,
|
||||
}: IAICanvasImageErrorFallbackProps) => {
|
||||
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
|
||||
useToken('colors', ['gray.400', 'gray.500', 'base.700', 'base.900']);
|
||||
const errorColor = useColorModeValue(errorColorLight, errorColorDark);
|
||||
const fontColor = useColorModeValue(fontColorLight, fontColorDark);
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Group>
|
||||
<Rect
|
||||
x={canvasImage.x}
|
||||
y={canvasImage.y}
|
||||
width={canvasImage.width}
|
||||
height={canvasImage.height}
|
||||
fill={errorColor}
|
||||
/>
|
||||
<Text
|
||||
x={canvasImage.x}
|
||||
y={canvasImage.y}
|
||||
width={canvasImage.width}
|
||||
height={canvasImage.height}
|
||||
align="center"
|
||||
verticalAlign="middle"
|
||||
fontFamily='"Inter Variable", sans-serif'
|
||||
fontSize={canvasImage.width / 16}
|
||||
fontStyle="600"
|
||||
text={t('common.imageFailedToLoad')}
|
||||
fill={fontColor}
|
||||
/>
|
||||
</Group>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IAICanvasImageErrorFallback);
|
@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => {
|
||||
dispatch(setShouldPreserveMaskedArea(e.target.checked))
|
||||
}
|
||||
/>
|
||||
<IAIColorPicker
|
||||
sx={{ paddingTop: 2, paddingBottom: 2 }}
|
||||
pickerColor={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
<Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
|
||||
<IAIColorPicker
|
||||
color={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
|
||||
Save Mask
|
||||
</IAIButton>
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { ButtonGroup, Flex } from '@chakra-ui/react';
|
||||
import { ButtonGroup, Flex, Box } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => {
|
||||
sliderNumberInputProps={{ max: 500 }}
|
||||
/>
|
||||
</Flex>
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
</ButtonGroup>
|
||||
|
@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice';
|
||||
import { IRect, Vector2d } from 'konva/lib/types';
|
||||
import { clamp, cloneDeep } from 'lodash-es';
|
||||
import { RgbaColor } from 'react-colorful';
|
||||
import { sessionCanceled } from 'services/api/thunks/session';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import calculateCoordinates from '../util/calculateCoordinates';
|
||||
import calculateScale from '../util/calculateScale';
|
||||
@ -786,11 +785,6 @@ export const canvasSlice = createSlice({
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addCase(sessionCanceled.pending, (state) => {
|
||||
if (!state.layerState.stagingArea.images.length) {
|
||||
state.layerState.stagingArea = initialLayerState.stagingArea;
|
||||
}
|
||||
});
|
||||
builder.addCase(setAspectRatio, (state, action) => {
|
||||
const ratio = action.payload;
|
||||
if (ratio) {
|
||||
|
@ -6,7 +6,6 @@ import {
|
||||
import { cloneDeep, forEach } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { components } from 'services/api/schema';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import { ImageDTO } from 'services/api/types';
|
||||
import { appSocketInvocationError } from 'services/events/actions';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
@ -418,10 +417,6 @@ export const controlNetSlice = createSlice({
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
||||
builder.addMatcher(isAnySessionRejected, (state) => {
|
||||
state.pendingControlImages = [];
|
||||
});
|
||||
|
||||
builder.addMatcher(
|
||||
imagesApi.endpoints.deleteImage.matchFulfilled,
|
||||
(state, action) => {
|
||||
|
@ -12,6 +12,7 @@ import {
|
||||
OnConnect,
|
||||
OnConnectEnd,
|
||||
OnConnectStart,
|
||||
OnEdgeUpdateFunc,
|
||||
OnEdgesChange,
|
||||
OnEdgesDelete,
|
||||
OnInit,
|
||||
@ -21,6 +22,7 @@ import {
|
||||
OnSelectionChangeFunc,
|
||||
ProOptions,
|
||||
ReactFlow,
|
||||
ReactFlowProps,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
import { useIsValidConnection } from '../../hooks/useIsValidConnection';
|
||||
@ -28,6 +30,8 @@ import {
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
edgeAdded,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
nodesChanged,
|
||||
@ -167,6 +171,63 @@ export const Flow = () => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
// #region Updatable Edges
|
||||
|
||||
/**
|
||||
* Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/
|
||||
* and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/
|
||||
*
|
||||
* - Edges can be dragged from one handle to another.
|
||||
* - If the user drags the edge away from the node and drops it, delete the edge.
|
||||
* - Do not delete the edge if the cursor didn't move (resolves annoying behaviour
|
||||
* where the edge is deleted if you click it accidentally).
|
||||
*/
|
||||
|
||||
// We have a ref for cursor position, but it is the *projected* cursor position.
|
||||
// Easiest to just keep track of the last mouse event for this particular feature
|
||||
const edgeUpdateMouseEvent = useRef<MouseEvent>();
|
||||
|
||||
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> =
|
||||
useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
// update mouse event
|
||||
edgeUpdateMouseEvent.current = e;
|
||||
// always delete the edge when starting an updated
|
||||
dispatch(edgeDeleted(edge.id));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
|
||||
(_oldEdge, newConnection) => {
|
||||
// instead of updating the edge (we deleted it earlier), we instead create
|
||||
// a new one.
|
||||
dispatch(connectionMade(newConnection));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> =
|
||||
useCallback(
|
||||
(e, edge, _handleType) => {
|
||||
// Handle the case where user begins a drag but didn't move the cursor -
|
||||
// bc we deleted the edge, we need to add it back
|
||||
if (
|
||||
// ignore touch events
|
||||
!('touches' in e) &&
|
||||
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
|
||||
edgeUpdateMouseEvent.current?.clientY === e.clientY
|
||||
) {
|
||||
dispatch(edgeAdded(edge));
|
||||
}
|
||||
// reset mouse event
|
||||
edgeUpdateMouseEvent.current = undefined;
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
// #endregion
|
||||
|
||||
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
|
||||
e.preventDefault();
|
||||
dispatch(selectionCopied());
|
||||
@ -196,6 +257,9 @@ export const Flow = () => {
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onEdgesDelete={onEdgesDelete}
|
||||
onEdgeUpdate={onEdgeUpdate}
|
||||
onEdgeUpdateStart={onEdgeUpdateStart}
|
||||
onEdgeUpdateEnd={onEdgeUpdateEnd}
|
||||
onNodesDelete={onNodesDelete}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
|
@ -53,13 +53,12 @@ export const useIsValidConnection = () => {
|
||||
}
|
||||
|
||||
if (
|
||||
edges
|
||||
.filter((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
})
|
||||
.find((edge) => {
|
||||
edge.source === source && edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
edges.find((edge) => {
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return false;
|
||||
|
@ -15,6 +15,7 @@ import {
|
||||
NodeChange,
|
||||
OnConnectStartParams,
|
||||
SelectionMode,
|
||||
updateEdge,
|
||||
Viewport,
|
||||
XYPosition,
|
||||
} from 'reactflow';
|
||||
@ -182,6 +183,16 @@ const nodesSlice = createSlice({
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
},
|
||||
edgeAdded: (state, action: PayloadAction<Edge>) => {
|
||||
state.edges = addEdge(action.payload, state.edges);
|
||||
},
|
||||
edgeUpdated: (
|
||||
state,
|
||||
action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }>
|
||||
) => {
|
||||
const { oldEdge, newConnection } = action.payload;
|
||||
state.edges = updateEdge(oldEdge, newConnection, state.edges);
|
||||
},
|
||||
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
|
||||
state.connectionStartParams = action.payload;
|
||||
const { nodeId, handleId, handleType } = action.payload;
|
||||
@ -366,6 +377,7 @@ const nodesSlice = createSlice({
|
||||
target: edge.target,
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -388,6 +400,7 @@ const nodesSlice = createSlice({
|
||||
target: edge.target,
|
||||
type: 'collapsed',
|
||||
data: { count: 1 },
|
||||
updatable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -400,6 +413,9 @@ const nodesSlice = createSlice({
|
||||
}
|
||||
}
|
||||
},
|
||||
edgeDeleted: (state, action: PayloadAction<string>) => {
|
||||
state.edges = state.edges.filter((e) => e.id !== action.payload);
|
||||
},
|
||||
edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
|
||||
const edges = action.payload;
|
||||
const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
|
||||
@ -890,69 +906,72 @@ const nodesSlice = createSlice({
|
||||
});
|
||||
|
||||
export const {
|
||||
nodesChanged,
|
||||
edgesChanged,
|
||||
nodeAdded,
|
||||
nodesDeleted,
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverOpened,
|
||||
addNodePopoverToggled,
|
||||
connectionEnded,
|
||||
connectionMade,
|
||||
connectionStarted,
|
||||
connectionEnded,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
nodeTemplatesBuilt,
|
||||
nodeEditorReset,
|
||||
imageCollectionFieldValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
edgeDeleted,
|
||||
edgesChanged,
|
||||
edgesDeleted,
|
||||
edgeUpdated,
|
||||
fieldBoardValueChanged,
|
||||
fieldBooleanValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldColorValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
imageCollectionFieldValueChanged,
|
||||
mouseOverFieldChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeAdded,
|
||||
nodeEditorReset,
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeExclusivelySelected,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
edgesDeleted,
|
||||
shouldValidateGraphChanged,
|
||||
shouldAnimateEdgesChanged,
|
||||
nodeOpacityChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldColorEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectedEdgesChanged,
|
||||
workflowNameChanged,
|
||||
workflowDescriptionChanged,
|
||||
workflowTagsChanged,
|
||||
workflowAuthorChanged,
|
||||
workflowNotesChanged,
|
||||
workflowVersionChanged,
|
||||
workflowContactChanged,
|
||||
workflowLoaded,
|
||||
nodesChanged,
|
||||
nodesDeleted,
|
||||
nodeTemplatesBuilt,
|
||||
nodeUseCacheChanged,
|
||||
notesNodeValueChanged,
|
||||
selectedAll,
|
||||
selectedEdgesChanged,
|
||||
selectedNodesChanged,
|
||||
selectionCopied,
|
||||
selectionModeChanged,
|
||||
selectionPasted,
|
||||
shouldAnimateEdgesChanged,
|
||||
shouldColorEdgesChanged,
|
||||
shouldShowFieldTypeLegendChanged,
|
||||
shouldShowMinimapPanelChanged,
|
||||
shouldSnapToGridChanged,
|
||||
shouldValidateGraphChanged,
|
||||
viewportChanged,
|
||||
workflowAuthorChanged,
|
||||
workflowContactChanged,
|
||||
workflowDescriptionChanged,
|
||||
workflowExposedFieldAdded,
|
||||
workflowExposedFieldRemoved,
|
||||
fieldLabelChanged,
|
||||
viewportChanged,
|
||||
mouseOverFieldChanged,
|
||||
selectionCopied,
|
||||
selectionPasted,
|
||||
selectedAll,
|
||||
addNodePopoverOpened,
|
||||
addNodePopoverClosed,
|
||||
addNodePopoverToggled,
|
||||
selectionModeChanged,
|
||||
nodeEmbedWorkflowChanged,
|
||||
nodeIsIntermediateChanged,
|
||||
mouseOverNodeChanged,
|
||||
nodeExclusivelySelected,
|
||||
nodeUseCacheChanged,
|
||||
workflowLoaded,
|
||||
workflowNameChanged,
|
||||
workflowNotesChanged,
|
||||
workflowTagsChanged,
|
||||
workflowVersionChanged,
|
||||
edgeAdded,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export default nodesSlice.reducer;
|
||||
|
@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = (
|
||||
return i18n.t('nodes.cannotConnectInputToInput');
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
const target = handleType === 'target' ? nodeId : connectionNodeId;
|
||||
const targetHandle =
|
||||
handleType === 'target' ? fieldName : connectionFieldName;
|
||||
const source = handleType === 'source' ? nodeId : connectionNodeId;
|
||||
const sourceHandle =
|
||||
handleType === 'source' ? fieldName : connectionFieldName;
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === nodeId && edge.targetHandle === fieldName;
|
||||
edge.target === target &&
|
||||
edge.targetHandle === targetHandle &&
|
||||
edge.source === source &&
|
||||
edge.sourceHandle === sourceHandle;
|
||||
})
|
||||
) {
|
||||
// We already have a connection from this source to this target
|
||||
return i18n.t('nodes.cannotDuplicateConnection');
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find((edge) => {
|
||||
return edge.target === target && edge.targetHandle === targetHandle;
|
||||
}) &&
|
||||
// except CollectionItem inputs can have multiples
|
||||
targetType !== 'CollectionItem'
|
||||
|
@ -1,9 +1,7 @@
|
||||
import { ButtonGroup } from '@chakra-ui/react';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
|
||||
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
|
||||
import StatusStatGroup from './common/StatusStatGroup';
|
||||
@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem';
|
||||
|
||||
const InvocationCacheStatus = () => {
|
||||
const { t } = useTranslation();
|
||||
const isConnected = useAppSelector((state) => state.system.isConnected);
|
||||
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
|
||||
pollingInterval:
|
||||
isConnected &&
|
||||
queueStatus?.processor.is_started &&
|
||||
queueStatus?.queue.pending > 0
|
||||
? 5000
|
||||
: 0,
|
||||
});
|
||||
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined);
|
||||
|
||||
return (
|
||||
<StatusStatGroup>
|
||||
|
@ -1,9 +1,8 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { t } from 'i18next';
|
||||
import { get, startCase, truncate, upperFirst } from 'lodash-es';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { LogLevelName } from 'roarr';
|
||||
import { isAnySessionRejected } from 'services/api/thunks/session';
|
||||
import {
|
||||
appSocketConnected,
|
||||
appSocketDisconnected,
|
||||
@ -20,8 +19,7 @@ import {
|
||||
} from 'services/events/actions';
|
||||
import { calculateStepPercentage } from '../util/calculateStepPercentage';
|
||||
import { makeToast } from '../util/makeToast';
|
||||
import { SystemState, LANGUAGES } from './types';
|
||||
import { zPydanticValidationError } from './zodSchemas';
|
||||
import { LANGUAGES, SystemState } from './types';
|
||||
|
||||
export const initialSystemState: SystemState = {
|
||||
isInitialized: false,
|
||||
@ -175,50 +173,6 @@ export const systemSlice = createSlice({
|
||||
|
||||
// *** Matchers - must be after all cases ***
|
||||
|
||||
/**
|
||||
* Session Invoked - REJECTED
|
||||
* Session Created - REJECTED
|
||||
*/
|
||||
builder.addMatcher(isAnySessionRejected, (state, action) => {
|
||||
let errorDescription = undefined;
|
||||
const duration = 5000;
|
||||
|
||||
if (action.payload?.status === 422) {
|
||||
const result = zPydanticValidationError.safeParse(action.payload);
|
||||
if (result.success) {
|
||||
result.data.error.detail.map((e) => {
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: truncate(upperFirst(e.msg), { length: 128 }),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
`Path:
|
||||
${e.loc.join('.')}`,
|
||||
{ length: 128 }
|
||||
),
|
||||
duration,
|
||||
})
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
} else if (action.payload?.error) {
|
||||
errorDescription = action.payload?.error;
|
||||
}
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description: truncate(
|
||||
get(errorDescription, 'detail', 'Unknown Error'),
|
||||
{ length: 128 }
|
||||
),
|
||||
duration,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Any server error
|
||||
*/
|
||||
|
@ -2,7 +2,7 @@ import { z } from 'zod';
|
||||
|
||||
export const zPydanticValidationError = z.object({
|
||||
status: z.literal(422),
|
||||
error: z.object({
|
||||
data: z.object({
|
||||
detail: z.array(
|
||||
z.object({
|
||||
loc: z.array(z.string()),
|
||||
|
@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
|
||||
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
|
||||
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab } from 'features/ui/store/uiSlice';
|
||||
import { ResourceKey } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
||||
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
|
||||
|
||||
const InvokeTabs = () => {
|
||||
const activeTab = useAppSelector(activeTabIndexSelector);
|
||||
const activeTabIndex = useAppSelector(activeTabIndexSelector);
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const enabledTabs = useAppSelector(enabledTabsSelector);
|
||||
const { t } = useTranslation();
|
||||
@ -150,13 +150,13 @@ const InvokeTabs = () => {
|
||||
|
||||
const handleTabChange = useCallback(
|
||||
(index: number) => {
|
||||
const activeTabName = tabMap[index];
|
||||
if (!activeTabName) {
|
||||
const tab = enabledTabs[index];
|
||||
if (!tab) {
|
||||
return;
|
||||
}
|
||||
dispatch(setActiveTab(activeTabName));
|
||||
dispatch(setActiveTab(tab.id));
|
||||
},
|
||||
[dispatch]
|
||||
[dispatch, enabledTabs]
|
||||
);
|
||||
|
||||
const {
|
||||
@ -216,8 +216,8 @@ const InvokeTabs = () => {
|
||||
return (
|
||||
<Tabs
|
||||
variant="appTabs"
|
||||
defaultIndex={activeTab}
|
||||
index={activeTab}
|
||||
defaultIndex={activeTabIndex}
|
||||
index={activeTabIndex}
|
||||
onChange={handleTabChange}
|
||||
sx={{
|
||||
flexGrow: 1,
|
||||
|
@ -95,26 +95,32 @@ export default function UnifiedCanvasColorPicker() {
|
||||
>
|
||||
<Flex minWidth={60} direction="column" gap={4} width="100%">
|
||||
{layer === 'base' && (
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={brushColor}
|
||||
onChange={(newColor) => dispatch(setBrushColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{layer === 'mask' && (
|
||||
<IAIColorPicker
|
||||
<Box
|
||||
sx={{
|
||||
width: '100%',
|
||||
paddingTop: 2,
|
||||
paddingBottom: 2,
|
||||
}}
|
||||
pickerColor={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
>
|
||||
<IAIColorPicker
|
||||
color={maskColor}
|
||||
onChange={(newColor) => dispatch(setMaskColor(newColor))}
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
|
@ -1,13 +0,0 @@
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
|
||||
export const setActiveTabReducer = (
|
||||
state: UIState,
|
||||
newActiveTab: number | InvokeTabName
|
||||
) => {
|
||||
if (typeof newActiveTab === 'number') {
|
||||
state.activeTab = newActiveTab;
|
||||
} else {
|
||||
state.activeTab = tabMap.indexOf(newActiveTab);
|
||||
}
|
||||
};
|
@ -1,27 +1,23 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { InvokeTabName, tabMap } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
import { isEqual, isString } from 'lodash-es';
|
||||
import { tabMap } from './tabMap';
|
||||
|
||||
export const activeTabNameSelector = createSelector(
|
||||
(state: RootState) => state.ui,
|
||||
(ui: UIState) => tabMap[ui.activeTab] as InvokeTabName,
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
},
|
||||
}
|
||||
(state: RootState) => state,
|
||||
/**
|
||||
* Previously `activeTab` was an integer, but now it's a string.
|
||||
* Default to first tab in case user has integer.
|
||||
*/
|
||||
({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img')
|
||||
);
|
||||
|
||||
export const activeTabIndexSelector = createSelector(
|
||||
(state: RootState) => state.ui,
|
||||
(ui: UIState) => ui.activeTab,
|
||||
{
|
||||
memoizeOptions: {
|
||||
equalityCheck: isEqual,
|
||||
},
|
||||
(state: RootState) => state,
|
||||
({ ui, config }) => {
|
||||
const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t));
|
||||
const idx = tabs.indexOf(ui.activeTab);
|
||||
return idx === -1 ? 0 : idx;
|
||||
}
|
||||
);
|
||||
|
||||
|
@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import { initialImageChanged } from 'features/parameters/store/generationSlice';
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { setActiveTabReducer } from './extraReducers';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
import { UIState } from './uiTypes';
|
||||
|
||||
export const initialUIState: UIState = {
|
||||
activeTab: 0,
|
||||
activeTab: 'txt2img',
|
||||
shouldShowImageDetails: false,
|
||||
shouldUseCanvasBetaLayout: false,
|
||||
shouldShowExistingModelsInSearch: false,
|
||||
@ -26,7 +25,7 @@ export const uiSlice = createSlice({
|
||||
initialState: initialUIState,
|
||||
reducers: {
|
||||
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
|
||||
setActiveTabReducer(state, action.payload);
|
||||
state.activeTab = action.payload;
|
||||
},
|
||||
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldShowImageDetails = action.payload;
|
||||
@ -73,7 +72,7 @@ export const uiSlice = createSlice({
|
||||
},
|
||||
extraReducers(builder) {
|
||||
builder.addCase(initialImageChanged, (state) => {
|
||||
setActiveTabReducer(state, 'img2img');
|
||||
state.activeTab = 'img2img';
|
||||
});
|
||||
},
|
||||
});
|
||||
|
@ -1,4 +1,5 @@
|
||||
import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
|
||||
import { InvokeTabName } from './tabMap';
|
||||
|
||||
export type Coordinates = {
|
||||
x: number;
|
||||
@ -13,7 +14,7 @@ export type Dimensions = {
|
||||
export type Rect = Coordinates & Dimensions;
|
||||
|
||||
export interface UIState {
|
||||
activeTab: number;
|
||||
activeTab: InvokeTabName;
|
||||
shouldShowImageDetails: boolean;
|
||||
shouldUseCanvasBetaLayout: boolean;
|
||||
shouldShowExistingModelsInSearch: boolean;
|
||||
|
@ -1,184 +0,0 @@
|
||||
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { $queueId } from 'features/queue/store/queueNanoStore';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { $client } from 'services/api/client';
|
||||
import { paths } from 'services/api/schema';
|
||||
import { O } from 'ts-toolbelt';
|
||||
|
||||
type CreateSessionArg = {
|
||||
graph: NonNullable<
|
||||
paths['/api/v1/sessions/']['post']['requestBody']
|
||||
>['content']['application/json'];
|
||||
};
|
||||
|
||||
type CreateSessionResponse = O.Required<
|
||||
NonNullable<
|
||||
paths['/api/v1/sessions/']['post']['requestBody']
|
||||
>['content']['application/json'],
|
||||
'id'
|
||||
>;
|
||||
|
||||
type CreateSessionThunkConfig = {
|
||||
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.createSession()` thunk
|
||||
*/
|
||||
export const sessionCreated = createAsyncThunk<
|
||||
CreateSessionResponse,
|
||||
CreateSessionArg,
|
||||
CreateSessionThunkConfig
|
||||
>('api/sessionCreated', async (arg, { rejectWithValue }) => {
|
||||
const { graph } = arg;
|
||||
const { POST } = $client.get();
|
||||
const { data, error, response } = await POST('/api/v1/sessions/', {
|
||||
body: graph,
|
||||
params: { query: { queue_id: $queueId.get() } },
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
type InvokedSessionArg = {
|
||||
session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id'];
|
||||
};
|
||||
|
||||
type InvokedSessionResponse =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json'];
|
||||
|
||||
type InvokedSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: InvokedSessionArg;
|
||||
error: unknown;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
|
||||
const isErrorWithStatus = (error: unknown): error is { status: number } =>
|
||||
isObject(error) && 'status' in error;
|
||||
|
||||
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
|
||||
isObject(error) && 'detail' in error;
|
||||
|
||||
/**
|
||||
* `SessionsService.invokeSession()` thunk
|
||||
*/
|
||||
export const sessionInvoked = createAsyncThunk<
|
||||
InvokedSessionResponse,
|
||||
InvokedSessionArg,
|
||||
InvokedSessionThunkConfig
|
||||
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
|
||||
const { session_id } = arg;
|
||||
const { PUT } = $client.get();
|
||||
const { error, response } = await PUT(
|
||||
'/api/v1/sessions/{session_id}/invoke',
|
||||
{
|
||||
params: {
|
||||
query: { queue_id: $queueId.get(), all: true },
|
||||
path: { session_id },
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (error) {
|
||||
if (isErrorWithStatus(error) && error.status === 403) {
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
error: (error as any).body.detail,
|
||||
});
|
||||
}
|
||||
if (isErrorWithDetail(error) && response.status === 403) {
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
error: error.detail,
|
||||
});
|
||||
}
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
type CancelSessionArg =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path'];
|
||||
|
||||
type CancelSessionResponse =
|
||||
paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json'];
|
||||
|
||||
type CancelSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: CancelSessionArg;
|
||||
error: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.cancelSession()` thunk
|
||||
*/
|
||||
export const sessionCanceled = createAsyncThunk<
|
||||
CancelSessionResponse,
|
||||
CancelSessionArg,
|
||||
CancelSessionThunkConfig
|
||||
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
|
||||
const { session_id } = arg;
|
||||
const { DELETE } = $client.get();
|
||||
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
|
||||
params: {
|
||||
path: { session_id },
|
||||
},
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
type ListSessionsArg = {
|
||||
params: paths['/api/v1/sessions/']['get']['parameters'];
|
||||
};
|
||||
|
||||
type ListSessionsResponse =
|
||||
paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
type ListSessionsThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: ListSessionsArg;
|
||||
error: unknown;
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* `SessionsService.listSessions()` thunk
|
||||
*/
|
||||
export const listedSessions = createAsyncThunk<
|
||||
ListSessionsResponse,
|
||||
ListSessionsArg,
|
||||
ListSessionsThunkConfig
|
||||
>('api/listSessions', async (arg, { rejectWithValue }) => {
|
||||
const { params } = arg;
|
||||
const { GET } = $client.get();
|
||||
const { data, error } = await GET('/api/v1/sessions/', {
|
||||
params,
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
});
|
||||
|
||||
export const isAnySessionRejected = isAnyOf(
|
||||
sessionCreated.rejected,
|
||||
sessionInvoked.rejected
|
||||
);
|
@ -1,5 +1,4 @@
|
||||
import { ThemeOverride } from '@chakra-ui/react';
|
||||
|
||||
import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react';
|
||||
import { InvokeAIColors } from './colors/colors';
|
||||
import { accordionTheme } from './components/accordion';
|
||||
import { buttonTheme } from './components/button';
|
||||
@ -149,3 +148,7 @@ export const theme: ThemeOverride = {
|
||||
Tooltip: tooltipTheme,
|
||||
},
|
||||
};
|
||||
|
||||
export const TOAST_OPTIONS: ToastProviderProps = {
|
||||
defaultOptions: { isClosable: true },
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user