Merge branch 'invoke-ai:main' into main

This commit is contained in:
Millun Atluri 2023-09-27 10:10:52 +10:00 committed by GitHub
commit f35dfa06bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 462 additions and 734 deletions

View File

@ -332,6 +332,7 @@ class InvokeAiInstance:
Configure the InvokeAI runtime directory Configure the InvokeAI runtime directory
""" """
auto_install = False
# set sys.argv to a consistent state # set sys.argv to a consistent state
new_argv = [sys.argv[0]] new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)): for i in range(1, len(sys.argv)):
@ -340,13 +341,17 @@ class InvokeAiInstance:
new_argv.append(el) new_argv.append(el)
new_argv.append(sys.argv[i + 1]) new_argv.append(sys.argv[i + 1])
elif el in ["-y", "--yes", "--yes-to-all"]: elif el in ["-y", "--yes", "--yes-to-all"]:
new_argv.append(el) auto_install = True
sys.argv = new_argv sys.argv = new_argv
import messages
import requests # to catch download exceptions import requests # to catch download exceptions
from messages import introduction
introduction() auto_install = auto_install or messages.user_wants_auto_configuration()
if auto_install:
sys.argv.append("--yes")
else:
messages.introduction()
from invokeai.frontend.install.invokeai_configure import invokeai_configure from invokeai.frontend.install.invokeai_configure import invokeai_configure

View File

@ -7,7 +7,7 @@ import os
import platform import platform
from pathlib import Path from pathlib import Path
from prompt_toolkit import prompt from prompt_toolkit import HTML, prompt
from prompt_toolkit.completion import PathCompleter from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.validation import Validator from prompt_toolkit.validation import Validator
from rich import box, print from rich import box, print
@ -65,17 +65,50 @@ def confirm_install(dest: Path) -> bool:
if dest.exists(): if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:") print(f":exclamation: Directory {dest} already exists :exclamation:")
dest_confirmed = Confirm.ask( dest_confirmed = Confirm.ask(
":stop_sign: Are you sure you want to (re)install in this location?", ":stop_sign: (re)install in this location?",
default=False, default=False,
) )
else: else:
print(f"InvokeAI will be installed in {dest}") print(f"InvokeAI will be installed in {dest}")
dest_confirmed = not Confirm.ask("Would you like to pick a different location?", default=False) dest_confirmed = Confirm.ask("Use this location?", default=True)
console.line() console.line()
return dest_confirmed return dest_confirmed
def user_wants_auto_configuration() -> bool:
"""Prompt the user to choose between manual and auto configuration."""
console.rule("InvokeAI Configuration Section")
console.print(
Panel(
Group(
"\n".join(
[
"Libraries are installed and InvokeAI will now set up its root directory and configuration. Choose between:",
"",
" * AUTOMATIC configuration: install reasonable defaults and a minimal set of starter models.",
" * MANUAL configuration: manually inspect and adjust configuration options and pick from a larger set of starter models.",
"",
"Later you can fine tune your configuration by selecting option [6] 'Change InvokeAI startup options' from the invoke.bat/invoke.sh launcher script.",
]
),
),
box=box.MINIMAL,
padding=(1, 1),
)
)
choice = (
prompt(
HTML("Choose <b>&lt;a&gt;</b>utomatic or <b>&lt;m&gt;</b>anual configuration [a/m] (a): "),
validator=Validator.from_callable(
lambda n: n == "" or n.startswith(("a", "A", "m", "M")), error_message="Please select 'a' or 'm'"
),
)
or "a"
)
return choice.lower().startswith("a")
def dest_path(dest=None) -> Path: def dest_path(dest=None) -> Path:
""" """
Prompt the user for the destination path and create the path Prompt the user for the destination path and create the path

View File

@ -241,7 +241,7 @@ class InvokeAIAppConfig(InvokeAISettings):
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other") version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
# CACHE # CACHE
ram : Union[float, Literal["auto"]] = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", ) ram : Union[float, Literal["auto"]] = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number or 'auto')", category="Model Cache", )
vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", ) vram : Union[float, Literal["auto"]] = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number or 'auto')", category="Model Cache", )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", ) lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", category="Model Cache", )

View File

@ -1,7 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from threading import Lock from threading import Lock
from time import time
from typing import Optional, Union from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@ -59,7 +58,7 @@ class MemoryInvocationCache(InvocationCacheBase):
# If the cache is full, we need to remove the least used # If the cache is full, we need to remove the least used
number_to_delete = len(self._cache) + 1 - self._max_cache_size number_to_delete = len(self._cache) + 1 - self._max_cache_size
self._delete_oldest_access(number_to_delete) self._delete_oldest_access(number_to_delete)
self._cache[key] = CachedItem(time(), invocation_output, invocation_output.json()) self._cache[key] = CachedItem(invocation_output, invocation_output.json())
def _delete_oldest_access(self, number_to_delete: int) -> None: def _delete_oldest_access(self, number_to_delete: int) -> None:
number_to_delete = min(number_to_delete, len(self._cache)) number_to_delete = min(number_to_delete, len(self._cache))

View File

