Merge branch 'main' into fix-types-2

This commit is contained in:
blessedcoolant 2023-07-24 20:01:48 +12:00 committed by GitHub
commit d6bf6513ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 198 additions and 109 deletions

View File

@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
@validator("seed", pre=True) @validator("seed", pre=True)
def modulo_seed(cls, v): def modulo_seed(cls, v):
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range.""" """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % SEED_MAX return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> NoiseOutput: def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise( noise = get_noise(

View File

@ -14,8 +14,9 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp) return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.int32).max SEED_MAX = np.iinfo(np.uint32).max
def get_random_seed(): def get_random_seed():
return np.random.randint(0, SEED_MAX) rng = np.random.default_rng(seed=0)
return int(rng.integers(0, SEED_MAX))

View File

@ -474,7 +474,7 @@ class ModelPatcher:
@staticmethod @staticmethod
def _lora_forward_hook( def _lora_forward_hook(
applied_loras: List[Tuple[LoraModel, float]], applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str, layer_name: str,
): ):
@ -519,7 +519,7 @@ class ModelPatcher:
def apply_lora( def apply_lora(
cls, cls,
model: torch.nn.Module, model: torch.nn.Module,
loras: List[Tuple[LoraModel, float]], loras: List[Tuple[LoRAModel, float]],
prefix: str, prefix: str,
): ):
original_weights = dict() original_weights = dict()

View File

@ -10,6 +10,7 @@ from .base import (
SubModelType, SubModelType,
classproperty, classproperty,
InvalidModelException, InvalidModelException,
ModelNotFoundException,
) )
# TODO: naming # TODO: naming
from ..lora import LoRAModel as LoRAModelRaw from ..lora import LoRAModel as LoRAModelRaw

View File

