mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/canvas-generation-mode
This commit is contained in:
commit
7ea477abef
@ -298,7 +298,7 @@ async def search_for_models(
|
||||
)->List[pathlib.Path]:
|
||||
if not search_path.is_dir():
|
||||
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
|
||||
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
|
||||
|
||||
@models_router.get(
|
||||
"/ckpt_confs",
|
||||
|
@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
@validator("seed", pre=True)
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
||||
return v % SEED_MAX
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
noise = get_noise(
|
||||
|
@ -3,7 +3,13 @@
|
||||
from typing import Any, Optional
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
|
||||
from invokeai.app.services.model_manager_service import (
|
||||
BaseModelType,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
ModelInfo,
|
||||
)
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
@ -38,7 +44,9 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=progress_image.dict() if progress_image is not None else None,
|
||||
progress_image=progress_image.dict()
|
||||
if progress_image is not None
|
||||
else None,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
),
|
||||
@ -67,6 +75,7 @@ class EventServiceBase:
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
@ -76,6 +85,7 @@ class EventServiceBase:
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
@ -102,7 +112,7 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started (
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
@ -145,3 +155,37 @@ class EventServiceBase:
|
||||
precision=str(model_info.precision),
|
||||
),
|
||||
)
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
@ -600,7 +600,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
"""
|
||||
Return list of all models found in the designated directory.
|
||||
"""
|
||||
search = FindModels(directory,self.logger)
|
||||
search = FindModels([directory], self.logger)
|
||||
return search.list_models()
|
||||
|
||||
def sync_to_config(self):
|
||||
|
@ -39,21 +39,41 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
logger.debug("Exception while getting from queue: %s" % e)
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = (
|
||||
self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
invocation = graph_execution_state.execution_graph.get_node(
|
||||
queue_item.invocation_id
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
@ -114,11 +134,13 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_execution_state
|
||||
)
|
||||
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
|
||||
@ -136,11 +158,12 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
except Exception as e:
|
||||
logger.error("Error while invoking: %s" % e)
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc()
|
||||
)
|
||||
elif is_complete:
|
||||
|
@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||
|
||||
|
||||
SEED_MAX = np.iinfo(np.int32).max
|
||||
SEED_MAX = np.iinfo(np.uint32).max
|
||||
|
||||
|
||||
def get_random_seed():
|
||||
return np.random.randint(0, SEED_MAX)
|
||||
rng = np.random.default_rng(seed=0)
|
||||
return int(rng.integers(0, SEED_MAX))
|
||||
|
@ -474,7 +474,7 @@ class ModelPatcher:
|
||||
|
||||
@staticmethod
|
||||
def _lora_forward_hook(
|
||||
applied_loras: List[Tuple[LoraModel, float]],
|
||||
applied_loras: List[Tuple[LoRAModel, float]],
|
||||
layer_name: str,
|
||||
):
|
||||
|
||||
@ -519,7 +519,7 @@ class ModelPatcher:
|
||||
def apply_lora(
|
||||
cls,
|
||||
model: torch.nn.Module,
|
||||
loras: List[Tuple[LoraModel, float]],
|
||||
loras: List[Tuple[LoRAModel, float]],
|
||||
prefix: str,
|
||||
):
|
||||
original_weights = dict()
|
||||
|
@ -98,6 +98,6 @@ class FindModels(ModelSearch):
|
||||
|
||||
def list_models(self) -> List[Path]:
|
||||
self.search()
|
||||
return self.models_found
|
||||
return list(self.models_found)
|
||||
|
||||
|
||||
|
@ -10,6 +10,7 @@ from .base import (
|
||||
SubModelType,
|
||||
classproperty,
|
||||
InvalidModelException,
|
||||
ModelNotFoundException,
|
||||
)
|
||||
# TODO: naming
|
||||
from ..lora import LoRAModel as LoRAModelRaw
|
||||
|
@ -102,8 +102,7 @@
|
||||
"openInNewTab": "Open in New Tab",
|
||||
"dontAskMeAgain": "Don't ask me again",
|
||||
"areYouSure": "Are you sure?",
|
||||
"imagePrompt": "Image Prompt",
|
||||
"clearNodes": "Are you sure you want to clear all nodes?"
|
||||
"imagePrompt": "Image Prompt"
|
||||
},
|
||||
"gallery": {
|
||||
"generations": "Generations",
|
||||
@ -615,6 +614,11 @@
|
||||
"initialImageNotSetDesc": "Could not load initial image",
|
||||
"nodesSaved": "Nodes Saved",
|
||||
"nodesLoaded": "Nodes Loaded",
|
||||
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
|
||||
"nodesNotValidJSON": "Not a valid JSON",
|
||||
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
|
||||
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
|
||||
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
|
||||
"nodesLoadedFailed": "Failed To Load Nodes",
|
||||
"nodesCleared": "Nodes Cleared"
|
||||
},
|
||||
@ -700,9 +704,10 @@
|
||||
},
|
||||
"nodes": {
|
||||
"reloadSchema": "Reload Schema",
|
||||
"saveNodes": "Save Nodes",
|
||||
"loadNodes": "Load Nodes",
|
||||
"clearNodes": "Clear Nodes",
|
||||
"saveGraph": "Save Graph",
|
||||
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
||||
"clearGraph": "Clear Graph",
|
||||
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
||||
"zoomInNodes": "Zoom In",
|
||||
"zoomOutNodes": "Zoom Out",
|
||||
"fitViewportNodes": "Fit View",
|
||||
|
@ -1,2 +1,2 @@
|
||||
export const NUMPY_RAND_MIN = 0;
|
||||
export const NUMPY_RAND_MAX = 2147483647;
|
||||
export const NUMPY_RAND_MAX = 4294967295;
|
||||
|
@ -75,6 +75,8 @@ import { addUserInvokedCanvasListener } from './listeners/userInvokedCanvas';
|
||||
import { addUserInvokedImageToImageListener } from './listeners/userInvokedImageToImage';
|
||||
import { addUserInvokedNodesListener } from './listeners/userInvokedNodes';
|
||||
import { addUserInvokedTextToImageListener } from './listeners/userInvokedTextToImage';
|
||||
import { addSessionRetrievalErrorEventListener } from './listeners/socketio/socketSessionRetrievalError';
|
||||
import { addInvocationRetrievalErrorEventListener } from './listeners/socketio/socketInvocationRetrievalError';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@ -153,6 +155,8 @@ addSocketDisconnectedListener();
|
||||
addSocketSubscribedListener();
|
||||
addSocketUnsubscribedListener();
|
||||
addModelLoadEventListener();
|
||||
addSessionRetrievalErrorEventListener();
|
||||
addInvocationRetrievalErrorEventListener();
|
||||
|
||||
// Session Created
|
||||
addSessionCreatedPendingListener();
|
||||
|
@ -33,12 +33,11 @@ export const addSessionCreatedRejectedListener = () => {
|
||||
effect: (action) => {
|
||||
const log = logger('session');
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
const { error, status } = action.payload;
|
||||
const graph = parseify(action.meta.arg);
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
log.error(
|
||||
{ graph, error: serializeError(error) },
|
||||
`Problem creating session: ${stringifiedError}`
|
||||
{ graph, status, error: serializeError(error) },
|
||||
`Problem creating session`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -31,13 +31,12 @@ export const addSessionInvokedRejectedListener = () => {
|
||||
const { session_id } = action.meta.arg;
|
||||
if (action.payload) {
|
||||
const { error } = action.payload;
|
||||
const stringifiedError = JSON.stringify(error);
|
||||
log.error(
|
||||
{
|
||||
session_id,
|
||||
error: serializeError(error),
|
||||
},
|
||||
`Problem invoking session: ${stringifiedError}`
|
||||
`Problem invoking session`
|
||||
);
|
||||
}
|
||||
},
|
||||
|
@ -0,0 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
appSocketInvocationRetrievalError,
|
||||
socketInvocationRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
export const addInvocationRetrievalErrorEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationRetrievalError,
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('socketio');
|
||||
log.error(
|
||||
action.payload,
|
||||
`Invocation retrieval error (${action.payload.data.graph_execution_state_id})`
|
||||
);
|
||||
dispatch(appSocketInvocationRetrievalError(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -0,0 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import {
|
||||
appSocketSessionRetrievalError,
|
||||
socketSessionRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
|
||||
export const addSessionRetrievalErrorEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketSessionRetrievalError,
|
||||
effect: (action, { dispatch }) => {
|
||||
const log = logger('socketio');
|
||||
log.error(
|
||||
action.payload,
|
||||
`Session retrieval error (${action.payload.data.graph_execution_state_id})`
|
||||
);
|
||||
dispatch(appSocketSessionRetrievalError(action.payload));
|
||||
},
|
||||
});
|
||||
};
|
@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react';
|
||||
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
||||
import { memo } from 'react';
|
||||
import { Panel } from 'reactflow';
|
||||
import LoadNodesButton from '../ui/LoadNodesButton';
|
||||
import ClearGraphButton from '../ui/ClearGraphButton';
|
||||
import LoadGraphButton from '../ui/LoadGraphButton';
|
||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
||||
import SaveNodesButton from '../ui/SaveNodesButton';
|
||||
import ClearNodesButton from '../ui/ClearNodesButton';
|
||||
import SaveGraphButton from '../ui/SaveGraphButton';
|
||||
|
||||
const TopCenterPanel = () => {
|
||||
return (
|
||||
@ -15,9 +15,9 @@ const TopCenterPanel = () => {
|
||||
<NodeInvokeButton />
|
||||
<CancelButton />
|
||||
<ReloadSchemaButton />
|
||||
<SaveNodesButton />
|
||||
<LoadNodesButton />
|
||||
<ClearNodesButton />
|
||||
<SaveGraphButton />
|
||||
<LoadGraphButton />
|
||||
<ClearGraphButton />
|
||||
</HStack>
|
||||
</Panel>
|
||||
);
|
||||
|
@ -9,17 +9,17 @@ import {
|
||||
Text,
|
||||
useDisclosure,
|
||||
} from '@chakra-ui/react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useRef, useCallback } from 'react';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
|
||||
const ClearNodesButton = () => {
|
||||
const ClearGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||
@ -46,8 +46,8 @@ const ClearNodesButton = () => {
|
||||
<>
|
||||
<IAIIconButton
|
||||
icon={<FaTrash />}
|
||||
tooltip={t('nodes.clearNodes')}
|
||||
aria-label={t('nodes.clearNodes')}
|
||||
tooltip={t('nodes.clearGraph')}
|
||||
aria-label={t('nodes.clearGraph')}
|
||||
onClick={onOpen}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
@ -62,11 +62,11 @@ const ClearNodesButton = () => {
|
||||
|
||||
<AlertDialogContent>
|
||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||
{t('nodes.clearNodes')}
|
||||
{t('nodes.clearGraph')}
|
||||
</AlertDialogHeader>
|
||||
|
||||
<AlertDialogBody>
|
||||
<Text>{t('common.clearNodes')}</Text>
|
||||
<Text>{t('nodes.clearGraphDesc')}</Text>
|
||||
</AlertDialogBody>
|
||||
|
||||
<AlertDialogFooter>
|
||||
@ -83,4 +83,4 @@ const ClearNodesButton = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ClearNodesButton);
|
||||
export default memo(ClearGraphButton);
|
@ -0,0 +1,161 @@
|
||||
import { FileButton } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import i18n from 'i18n';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
interface JsonFile {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
|
||||
isValid: boolean;
|
||||
message: string;
|
||||
} {
|
||||
// Check if primary keys exist
|
||||
const keys = ['nodes', 'edges', 'viewport'];
|
||||
for (const key of keys) {
|
||||
if (!(key in jsonFile)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Check if nodes and edges are arrays
|
||||
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
|
||||
// Check if data is present in nodes
|
||||
const nodeKeys = ['data', 'type'];
|
||||
const nodeTypes = ['invocation', 'progress_image'];
|
||||
if (jsonFile.nodes.length > 0) {
|
||||
for (const node of jsonFile.nodes) {
|
||||
for (const nodeKey of nodeKeys) {
|
||||
if (!(nodeKey in node)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesNotValidGraph'),
|
||||
};
|
||||
}
|
||||
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesUnrecognizedTypes'),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check Edge Object
|
||||
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
|
||||
if (jsonFile.edges.length > 0) {
|
||||
for (const edge of jsonFile.edges) {
|
||||
for (const edgeKey of edgeKeys) {
|
||||
if (!(edgeKey in edge)) {
|
||||
return {
|
||||
isValid: false,
|
||||
message: i18n.t('toast.nodesBrokenConnections'),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: true,
|
||||
message: i18n.t('toast.nodesLoaded'),
|
||||
};
|
||||
}
|
||||
|
||||
const LoadGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { fitView } = useReactFlow();
|
||||
|
||||
const uploadedFileRef = useRef<() => void>(null);
|
||||
|
||||
const restoreJSONToEditor = useCallback(
|
||||
(v: File | null) => {
|
||||
if (!v) return;
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const json = reader.result;
|
||||
|
||||
try {
|
||||
const retrievedNodeTree = await JSON.parse(String(json));
|
||||
const { isValid, message } =
|
||||
sanityCheckInvokeAIGraph(retrievedNodeTree);
|
||||
|
||||
if (isValid) {
|
||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||
fitView();
|
||||
|
||||
dispatch(
|
||||
addToast(makeToast({ title: message, status: 'success' }))
|
||||
);
|
||||
} else {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: message,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
// Cleanup
|
||||
reader.abort();
|
||||
} catch (error) {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesNotValidJSON'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
reader.readAsText(v);
|
||||
|
||||
// Cleanup
|
||||
uploadedFileRef.current?.();
|
||||
},
|
||||
[fitView, dispatch, t]
|
||||
);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={uploadedFileRef}
|
||||
accept="application/json"
|
||||
onChange={restoreJSONToEditor}
|
||||
>
|
||||
{(props) => (
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
tooltip={t('nodes.loadGraph')}
|
||||
aria-label={t('nodes.loadGraph')}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</FileButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoadGraphButton);
|
@ -1,79 +0,0 @@
|
||||
import { FileButton } from '@mantine/core';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaUpload } from 'react-icons/fa';
|
||||
import { useReactFlow } from 'reactflow';
|
||||
|
||||
const LoadNodesButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { fitView } = useReactFlow();
|
||||
|
||||
const uploadedFileRef = useRef<() => void>(null);
|
||||
|
||||
const restoreJSONToEditor = useCallback(
|
||||
(v: File | null) => {
|
||||
if (!v) return;
|
||||
const reader = new FileReader();
|
||||
reader.onload = async () => {
|
||||
const json = reader.result;
|
||||
const retrievedNodeTree = await JSON.parse(String(json));
|
||||
|
||||
if (!retrievedNodeTree) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.nodesLoadedFailed'),
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (retrievedNodeTree) {
|
||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||
fitView();
|
||||
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
reader.abort();
|
||||
};
|
||||
|
||||
reader.readAsText(v);
|
||||
|
||||
// Cleanup
|
||||
uploadedFileRef.current?.();
|
||||
},
|
||||
[fitView, dispatch, t]
|
||||
);
|
||||
return (
|
||||
<FileButton
|
||||
resetRef={uploadedFileRef}
|
||||
accept="application/json"
|
||||
onChange={restoreJSONToEditor}
|
||||
>
|
||||
{(props) => (
|
||||
<IAIIconButton
|
||||
icon={<FaUpload />}
|
||||
tooltip={t('nodes.loadNodes')}
|
||||
aria-label={t('nodes.loadNodes')}
|
||||
{...props}
|
||||
/>
|
||||
)}
|
||||
</FileButton>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(LoadNodesButton);
|
@ -6,7 +6,7 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaSave } from 'react-icons/fa';
|
||||
|
||||
const SaveNodesButton = () => {
|
||||
const SaveGraphButton = () => {
|
||||
const { t } = useTranslation();
|
||||
const editorInstance = useAppSelector(
|
||||
(state: RootState) => state.nodes.editorInstance
|
||||
@ -37,12 +37,12 @@ const SaveNodesButton = () => {
|
||||
<IAIIconButton
|
||||
icon={<FaSave />}
|
||||
fontSize={18}
|
||||
tooltip={t('nodes.saveNodes')}
|
||||
aria-label={t('nodes.saveNodes')}
|
||||
tooltip={t('nodes.saveGraph')}
|
||||
aria-label={t('nodes.saveGraph')}
|
||||
onClick={saveEditorToJSON}
|
||||
isDisabled={nodes.length === 0}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SaveNodesButton);
|
||||
export default memo(SaveGraphButton);
|
@ -1,5 +1,5 @@
|
||||
import { UseToastOptions } from '@chakra-ui/react';
|
||||
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
|
||||
import { PayloadAction, createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { InvokeLogLevel } from 'app/logging/logger';
|
||||
import { userInvoked } from 'app/store/actions';
|
||||
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
|
||||
@ -16,13 +16,16 @@ import {
|
||||
appSocketGraphExecutionStateComplete,
|
||||
appSocketInvocationComplete,
|
||||
appSocketInvocationError,
|
||||
appSocketInvocationRetrievalError,
|
||||
appSocketInvocationStarted,
|
||||
appSocketSessionRetrievalError,
|
||||
appSocketSubscribed,
|
||||
appSocketUnsubscribed,
|
||||
} from 'services/events/actions';
|
||||
import { ProgressImage } from 'services/events/types';
|
||||
import { makeToast } from '../util/makeToast';
|
||||
import { LANGUAGES } from './constants';
|
||||
import { startCase } from 'lodash-es';
|
||||
|
||||
export type CancelStrategy = 'immediate' | 'scheduled';
|
||||
|
||||
@ -288,25 +291,6 @@ export const systemSlice = createSlice({
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation Error
|
||||
*/
|
||||
builder.addCase(appSocketInvocationError, (state) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
// state.currentIteration = 0;
|
||||
// state.totalIterations = 0;
|
||||
state.currentStatusHasSteps = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusError';
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: t('toast.serverError'), status: 'error' })
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Graph Execution State Complete
|
||||
*/
|
||||
@ -362,7 +346,7 @@ export const systemSlice = createSlice({
|
||||
* Session Invoked - REJECTED
|
||||
* Session Created - REJECTED
|
||||
*/
|
||||
builder.addMatcher(isAnySessionRejected, (state) => {
|
||||
builder.addMatcher(isAnySessionRejected, (state, action) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = false;
|
||||
state.isCancelScheduled = false;
|
||||
@ -372,7 +356,35 @@ export const systemSlice = createSlice({
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({ title: t('toast.serverError'), status: 'error' })
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description:
|
||||
action.payload?.status === 422 ? 'Validation Error' : undefined,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Any server error
|
||||
*/
|
||||
builder.addMatcher(isAnyServerError, (state, action) => {
|
||||
state.isProcessing = false;
|
||||
state.isCancelable = true;
|
||||
// state.currentIteration = 0;
|
||||
// state.totalIterations = 0;
|
||||
state.currentStatusHasSteps = false;
|
||||
state.currentStep = 0;
|
||||
state.totalSteps = 0;
|
||||
state.statusTranslationKey = 'common.statusError';
|
||||
state.progressImage = null;
|
||||
|
||||
state.toastQueue.push(
|
||||
makeToast({
|
||||
title: t('toast.serverError'),
|
||||
status: 'error',
|
||||
description: startCase(action.payload.data.error_type),
|
||||
})
|
||||
);
|
||||
});
|
||||
},
|
||||
@ -400,3 +412,9 @@ export const {
|
||||
} = systemSlice.actions;
|
||||
|
||||
export default systemSlice.reducer;
|
||||
|
||||
const isAnyServerError = isAnyOf(
|
||||
appSocketInvocationError,
|
||||
appSocketSessionRetrievalError,
|
||||
appSocketInvocationRetrievalError
|
||||
);
|
||||
|
@ -18,7 +18,7 @@ type CreateSessionResponse = O.Required<
|
||||
>;
|
||||
|
||||
type CreateSessionThunkConfig = {
|
||||
rejectValue: { arg: CreateSessionArg; error: unknown };
|
||||
rejectValue: { arg: CreateSessionArg; status: number; error: unknown };
|
||||
};
|
||||
|
||||
/**
|
||||
@ -36,7 +36,7 @@ export const sessionCreated = createAsyncThunk<
|
||||
});
|
||||
|
||||
if (error) {
|
||||
return rejectWithValue({ arg, error });
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
|
||||
return data;
|
||||
@ -53,6 +53,7 @@ type InvokedSessionThunkConfig = {
|
||||
rejectValue: {
|
||||
arg: InvokedSessionArg;
|
||||
error: unknown;
|
||||
status: number;
|
||||
};
|
||||
};
|
||||
|
||||
@ -78,9 +79,13 @@ export const sessionInvoked = createAsyncThunk<
|
||||
|
||||
if (error) {
|
||||
if (isErrorWithStatus(error) && error.status === 403) {
|
||||
return rejectWithValue({ arg, error: (error as any).body.detail });
|
||||
return rejectWithValue({
|
||||
arg,
|
||||
status: response.status,
|
||||
error: (error as any).body.detail,
|
||||
});
|
||||
}
|
||||
return rejectWithValue({ arg, error });
|
||||
return rejectWithValue({ arg, status: response.status, error });
|
||||
}
|
||||
});
|
||||
|
||||
|
@ -4,9 +4,11 @@ import {
|
||||
GraphExecutionStateCompleteEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationRetrievalErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelLoadCompletedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
SessionRetrievalErrorEvent,
|
||||
} from 'services/events/types';
|
||||
|
||||
// Create actions for each socket
|
||||
@ -181,3 +183,35 @@ export const socketModelLoadCompleted = createAction<{
|
||||
export const appSocketModelLoadCompleted = createAction<{
|
||||
data: ModelLoadCompletedEvent;
|
||||
}>('socket/appSocketModelLoadCompleted');
|
||||
|
||||
/**
|
||||
* Socket.IO Session Retrieval Error
|
||||
*
|
||||
* Do not use. Only for use in middleware.
|
||||
*/
|
||||
export const socketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/socketSessionRetrievalError');
|
||||
|
||||
/**
|
||||
* App-level Session Retrieval Error
|
||||
*/
|
||||
export const appSocketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/appSocketSessionRetrievalError');
|
||||
|
||||
/**
|
||||
* Socket.IO Invocation Retrieval Error
|
||||
*
|
||||
* Do not use. Only for use in middleware.
|
||||
*/
|
||||
export const socketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/socketInvocationRetrievalError');
|
||||
|
||||
/**
|
||||
* App-level Invocation Retrieval Error
|
||||
*/
|
||||
export const appSocketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/appSocketInvocationRetrievalError');
|
||||
|
@ -87,6 +87,7 @@ export type InvocationErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
@ -110,6 +111,29 @@ export type GraphExecutionStateCompleteEvent = {
|
||||
graph_execution_state_id: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `session_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type SessionRetrievalErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationRetrievalErrorEvent = {
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
export type ClientEmitSubscribe = {
|
||||
session: string;
|
||||
};
|
||||
@ -128,6 +152,8 @@ export type ServerToClientEvents = {
|
||||
) => void;
|
||||
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
|
||||
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
|
||||
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
|
||||
};
|
||||
|
||||
export type ClientToServerEvents = {
|
||||
|
@ -11,9 +11,11 @@ import {
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationRetrievalError,
|
||||
socketInvocationStarted,
|
||||
socketModelLoadCompleted,
|
||||
socketModelLoadStarted,
|
||||
socketSessionRetrievalError,
|
||||
socketSubscribed,
|
||||
} from '../actions';
|
||||
import { ClientToServerEvents, ServerToClientEvents } from '../types';
|
||||
@ -138,4 +140,26 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Session retrieval error
|
||||
*/
|
||||
socket.on('session_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketSessionRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation retrieval error
|
||||
*/
|
||||
socket.on('invocation_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketInvocationRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user