@ -70,7 +70,6 @@ def get_literal_fields(field) -> list[Any]:
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
Model_dir = "models" Model_dir = "models"
Default_config_file = config.model_conf_path Default_config_file = config.model_conf_path
SD_Configs = config.legacy_conf_path SD_Configs = config.legacy_conf_path
@ -458,7 +457,7 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
) )
self.add_widget_intelligent( self.add_widget_intelligent(
npyscreen.TitleFixedText, npyscreen.TitleFixedText,
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model.", name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
begin_entry_at=0, begin_entry_at=0,
editable=False, editable=False,
color="CONTROL", color="CONTROL",
@ -651,8 +650,19 @@ def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Nam
return editApp.new_opts() return editApp.new_opts()
def default_ramcache() -> float:
"""Run a heuristic for the default RAM cache based on installed RAM."""
# Note that on my 64 GB machine, psutil.virtual_memory().total gives 62 GB,
# So we adjust everthing down a bit.
return (
15.0 if MAX_RAM >= 60 else 7.5 if MAX_RAM >= 30 else 4 if MAX_RAM >= 14 else 2.1
) # 2.1 is just large enough for sd 1.5 ;-)
def default_startup_options(init_file: Path) -> Namespace: def default_startup_options(init_file: Path) -> Namespace:
opts = InvokeAIAppConfig.get_config() opts = InvokeAIAppConfig.get_config()
opts.ram = default_ramcache()
return opts return opts

View File

@ -58,6 +58,7 @@
"githubLabel": "Github", "githubLabel": "Github",
"hotkeysLabel": "Hotkeys", "hotkeysLabel": "Hotkeys",
"imagePrompt": "Image Prompt", "imagePrompt": "Image Prompt",
"imageFailedToLoad": "Unable to Load Image",
"img2img": "Image To Image", "img2img": "Image To Image",
"langArabic": "العربية", "langArabic": "العربية",
"langBrPortuguese": "Português do Brasil", "langBrPortuguese": "Português do Brasil",
@ -716,6 +717,7 @@
"cannotConnectInputToInput": "Cannot connect input to input", "cannotConnectInputToInput": "Cannot connect input to input",
"cannotConnectOutputToOutput": "Cannot connect output to output", "cannotConnectOutputToOutput": "Cannot connect output to output",
"cannotConnectToSelf": "Cannot connect to self", "cannotConnectToSelf": "Cannot connect to self",
"cannotDuplicateConnection": "Cannot create duplicate connections",
"clipField": "Clip", "clipField": "Clip",
"clipFieldDescription": "Tokenizer and text_encoder submodels.", "clipFieldDescription": "Tokenizer and text_encoder submodels.",
"collection": "Collection", "collection": "Collection",

View File

@ -5,7 +5,7 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { ReactNode, memo, useEffect, useMemo } from 'react'; import { ReactNode, memo, useEffect, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { theme as invokeAITheme } from 'theme/theme'; import { TOAST_OPTIONS, theme as invokeAITheme } from 'theme/theme';
import '@fontsource-variable/inter'; import '@fontsource-variable/inter';
import { MantineProvider } from '@mantine/core'; import { MantineProvider } from '@mantine/core';
@ -39,7 +39,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return ( return (
<MantineProvider theme={mantineTheme}> <MantineProvider theme={mantineTheme}>
<ChakraProvider theme={theme} colorModeManager={manager}> <ChakraProvider
theme={theme}
colorModeManager={manager}
toastOptions={TOAST_OPTIONS}
>
{children} {children}
</ChakraProvider> </ChakraProvider>
</MantineProvider> </MantineProvider>

View File

@ -54,21 +54,6 @@ import { addModelSelectedListener } from './listeners/modelSelected';
import { addModelsLoadedListener } from './listeners/modelsLoaded'; import { addModelsLoadedListener } from './listeners/modelsLoaded';
import { addDynamicPromptsListener } from './listeners/promptChanged'; import { addDynamicPromptsListener } from './listeners/promptChanged';
import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema'; import { addReceivedOpenAPISchemaListener } from './listeners/receivedOpenAPISchema';
import {
addSessionCanceledFulfilledListener,
addSessionCanceledPendingListener,
addSessionCanceledRejectedListener,
} from './listeners/sessionCanceled';
import {
addSessionCreatedFulfilledListener,
addSessionCreatedPendingListener,
addSessionCreatedRejectedListener,
} from './listeners/sessionCreated';
import {
addSessionInvokedFulfilledListener,
addSessionInvokedPendingListener,
addSessionInvokedRejectedListener,
} from './listeners/sessionInvoked';
import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected'; import { addSocketConnectedEventListener as addSocketConnectedListener } from './listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected'; import { addSocketDisconnectedEventListener as addSocketDisconnectedListener } from './listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress'; import { addGeneratorProgressEventListener as addGeneratorProgressListener } from './listeners/socketio/socketGeneratorProgress';
@ -86,6 +71,7 @@ import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSa
import { addTabChangedListener } from './listeners/tabChanged'; import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
export const listenerMiddleware = createListenerMiddleware(); export const listenerMiddleware = createListenerMiddleware();
@ -136,6 +122,7 @@ addEnqueueRequestedCanvasListener();
addEnqueueRequestedNodes(); addEnqueueRequestedNodes();
addEnqueueRequestedLinear(); addEnqueueRequestedLinear();
addAnyEnqueuedListener(); addAnyEnqueuedListener();
addBatchEnqueuedListener();
// Canvas actions // Canvas actions
addCanvasSavedToGalleryListener(); addCanvasSavedToGalleryListener();
@ -175,21 +162,6 @@ addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener(); addInvocationRetrievalErrorEventListener();
addSocketQueueItemStatusChangedEventListener(); addSocketQueueItemStatusChangedEventListener();
// Session Created
addSessionCreatedPendingListener();
addSessionCreatedFulfilledListener();
addSessionCreatedRejectedListener();
// Session Invoked
addSessionInvokedPendingListener();
addSessionInvokedFulfilledListener();
addSessionInvokedRejectedListener();
// Session Canceled
addSessionCanceledPendingListener();
addSessionCanceledFulfilledListener();
addSessionCanceledRejectedListener();
// ControlNet // ControlNet
addControlNetImageProcessedListener(); addControlNetImageProcessedListener();
addControlNetAutoProcessListener(); addControlNetAutoProcessListener();

View File

@ -0,0 +1,96 @@
import { createStandaloneToast } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { t } from 'i18next';
import { get, truncate, upperFirst } from 'lodash-es';
import { queueApi } from 'services/api/endpoints/queue';
import { TOAST_OPTIONS, theme } from 'theme/theme';
import { startAppListening } from '..';
const { toast } = createStandaloneToast({
theme: theme,
defaultOptions: TOAST_OPTIONS.defaultOptions,
});
export const addBatchEnqueuedListener = () => {
// success
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchFulfilled,
effect: async (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
logger('queue').debug(
{ enqueueResult: parseify(response) },
'Batch enqueued'
);
if (!toast.isActive('batch-queued')) {
toast({
id: 'batch-queued',
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: response.enqueued,
direction: arg.prepend ? t('queue.front') : t('queue.back'),
}),
duration: 1000,
status: 'success',
});
}
},
});
// error
startAppListening({
matcher: queueApi.endpoints.enqueueBatch.matchRejected,
effect: async (action) => {
const response = action.payload;
const arg = action.meta.arg.originalArgs;
if (!response) {
toast({
title: t('queue.batchFailedToQueue'),
status: 'error',
description: 'Unknown Error',
});
logger('queue').error(
{ batchConfig: parseify(arg), error: parseify(response) },
t('queue.batchFailedToQueue')
);
return;
}
const result = zPydanticValidationError.safeParse(response);
if (result.success) {
result.data.data.detail.map((e) => {
toast({
id: 'batch-failed-to-queue',
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
});
});
} else {
let detail = 'Unknown Error';
if (response.status === 403 && 'body' in response) {
detail = get(response, 'body.detail', 'Unknown Error');
} else if (response.status === 403 && 'error' in response) {
detail = get(response, 'error.detail', 'Unknown Error');
}
toast({
title: t('queue.batchFailedToQueue'),
status: 'error',
description: detail,
});
}
logger('queue').error(
{ batchConfig: parseify(arg), error: parseify(response) },
t('queue.batchFailedToQueue')
);
},
});
};

View File

@ -12,8 +12,6 @@ import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGeneratio
import { canvasGraphBuilt } from 'features/nodes/store/actions'; import { canvasGraphBuilt } from 'features/nodes/store/actions';
import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph';
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
@ -140,8 +138,6 @@ export const addEnqueueRequestedCanvasListener = () => {
const enqueueResult = await req.unwrap(); const enqueueResult = await req.unwrap();
req.reset(); req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it const batchId = enqueueResult.batch.batch_id as string; // we know the is a string, backend provides it
// Prep the canvas staging area if it is not yet initialized // Prep the canvas staging area if it is not yet initialized
@ -158,28 +154,8 @@ export const addEnqueueRequestedCanvasListener = () => {
// Associate the session with the canvas session ID // Associate the session with the canvas session ID
dispatch(canvasBatchIdAdded(batchId)); dispatch(canvasBatchIdAdded(batchId));
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch { } catch {
log.error( // no-op
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
} }
}, },
}); });