@ -102,8 +102,7 @@
"openInNewTab": "Open in New Tab", "openInNewTab": "Open in New Tab",
"dontAskMeAgain": "Don't ask me again", "dontAskMeAgain": "Don't ask me again",
"areYouSure": "Are you sure?", "areYouSure": "Are you sure?",
"imagePrompt": "Image Prompt", "imagePrompt": "Image Prompt"
"clearNodes": "Are you sure you want to clear all nodes?"
}, },
"gallery": { "gallery": {
"generations": "Generations", "generations": "Generations",
@ -615,6 +614,11 @@
"initialImageNotSetDesc": "Could not load initial image", "initialImageNotSetDesc": "Could not load initial image",
"nodesSaved": "Nodes Saved", "nodesSaved": "Nodes Saved",
"nodesLoaded": "Nodes Loaded", "nodesLoaded": "Nodes Loaded",
"nodesNotValidGraph": "Not a valid InvokeAI Node Graph",
"nodesNotValidJSON": "Not a valid JSON",
"nodesCorruptedGraph": "Cannot load. Graph seems to be corrupted.",
"nodesUnrecognizedTypes": "Cannot load. Graph has unrecognized types",
"nodesBrokenConnections": "Cannot load. Some connections are broken.",
"nodesLoadedFailed": "Failed To Load Nodes", "nodesLoadedFailed": "Failed To Load Nodes",
"nodesCleared": "Nodes Cleared" "nodesCleared": "Nodes Cleared"
}, },
@ -700,9 +704,10 @@
}, },
"nodes": { "nodes": {
"reloadSchema": "Reload Schema", "reloadSchema": "Reload Schema",
"saveNodes": "Save Nodes", "saveGraph": "Save Graph",
"loadNodes": "Load Nodes", "loadGraph": "Load Graph (saved from Node Editor) (Do not copy-paste metadata)",
"clearNodes": "Clear Nodes", "clearGraph": "Clear Graph",
"clearGraphDesc": "Are you sure you want to clear all nodes?",
"zoomInNodes": "Zoom In", "zoomInNodes": "Zoom In",
"zoomOutNodes": "Zoom Out", "zoomOutNodes": "Zoom Out",
"fitViewportNodes": "Fit View", "fitViewportNodes": "Fit View",

View File

@ -1,2 +1,2 @@
export const NUMPY_RAND_MIN = 0; export const NUMPY_RAND_MIN = 0;
export const NUMPY_RAND_MAX = 2147483647; export const NUMPY_RAND_MAX = 4294967295;

View File

@ -2,11 +2,11 @@ import { HStack } from '@chakra-ui/react';
import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton'; import CancelButton from 'features/parameters/components/ProcessButtons/CancelButton';
import { memo } from 'react'; import { memo } from 'react';
import { Panel } from 'reactflow'; import { Panel } from 'reactflow';
import LoadNodesButton from '../ui/LoadNodesButton'; import ClearGraphButton from '../ui/ClearGraphButton';
import LoadGraphButton from '../ui/LoadGraphButton';
import NodeInvokeButton from '../ui/NodeInvokeButton'; import NodeInvokeButton from '../ui/NodeInvokeButton';
import ReloadSchemaButton from '../ui/ReloadSchemaButton'; import ReloadSchemaButton from '../ui/ReloadSchemaButton';
import SaveNodesButton from '../ui/SaveNodesButton'; import SaveGraphButton from '../ui/SaveGraphButton';
import ClearNodesButton from '../ui/ClearNodesButton';
const TopCenterPanel = () => { const TopCenterPanel = () => {
return ( return (
@ -15,9 +15,9 @@ const TopCenterPanel = () => {
<NodeInvokeButton /> <NodeInvokeButton />
<CancelButton /> <CancelButton />
<ReloadSchemaButton /> <ReloadSchemaButton />
<SaveNodesButton /> <SaveGraphButton />
<LoadNodesButton /> <LoadGraphButton />
<ClearNodesButton /> <ClearGraphButton />
</HStack> </HStack>
</Panel> </Panel>
); );

View File

@ -9,17 +9,17 @@ import {
Text, Text,
useDisclosure, useDisclosure,
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { makeToast } from 'features/system/util/makeToast';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton'; import IAIIconButton from 'common/components/IAIIconButton';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice'; import { addToast } from 'features/system/store/systemSlice';
import { memo, useRef, useCallback } from 'react'; import { makeToast } from 'features/system/util/makeToast';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa'; import { FaTrash } from 'react-icons/fa';
const ClearNodesButton = () => { const ClearGraphButton = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
@ -46,8 +46,8 @@ const ClearNodesButton = () => {
<> <>
<IAIIconButton <IAIIconButton
icon={<FaTrash />} icon={<FaTrash />}
tooltip={t('nodes.clearNodes')} tooltip={t('nodes.clearGraph')}
aria-label={t('nodes.clearNodes')} aria-label={t('nodes.clearGraph')}
onClick={onOpen} onClick={onOpen}
isDisabled={nodes.length === 0} isDisabled={nodes.length === 0}
/> />
@ -62,11 +62,11 @@ const ClearNodesButton = () => {
<AlertDialogContent> <AlertDialogContent>
<AlertDialogHeader fontSize="lg" fontWeight="bold"> <AlertDialogHeader fontSize="lg" fontWeight="bold">
{t('nodes.clearNodes')} {t('nodes.clearGraph')}
</AlertDialogHeader> </AlertDialogHeader>
<AlertDialogBody> <AlertDialogBody>
<Text>{t('common.clearNodes')}</Text> <Text>{t('nodes.clearGraphDesc')}</Text>
</AlertDialogBody> </AlertDialogBody>
<AlertDialogFooter> <AlertDialogFooter>
@ -83,4 +83,4 @@ const ClearNodesButton = () => {
); );
}; };
export default memo(ClearNodesButton); export default memo(ClearGraphButton);

View File

@ -0,0 +1,161 @@
import { FileButton } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import i18n from 'i18n';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
interface JsonFile {
[key: string]: unknown;
}
function sanityCheckInvokeAIGraph(jsonFile: JsonFile): {
isValid: boolean;
message: string;
} {
// Check if primary keys exist
const keys = ['nodes', 'edges', 'viewport'];
for (const key of keys) {
if (!(key in jsonFile)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
}
// Check if nodes and edges are arrays
if (!Array.isArray(jsonFile.nodes) || !Array.isArray(jsonFile.edges)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
// Check if data is present in nodes
const nodeKeys = ['data', 'type'];
const nodeTypes = ['invocation', 'progress_image'];
if (jsonFile.nodes.length > 0) {
for (const node of jsonFile.nodes) {
for (const nodeKey of nodeKeys) {
if (!(nodeKey in node)) {
return {
isValid: false,
message: i18n.t('toast.nodesNotValidGraph'),
};
}
if (nodeKey === 'type' && !nodeTypes.includes(node[nodeKey])) {
return {
isValid: false,
message: i18n.t('toast.nodesUnrecognizedTypes'),
};
}
}
}
}
// Check Edge Object
const edgeKeys = ['source', 'sourceHandle', 'target', 'targetHandle'];
if (jsonFile.edges.length > 0) {
for (const edge of jsonFile.edges) {
for (const edgeKey of edgeKeys) {
if (!(edgeKey in edge)) {
return {
isValid: false,
message: i18n.t('toast.nodesBrokenConnections'),
};
}
}
}
}
return {
isValid: true,
message: i18n.t('toast.nodesLoaded'),
};
}
const LoadGraphButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { fitView } = useReactFlow();
const uploadedFileRef = useRef<() => void>(null);
const restoreJSONToEditor = useCallback(
(v: File | null) => {
if (!v) return;
const reader = new FileReader();
reader.onload = async () => {
const json = reader.result;
try {
const retrievedNodeTree = await JSON.parse(String(json));
const { isValid, message } =
sanityCheckInvokeAIGraph(retrievedNodeTree);
if (isValid) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(makeToast({ title: message, status: 'success' }))
);
} else {
dispatch(
addToast(
makeToast({
title: message,
status: 'error',
})
)
);
}
// Cleanup
reader.abort();
} catch (error) {
if (error) {
dispatch(
addToast(
makeToast({
title: t('toast.nodesNotValidJSON'),
status: 'error',
})
)
);
}
}
};
reader.readAsText(v);
// Cleanup
uploadedFileRef.current?.();
},
[fitView, dispatch, t]
);
return (
<FileButton
resetRef={uploadedFileRef}
accept="application/json"
onChange={restoreJSONToEditor}
>
{(props) => (
<IAIIconButton
icon={<FaUpload />}
tooltip={t('nodes.loadGraph')}
aria-label={t('nodes.loadGraph')}
{...props}
/>
)}
</FileButton>
);
};
export default memo(LoadGraphButton);

View File

@ -1,79 +0,0 @@
import { FileButton } from '@mantine/core';
import { makeToast } from 'features/system/util/makeToast';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
import { loadFileEdges, loadFileNodes } from 'features/nodes/store/nodesSlice';
import { addToast } from 'features/system/store/systemSlice';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUpload } from 'react-icons/fa';
import { useReactFlow } from 'reactflow';
const LoadNodesButton = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { fitView } = useReactFlow();
const uploadedFileRef = useRef<() => void>(null);
const restoreJSONToEditor = useCallback(
(v: File | null) => {
if (!v) return;
const reader = new FileReader();
reader.onload = async () => {
const json = reader.result;
const retrievedNodeTree = await JSON.parse(String(json));
if (!retrievedNodeTree) {
dispatch(
addToast(
makeToast({
title: t('toast.nodesLoadedFailed'),
status: 'error',
})
)
);
}
if (retrievedNodeTree) {
dispatch(loadFileNodes(retrievedNodeTree.nodes));
dispatch(loadFileEdges(retrievedNodeTree.edges));
fitView();
dispatch(
addToast(
makeToast({ title: t('toast.nodesLoaded'), status: 'success' })
)
);
}
// Cleanup
reader.abort();
};
reader.readAsText(v);
// Cleanup
uploadedFileRef.current?.();
},
[fitView, dispatch, t]
);
return (
<FileButton
resetRef={uploadedFileRef}
accept="application/json"
onChange={restoreJSONToEditor}
>
{(props) => (
<IAIIconButton
icon={<FaUpload />}
tooltip={t('nodes.loadNodes')}
aria-label={t('nodes.loadNodes')}
{...props}
/>
)}
</FileButton>
);
};
export default memo(LoadNodesButton);

View File

@ -6,7 +6,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { FaSave } from 'react-icons/fa'; import { FaSave } from 'react-icons/fa';
const SaveNodesButton = () => { const SaveGraphButton = () => {
const { t } = useTranslation(); const { t } = useTranslation();
const editorInstance = useAppSelector( const editorInstance = useAppSelector(
(state: RootState) => state.nodes.editorInstance (state: RootState) => state.nodes.editorInstance
@ -37,12 +37,12 @@ const SaveNodesButton = () => {
<IAIIconButton <IAIIconButton
icon={<FaSave />} icon={<FaSave />}
fontSize={18} fontSize={18}
tooltip={t('nodes.saveNodes')} tooltip={t('nodes.saveGraph')}
aria-label={t('nodes.saveNodes')} aria-label={t('nodes.saveGraph')}
onClick={saveEditorToJSON} onClick={saveEditorToJSON}
isDisabled={nodes.length === 0} isDisabled={nodes.length === 0}
/> />
); );
}; };
export default memo(SaveNodesButton); export default memo(SaveGraphButton);