mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into fix-types-2
This commit is contained in:
commit
d6bf6513ef
@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
|
|||||||
|
|
||||||
@validator("seed", pre=True)
|
@validator("seed", pre=True)
|
||||||
def modulo_seed(cls, v):
|
def modulo_seed(cls, v):
|
||||||
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
|
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||||
return v % SEED_MAX
|
return v % (SEED_MAX + 1)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||||
noise = get_noise(
|
noise = get_noise(
|
||||||
|
@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
|
|||||||
return datetime.datetime.fromisoformat(iso_timestamp)
|
return datetime.datetime.fromisoformat(iso_timestamp)
|
||||||
|
|
||||||
|
|
||||||
SEED_MAX = np.iinfo(np.int32).max
|
SEED_MAX = np.iinfo(np.uint32).max
|
||||||
|
|
||||||
|
|
||||||
def get_random_seed():
|
def get_random_seed():
|
||||||
return np.random.randint(0, SEED_MAX)
|
rng = np.random.default_rng(seed=0)
|
||||||
|
return int(rng.integers(0, SEED_MAX))
|
||||||
|
@ -474,7 +474,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _lora_forward_hook(
|
def _lora_forward_hook(
|
||||||
applied_loras: List[Tuple[LoraModel, float]],
|
applied_loras: List[Tuple[LoRAModel, float]],
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
):
|
):
|
||||||
|
|
||||||
@ -519,7 +519,7 @@ class ModelPatcher:
|
|||||||
def apply_lora(
|
def apply_lora(
|
||||||
cls,
|
cls,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
loras: List[Tuple[LoraModel, float]],
|
loras: List[Tuple[LoRAModel, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
):
|
):
|
||||||
original_weights = dict()
|
original_weights = dict()
|
||||||
|
@ -10,6 +10,7 @@ from .base import (
|
|||||||
SubModelType,
|
SubModelType,
|
||||||
classproperty,
|
classproperty,
|
||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
# TODO: naming
|
# TODO: naming
|
||||||
from ..lora import LoRAModel as LoRAModelRaw
|
from ..lora import LoRAModel as LoRAModelRaw
|
||||||
|
@ -102,8 +102,7 @@
|
|||||||
"openInNewTab": "Open in New Tab",
|
"openInNewTab": "Open in New Tab",
|
||||||
"dontAskMeAgain": "Don't ask me again",
|
"dontAskMeAgain": "Don't ask me again",
|
||||||
"areYouSure": "Are you sure?",
|
"areYouSure": "Are you sure?",
|
||||||
"imagePrompt": "Image Prompt",
|
"imagePrompt": "Image Prompt"
|
||||||
"clearNodes": "Are you sure you want to clear all nodes?"
|
|
||||||
},
|
},
|
||||||
"gallery": {
|
"gallery": {
|
||||||
"generations": "Generations",
|
"generations": "Generations",
|
||||||
@ -615,6 +614,11 @@
|
|||||||
"initialImageNotSetDesc": "Could not load initial image",
|
"initialImageNotSetDesc": "Could not load initial image",
|
||||||
"nodesSaved": "Nodes Saved",
|
"nodesSaved": "Nodes Saved",
|
||||||
"nodesLoaded": "Nodes Loaded",
|
"nodesLoaded": "Nodes Loaded",
|
||||||
|
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
|
||||||
|
"nodesNotValidJSON": "Not a valid JSON",
|
||||||
|
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
|
||||||
|
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
|
||||||
|
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
|
||||||
"nodesLoadedFailed": "Failed To Load Nodes",
|
"nodesLoadedFailed": "Failed To Load Nodes",
|
||||||
"nodesCleared": "Nodes Cleared"
|
"nodesCleared": "Nodes Cleared"
|
||||||
},
|
},
|
||||||
@ -700,9 +704,10 @@
|
|||||||
},
|
},
|
||||||
"nodes": {
|
"nodes": {
|
||||||
"reloadSchema": "Reload Schema",
|
"reloadSchema": "Reload Schema",
|
||||||
"saveNodes": "Save Nodes",
|
"saveGraph": "Save Graph",
|
||||||
"loadNodes": "Load Nodes",
|
"loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
|
||||||
"clearNodes": "Clear Nodes",
|
"clearGraph": "Clear Graph",
|
||||||
|
"clearGraphDesc": "Are you sure you want to clear all nodes?",
|
||||||
"zoomInNodes": "Zoom In",
|
"zoomInNodes": "Zoom In",
|
||||||
"zoomOutNodes": "Zoom Out",
|
"zoomOutNodes": "Zoom Out",
|
||||||
"fitViewportNodes": "Fit View",
|
"fitViewportNodes": "Fit View",
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
export const NUMPY_RAND_MIN = 0;
|
export const NUMPY_RAND_MIN = 0;
|
||||||
export const NUMPY_RAND_MAX = 2147483647;
|
export const NUMPY_RAND_MAX = 4294967295;
|
||||||
|
@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react';
|
|||||||
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
|
||||||
import { memo } from 'react';
|
import { memo } from 'react';
|
||||||
import { Panel } from 'reactflow';
|
import { Panel } from 'reactflow';
|
||||||
import LoadNodesButton from '../ui/LoadNodesButton';
|
import ClearGraphButton from '../ui/ClearGraphButton';
|
||||||
|
import LoadGraphButton from '../ui/LoadGraphButton';
|
||||||
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
import NodeInvokeButton from '../ui/NodeInvokeButton';
|
||||||
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
import ReloadSchemaButton from '../ui/ReloadSchemaButton';
|
||||||
import SaveNodesButton from '../ui/SaveNodesButton';
|
import SaveGraphButton from '../ui/SaveGraphButton';
|
||||||
import ClearNodesButton from '../ui/ClearNodesButton';
|
|
||||||
|
|
||||||
const TopCenterPanel = () => {
|
const TopCenterPanel = () => {
|
||||||
return (
|
return (
|
||||||
@ -15,9 +15,9 @@ const TopCenterPanel = () => {
|
|||||||
<NodeInvokeButton />
|
<NodeInvokeButton />
|
||||||
<CancelButton />
|
<CancelButton />
|
||||||
<ReloadSchemaButton />
|
<ReloadSchemaButton />
|
||||||
<SaveNodesButton />
|
<SaveGraphButton />
|
||||||
<LoadNodesButton />
|
<LoadGraphButton />
|
||||||
<ClearNodesButton />
|
<ClearGraphButton />
|
||||||
</HStack>
|
</HStack>
|
||||||
</Panel>
|
</Panel>
|
||||||
);
|
);
|
||||||
|
@ -9,17 +9,17 @@ import {
|
|||||||
Text,
|
Text,
|
||||||
useDisclosure,
|
useDisclosure,
|
||||||
} from '@chakra-ui/react';
|
} from '@chakra-ui/react';
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
import { memo, useRef, useCallback } from 'react';
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaTrash } from 'react-icons/fa';
|
import { FaTrash } from 'react-icons/fa';
|
||||||
|
|
||||||
const ClearNodesButton = () => {
|
const ClearGraphButton = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useAppDispatch();
|
const dispatch = useAppDispatch();
|
||||||
const { isOpen, onOpen, onClose } = useDisclosure();
|
const { isOpen, onOpen, onClose } = useDisclosure();
|
||||||
@ -46,8 +46,8 @@ const ClearNodesButton = () => {
|
|||||||
<>
|
<>
|
||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaTrash />}
|
icon={<FaTrash />}
|
||||||
tooltip={t('nodes.clearNodes')}
|
tooltip={t('nodes.clearGraph')}
|
||||||
aria-label={t('nodes.clearNodes')}
|
aria-label={t('nodes.clearGraph')}
|
||||||
onClick={onOpen}
|
onClick={onOpen}
|
||||||
isDisabled={nodes.length === 0}
|
isDisabled={nodes.length === 0}
|
||||||
/>
|
/>
|
||||||
@ -62,11 +62,11 @@ const ClearNodesButton = () => {
|
|||||||
|
|
||||||
<AlertDialogContent>
|
<AlertDialogContent>
|
||||||
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
<AlertDialogHeader fontSize="lg" fontWeight="bold">
|
||||||
{t('nodes.clearNodes')}
|
{t('nodes.clearGraph')}
|
||||||
</AlertDialogHeader>
|
</AlertDialogHeader>
|
||||||
|
|
||||||
<AlertDialogBody>
|
<AlertDialogBody>
|
||||||
<Text>{t('common.clearNodes')}</Text>
|
<Text>{t('nodes.clearGraphDesc')}</Text>
|
||||||
</AlertDialogBody>
|
</AlertDialogBody>
|
||||||
|
|
||||||
<AlertDialogFooter>
|
<AlertDialogFooter>
|
||||||
@ -83,4 +83,4 @@ const ClearNodesButton = () => {
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(ClearNodesButton);
|
export default memo(ClearGraphButton);
|
@ -0,0 +1,161 @@
|
|||||||
|
import { FileButton } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIIconButton from 'common/components/IAIIconButton';
|
||||||
|
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
||||||
|
import { addToast } from 'features/system/store/systemSlice';
|
||||||
|
import { makeToast } from 'features/system/util/makeToast';
|
||||||
|
import i18n from 'i18n';
|
||||||
|
import { memo, useCallback, useRef } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { FaUpload } from 'react-icons/fa';
|
||||||
|
import { useReactFlow } from 'reactflow';
|
||||||
|
|
||||||
|
interface JsonFile {
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
|
||||||
|
isValid: boolean;
|
||||||
|
message: string;
|
||||||
|
} {
|
||||||
|
// Check if primary keys exist
|
||||||
|
const keys = ['nodes', 'edges', 'viewport'];
|
||||||
|
for (const key of keys) {
|
||||||
|
if (!(key in jsonFile)) {
|
||||||
|
return {
|
||||||
|
isValid: false,
|
||||||
|
message: i18n.t('toast.nodesNotValidGraph'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if nodes and edges are arrays
|
||||||
|
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
|
||||||
|
return {
|
||||||
|
isValid: false,
|
||||||
|
message: i18n.t('toast.nodesNotValidGraph'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if data is present in nodes
|
||||||
|
const nodeKeys = ['data', 'type'];
|
||||||
|
const nodeTypes = ['invocation', 'progress_image'];
|
||||||
|
if (jsonFile.nodes.length > 0) {
|
||||||
|
for (const node of jsonFile.nodes) {
|
||||||
|
for (const nodeKey of nodeKeys) {
|
||||||
|
if (!(nodeKey in node)) {
|
||||||
|
return {
|
||||||
|
isValid: false,
|
||||||
|
message: i18n.t('toast.nodesNotValidGraph'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
|
||||||
|
return {
|
||||||
|
isValid: false,
|
||||||
|
message: i18n.t('toast.nodesUnrecognizedTypes'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Edge Object
|
||||||
|
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
|
||||||
|
if (jsonFile.edges.length > 0) {
|
||||||
|
for (const edge of jsonFile.edges) {
|
||||||
|
for (const edgeKey of edgeKeys) {
|
||||||
|
if (!(edgeKey in edge)) {
|
||||||
|
return {
|
||||||
|
isValid: false,
|
||||||
|
message: i18n.t('toast.nodesBrokenConnections'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
isValid: true,
|
||||||
|
message: i18n.t('toast.nodesLoaded'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const LoadGraphButton = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { fitView } = useReactFlow();
|
||||||
|
|
||||||
|
const uploadedFileRef = useRef<() => void>(null);
|
||||||
|
|
||||||
|
const restoreJSONToEditor = useCallback(
|
||||||
|
(v: File | null) => {
|
||||||
|
if (!v) return;
|
||||||
|
const reader = new FileReader();
|
||||||
|
reader.onload = async () => {
|
||||||
|
const json = reader.result;
|
||||||
|
|
||||||
|
try {
|
||||||
|
const retrievedNodeTree = await JSON.parse(String(json));
|
||||||
|
const { isValid, message } =
|
||||||
|
sanityCheckInvokeAIGraph(retrievedNodeTree);
|
||||||
|
|
||||||
|
if (isValid) {
|
||||||
|
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
||||||
|
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
||||||
|
fitView();
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
addToast(makeToast({ title: message, status: 'success' }))
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: message,
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Cleanup
|
||||||
|
reader.abort();
|
||||||
|
} catch (error) {
|
||||||
|
if (error) {
|
||||||
|
dispatch(
|
||||||
|
addToast(
|
||||||
|
makeToast({
|
||||||
|
title: t('toast.nodesNotValidJSON'),
|
||||||
|
status: 'error',
|
||||||
|
})
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
reader.readAsText(v);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
uploadedFileRef.current?.();
|
||||||
|
},
|
||||||
|
[fitView, dispatch, t]
|
||||||
|
);
|
||||||
|
return (
|
||||||
|
<FileButton
|
||||||
|
resetRef={uploadedFileRef}
|
||||||
|
accept="application/json"
|
||||||
|
onChange={restoreJSONToEditor}
|
||||||
|
>
|
||||||
|
{(props) => (
|
||||||
|
<IAIIconButton
|
||||||
|
icon={<FaUpload />}
|
||||||
|
tooltip={t('nodes.loadGraph')}
|
||||||
|
aria-label={t('nodes.loadGraph')}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</FileButton>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(LoadGraphButton);
|
@ -1,79 +0,0 @@
|
|||||||
import { FileButton } from '@mantine/core';
|
|
||||||
import { makeToast } from 'features/system/util/makeToast';
|
|
||||||
import { useAppDispatch } from 'app/store/storeHooks';
|
|
||||||
import IAIIconButton from 'common/components/IAIIconButton';
|
|
||||||
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { addToast } from 'features/system/store/systemSlice';
|
|
||||||
import { memo, useCallback, useRef } from 'react';
|
|
||||||
import { useTranslation } from 'react-i18next';
|
|
||||||
import { FaUpload } from 'react-icons/fa';
|
|
||||||
import { useReactFlow } from 'reactflow';
|
|
||||||
|
|
||||||
const LoadNodesButton = () => {
|
|
||||||
const { t } = useTranslation();
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
const { fitView } = useReactFlow();
|
|
||||||
|
|
||||||
const uploadedFileRef = useRef<() => void>(null);
|
|
||||||
|
|
||||||
const restoreJSONToEditor = useCallback(
|
|
||||||
(v: File | null) => {
|
|
||||||
if (!v) return;
|
|
||||||
const reader = new FileReader();
|
|
||||||
reader.onload = async () => {
|
|
||||||
const json = reader.result;
|
|
||||||
const retrievedNodeTree = await JSON.parse(String(json));
|
|
||||||
|
|
||||||
if (!retrievedNodeTree) {
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({
|
|
||||||
title: t('toast.nodesLoadedFailed'),
|
|
||||||
status: 'error',
|
|
||||||
})
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (retrievedNodeTree) {
|
|
||||||
dispatch(loadFileNodes(retrievedNodeTree.nodes));
|
|
||||||
dispatch(loadFileEdges(retrievedNodeTree.edges));
|
|
||||||
fitView();
|
|
||||||
|
|
||||||
dispatch(
|
|
||||||
addToast(
|
|
||||||
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
reader.abort();
|
|
||||||
};
|
|
||||||
|
|
||||||
reader.readAsText(v);
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
uploadedFileRef.current?.();
|
|
||||||
},
|
|
||||||
[fitView, dispatch, t]
|
|
||||||
);
|
|
||||||
return (
|
|
||||||
<FileButton
|
|
||||||
resetRef={uploadedFileRef}
|
|
||||||
accept="application/json"
|
|
||||||
onChange={restoreJSONToEditor}
|
|
||||||
>
|
|
||||||
{(props) => (
|
|
||||||
<IAIIconButton
|
|
||||||
icon={<FaUpload />}
|
|
||||||
tooltip={t('nodes.loadNodes')}
|
|
||||||
aria-label={t('nodes.loadNodes')}
|
|
||||||
{...props}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</FileButton>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default memo(LoadNodesButton);
|
|
@ -6,7 +6,7 @@ import { memo, useCallback } from 'react';
|
|||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { FaSave } from 'react-icons/fa';
|
import { FaSave } from 'react-icons/fa';
|
||||||
|
|
||||||
const SaveNodesButton = () => {
|
const SaveGraphButton = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const editorInstance = useAppSelector(
|
const editorInstance = useAppSelector(
|
||||||
(state: RootState) => state.nodes.editorInstance
|
(state: RootState) => state.nodes.editorInstance
|
||||||
@ -37,12 +37,12 @@ const SaveNodesButton = () => {
|
|||||||
<IAIIconButton
|
<IAIIconButton
|
||||||
icon={<FaSave />}
|
icon={<FaSave />}
|
||||||
fontSize={18}
|
fontSize={18}
|
||||||
tooltip={t('nodes.saveNodes')}
|
tooltip={t('nodes.saveGraph')}
|
||||||
aria-label={t('nodes.saveNodes')}
|
aria-label={t('nodes.saveGraph')}
|
||||||
onClick={saveEditorToJSON}
|
onClick={saveEditorToJSON}
|
||||||
isDisabled={nodes.length === 0}
|
isDisabled={nodes.length === 0}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
export default memo(SaveNodesButton);
|
export default memo(SaveGraphButton);
|
Loading…
Reference in New Issue
Block a user