View File

@ -1,13 +1,9 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig';
import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph'; import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph';
import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph'; import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph';
import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph'; import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph';
import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph'; import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -18,7 +14,6 @@ export const addEnqueueRequestedLinear = () => {
(action.payload.tabName === 'txt2img' || (action.payload.tabName === 'txt2img' ||
action.payload.tabName === 'img2img'), action.payload.tabName === 'img2img'),
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const state = getState(); const state = getState();
const model = state.generation.model; const model = state.generation.model;
const { prepend } = action.payload; const { prepend } = action.payload;
@ -41,38 +36,12 @@ export const addEnqueueRequestedLinear = () => {
const batchConfig = prepareLinearUIBatch(state, graph, prepend); const batchConfig = prepareLinearUIBatch(state, graph, prepend);
try { const req = dispatch(
const req = dispatch( queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { fixedCacheKey: 'enqueueBatch',
fixedCacheKey: 'enqueueBatch', })
}) );
); req.reset();
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
}, },
}); });
}; };

View File

@ -1,9 +1,5 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions'; import { enqueueRequested } from 'app/store/actions';
import { parseify } from 'common/util/serialize';
import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph'; import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue'; import { queueApi } from 'services/api/endpoints/queue';
import { BatchConfig } from 'services/api/types'; import { BatchConfig } from 'services/api/types';
import { startAppListening } from '..'; import { startAppListening } from '..';
@ -13,9 +9,7 @@ export const addEnqueueRequestedNodes = () => {
predicate: (action): action is ReturnType<typeof enqueueRequested> => predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'nodes', enqueueRequested.match(action) && action.payload.tabName === 'nodes',
effect: async (action, { getState, dispatch }) => { effect: async (action, { getState, dispatch }) => {
const log = logger('queue');
const state = getState(); const state = getState();
const { prepend } = action.payload;
const graph = buildNodesGraph(state.nodes); const graph = buildNodesGraph(state.nodes);
const batchConfig: BatchConfig = { const batchConfig: BatchConfig = {
batch: { batch: {
@ -25,38 +19,12 @@ export const addEnqueueRequestedNodes = () => {
prepend: action.payload.prepend, prepend: action.payload.prepend,
}; };
try { const req = dispatch(
const req = dispatch( queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { fixedCacheKey: 'enqueueBatch',
fixedCacheKey: 'enqueueBatch', })
}) );
); req.reset();
const enqueueResult = await req.unwrap();
req.reset();
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
'Failed to enqueue batch'
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
}, },
}); });
}; };

View File

@ -1,44 +0,0 @@
import { logger } from 'app/logging/logger';
import { serializeError } from 'serialize-error';
import { sessionCanceled } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionCanceledPendingListener = () => {
startAppListening({
actionCreator: sessionCanceled.pending,
effect: () => {
//
},
});
};
export const addSessionCanceledFulfilledListener = () => {
startAppListening({
actionCreator: sessionCanceled.fulfilled,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
log.debug({ session_id }, `Session canceled (${session_id})`);
},
});
};
export const addSessionCanceledRejectedListener = () => {
startAppListening({
actionCreator: sessionCanceled.rejected,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
log.error(
{
session_id,
error: serializeError(error),
},
`Problem canceling session`
);
}
},
});
};

View File

@ -1,45 +0,0 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { serializeError } from 'serialize-error';
import { sessionCreated } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionCreatedPendingListener = () => {
startAppListening({
actionCreator: sessionCreated.pending,
effect: () => {
//
},
});
};
export const addSessionCreatedFulfilledListener = () => {
startAppListening({
actionCreator: sessionCreated.fulfilled,
effect: (action) => {
const log = logger('session');
const session = action.payload;
log.debug(
{ session: parseify(session) },
`Session created (${session.id})`
);
},
});
};
export const addSessionCreatedRejectedListener = () => {
startAppListening({
actionCreator: sessionCreated.rejected,
effect: (action) => {
const log = logger('session');
if (action.payload) {
const { error, status } = action.payload;
const graph = parseify(action.meta.arg);
log.error(
{ graph, status, error: serializeError(error) },
`Problem creating session`
);
}
},
});
};

View File

@ -1,44 +0,0 @@
import { logger } from 'app/logging/logger';
import { serializeError } from 'serialize-error';
import { sessionInvoked } from 'services/api/thunks/session';
import { startAppListening } from '..';
export const addSessionInvokedPendingListener = () => {
startAppListening({
actionCreator: sessionInvoked.pending,
effect: () => {
//
},
});
};
export const addSessionInvokedFulfilledListener = () => {
startAppListening({
actionCreator: sessionInvoked.fulfilled,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
log.debug({ session_id }, `Session invoked (${session_id})`);
},
});
};
export const addSessionInvokedRejectedListener = () => {
startAppListening({
actionCreator: sessionInvoked.rejected,
effect: (action) => {
const log = logger('session');
const { session_id } = action.meta.arg;
if (action.payload) {
const { error } = action.payload;
log.error(
{
session_id,
error: serializeError(error),
},
`Problem invoking session`
);
}
},
});
};

View File

