Merge branch 'main' into feat/canvas-generation-mode

blessedcoolant 2023-07-24 20:34:25 +12:00 committed by GitHub
commit 7ea477abef
26 changed files with 474 additions and 169 deletions

View File

@@ -298,7 +298,7 @@ async def search_for_models(
 )->List[pathlib.Path]:
     if not search_path.is_dir():
         raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
-    return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
+    return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)

 @models_router.get(
     "/ckpt_confs",

View File

@@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
     @validator("seed", pre=True)
     def modulo_seed(cls, v):
-        """Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
-        return v % SEED_MAX
+        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        return v % (SEED_MAX + 1)

     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(

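Note: the switch from v % SEED_MAX to v % (SEED_MAX + 1) matters because Python's modulo excludes the divisor itself, so the old form could never return SEED_MAX. A small standalone sketch, not part of this commit, of the validator's arithmetic:

import numpy as np

SEED_MAX = np.iinfo(np.uint32).max  # 4294967295, mirroring the new frontend NUMPY_RAND_MAX

def modulo_seed(v: int) -> int:
    # (SEED_MAX + 1) keeps every result in the inclusive range [0, SEED_MAX];
    # with the old "v % SEED_MAX", SEED_MAX itself wrapped around to 0.
    return v % (SEED_MAX + 1)

assert modulo_seed(SEED_MAX) == SEED_MAX   # largest valid seed is preserved
assert modulo_seed(SEED_MAX + 1) == 0      # out-of-range values wrap
assert modulo_seed(-1) == SEED_MAX         # negative seeds wrap back into range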
View File

@@ -3,7 +3,13 @@
 from typing import Any, Optional
 from invokeai.app.models.image import ProgressImage
 from invokeai.app.util.misc import get_timestamp
-from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
+from invokeai.app.services.model_manager_service import (
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    ModelInfo,
+)

 class EventServiceBase:
     session_event: str = "session_event"
@@ -38,7 +44,9 @@ class EventServiceBase:
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
-                progress_image=progress_image.dict() if progress_image is not None else None,
+                progress_image=progress_image.dict()
+                if progress_image is not None
+                else None,
                 step=step,
                 total_steps=total_steps,
             ),
@@ -67,6 +75,7 @@
         graph_execution_state_id: str,
         node: dict,
         source_node_id: str,
+        error_type: str,
         error: str,
     ) -> None:
         """Emitted when an invocation has completed"""
@@ -76,6 +85,7 @@
                 graph_execution_state_id=graph_execution_state_id,
                 node=node,
                 source_node_id=source_node_id,
+                error_type=error_type,
                 error=error,
             ),
         )
@@ -102,7 +112,7 @@
             ),
         )

-    def emit_model_load_started (
+    def emit_model_load_started(
         self,
         graph_execution_state_id: str,
         model_name: str,
@@ -145,3 +155,37 @@
                 precision=str(model_info.precision),
             ),
         )
+
+    def emit_session_retrieval_error(
+        self,
+        graph_execution_state_id: str,
+        error_type: str,
+        error: str,
+    ) -> None:
+        """Emitted when session retrieval fails"""
+        self.__emit_session_event(
+            event_name="session_retrieval_error",
+            payload=dict(
+                graph_execution_state_id=graph_execution_state_id,
+                error_type=error_type,
+                error=error,
+            ),
+        )
+
+    def emit_invocation_retrieval_error(
+        self,
+        graph_execution_state_id: str,
+        node_id: str,
+        error_type: str,
+        error: str,
+    ) -> None:
+        """Emitted when invocation retrieval fails"""
+        self.__emit_session_event(
+            event_name="invocation_retrieval_error",
+            payload=dict(
+                graph_execution_state_id=graph_execution_state_id,
+                node_id=node_id,
+                error_type=error_type,
+                error=error,
+            ),
+        )

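Note: the payloads built by these two emitters correspond to the SessionRetrievalErrorEvent and InvocationRetrievalErrorEvent types added to services/events/types.ts later in this diff. A small sketch, not part of this commit and with example values only, of the expected shapes:

# Example payload shapes for the two new events; the field names match the
# TypeScript event types declared further down in this diff.
session_retrieval_error = dict(
    graph_execution_state_id="some-session-id",   # example value only
    error_type="KeyError",                        # exception class name
    error="Traceback (most recent call last): ...",
)

invocation_retrieval_error = dict(
    graph_execution_state_id="some-session-id",
    node_id="some-node-id",                       # example value only
    error_type="KeyError",
    error="Traceback (most recent call last): ...",
)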
View File

@@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase):
         """
         Return list of all models found in the designated directory.
         """
-        search = FindModels(directory,self.logger)
+        search = FindModels([directory], self.logger)
         return search.list_models()

     def sync_to_config(self):

View File

@@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                 try:
                     queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
                 except Exception as e:
-                    logger.debug("Exception while getting from queue: %s" % e)
+                    self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)

                 if not queue_item:  # Probably stopping
                     # do not hammer the queue
                     time.sleep(0.5)
                     continue

-                graph_execution_state = (
-                    self.__invoker.services.graph_execution_manager.get(
-                        queue_item.graph_execution_state_id
-                    )
-                )
+                try:
+                    graph_execution_state = (
+                        self.__invoker.services.graph_execution_manager.get(
+                            queue_item.graph_execution_state_id
+                        )
+                    )
+                except Exception as e:
+                    self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
+                    self.__invoker.services.events.emit_session_retrieval_error(
+                        graph_execution_state_id=queue_item.graph_execution_state_id,
+                        error_type=e.__class__.__name__,
+                        error=traceback.format_exc(),
+                    )
+                    continue

-                invocation = graph_execution_state.execution_graph.get_node(
-                    queue_item.invocation_id
-                )
+                try:
+                    invocation = graph_execution_state.execution_graph.get_node(
+                        queue_item.invocation_id
+                    )
+                except Exception as e:
+                    self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
+                    self.__invoker.services.events.emit_invocation_retrieval_error(
+                        graph_execution_state_id=queue_item.graph_execution_state_id,
+                        node_id=queue_item.invocation_id,
+                        error_type=e.__class__.__name__,
+                        error=traceback.format_exc(),
+                    )
+                    continue

                 # get the source node id to provide to clients (the prepared node id is not as useful)
                 source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                         graph_execution_state
                     )

+                    self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
+
                     # Send error event
                     self.__invoker.services.events.emit_invocation_error(
                         graph_execution_state_id=graph_execution_state.id,
                         node=invocation.dict(),
                         source_node_id=source_node_id,
+                        error_type=e.__class__.__name__,
                         error=error,
                     )
@@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
                     try:
                         self.__invoker.invoke(graph_execution_state, invoke_all=True)
                     except Exception as e:
-                        logger.error("Error while invoking: %s" % e)
+                        self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
                         self.__invoker.services.events.emit_invocation_error(
                             graph_execution_state_id=graph_execution_state.id,
                             node=invocation.dict(),
                             source_node_id=source_node_id,
+                            error_type=e.__class__.__name__,
                             error=traceback.format_exc()
                         )
                 elif is_complete:

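Note: the point of wrapping each lookup in its own try/except is that this loop runs in a long-lived worker thread, so a bad queue item should be reported over the event bus and skipped rather than allowed to kill the processor. A generic standalone sketch of that pattern, illustrative only and not the actual InvokeAI classes:

import queue
import traceback

def process_items(work_queue: "queue.Queue", lookup, emit_error) -> None:
    # Each iteration isolates its own failures: report the error, then keep going.
    while True:
        item_id = work_queue.get()
        if item_id is None:          # sentinel value: stop the worker
            break
        try:
            item = lookup(item_id)   # analogous to fetching the session / invocation
        except Exception as e:
            emit_error(item_id, e.__class__.__name__, traceback.format_exc())
            continue                 # skip this item; the worker thread stays alive
        print("processing", item)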
View File

@@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
     return datetime.datetime.fromisoformat(iso_timestamp)

-SEED_MAX = np.iinfo(np.int32).max
+SEED_MAX = np.iinfo(np.uint32).max

 def get_random_seed():
-    return np.random.randint(0, SEED_MAX)
+    rng = np.random.default_rng(seed=None)
+    return int(rng.integers(0, SEED_MAX))

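Note: the new SEED_MAX lines up with the frontend change to NUMPY_RAND_MAX below (4294967295, i.e. 2**32 - 1). A quick standalone check, not part of this commit, of the ranges involved:

import numpy as np

assert np.iinfo(np.int32).max == 2147483647    # old SEED_MAX and old NUMPY_RAND_MAX
assert np.iinfo(np.uint32).max == 4294967295   # new SEED_MAX and new NUMPY_RAND_MAX

rng = np.random.default_rng(seed=None)
# Generator.integers() excludes the upper bound by default, so this draws from
# [0, SEED_MAX - 1]; pass endpoint=True if SEED_MAX itself should be reachable.
seed = int(rng.integers(0, np.iinfo(np.uint32).max))
assert 0 <= seed < np.iinfo(np.uint32).max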
View File

@@ -474,7 +474,7 @@ class ModelPatcher:
     @staticmethod
     def _lora_forward_hook(
-        applied_loras: List[Tuple[LoraModel, float]],
+        applied_loras: List[Tuple[LoRAModel, float]],
         layer_name: str,
     ):
@@ -519,7 +519,7 @@
     def apply_lora(
         cls,
         model: torch.nn.Module,
-        loras: List[Tuple[LoraModel, float]],
+        loras: List[Tuple[LoRAModel, float]],
         prefix: str,
     ):
         original_weights = dict()

View File

@@ -98,6 +98,6 @@ class FindModels(ModelSearch):
     def list_models(self) -> List[Path]:
         self.search()
-        return self.models_found
+        return list(self.models_found)

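Note: FindModels takes a list of search roots (hence the [directory] wrapper in the service change above), and the search accumulates matches in a set, which is why list_models now converts to a list before returning. A toy standalone sketch of that shape, not the real ModelSearch API:

from pathlib import Path
from typing import List, Set

class TinyFindModels:
    """Toy illustration of the search pattern; the real class adds filtering and logging."""

    def __init__(self, directories: List[Path]):
        self.directories = directories
        self.models_found: Set[Path] = set()   # a set, so duplicate hits collapse

    def search(self) -> None:
        for directory in self.directories:
            for weights in Path(directory).rglob("*.safetensors"):
                self.models_found.add(weights)

    def list_models(self) -> List[Path]:
        self.search()
        return list(self.models_found)         # callers receive a plain list, as in the diff

models = TinyFindModels([Path(".")]).list_models()   # a single root is still passed as a list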
View File

@@ -10,6 +10,7 @@ from .base import (
     SubModelType,
     classproperty,
     InvalidModelException,
+    ModelNotFoundException,
 )

 # TODO: naming
 from ..lora import LoRAModel as LoRAModelRaw

View File

@@ -102,8 +102,7 @@
     "openInNewTab": "Open in New Tab",
     "dontAskMeAgain": "Don't ask me again",
     "areYouSure": "Are you sure?",
-    "imagePrompt": "Image Prompt",
-    "clearNodes": "Are you sure you want to clear all nodes?"
+    "imagePrompt": "Image Prompt"
   },
   "gallery": {
     "generations": "Generations",
@@ -615,6 +614,11 @@
     "initialImageNotSetDesc": "Could not load initial image",
     "nodesSaved": "Nodes Saved",
     "nodesLoaded": "Nodes Loaded",
+    "nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
+    "nodesNotValidJSON": "Not a valid JSON",
+    "nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
+    "nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
+    "nodesBrokenConnections": "Cannot load. Some connections are broken.",
     "nodesLoadedFailed": "Failed To Load Nodes",
     "nodesCleared": "Nodes Cleared"
   },
@@ -700,9 +704,10 @@
   },
   "nodes": {
     "reloadSchema": "Reload Schema",
-    "saveNodes": "Save Nodes",
-    "loadNodes": "Load Nodes",
-    "clearNodes": "Clear Nodes",
+    "saveGraph": "Save Graph",
+    "loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
+    "clearGraph": "Clear Graph",
+    "clearGraphDesc": "Are you sure you want to clear all nodes?",
     "zoomInNodes": "Zoom In",
     "zoomOutNodes": "Zoom Out",
     "fitViewportNodes": "Fit View",

View File

@@ -1,2 +1,2 @@
 export const NUMPY_RAND_MIN = 0;
-export const NUMPY_RAND_MAX = 2147483647;
+export const NUMPY_RAND_MAX = 4294967295;

View File

@@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
 import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
 import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
 import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
+import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
+import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';

 export const listenerMiddleware = createListenerMiddleware();
@@ -153,6 +155,8 @@ addSocketDisconnectedListener();
 addSocketSubscribedListener();
 addSocketUnsubscribedListener();
 addModelLoadEventListener();
+addSessionRetrievalErrorEventListener();
+addInvocationRetrievalErrorEventListener();

 // Session Created
 addSessionCreatedPendingListener();

View File

@@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
     effect: (action) => {
       const log = logger('session');
       if (action.payload) {
-        const { error } = action.payload;
+        const { error, status } = action.payload;
         const graph = parseify(action.meta.arg);
-        const stringifiedError = JSON.stringify(error);
         log.error(
-          { graph, error: serializeError(error) },
-          `Problem creating session: ${stringifiedError}`
+          { graph, status, error: serializeError(error) },
+          `Problem creating session`
         );
       }
     },

View File

@@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
       const { session_id } = action.meta.arg;
       if (action.payload) {
         const { error } = action.payload;
-        const stringifiedError = JSON.stringify(error);
         log.error(
           {
             session_id,
             error: serializeError(error),
           },
-          `Problem invoking session: ${stringifiedError}`
+          `Problem invoking session`
         );
       }
     },

View File

@@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketInvocationRetrievalError,
socketInvocationRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';
export const addInvocationRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketInvocationRetrievalError(action.payload));
},
});
};

View File

@@ -0,0 +1,20 @@
import { logger } from 'app/logging/logger';
import {
appSocketSessionRetrievalError,
socketSessionRetrievalError,
} from 'services/events/actions';
import { startAppListening } from '../..';
export const addSessionRetrievalErrorEventListener = () => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action, { dispatch }) => {
const log = logger('socketio');
log.error(
action.payload,
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
);
dispatch(appSocketSessionRetrievalError(action.payload));
},
});
};

View File

@@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react';
 import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
 import { memo } from 'react';
 import { Panel } from 'reactflow';
-import LoadNodesButton from '../ui/LoadNodesButton';
+import ClearGraphButton from '../ui/ClearGraphButton';
+import LoadGraphButton from '../ui/LoadGraphButton';
 import NodeInvokeButton from '../ui/NodeInvokeButton';
 import ReloadSchemaButton from '../ui/ReloadSchemaButton';
-import SaveNodesButton from '../ui/SaveNodesButton';
-import ClearNodesButton from '../ui/ClearNodesButton';
+import SaveGraphButton from '../ui/SaveGraphButton';

 const TopCenterPanel = () => {
   return (
@@ -15,9 +15,9 @@ const TopCenterPanel = () => {
       <NodeInvokeButton />
       <CancelButton />
       <ReloadSchemaButton />
-      <SaveNodesButton />
-      <LoadNodesButton />
-      <ClearNodesButton />
+      <SaveGraphButton />
+      <LoadGraphButton />
+      <ClearGraphButton />
     </HStack>
   </Panel>
 );

View File

@@ -9,17 +9,17 @@ import {
   Text,
   useDisclosure,
 } from '@chakra-ui/react';
-import { makeToast } from 'features/system/util/makeToast';
 import { RootState } from 'app/store/store';
 import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
 import IAIIconButton from 'common/components/IAIIconButton';
 import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
 import { addToast } from 'features/system/store/systemSlice';
-import { memo, useRef, useCallback } from 'react';
+import { makeToast } from 'features/system/util/makeToast';
+import { memo, useCallback, useRef } from 'react';
 import { useTranslation } from 'react-i18next';
 import { FaTrash } from 'react-icons/fa';

-const ClearNodesButton = () => {
+const ClearGraphButton = () => {
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
   const { isOpen, onOpen, onClose } = useDisclosure();
@@ -46,8 +46,8 @@ const ClearNodesButton = () => {
     <>
       <IAIIconButton
         icon={<FaTrash />}
-        tooltip={t('nodes.clearNodes')}
-        aria-label={t('nodes.clearNodes')}
+        tooltip={t('nodes.clearGraph')}
+        aria-label={t('nodes.clearGraph')}
         onClick={onOpen}
         isDisabled={nodes.length === 0}
       />
@@ -62,11 +62,11 @@ const ClearNodesButton = () => {
         <AlertDialogContent>
           <AlertDialogHeader fontSize="lg" fontWeight="bold">
-            {t('nodes.clearNodes')}
+            {t('nodes.clearGraph')}
           </AlertDialogHeader>

           <AlertDialogBody>
-            <Text>{t('common.clearNodes')}</Text>
+            <Text>{t('nodes.clearGraphDesc')}</Text>
           </AlertDialogBody>

           <AlertDialogFooter>
@@ -83,4 +83,4 @@ const ClearNodesButton = () => {
   );
 };

-export default memo(ClearNodesButton);
+export default memo(ClearGraphButton);

View File

@@ -0,0 +1,161 @@
import { FileButton } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import i18n from 'i18n';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
interface JsonFile {
[key: string]: unknown;
}
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
isValid: boolean;
message: string;
} {
// Check if primary keys exist
const keys = ['nodes', 'edges', 'viewport'];
for (const key of keys) {
if (!(key in jsonFile)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
}
// Check if nodes and edges are arrays
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
// Check if data is present in nodes
const nodeKeys = ['data', 'type'];
const nodeTypes = ['invocation', 'progress_image'];
if (jsonFile.nodes.length > 0) {
for (const node of jsonFile.nodes) {
for (const nodeKey of nodeKeys) {
if (!(nodeKey in node)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
return {
isValid: false,
message: i18n.t('toast.nodesUnrecognizedTypes'),
};
}
}
}
}
// Check Edge Object
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
if (jsonFile.edges.length > 0) {
for (const edge of jsonFile.edges) {
for (const edgeKey of edgeKeys) {
if (!(edgeKey in edge)) {
return {
isValid: false,
message: i18n.t('toast.nodesBrokenConnections'),
};
}
}
}
}
return {
isValid: true,
message: i18n.t('toast.nodesLoaded'),
};
}
const LoadGraphButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { fitView } = useReactFlow();
const uploadedFileRef = useRef<() => void>(null);
const restoreJSONToEditor = useCallback(
(v: File | null) => {
if (!v) return;
const reader = new FileReader();
reader.onload = async () => {
const json = reader.result;
try {
const retrievedNodeTree = await JSON.parse(String(json));
const { isValid, message } =
sanityCheckInvokeAIGraph(retrievedNodeTree);
if (isValid) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(makeToast({ title: message, status: 'success' }))
);
} else {
dispatch(
addToast(
makeToast({
title: message,
status: 'error',
})
)
);
}
// Cleanup
reader.abort();
} catch (error) {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.nodesNotValidJSON'),
status: 'error',
})
)
);
}
}
};
reader.readAsText(v);
// Cleanup
uploadedFileRef.current?.();
},
[fitView, dispatch, t]
);
return (
<FileButton
resetRef={uploadedFileRef}
accept="application/json"
onChange={restoreJSONToEditor}
>
{(props) => (
<IAIIconButton
icon={<FaUpload />}
tooltip={t('nodes.loadGraph')}
aria-label={t('nodes.loadGraph')}
{...props}
/>
)}
</FileButton>
);
};
export default memo(LoadGraphButton);

View File

@@ -1,79 +0,0 @@
import { FileButton } from '@mantine/core';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
const LoadNodesButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { fitView } = useReactFlow();
const uploadedFileRef = useRef<() => void>(null);
const restoreJSONToEditor = useCallback(
(v: File | null) => {
if (!v) return;
const reader = new FileReader();
reader.onload = async () => {
const json = reader.result;
const retrievedNodeTree = await JSON.parse(String(json));
if (!retrievedNodeTree) {
dispatch(
addToast(
makeToast({
title: t('toast.nodesLoadedFailed'),
status: 'error',
})
)
);
}
if (retrievedNodeTree) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
)
);
}
// Cleanup
reader.abort();
};
reader.readAsText(v);
// Cleanup
uploadedFileRef.current?.();
},
[fitView, dispatch, t]
);
return (
<FileButton
resetRef={uploadedFileRef}
accept="application/json"
onChange={restoreJSONToEditor}
>
{(props) => (
<IAIIconButton
icon={<FaUpload />}
tooltip={t('nodes.loadNodes')}
aria-label={t('nodes.loadNodes')}
{...props}
/>
)}
</FileButton>
);
};
export default memo(LoadNodesButton);

View File

@@ -6,7 +6,7 @@ import { memo, useCallback } from 'react';
 import { useTranslation } from 'react-i18next';
 import { FaSave } from 'react-icons/fa';

-const SaveNodesButton = () => {
+const SaveGraphButton = () => {
   const { t } = useTranslation();
   const editorInstance = useAppSelector(
     (state: RootState) => state.nodes.editorInstance
@@ -37,12 +37,12 @@ const SaveNodesButton = () => {
     <IAIIconButton
       icon={<FaSave />}
       fontSize={18}
-      tooltip={t('nodes.saveNodes')}
-      aria-label={t('nodes.saveNodes')}
+      tooltip={t('nodes.saveGraph')}
+      aria-label={t('nodes.saveGraph')}
       onClick={saveEditorToJSON}
       isDisabled={nodes.length === 0}
     />
   );
 };

-export default memo(SaveNodesButton);
+export default memo(SaveGraphButton);

View File

@@ -1,5 +1,5 @@
 import { UseToastOptions } from '@chakra-ui/react';
-import { PayloadAction, createSlice } from '@reduxjs/toolkit';
+import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
 import { InvokeLogLevel } from 'app/logging/logger';
 import { userInvoked } from 'app/store/actions';
 import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
@@ -16,13 +16,16 @@ import {
   appSocketGraphExecutionStateComplete,
   appSocketInvocationComplete,
   appSocketInvocationError,
+  appSocketInvocationRetrievalError,
   appSocketInvocationStarted,
+  appSocketSessionRetrievalError,
   appSocketSubscribed,
   appSocketUnsubscribed,
 } from 'services/events/actions';
 import { ProgressImage } from 'services/events/types';
 import { makeToast } from '../util/makeToast';
 import { LANGUAGES } from './constants';
+import { startCase } from 'lodash-es';

 export type CancelStrategy = 'immediate' | 'scheduled';
@@ -288,25 +291,6 @@ export const systemSlice = createSlice({
       }
     });

-    /**
-     * Invocation Error
-     */
-    builder.addCase(appSocketInvocationError, (state) => {
-      state.isProcessing = false;
-      state.isCancelable = true;
-      // state.currentIteration = 0;
-      // state.totalIterations = 0;
-      state.currentStatusHasSteps = false;
-      state.currentStep = 0;
-      state.totalSteps = 0;
-      state.statusTranslationKey = 'common.statusError';
-      state.progressImage = null;
-
-      state.toastQueue.push(
-        makeToast({ title: t('toast.serverError'), status: 'error' })
-      );
-    });
-
     /**
      * Graph Execution State Complete
      */
@@ -362,7 +346,7 @@ export const systemSlice = createSlice({
      * Session Invoked - REJECTED
      * Session Created - REJECTED
      */
-    builder.addMatcher(isAnySessionRejected, (state) => {
+    builder.addMatcher(isAnySessionRejected, (state, action) => {
       state.isProcessing = false;
       state.isCancelable = false;
       state.isCancelScheduled = false;
@@ -372,7 +356,35 @@
       state.progressImage = null;

       state.toastQueue.push(
-        makeToast({ title: t('toast.serverError'), status: 'error' })
+        makeToast({
+          title: t('toast.serverError'),
+          status: 'error',
+          description:
+            action.payload?.status === 422 ? 'Validation Error' : undefined,
+        })
+      );
+    });
+
+    /**
+     * Any server error
+     */
+    builder.addMatcher(isAnyServerError, (state, action) => {
+      state.isProcessing = false;
+      state.isCancelable = true;
+      // state.currentIteration = 0;
+      // state.totalIterations = 0;
+      state.currentStatusHasSteps = false;
+      state.currentStep = 0;
+      state.totalSteps = 0;
+      state.statusTranslationKey = 'common.statusError';
+      state.progressImage = null;
+
+      state.toastQueue.push(
+        makeToast({
+          title: t('toast.serverError'),
+          status: 'error',
+          description: startCase(action.payload.data.error_type),
+        })
       );
     });
   },
@@ -400,3 +412,9 @@
 } = systemSlice.actions;

 export default systemSlice.reducer;
+
+const isAnyServerError = isAnyOf(
+  appSocketInvocationError,
+  appSocketSessionRetrievalError,
+  appSocketInvocationRetrievalError
+);

View File

@@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required<
 >;

 type CreateSessionThunkConfig = {
-  rejectValue: { arg: CreateSessionArg; error: unknown };
+  rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
 };

 /**
@@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk<
   });

   if (error) {
-    return rejectWithValue({ arg, error });
+    return rejectWithValue({ arg, status: response.status, error });
   }

   return data;
@@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = {
   rejectValue: {
     arg: InvokedSessionArg;
     error: unknown;
+    status: number;
   };
 };
@@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk<
   if (error) {
     if (isErrorWithStatus(error) && error.status === 403) {
-      return rejectWithValue({ arg, error: (error as any).body.detail });
+      return rejectWithValue({
+        arg,
+        status: response.status,
+        error: (error as any).body.detail,
+      });
     }
-    return rejectWithValue({ arg, error });
+    return rejectWithValue({ arg, status: response.status, error });
   }
 });

View File

@@ -4,9 +4,11 @@ import {
   GraphExecutionStateCompleteEvent,
   InvocationCompleteEvent,
   InvocationErrorEvent,
+  InvocationRetrievalErrorEvent,
   InvocationStartedEvent,
   ModelLoadCompletedEvent,
   ModelLoadStartedEvent,
+  SessionRetrievalErrorEvent,
 } from 'services/events/types';

 // Create actions for each socket
@@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{
 export const appSocketModelLoadCompleted = createAction<{
   data: ModelLoadCompletedEvent;
 }>('socket/appSocketModelLoadCompleted');
+
+/**
+ * Socket.IO Session Retrieval Error
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketSessionRetrievalError = createAction<{
+  data: SessionRetrievalErrorEvent;
+}>('socket/socketSessionRetrievalError');
+
+/**
+ * App-level Session Retrieval Error
+ */
+export const appSocketSessionRetrievalError = createAction<{
+  data: SessionRetrievalErrorEvent;
+}>('socket/appSocketSessionRetrievalError');
+
+/**
+ * Socket.IO Invocation Retrieval Error
+ *
+ * Do not use. Only for use in middleware.
+ */
+export const socketInvocationRetrievalError = createAction<{
+  data: InvocationRetrievalErrorEvent;
+}>('socket/socketInvocationRetrievalError');
+
+/**
+ * App-level Invocation Retrieval Error
+ */
+export const appSocketInvocationRetrievalError = createAction<{
+  data: InvocationRetrievalErrorEvent;
+}>('socket/appSocketInvocationRetrievalError');

View File

@@ -87,6 +87,7 @@ export type InvocationErrorEvent = {
   graph_execution_state_id: string;
   node: BaseNode;
   source_node_id: string;
+  error_type: string;
   error: string;
 };
@@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = {
   graph_execution_state_id: string;
 };

+/**
+ * A `session_retrieval_error` socket.io event.
+ *
+ * @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
+ */
+export type SessionRetrievalErrorEvent = {
+  graph_execution_state_id: string;
+  error_type: string;
+  error: string;
+};
+
+/**
+ * A `invocation_retrieval_error` socket.io event.
+ *
+ * @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
+ */
+export type InvocationRetrievalErrorEvent = {
+  graph_execution_state_id: string;
+  node_id: string;
+  error_type: string;
+  error: string;
+};
+
 export type ClientEmitSubscribe = {
   session: string;
 };
@@ -128,6 +152,8 @@ export type ServerToClientEvents = {
   ) => void;
   model_load_started: (payload: ModelLoadStartedEvent) => void;
   model_load_completed: (payload: ModelLoadCompletedEvent) => void;
+  session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
+  invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
 };

 export type ClientToServerEvents = {
export type ClientToServerEvents = { export type ClientToServerEvents = {

View File

@@ -11,9 +11,11 @@ import {
   socketGraphExecutionStateComplete,
   socketInvocationComplete,
   socketInvocationError,
+  socketInvocationRetrievalError,
   socketInvocationStarted,
   socketModelLoadCompleted,
   socketModelLoadStarted,
+  socketSessionRetrievalError,
   socketSubscribed,
 } from '../actions';
 import { ClientToServerEvents, ServerToClientEvents } from '../types';
@@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
       })
     );
   });
+
+  /**
+   * Session retrieval error
+   */
+  socket.on('session_retrieval_error', (data) => {
+    dispatch(
+      socketSessionRetrievalError({
+        data,
+      })
+    );
+  });
+
+  /**
+   * Invocation retrieval error
+   */
+  socket.on('invocation_retrieval_error', (data) => {
+    dispatch(
+      socketInvocationRetrievalError({
+        data,
+      })
+    );
+  });
 };