@ -35,6 +35,7 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
queueApi.util.invalidateTags([ queueApi.util.invalidateTags([
'CurrentSessionQueueItem', 'CurrentSessionQueueItem',
'NextSessionQueueItem', 'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id }, { type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id }, { type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: queue_batch_id }, { type: 'BatchStatus', id: queue_batch_id },

View File

@ -1,54 +0,0 @@
import { logger } from 'app/logging/logger';
import { AppThunkDispatch } from 'app/store/store';
import { parseify } from 'common/util/serialize';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import { BatchConfig } from 'services/api/types';
export const enqueueBatch = async (
batchConfig: BatchConfig,
dispatch: AppThunkDispatch
) => {
const log = logger('session');
const { prepend } = batchConfig;
try {
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'resumeProcessor',
})
);
log.debug({ enqueueResult: parseify(enqueueResult) }, 'Batch enqueued');
dispatch(
addToast({
title: t('queue.batchQueued'),
description: t('queue.batchQueuedDesc', {
item_count: enqueueResult.enqueued,
direction: prepend ? t('queue.front') : t('queue.back'),
}),
status: 'success',
})
);
} catch {
log.error(
{ batchConfig: parseify(batchConfig) },
t('queue.batchFailedToQueue')
);
dispatch(
addToast({
title: t('queue.batchFailedToQueue'),
status: 'error',
})
);
}
};

View File

@ -1,18 +1,9 @@
import { chakra, ChakraProps } from '@chakra-ui/react'; import { Box, ChakraProps } from '@chakra-ui/react';
import { memo } from 'react'; import { memo } from 'react';
import { RgbaColorPicker } from 'react-colorful'; import { RgbaColorPicker } from 'react-colorful';
import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types'; import { ColorPickerBaseProps, RgbaColor } from 'react-colorful/dist/types';
type IAIColorPickerProps = Omit<ColorPickerBaseProps<RgbaColor>, 'color'> & type IAIColorPickerProps = ColorPickerBaseProps<RgbaColor>;
ChakraProps & {
pickerColor: RgbaColor;
styleClass?: string;
};
const ChakraRgbaColorPicker = chakra(RgbaColorPicker, {
baseStyle: { paddingInline: 4 },
shouldForwardProp: (prop) => !['pickerColor'].includes(prop),
});
const colorPickerStyles: NonNullable<ChakraProps['sx']> = { const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
width: 6, width: 6,
@ -20,19 +11,17 @@ const colorPickerStyles: NonNullable<ChakraProps['sx']> = {
borderColor: 'base.100', borderColor: 'base.100',
}; };
const IAIColorPicker = (props: IAIColorPickerProps) => { const sx = {
const { styleClass = '', ...rest } = props; '.react-colorful__hue-pointer': colorPickerStyles,
'.react-colorful__saturation-pointer': colorPickerStyles,
'.react-colorful__alpha-pointer': colorPickerStyles,
};
const IAIColorPicker = (props: IAIColorPickerProps) => {
return ( return (
<ChakraRgbaColorPicker <Box sx={sx}>
sx={{ <RgbaColorPicker {...props} />
'.react-colorful__hue-pointer': colorPickerStyles, </Box>
'.react-colorful__saturation-pointer': colorPickerStyles,
'.react-colorful__alpha-pointer': colorPickerStyles,
}}
className={styleClass}
{...rest}
/>
); );
}; };

View File

@ -1,26 +1,27 @@
import { skipToken } from '@reduxjs/toolkit/dist/query'; import { skipToken } from '@reduxjs/toolkit/dist/query';
import { Image, Rect } from 'react-konva'; import { memo } from 'react';
import { Image } from 'react-konva';
import { $authToken } from 'services/api/client';
import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import useImage from 'use-image'; import useImage from 'use-image';
import { CanvasImage } from '../store/canvasTypes'; import { CanvasImage } from '../store/canvasTypes';
import { $authToken } from 'services/api/client'; import IAICanvasImageErrorFallback from './IAICanvasImageErrorFallback';
import { memo } from 'react';
type IAICanvasImageProps = { type IAICanvasImageProps = {
canvasImage: CanvasImage; canvasImage: CanvasImage;
}; };
const IAICanvasImage = (props: IAICanvasImageProps) => { const IAICanvasImage = (props: IAICanvasImageProps) => {
const { width, height, x, y, imageName } = props.canvasImage; const { x, y, imageName } = props.canvasImage;
const { currentData: imageDTO, isError } = useGetImageDTOQuery( const { currentData: imageDTO, isError } = useGetImageDTOQuery(
imageName ?? skipToken imageName ?? skipToken
); );
const [image] = useImage( const [image, status] = useImage(
imageDTO?.image_url ?? '', imageDTO?.image_url ?? '',
$authToken.get() ? 'use-credentials' : 'anonymous' $authToken.get() ? 'use-credentials' : 'anonymous'
); );
if (isError) { if (isError || status === 'failed') {
return <Rect x={x} y={y} width={width} height={height} fill="red" />; return <IAICanvasImageErrorFallback canvasImage={props.canvasImage} />;
} }
return <Image x={x} y={y} image={image} listening={false} />; return <Image x={x} y={y} image={image} listening={false} />;

View File

@ -0,0 +1,44 @@
import { useColorModeValue, useToken } from '@chakra-ui/react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Group, Rect, Text } from 'react-konva';
import { CanvasImage } from '../store/canvasTypes';
type IAICanvasImageErrorFallbackProps = {
canvasImage: CanvasImage;
};
const IAICanvasImageErrorFallback = ({
canvasImage,
}: IAICanvasImageErrorFallbackProps) => {
const [errorColorLight, errorColorDark, fontColorLight, fontColorDark] =
useToken('colors', ['gray.400', 'gray.500', 'base.700', 'base.900']);
const errorColor = useColorModeValue(errorColorLight, errorColorDark);
const fontColor = useColorModeValue(fontColorLight, fontColorDark);
const { t } = useTranslation();
return (
<Group>
<Rect
x={canvasImage.x}
y={canvasImage.y}
width={canvasImage.width}
height={canvasImage.height}
fill={errorColor}
/>
<Text
x={canvasImage.x}
y={canvasImage.y}
width={canvasImage.width}
height={canvasImage.height}
align="center"
verticalAlign="middle"
fontFamily='"Inter Variable", sans-serif'
fontSize={canvasImage.width / 16}
fontStyle="600"
text={t('common.imageFailedToLoad')}
fill={fontColor}
/>
</Group>
);
};
export default memo(IAICanvasImageErrorFallback);

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex } from '@chakra-ui/react'; import { Box, ButtonGroup, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton'; import IAIButton from 'common/components/IAIButton';
@ -135,11 +135,12 @@ const IAICanvasMaskOptions = () => {
dispatch(setShouldPreserveMaskedArea(e.target.checked)) dispatch(setShouldPreserveMaskedArea(e.target.checked))
} }
/> />
<IAIColorPicker <Box sx={{ paddingTop: 2, paddingBottom: 2 }}>
sx={{ paddingTop: 2, paddingBottom: 2 }} <IAIColorPicker
pickerColor={maskColor} color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))} onChange={(newColor) => dispatch(setMaskColor(newColor))}
/> />
</Box>
<IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}> <IAIButton size="sm" leftIcon={<FaSave />} onClick={handleSaveMask}>
Save Mask Save Mask
</IAIButton> </IAIButton>

View File

@ -1,4 +1,4 @@
import { ButtonGroup, Flex } from '@chakra-ui/react'; import { ButtonGroup, Flex, Box } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store'; import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@ -237,15 +237,18 @@ const IAICanvasToolChooserOptions = () => {
sliderNumberInputProps={{ max: 500 }} sliderNumberInputProps={{ max: 500 }}
/> />
</Flex> </Flex>
<IAIColorPicker <Box
sx={{ sx={{
width: '100%', width: '100%',
paddingTop: 2, paddingTop: 2,
paddingBottom: 2, paddingBottom: 2,
}} }}
pickerColor={brushColor} >
onChange={(newColor) => dispatch(setBrushColor(newColor))} <IAIColorPicker
/> color={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
</Box>
</Flex> </Flex>
</IAIPopover> </IAIPopover>
</ButtonGroup> </ButtonGroup>

View File

@ -8,7 +8,6 @@ import { setAspectRatio } from 'features/parameters/store/generationSlice';
import { IRect, Vector2d } from 'konva/lib/types'; import { IRect, Vector2d } from 'konva/lib/types';
import { clamp, cloneDeep } from 'lodash-es'; import { clamp, cloneDeep } from 'lodash-es';
import { RgbaColor } from 'react-colorful'; import { RgbaColor } from 'react-colorful';
import { sessionCanceled } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import calculateCoordinates from '../util/calculateCoordinates'; import calculateCoordinates from '../util/calculateCoordinates';
import calculateScale from '../util/calculateScale'; import calculateScale from '../util/calculateScale';
@ -786,11 +785,6 @@ export const canvasSlice = createSlice({
}, },
}, },
extraReducers: (builder) => { extraReducers: (builder) => {
builder.addCase(sessionCanceled.pending, (state) => {
if (!state.layerState.stagingArea.images.length) {
state.layerState.stagingArea = initialLayerState.stagingArea;
}
});
builder.addCase(setAspectRatio, (state, action) => { builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload; const ratio = action.payload;
if (ratio) { if (ratio) {

View File

@ -6,7 +6,6 @@ import {
import { cloneDeep, forEach } from 'lodash-es'; import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images'; import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema'; import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types'; import { ImageDTO } from 'services/api/types';
import { appSocketInvocationError } from 'services/events/actions'; import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
@ -418,10 +417,6 @@ export const controlNetSlice = createSlice({
state.pendingControlImages = []; state.pendingControlImages = [];
}); });
builder.addMatcher(isAnySessionRejected, (state) => {
state.pendingControlImages = [];
});
builder.addMatcher( builder.addMatcher(
imagesApi.endpoints.deleteImage.matchFulfilled, imagesApi.endpoints.deleteImage.matchFulfilled,
(state, action) => { (state, action) => {

View File

@ -12,6 +12,7 @@ import {
OnConnect, OnConnect,
OnConnectEnd, OnConnectEnd,
OnConnectStart, OnConnectStart,
OnEdgeUpdateFunc,
OnEdgesChange, OnEdgesChange,
OnEdgesDelete, OnEdgesDelete,
OnInit, OnInit,
@ -21,6 +22,7 @@ import {
OnSelectionChangeFunc, OnSelectionChangeFunc,
ProOptions, ProOptions,
ReactFlow, ReactFlow,
ReactFlowProps,
XYPosition, XYPosition,
} from 'reactflow'; } from 'reactflow';
import { useIsValidConnection } from '../../hooks/useIsValidConnection'; import { useIsValidConnection } from '../../hooks/useIsValidConnection';
@ -28,6 +30,8 @@ import {
connectionEnded, connectionEnded,
connectionMade, connectionMade,
connectionStarted, connectionStarted,
edgeAdded,
edgeDeleted,
edgesChanged, edgesChanged,
edgesDeleted, edgesDeleted,
nodesChanged, nodesChanged,
@ -167,6 +171,63 @@ export const Flow = () => {
} }
}, []); }, []);
// #region Updatable Edges
/**
* Adapted from https://reactflow.dev/docs/examples/edges/updatable-edge/
* and https://reactflow.dev/docs/examples/edges/delete-edge-on-drop/
*
* - Edges can be dragged from one handle to another.
* - If the user drags the edge away from the node and drops it, delete the edge.
* - Do not delete the edge if the cursor didn't move (resolves annoying behaviour
* where the edge is deleted if you click it accidentally).
*/
// We have a ref for cursor position, but it is the *projected* cursor position.
// Easiest to just keep track of the last mouse event for this particular feature
const edgeUpdateMouseEvent = useRef<MouseEvent>();
const onEdgeUpdateStart: NonNullable<ReactFlowProps['onEdgeUpdateStart']> =
useCallback(
(e, edge, _handleType) => {
// update mouse event
edgeUpdateMouseEvent.current = e;
// always delete the edge when starting an updated
dispatch(edgeDeleted(edge.id));
},
[dispatch]
);
const onEdgeUpdate: OnEdgeUpdateFunc = useCallback(
(_oldEdge, newConnection) => {
// instead of updating the edge (we deleted it earlier), we instead create
// a new one.
dispatch(connectionMade(newConnection));
},
[dispatch]
);
const onEdgeUpdateEnd: NonNullable<ReactFlowProps['onEdgeUpdateEnd']> =
useCallback(
(e, edge, _handleType) => {
// Handle the case where user begins a drag but didn't move the cursor -
// bc we deleted the edge, we need to add it back
if (
// ignore touch events
!('touches' in e) &&
edgeUpdateMouseEvent.current?.clientX === e.clientX &&
edgeUpdateMouseEvent.current?.clientY === e.clientY
) {
dispatch(edgeAdded(edge));
}
// reset mouse event
edgeUpdateMouseEvent.current = undefined;
},
[dispatch]
);
// #endregion
useHotkeys(['Ctrl+c', 'Meta+c'], (e) => { useHotkeys(['Ctrl+c', 'Meta+c'], (e) => {
e.preventDefault(); e.preventDefault();
dispatch(selectionCopied()); dispatch(selectionCopied());
@ -196,6 +257,9 @@ export const Flow = () => {
onNodesChange={onNodesChange} onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange} onEdgesChange={onEdgesChange}
onEdgesDelete={onEdgesDelete} onEdgesDelete={onEdgesDelete}
onEdgeUpdate={onEdgeUpdate}
onEdgeUpdateStart={onEdgeUpdateStart}
onEdgeUpdateEnd={onEdgeUpdateEnd}
onNodesDelete={onNodesDelete} onNodesDelete={onNodesDelete}
onConnectStart={onConnectStart} onConnectStart={onConnectStart}
onConnect={onConnect} onConnect={onConnect}

View File

@ -53,13 +53,12 @@ export const useIsValidConnection = () => {
} }
if ( if (
edges edges.find((edge) => {
.filter((edge) => { edge.target === target &&
return edge.target === target && edge.targetHandle === targetHandle; edge.targetHandle === targetHandle &&
}) edge.source === source &&
.find((edge) => { edge.sourceHandle === sourceHandle;
edge.source === source && edge.sourceHandle === sourceHandle; })
})
) { ) {
// We already have a connection from this source to this target // We already have a connection from this source to this target
return false; return false;

View File

@ -15,6 +15,7 @@ import {
NodeChange, NodeChange,
OnConnectStartParams, OnConnectStartParams,
SelectionMode, SelectionMode,
updateEdge,
Viewport, Viewport,
XYPosition, XYPosition,
} from 'reactflow'; } from 'reactflow';
@ -182,6 +183,16 @@ const nodesSlice = createSlice({
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => { edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges); state.edges = applyEdgeChanges(action.payload, state.edges);
}, },
edgeAdded: (state, action: PayloadAction<Edge>) => {
state.edges = addEdge(action.payload, state.edges);
},
edgeUpdated: (
state,
action: PayloadAction<{ oldEdge: Edge; newConnection: Connection }>
) => {
const { oldEdge, newConnection } = action.payload;
state.edges = updateEdge(oldEdge, newConnection, state.edges);
},
connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => { connectionStarted: (state, action: PayloadAction<OnConnectStartParams>) => {
state.connectionStartParams = action.payload; state.connectionStartParams = action.payload;
const { nodeId, handleId, handleType } = action.payload; const { nodeId, handleId, handleType } = action.payload;
@ -366,6 +377,7 @@ const nodesSlice = createSlice({
target: edge.target, target: edge.target,
type: 'collapsed', type: 'collapsed',
data: { count: 1 }, data: { count: 1 },
updatable: false,
}); });
} }
} }
@ -388,6 +400,7 @@ const nodesSlice = createSlice({
target: edge.target, target: edge.target,
type: 'collapsed', type: 'collapsed',
data: { count: 1 }, data: { count: 1 },
updatable: false,
}); });
} }
} }
@ -400,6 +413,9 @@ const nodesSlice = createSlice({
} }
} }
}, },
edgeDeleted: (state, action: PayloadAction<string>) => {
state.edges = state.edges.filter((e) => e.id !== action.payload);
},
edgesDeleted: (state, action: PayloadAction<Edge[]>) => { edgesDeleted: (state, action: PayloadAction<Edge[]>) => {
const edges = action.payload; const edges = action.payload;
const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); const collapsedEdges = edges.filter((e) => e.type === 'collapsed');
@ -890,69 +906,72 @@ const nodesSlice = createSlice({
}); });
export const { export const {
nodesChanged, addNodePopoverClosed,
edgesChanged, addNodePopoverOpened,
nodeAdded, addNodePopoverToggled,
nodesDeleted, connectionEnded,
connectionMade, connectionMade,
connectionStarted, connectionStarted,
connectionEnded, edgeDeleted,
shouldShowFieldTypeLegendChanged, edgesChanged,
shouldShowMinimapPanelChanged, edgesDeleted,
nodeTemplatesBuilt, edgeUpdated,
nodeEditorReset,
imageCollectionFieldValueChanged,
fieldStringValueChanged,
fieldNumberValueChanged,
fieldBoardValueChanged, fieldBoardValueChanged,
fieldBooleanValueChanged, fieldBooleanValueChanged,
fieldImageValueChanged,
fieldColorValueChanged, fieldColorValueChanged,
fieldMainModelValueChanged,
fieldVaeModelValueChanged,
fieldLoRAModelValueChanged,
fieldEnumModelValueChanged,
fieldControlNetModelValueChanged, fieldControlNetModelValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldIPAdapterModelValueChanged, fieldIPAdapterModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldRefinerModelValueChanged, fieldRefinerModelValueChanged,
fieldSchedulerValueChanged, fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldVaeModelValueChanged,
imageCollectionFieldValueChanged,
mouseOverFieldChanged,
mouseOverNodeChanged,
nodeAdded,
nodeEditorReset,
nodeEmbedWorkflowChanged,
nodeExclusivelySelected,
nodeIsIntermediateChanged,
nodeIsOpenChanged, nodeIsOpenChanged,
nodeLabelChanged, nodeLabelChanged,
nodeNotesChanged, nodeNotesChanged,
edgesDeleted,
shouldValidateGraphChanged,
shouldAnimateEdgesChanged,
nodeOpacityChanged, nodeOpacityChanged,
shouldSnapToGridChanged, nodesChanged,
shouldColorEdgesChanged, nodesDeleted,
selectedNodesChanged, nodeTemplatesBuilt,
selectedEdgesChanged, nodeUseCacheChanged,
workflowNameChanged,
workflowDescriptionChanged,
workflowTagsChanged,
workflowAuthorChanged,
workflowNotesChanged,
workflowVersionChanged,
workflowContactChanged,
workflowLoaded,
notesNodeValueChanged, notesNodeValueChanged,
selectedAll,
selectedEdgesChanged,
selectedNodesChanged,
selectionCopied,
selectionModeChanged,
selectionPasted,
shouldAnimateEdgesChanged,
shouldColorEdgesChanged,
shouldShowFieldTypeLegendChanged,
shouldShowMinimapPanelChanged,
shouldSnapToGridChanged,
shouldValidateGraphChanged,
viewportChanged,
workflowAuthorChanged,
workflowContactChanged,
workflowDescriptionChanged,
workflowExposedFieldAdded, workflowExposedFieldAdded,
workflowExposedFieldRemoved, workflowExposedFieldRemoved,
fieldLabelChanged, workflowLoaded,
viewportChanged, workflowNameChanged,
mouseOverFieldChanged, workflowNotesChanged,
selectionCopied, workflowTagsChanged,
selectionPasted, workflowVersionChanged,
selectedAll, edgeAdded,
addNodePopoverOpened,
addNodePopoverClosed,
addNodePopoverToggled,
selectionModeChanged,
nodeEmbedWorkflowChanged,
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
nodeUseCacheChanged,
} = nodesSlice.actions; } = nodesSlice.actions;
export default nodesSlice.reducer; export default nodesSlice.reducer;

View File

@ -55,9 +55,29 @@ export const makeConnectionErrorSelector = (
return i18n.t('nodes.cannotConnectInputToInput'); return i18n.t('nodes.cannotConnectInputToInput');
} }
// we have to figure out which is the target and which is the source
const target = handleType === 'target' ? nodeId : connectionNodeId;
const targetHandle =
handleType === 'target' ? fieldName : connectionFieldName;
const source = handleType === 'source' ? nodeId : connectionNodeId;
const sourceHandle =
handleType === 'source' ? fieldName : connectionFieldName;
if ( if (
edges.find((edge) => { edges.find((edge) => {
return edge.target === nodeId && edge.targetHandle === fieldName; edge.target === target &&
edge.targetHandle === targetHandle &&
edge.source === source &&
edge.sourceHandle === sourceHandle;
})
) {
// We already have a connection from this source to this target
return i18n.t('nodes.cannotDuplicateConnection');
}
if (
edges.find((edge) => {
return edge.target === target && edge.targetHandle === targetHandle;
}) && }) &&
// except CollectionItem inputs can have multiples // except CollectionItem inputs can have multiples
targetType !== 'CollectionItem' targetType !== 'CollectionItem'

View File

@ -1,9 +1,7 @@
import { ButtonGroup } from '@chakra-ui/react'; import { ButtonGroup } from '@chakra-ui/react';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo'; import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import ClearInvocationCacheButton from './ClearInvocationCacheButton'; import ClearInvocationCacheButton from './ClearInvocationCacheButton';
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton'; import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
import StatusStatGroup from './common/StatusStatGroup'; import StatusStatGroup from './common/StatusStatGroup';
@ -11,16 +9,7 @@ import StatusStatItem from './common/StatusStatItem';
const InvocationCacheStatus = () => { const InvocationCacheStatus = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const isConnected = useAppSelector((state) => state.system.isConnected); const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined);
const { data: queueStatus } = useGetQueueStatusQuery(undefined);
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
pollingInterval:
isConnected &&
queueStatus?.processor.is_started &&
queueStatus?.queue.pending > 0
? 5000
: 0,
});
return ( return (
<StatusStatGroup> <StatusStatGroup>

View File

@ -1,9 +1,8 @@
import { UseToastOptions } from '@chakra-ui/react'; import { UseToastOptions } from '@chakra-ui/react';
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit'; import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
import { t } from 'i18next'; import { t } from 'i18next';
import { get, startCase, truncate, upperFirst } from 'lodash-es'; import { startCase } from 'lodash-es';
import { LogLevelName } from 'roarr'; import { LogLevelName } from 'roarr';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { import {
appSocketConnected, appSocketConnected,
appSocketDisconnected, appSocketDisconnected,
@ -20,8 +19,7 @@ import {
} from 'services/events/actions'; } from 'services/events/actions';
import { calculateStepPercentage } from '../util/calculateStepPercentage'; import { calculateStepPercentage } from '../util/calculateStepPercentage';
import { makeToast } from '../util/makeToast'; import { makeToast } from '../util/makeToast';
import { SystemState, LANGUAGES } from './types'; import { LANGUAGES, SystemState } from './types';
import { zPydanticValidationError } from './zodSchemas';
export const initialSystemState: SystemState = { export const initialSystemState: SystemState = {
isInitialized: false, isInitialized: false,
@ -175,50 +173,6 @@ export const systemSlice = createSlice({
// *** Matchers - must be after all cases *** // *** Matchers - must be after all cases ***
/**
* Session Invoked - REJECTED
* Session Created - REJECTED
*/
builder.addMatcher(isAnySessionRejected, (state, action) => {
let errorDescription = undefined;
const duration = 5000;
if (action.payload?.status === 422) {
const result = zPydanticValidationError.safeParse(action.payload);
if (result.success) {
result.data.error.detail.map((e) => {
state.toastQueue.push(
makeToast({
title: truncate(upperFirst(e.msg), { length: 128 }),
status: 'error',
description: truncate(
`Path:
${e.loc.join('.')}`,
{ length: 128 }
),
duration,
})
);
});
return;
}
} else if (action.payload?.error) {
errorDescription = action.payload?.error;
}
state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
status: 'error',
description: truncate(
get(errorDescription, 'detail', 'Unknown Error'),
{ length: 128 }
),
duration,
})
);
});
/** /**
* Any server error * Any server error
*/ */

View File

@ -2,7 +2,7 @@ import { z } from 'zod';
export const zPydanticValidationError = z.object({ export const zPydanticValidationError = z.object({
status: z.literal(422), status: z.literal(422),
error: z.object({ data: z.object({
detail: z.array( detail: z.array(
z.object({ z.object({
loc: z.array(z.string()), loc: z.array(z.string()),

View File

@ -14,7 +14,7 @@ import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent'; import ImageGalleryContent from 'features/gallery/components/ImageGalleryContent';
import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup'; import NodeEditorPanelGroup from 'features/nodes/components/sidePanel/NodeEditorPanelGroup';
import { InvokeTabName, tabMap } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab } from 'features/ui/store/uiSlice'; import { setActiveTab } from 'features/ui/store/uiSlice';
import { ResourceKey } from 'i18next'; import { ResourceKey } from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -110,7 +110,7 @@ export const NO_GALLERY_TABS: InvokeTabName[] = ['modelManager', 'queue'];
export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue']; export const NO_SIDE_PANEL_TABS: InvokeTabName[] = ['modelManager', 'queue'];
const InvokeTabs = () => { const InvokeTabs = () => {
const activeTab = useAppSelector(activeTabIndexSelector); const activeTabIndex = useAppSelector(activeTabIndexSelector);
const activeTabName = useAppSelector(activeTabNameSelector); const activeTabName = useAppSelector(activeTabNameSelector);
const enabledTabs = useAppSelector(enabledTabsSelector); const enabledTabs = useAppSelector(enabledTabsSelector);
const { t } = useTranslation(); const { t } = useTranslation();
@ -150,13 +150,13 @@ const InvokeTabs = () => {
const handleTabChange = useCallback( const handleTabChange = useCallback(
(index: number) => { (index: number) => {
const activeTabName = tabMap[index]; const tab = enabledTabs[index];
if (!activeTabName) { if (!tab) {
return; return;
} }
dispatch(setActiveTab(activeTabName)); dispatch(setActiveTab(tab.id));
}, },
[dispatch] [dispatch, enabledTabs]
); );
const { const {
@ -216,8 +216,8 @@ const InvokeTabs = () => {
return ( return (
<Tabs <Tabs
variant="appTabs" variant="appTabs"
defaultIndex={activeTab} defaultIndex={activeTabIndex}
index={activeTab} index={activeTabIndex}
onChange={handleTabChange} onChange={handleTabChange}
sx={{ sx={{
flexGrow: 1, flexGrow: 1,

View File

@ -95,26 +95,32 @@ export default function UnifiedCanvasColorPicker() {
> >
<Flex minWidth={60} direction="column" gap={4} width="100%"> <Flex minWidth={60} direction="column" gap={4} width="100%">
{layer === 'base' && ( {layer === 'base' && (
<IAIColorPicker <Box
sx={{ sx={{
width: '100%', width: '100%',
paddingTop: 2, paddingTop: 2,
paddingBottom: 2, paddingBottom: 2,
}} }}
pickerColor={brushColor} >
onChange={(newColor) => dispatch(setBrushColor(newColor))} <IAIColorPicker
/> color={brushColor}
onChange={(newColor) => dispatch(setBrushColor(newColor))}
/>
</Box>
)} )}
{layer === 'mask' && ( {layer === 'mask' && (
<IAIColorPicker <Box
sx={{ sx={{
width: '100%', width: '100%',
paddingTop: 2, paddingTop: 2,
paddingBottom: 2, paddingBottom: 2,
}} }}
pickerColor={maskColor} >
onChange={(newColor) => dispatch(setMaskColor(newColor))} <IAIColorPicker
/> color={maskColor}
onChange={(newColor) => dispatch(setMaskColor(newColor))}
/>
</Box>
)} )}
</Flex> </Flex>
</IAIPopover> </IAIPopover>

View File

@ -1,13 +0,0 @@
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';
export const setActiveTabReducer = (
state: UIState,
newActiveTab: number | InvokeTabName
) => {
if (typeof newActiveTab === 'number') {
state.activeTab = newActiveTab;
} else {
state.activeTab = tabMap.indexOf(newActiveTab);
}
};

View File

@ -1,27 +1,23 @@
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { isEqual } from 'lodash-es'; import { isEqual, isString } from 'lodash-es';
import { tabMap } from './tabMap';
import { InvokeTabName, tabMap } from './tabMap';
import { UIState } from './uiTypes';
export const activeTabNameSelector = createSelector( export const activeTabNameSelector = createSelector(
(state: RootState) => state.ui, (state: RootState) => state,
(ui: UIState) => tabMap[ui.activeTab] as InvokeTabName, /**
{ * Previously `activeTab` was an integer, but now it's a string.
memoizeOptions: { * Default to first tab in case user has integer.
equalityCheck: isEqual, */
}, ({ ui }) => (isString(ui.activeTab) ? ui.activeTab : 'txt2img')
}
); );
export const activeTabIndexSelector = createSelector( export const activeTabIndexSelector = createSelector(
(state: RootState) => state.ui, (state: RootState) => state,
(ui: UIState) => ui.activeTab, ({ ui, config }) => {
{ const tabs = tabMap.filter((t) => !config.disabledTabs.includes(t));
memoizeOptions: { const idx = tabs.indexOf(ui.activeTab);
equalityCheck: isEqual, return idx === -1 ? 0 : idx;
},
} }
); );

View File

@ -2,12 +2,11 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit';
import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { setActiveTabReducer } from './extraReducers';
import { InvokeTabName } from './tabMap'; import { InvokeTabName } from './tabMap';
import { UIState } from './uiTypes'; import { UIState } from './uiTypes';
export const initialUIState: UIState = { export const initialUIState: UIState = {
activeTab: 0, activeTab: 'txt2img',
shouldShowImageDetails: false, shouldShowImageDetails: false,
shouldUseCanvasBetaLayout: false, shouldUseCanvasBetaLayout: false,
shouldShowExistingModelsInSearch: false, shouldShowExistingModelsInSearch: false,
@ -26,7 +25,7 @@ export const uiSlice = createSlice({
initialState: initialUIState, initialState: initialUIState,
reducers: { reducers: {
setActiveTab: (state, action: PayloadAction<InvokeTabName>) => { setActiveTab: (state, action: PayloadAction<InvokeTabName>) => {
setActiveTabReducer(state, action.payload); state.activeTab = action.payload;
}, },
setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => { setShouldShowImageDetails: (state, action: PayloadAction<boolean>) => {
state.shouldShowImageDetails = action.payload; state.shouldShowImageDetails = action.payload;
@ -73,7 +72,7 @@ export const uiSlice = createSlice({
}, },
extraReducers(builder) { extraReducers(builder) {
builder.addCase(initialImageChanged, (state) => { builder.addCase(initialImageChanged, (state) => {
setActiveTabReducer(state, 'img2img'); state.activeTab = 'img2img';
}); });
}, },
}); });

View File

@ -1,4 +1,5 @@
import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; import { SchedulerParam } from 'features/parameters/types/parameterSchemas';
import { InvokeTabName } from './tabMap';
export type Coordinates = { export type Coordinates = {
x: number; x: number;
@ -13,7 +14,7 @@ export type Dimensions = {
export type Rect = Coordinates & Dimensions; export type Rect = Coordinates & Dimensions;
export interface UIState { export interface UIState {
activeTab: number; activeTab: InvokeTabName;
shouldShowImageDetails: boolean; shouldShowImageDetails: boolean;
shouldUseCanvasBetaLayout: boolean; shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean; shouldShowExistingModelsInSearch: boolean;

View File

@ -1,184 +0,0 @@
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/queueNanoStore';
import { isObject } from 'lodash-es';
import { $client } from 'services/api/client';
import { paths } from 'services/api/schema';
import { O } from 'ts-toolbelt';
type CreateSessionArg = {
graph: NonNullable<
paths['/api/v1/sessions/']['post']['requestBody']
>['content']['application/json'];
};
type CreateSessionResponse = O.Required<
NonNullable<
paths['/api/v1/sessions/']['post']['requestBody']
>['content']['application/json'],
'id'
>;
type CreateSessionThunkConfig = {
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
};
/**
* `SessionsService.createSession()` thunk
*/
export const sessionCreated = createAsyncThunk<
CreateSessionResponse,
CreateSessionArg,
CreateSessionThunkConfig
>('api/sessionCreated', async (arg, { rejectWithValue }) => {
const { graph } = arg;
const { POST } = $client.get();
const { data, error, response } = await POST('/api/v1/sessions/', {
body: graph,
params: { query: { queue_id: $queueId.get() } },
});
if (error) {
return rejectWithValue({ arg, status: response.status, error });
}
return data;
});
type InvokedSessionArg = {
session_id: paths['/api/v1/sessions/{session_id}/invoke']['put']['parameters']['path']['session_id'];
};
type InvokedSessionResponse =
paths['/api/v1/sessions/{session_id}/invoke']['put']['responses']['200']['content']['application/json'];
type InvokedSessionThunkConfig = {
rejectValue: {
arg: InvokedSessionArg;
error: unknown;
status: number;
};
};
const isErrorWithStatus = (error: unknown): error is { status: number } =>
isObject(error) && 'status' in error;
const isErrorWithDetail = (error: unknown): error is { detail: string } =>
isObject(error) && 'detail' in error;
/**
* `SessionsService.invokeSession()` thunk
*/
export const sessionInvoked = createAsyncThunk<
InvokedSessionResponse,
InvokedSessionArg,
InvokedSessionThunkConfig
>('api/sessionInvoked', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { PUT } = $client.get();
const { error, response } = await PUT(
'/api/v1/sessions/{session_id}/invoke',
{
params: {
query: { queue_id: $queueId.get(), all: true },
path: { session_id },
},
}
);
if (error) {
if (isErrorWithStatus(error) && error.status === 403) {
return rejectWithValue({
arg,
status: response.status,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: (error as any).body.detail,
});
}
if (isErrorWithDetail(error) && response.status === 403) {
return rejectWithValue({
arg,
status: response.status,
error: error.detail,
});
}
if (error) {
return rejectWithValue({ arg, status: response.status, error });
}
}
});
type CancelSessionArg =
paths['/api/v1/sessions/{session_id}/invoke']['delete']['parameters']['path'];
type CancelSessionResponse =
paths['/api/v1/sessions/{session_id}/invoke']['delete']['responses']['200']['content']['application/json'];
type CancelSessionThunkConfig = {
rejectValue: {
arg: CancelSessionArg;
error: unknown;
};
};
/**
* `SessionsService.cancelSession()` thunk
*/
export const sessionCanceled = createAsyncThunk<
CancelSessionResponse,
CancelSessionArg,
CancelSessionThunkConfig
>('api/sessionCanceled', async (arg, { rejectWithValue }) => {
const { session_id } = arg;
const { DELETE } = $client.get();
const { data, error } = await DELETE('/api/v1/sessions/{session_id}/invoke', {
params: {
path: { session_id },
},
});
if (error) {
return rejectWithValue({ arg, error });
}
return data;
});
type ListSessionsArg = {
params: paths['/api/v1/sessions/']['get']['parameters'];
};
type ListSessionsResponse =
paths['/api/v1/sessions/']['get']['responses']['200']['content']['application/json'];
type ListSessionsThunkConfig = {
rejectValue: {
arg: ListSessionsArg;
error: unknown;
};
};
/**
* `SessionsService.listSessions()` thunk
*/
export const listedSessions = createAsyncThunk<
ListSessionsResponse,
ListSessionsArg,
ListSessionsThunkConfig
>('api/listSessions', async (arg, { rejectWithValue }) => {
const { params } = arg;
const { GET } = $client.get();
const { data, error } = await GET('/api/v1/sessions/', {
params,
});
if (error) {
return rejectWithValue({ arg, error });
}
return data;
});
export const isAnySessionRejected = isAnyOf(
sessionCreated.rejected,
sessionInvoked.rejected
);

View File

@ -1,5 +1,4 @@
import { ThemeOverride } from '@chakra-ui/react'; import { ThemeOverride, ToastProviderProps } from '@chakra-ui/react';
import { InvokeAIColors } from './colors/colors'; import { InvokeAIColors } from './colors/colors';
import { accordionTheme } from './components/accordion'; import { accordionTheme } from './components/accordion';
import { buttonTheme } from './components/button'; import { buttonTheme } from './components/button';
@ -149,3 +148,7 @@ export const theme: ThemeOverride = {
Tooltip: tooltipTheme, Tooltip: tooltipTheme,
}, },
}; };
export const TOAST_OPTIONS: ToastProviderProps = {
defaultOptions: { isClosable: true },
};