Merge branch 'sdxl-refiner-gradient-mask' of https://github.com/blessedcoolant/invokeai into sdxl-refiner-gradient-mask

Kent Keirsey
2024-04-08 20:48:15 -04:00
36 changed files with 419 additions and 84 deletions

View File

@ -40,6 +40,25 @@ Follow the same steps to scan and import the missing models.
- Check the `ram` setting in `invokeai.yaml`. This setting tells Invoke how much of your system RAM can be used to cache models. Having this too high or too low can slow things down. That said, it's generally safest to not set this at all and instead let Invoke manage it.
- Check the `vram` setting in `invokeai.yaml`. This setting tells Invoke how much of your GPU VRAM can be used to cache models. Counter-intuitively, if this setting is too high, Invoke will need to do a lot of shuffling of models as it juggles the VRAM cache and the currently-loaded model. The default value of 0.25 generally works well for GPUs with less than 16GB of VRAM; even on a 24GB card, the default works well (see the excerpt below).
- Check that your generations are happening on your GPU (if you have one). InvokeAI will log what is being used for generation upon startup. If your GPU isn't used, re-install to ensure the correct versions of torch get installed.
- If you are on Windows, you may have exceeded your GPU's VRAM capacity and are using slower [shared GPU memory](#shared-gpu-memory-windows). There's a guide to opt out of this behavior in the linked FAQ entry.
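For reference, both cache settings live at the top level of `invokeai.yaml`. A hypothetical excerpt (the values are illustrative, not recommendations; omitting `ram` entirely lets Invoke manage it):

```yaml
# Hypothetical excerpt from invokeai.yaml - illustrative values only.
ram: 7.5    # max system RAM (GB) used to cache models; safest to leave unset
vram: 0.25  # max VRAM (GB) used to cache models; 0.25 is the default
```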
## Shared GPU Memory (Windows)
!!! tip "Nvidia GPUs with driver 536.40"
This only applies to current Nvidia cards with driver 536.40 or later, released in June 2023.
When the GPU doesn't have enough VRAM for a task, Windows can allocate some of its CPU RAM to the GPU. This is much slower than VRAM, but it does allow the system to generate when it otherwise might not have enough VRAM.
When shared GPU memory is used, generation slows down dramatically - but at least it doesn't crash.
If you'd like to opt out of this behavior and instead get an error when you exceed your GPU's VRAM, follow [this guide from Nvidia](https://nvidia.custhelp.com/app/answers/detail/a_id/5490).
Here's how to get the python path required in the linked guide (a one-liner alternative follows these steps):
- Run `invoke.bat`.
- Select option 2 for developer console.
- At least one python path will be printed. Copy the path that includes your invoke installation directory (typically the first).
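If you'd rather not hunt through the printed paths, a minimal alternative (assuming the developer console has the venv's `python` on its PATH) is to start `python` there and run:

```python
# Prints the full path of the active python interpreter - copy this
# path into the Nvidia guide.
import sys

print(sys.executable)
```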
## Installer cannot find python (Windows)

View File

@ -67,7 +67,7 @@ class IPAdapterInvocation(BaseInvocation):
)
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="auto",
default="ViT-H",
ui_order=2,
)
weight: Union[float, List[float]] = InputField(

View File

@ -245,6 +245,18 @@ class ImagesInterface(InvocationContextInterface):
"""
return self._services.images.get_dto(image_name)
def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
"""Gets the internal path to an image or thumbnail.
Args:
image_name: The name of the image to get the path of.
thumbnail: Get the path of the thumbnail instead of the full image.
Returns:
The local path of the image or thumbnail.
"""
return self._services.images.get_path(image_name, thumbnail)
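For node authors, a minimal usage sketch of the new accessor (a hypothetical helper, not part of this commit; `context` is the invocation context whose `images` attribute is this `ImagesInterface`):

```python
from pathlib import Path

# Hypothetical helper illustrating ImagesInterface.get_path(); `context` is
# the invocation context passed to a node's invoke() method.
def get_image_paths(context, image_name: str) -> tuple[Path, Path]:
    image_path = context.images.get_path(image_name)                  # full image
    thumb_path = context.images.get_path(image_name, thumbnail=True)  # thumbnail
    return image_path, thumb_path
```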
class TensorsInterface(InvocationContextInterface):
def save(self, tensor: Tensor) -> str:

View File

@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
self._torch_dtype = torch_dtype(choose_torch_device())
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@ -117,7 +117,7 @@ class ModelCacheBase(ABC, Generic[T]):
@property
@abstractmethod
def stats(self) -> CacheStats:
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
pass
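With this change, callers can no longer assume a `CacheStats` object is returned. A minimal sketch of the guard now needed (assuming `hits` and `misses` fields on `CacheStats`, with `cache` and `logger` already in scope):

```python
# stats is now Optional and may be None, e.g. if stat collection is disabled.
stats = cache.stats
if stats is None:
    logger.debug("no cache stats were collected")
else:
    logger.debug(f"cache hits={stats.hits}, misses={stats.misses}")
```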

View File

@ -269,9 +269,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
if torch.device(source_device).type == torch.device(target_device).type:
return
# may raise an exception here if insufficient GPU VRAM
self._check_free_vram(target_device, cache_entry.size)
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
@ -329,11 +326,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def make_room(self, model_size: int) -> None:
def make_room(self, size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
bytes_needed = size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size()
@ -388,7 +385,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
@ -420,24 +417,3 @@ class ModelCache(ModelCacheBase[AnyModel]):
mps.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _free_vram(self, device: torch.device) -> int:
vram_device = ( # mem_get_info() needs an indexed device
device if device.index is not None else torch.device(str(device), index=0)
)
free_mem, _ = torch.cuda.mem_get_info(vram_device)
for _, cache_entry in self._cached_models.items():
if cache_entry.loaded and not cache_entry.locked:
free_mem += cache_entry.size
return free_mem
def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
if target_device.type != "cuda":
return
free_mem = self._free_vram(target_device)
if needed_size > free_mem:
needed_gb = round(needed_size / GIG, 2)
free_gb = round(free_mem / GIG, 2)
raise torch.cuda.OutOfMemoryError(
f"Insufficient VRAM to load model, requested {needed_gb}GB but only had {free_gb}GB free"
)

View File

@ -6,8 +6,7 @@ from typing import Literal, Optional, Union
import torch
from torch import autocast
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.config.config_default import PRECISION, get_config
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
@ -33,35 +32,34 @@ def get_torch_device_name() -> str:
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def choose_precision(
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
) -> Literal["float32", "float16", "bfloat16"]:
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
"""Return an appropriate precision for the given torch device."""
app_config = app_config or get_config()
app_config = get_config()
if device.type == "cuda":
device_name = torch.cuda.get_device_name(device)
if not ("GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name):
if app_config.precision == "float32":
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
# These GPUs have limited support for float16
return "float32"
elif app_config.precision == "bfloat16":
return "bfloat16"
elif app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for CUDA devices
return "float16"
else:
return "float16"
# Use the user-defined precision
return app_config.precision
elif device.type == "mps":
if app_config.precision == "auto" or app_config.precision == "autocast":
# Default to float16 for MPS devices
return "float16"
else:
# Use the user-defined precision
return app_config.precision
# CPU / safe fallback
return "float32"
# We are in transition here from using a single global AppConfig to allowing multiple
# configurations. It is strongly recommended to pass the app_config to this function.
def torch_dtype(
device: Optional[torch.device] = None,
app_config: Optional[InvokeAIAppConfig] = None,
) -> torch.dtype:
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device, app_config)
precision = choose_precision(device)
if precision == "float16":
return torch.float16
if precision == "bfloat16":
@ -71,7 +69,7 @@ def torch_dtype(
return torch.float32
def choose_autocast(precision):
def choose_autocast(precision: PRECISION):
"""Returns an autocast context or nullcontext for the given precision string"""
# float16 currently requires autocast to avoid errors like:
# 'expected scalar type Half but found Float'
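A quick sketch of the simplified call sites (assuming these helpers live in `invokeai.backend.util.devices`; precision is now always read from the global config rather than a passed-in `app_config`):

```python
import torch

from invokeai.backend.util.devices import choose_torch_device, torch_dtype

device = choose_torch_device()  # cuda, mps, or cpu
dtype = torch_dtype(device)     # e.g. torch.float16 on most CUDA cards

# Any model moved to this device/dtype pair will match the cache's precision.
model = torch.nn.Linear(4, 4).to(device=device, dtype=dtype)
```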

View File

@ -52,6 +52,7 @@
},
"dependencies": {
"@chakra-ui/react-use-size": "^2.1.0",
"@dagrejs/dagre": "^1.1.1",
"@dagrejs/graphlib": "^2.2.1",
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/sortable": "^8.0.0",

View File

@ -11,6 +11,9 @@ dependencies:
'@chakra-ui/react-use-size':
specifier: ^2.1.0
version: 2.1.0(react@18.2.0)
'@dagrejs/dagre':
specifier: ^1.1.1
version: 1.1.1
'@dagrejs/graphlib':
specifier: ^2.2.1
version: 2.2.1
@ -3092,6 +3095,12 @@ packages:
dev: true
optional: true
/@dagrejs/dagre@1.1.1:
resolution: {integrity: sha512-AQfT6pffEuPE32weFzhS/u3UpX+bRXUARIXL7UqLaxz497cN8pjuBlX6axO4IIECE2gBV8eLFQkGCtKX5sDaUA==}
dependencies:
'@dagrejs/graphlib': 2.2.1
dev: false
/@dagrejs/graphlib@2.2.1:
resolution: {integrity: sha512-xJsN1v6OAxXk6jmNdM+OS/bBE8nDCwM0yDNprXR18ZNatL6to9ggod9+l2XtiLhXfLm0NkE7+Er/cpdlM+SkUA==}
engines: {node: '>17.0.0'}

View File

@ -291,7 +291,6 @@
"canvasMerged": "تم دمج الخط",
"sentToImageToImage": "تم إرسال إلى صورة إلى صورة",
"sentToUnifiedCanvas": "تم إرسال إلى لوحة موحدة",
"parametersSet": "تم تعيين المعلمات",
"parametersNotSet": "لم يتم تعيين المعلمات",
"metadataLoadFailed": "فشل تحميل البيانات الوصفية"
},

View File

@ -480,7 +480,6 @@
"canvasMerged": "Leinwand zusammengeführt",
"sentToImageToImage": "Gesendet an Bild zu Bild",
"sentToUnifiedCanvas": "Gesendet an Leinwand",
"parametersSet": "Parameter festlegen",
"parametersNotSet": "Parameter nicht festgelegt",
"metadataLoadFailed": "Metadaten konnten nicht geladen werden",
"setCanvasInitialImage": "Ausgangsbild setzen",

View File

@ -849,6 +849,7 @@
"version": "Version",
"versionUnknown": " Version Unknown",
"workflow": "Workflow",
"graph": "Graph",
"workflowAuthor": "Author",
"workflowContact": "Contact",
"workflowDescription": "Short Description",
@ -1041,10 +1042,10 @@
"metadataLoadFailed": "Failed to load metadata",
"modelAddedSimple": "Model Added to Queue",
"modelImportCanceled": "Model Import Canceled",
"parameters": "Parameters",
"parameterNotSet": "{{parameter}} not set",
"parameterSet": "{{parameter}} set",
"parametersNotSet": "Parameters Not Set",
"parametersSet": "Parameters Set",
"problemCopyingCanvas": "Problem Copying Canvas",
"problemCopyingCanvasDesc": "Unable to export base layer",
"problemCopyingImage": "Unable to Copy Image",
@ -1482,7 +1483,11 @@
"workflowName": "Workflow Name",
"newWorkflowCreated": "New Workflow Created",
"workflowCleared": "Workflow Cleared",
"workflowEditorMenu": "Workflow Editor Menu"
"workflowEditorMenu": "Workflow Editor Menu",
"loadFromGraph": "Load Workflow from Graph",
"convertGraph": "Convert Graph",
"loadWorkflow": "$t(common.load) Workflow",
"autoLayout": "Auto Layout"
},
"app": {
"storeNotInitialized": "Store is not initialized"

View File

@ -363,7 +363,6 @@
"canvasMerged": "Lienzo consolidado",
"sentToImageToImage": "Enviar hacia Imagen a Imagen",
"sentToUnifiedCanvas": "Enviar hacia Lienzo Consolidado",
"parametersSet": "Parámetros establecidos",
"parametersNotSet": "Parámetros no establecidos",
"metadataLoadFailed": "Error al cargar metadatos",
"serverError": "Error en el servidor",

View File

@ -298,7 +298,6 @@
"canvasMerged": "Canvas fusionné",
"sentToImageToImage": "Envoyé à Image à Image",
"sentToUnifiedCanvas": "Envoyé à Canvas unifié",
"parametersSet": "Paramètres définis",
"parametersNotSet": "Paramètres non définis",
"metadataLoadFailed": "Échec du chargement des métadonnées"
},

View File

@ -306,7 +306,6 @@
"canvasMerged": "קנבס מוזג",
"sentToImageToImage": "נשלח לתמונה לתמונה",
"sentToUnifiedCanvas": "נשלח אל קנבס מאוחד",
"parametersSet": "הגדרת פרמטרים",
"parametersNotSet": "פרמטרים לא הוגדרו",
"metadataLoadFailed": "טעינת מטא-נתונים נכשלה"
},

View File

@ -569,7 +569,6 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Immagine a Immagine",
"sentToUnifiedCanvas": "Inviato a Tela Unificata",
"parametersSet": "Parametri impostati",
"parametersNotSet": "Parametri non impostati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",

View File

@ -420,7 +420,6 @@
"canvasMerged": "Canvas samengevoegd",
"sentToImageToImage": "Gestuurd naar Afbeelding naar afbeelding",
"sentToUnifiedCanvas": "Gestuurd naar Centraal canvas",
"parametersSet": "Parameters ingesteld",
"parametersNotSet": "Parameters niet ingesteld",
"metadataLoadFailed": "Fout bij laden metagegevens",
"serverError": "Serverfout",

View File

@ -267,7 +267,6 @@
"canvasMerged": "Scalono widoczne warstwy",
"sentToImageToImage": "Wysłano do Obraz na obraz",
"sentToUnifiedCanvas": "Wysłano do trybu uniwersalnego",
"parametersSet": "Ustawiono parametry",
"parametersNotSet": "Nie ustawiono parametrów",
"metadataLoadFailed": "Błąd wczytywania metadanych"
},

View File

@ -310,7 +310,6 @@
"canvasMerged": "Tela Fundida",
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
"parametersSet": "Parâmetros Definidos",
"parametersNotSet": "Parâmetros Não Definidos",
"metadataLoadFailed": "Falha ao tentar carregar metadados"
},

View File

@ -307,7 +307,6 @@
"canvasMerged": "Tela Fundida",
"sentToImageToImage": "Mandar Para Imagem Para Imagem",
"sentToUnifiedCanvas": "Enviada para a Tela Unificada",
"parametersSet": "Parâmetros Definidos",
"parametersNotSet": "Parâmetros Não Definidos",
"metadataLoadFailed": "Falha ao tentar carregar metadados"
},

View File

@ -575,7 +575,6 @@
"canvasMerged": "Холст объединен",
"sentToImageToImage": "Отправить в img2img",
"sentToUnifiedCanvas": "Отправлено на Единый холст",
"parametersSet": "Параметры заданы",
"parametersNotSet": "Параметры не заданы",
"metadataLoadFailed": "Не удалось загрузить метаданные",
"serverError": "Ошибка сервера",

View File

@ -315,7 +315,6 @@
"canvasMerged": "Полотно об'єднане",
"sentToImageToImage": "Надіслати до img2img",
"sentToUnifiedCanvas": "Надіслати на полотно",
"parametersSet": "Параметри задані",
"parametersNotSet": "Параметри не задані",
"metadataLoadFailed": "Не вдалося завантажити метадані",
"serverError": "Помилка сервера",

View File

@ -487,7 +487,6 @@
"canvasMerged": "画布已合并",
"sentToImageToImage": "已发送到图生图",
"sentToUnifiedCanvas": "已发送到统一画布",
"parametersSet": "参数已设定",
"parametersNotSet": "参数未设定",
"metadataLoadFailed": "加载元数据失败",
"uploadFailedInvalidUploadDesc": "必须是单张的 PNG 或 JPEG 图片",

View File

@ -33,6 +33,7 @@ const ImageMetadataActions = (props: Props) => {
<MetadataItem metadata={metadata} handlers={handlers.scheduler} />
<MetadataItem metadata={metadata} handlers={handlers.cfgScale} />
<MetadataItem metadata={metadata} handlers={handlers.cfgRescaleMultiplier} />
<MetadataItem metadata={metadata} handlers={handlers.initialImage} />
<MetadataItem metadata={metadata} handlers={handlers.strength} />
<MetadataItem metadata={metadata} handlers={handlers.hrfEnabled} />
<MetadataItem metadata={metadata} handlers={handlers.hrfMethod} />

View File

@ -189,6 +189,12 @@ export const handlers = {
recaller: recallers.cfgScale,
}),
height: buildHandlers({ getLabel: () => t('metadata.height'), parser: parsers.height, recaller: recallers.height }),
initialImage: buildHandlers({
getLabel: () => t('metadata.initImage'),
parser: parsers.initialImage,
recaller: recallers.initialImage,
renderValue: async (imageDTO) => imageDTO.image_name,
}),
negativePrompt: buildHandlers({
getLabel: () => t('metadata.negativePrompt'),
parser: parsers.negativePrompt,
@ -405,6 +411,6 @@ export const parseAndRecallAllMetadata = async (metadata: unknown, skip: (keyof
})
);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('toast.parametersSet'));
parameterSetToast(t('toast.parameters'));
}
};

View File

@ -1,3 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import {
initialControlNet,
initialIPAdapter,
@ -57,6 +58,8 @@ import {
isParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import { get, isArray, isString } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import {
isControlNetModelConfig,
isIPAdapterModelConfig,
@ -135,6 +138,14 @@ const parseCFGRescaleMultiplier: MetadataParseFunc<ParameterCFGRescaleMultiplier
const parseScheduler: MetadataParseFunc<ParameterScheduler> = (metadata) =>
getProperty(metadata, 'scheduler', isParameterScheduler);
const parseInitialImage: MetadataParseFunc<ImageDTO> = async (metadata) => {
const imageName = await getProperty(metadata, 'init_image', isString);
const imageDTORequest = getStore().dispatch(imagesApi.endpoints.getImageDTO.initiate(imageName));
const imageDTO = await imageDTORequest.unwrap();
imageDTORequest.unsubscribe();
return imageDTO;
};
const parseWidth: MetadataParseFunc<ParameterWidth> = (metadata) => getProperty(metadata, 'width', isParameterWidth);
const parseHeight: MetadataParseFunc<ParameterHeight> = (metadata) =>
@ -402,6 +413,7 @@ export const parsers = {
cfgScale: parseCFGScale,
cfgRescaleMultiplier: parseCFGRescaleMultiplier,
scheduler: parseScheduler,
initialImage: parseInitialImage,
width: parseWidth,
height: parseHeight,
steps: parseSteps,

View File

@ -17,6 +17,7 @@ import type {
import { modelSelected } from 'features/parameters/store/actions';
import {
heightRecalled,
initialImageChanged,
setCfgRescaleMultiplier,
setCfgScale,
setImg2imgStrength,
@ -61,6 +62,7 @@ import {
setRefinerStart,
setRefinerSteps,
} from 'features/sdxl/store/sdxlSlice';
import type { ImageDTO } from 'services/api/types';
const recallPositivePrompt: MetadataRecallFunc<ParameterPositivePrompt> = (positivePrompt) => {
getStore().dispatch(setPositivePrompt(positivePrompt));
@ -94,6 +96,10 @@ const recallScheduler: MetadataRecallFunc<ParameterScheduler> = (scheduler) => {
getStore().dispatch(setScheduler(scheduler));
};
const recallInitialImage: MetadataRecallFunc<ImageDTO> = async (imageDTO) => {
getStore().dispatch(initialImageChanged(imageDTO));
};
const recallWidth: MetadataRecallFunc<ParameterWidth> = (width) => {
getStore().dispatch(widthRecalled(width));
};
@ -235,6 +241,7 @@ export const recallers = {
cfgScale: recallCFGScale,
cfgRescaleMultiplier: recallCFGRescaleMultiplier,
scheduler: recallScheduler,
initialImage: recallInitialImage,
width: recallWidth,
height: recallHeight,
steps: recallSteps,

View File

@ -3,6 +3,7 @@ import 'reactflow/dist/style.css';
import { Flex } from '@invoke-ai/ui-library';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import TopPanel from 'features/nodes/components/flow/panels/TopPanel/TopPanel';
import { LoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
import { SaveWorkflowAsDialog } from 'features/workflowLibrary/components/SaveWorkflowAsDialog/SaveWorkflowAsDialog';
import type { AnimationProps } from 'framer-motion';
import { AnimatePresence, motion } from 'framer-motion';
@ -61,6 +62,7 @@ const NodeEditor = () => {
<BottomLeftPanel />
<MinimapPanel />
<SaveWorkflowAsDialog />
<LoadWorkflowFromGraphModal />
</motion.div>
)}
</AnimatePresence>

View File

@ -37,34 +37,50 @@ const NumberFieldInputComponent = (
);
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
return fieldTemplate.minimum;
min = fieldTemplate.minimum;
}
if (!isNil(fieldTemplate.exclusiveMinimum)) {
return fieldTemplate.exclusiveMinimum + 0.01;
min = fieldTemplate.exclusiveMinimum + 0.01;
}
return;
return min;
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
const max = useMemo(() => {
let max = NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.maximum)) {
return fieldTemplate.maximum;
max = fieldTemplate.maximum;
}
if (!isNil(fieldTemplate.exclusiveMaximum)) {
return fieldTemplate.exclusiveMaximum - 0.01;
max = fieldTemplate.exclusiveMaximum - 0.01;
}
return;
return max;
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
const step = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.1;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const fineStep = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.01;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
return (
<CompositeNumberInput
defaultValue={fieldTemplate.default}
onChange={handleValueChanged}
value={field.value}
min={min ?? -NUMPY_RAND_MAX}
max={max ?? NUMPY_RAND_MAX}
step={isIntegerField ? 1 : 0.1}
fineStep={isIntegerField ? 1 : 0.01}
min={min}
max={max}
step={step}
fineStep={fineStep}
className="nodrag"
/>
);

View File

@ -0,0 +1,148 @@
import * as dagre from '@dagrejs/dagre';
import { logger } from 'app/logging/logger';
import { getStore } from 'app/store/nanostores/store';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import type { NonNullableGraph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
/**
* Converts a graph to a workflow. This is a best-effort conversion and may not be perfect.
* For example, if a graph references an unknown node type, that node will be skipped.
* @param graph The graph to convert to a workflow
* @param autoLayout Whether to auto-layout the nodes using `dagre`. If false, nodes will simply be stacked on top of one another with an offset.
* @returns The workflow.
*/
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
const invocationTemplates = getStore().getState().nodes.templates;
if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized'));
}
// Initialize the workflow
const workflow: WorkflowV3 = {
name: '',
author: '',
contact: '',
description: '',
meta: {
category: 'user',
version: '3.0.0',
},
notes: '',
tags: '',
version: '',
exposedFields: [],
edges: [],
nodes: [],
};
// Convert nodes
forEach(graph.nodes, (node) => {
const template = invocationTemplates[node.type];
// Skip missing node templates - this is a best-effort
if (!template) {
logger('nodes').warn(`Node type ${node.type} not found in invocationTemplates`);
return;
}
// Build field input instances for each attr
const inputs: Record<string, FieldInputInstance> = {};
forEach(node, (value, key) => {
// Ignore the non-input keys - I think this is all of them?
if (key === 'id' || key === 'type' || key === 'is_intermediate' || key === 'use_cache') {
return;
}
const inputTemplate = template.inputs[key];
// Skip missing input templates
if (!inputTemplate) {
logger('nodes').warn(`Input ${key} not found in template for node type ${node.type}`);
return;
}
// This _should_ be all we need to do!
const inputInstance = buildFieldInputInstance(node.id, inputTemplate);
inputInstance.value = value;
inputs[key] = inputInstance;
});
workflow.nodes.push({
id: node.id,
type: 'invocation',
position: { x: 0, y: 0 }, // we'll do layout later, just need something here
data: {
id: node.id,
type: node.type,
version: template.version,
label: '',
notes: '',
isOpen: true,
isIntermediate: node.is_intermediate ?? false,
useCache: node.use_cache ?? true,
inputs,
},
});
});
forEach(graph.edges, (edge) => {
workflow.edges.push({
id: uuidv4(), // we don't have edge IDs in the graph
type: 'default',
source: edge.source.node_id,
sourceHandle: edge.source.field,
target: edge.destination.node_id,
targetHandle: edge.destination.field,
});
});
if (autoLayout) {
// Best-effort auto layout via dagre - not perfect but better than nothing
const dagreGraph = new dagre.graphlib.Graph();
// `rankdir` and `align` could be tweaked, but it's gonna be janky no matter what we choose
dagreGraph.setGraph({ rankdir: 'TB', align: 'UL' });
dagreGraph.setDefaultEdgeLabel(() => ({}));
// We don't know the dimensions of the nodes until we load the graph into `reactflow` - use a reasonable value
forEach(graph.nodes, (node) => {
const width = NODE_WIDTH;
const height = NODE_WIDTH * 1.5;
dagreGraph.setNode(node.id, { width, height });
});
graph.edges.forEach((edge) => {
dagreGraph.setEdge(edge.source.node_id, edge.destination.node_id);
});
// This does the magic
dagre.layout(dagreGraph);
// Update the workflow now that we've got the positions
workflow.nodes.forEach((node) => {
const nodeWithPosition = dagreGraph.node(node.id);
node.position = {
x: nodeWithPosition.x - nodeWithPosition.width / 2,
y: nodeWithPosition.y - nodeWithPosition.height / 2,
};
});
} else {
// Stack nodes with a 50px,50px offset from the previous node
let x = 0;
let y = 0;
workflow.nodes.forEach((node) => {
node.position = { x, y };
x = x + 50;
y = y + 50;
});
}
return workflow;
};

View File

@ -0,0 +1,111 @@
import {
Button,
Checkbox,
Flex,
FormControl,
FormLabel,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalHeader,
ModalOverlay,
Spacer,
Textarea,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { workflowLoadRequested } from 'features/nodes/store/actions';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
import { atom } from 'nanostores';
import type { ChangeEvent } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
const $isOpen = atom<boolean>(false);
export const useLoadWorkflowFromGraphModal = () => {
const isOpen = useStore($isOpen);
const onOpen = useCallback(() => {
$isOpen.set(true);
}, []);
const onClose = useCallback(() => {
$isOpen.set(false);
}, []);
return { isOpen, onOpen, onClose };
};
export const LoadWorkflowFromGraphModal = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { isOpen, onClose } = useLoadWorkflowFromGraphModal();
const [graphRaw, setGraphRaw] = useState<string>('');
const [workflowRaw, setWorkflowRaw] = useState<string>('');
const [shouldAutoLayout, setShouldAutoLayout] = useState(true);
const onChangeGraphRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
setGraphRaw(e.target.value);
}, []);
const onChangeWorkflowRaw = useCallback((e: ChangeEvent<HTMLTextAreaElement>) => {
setWorkflowRaw(e.target.value);
}, []);
const onChangeShouldAutoLayout = useCallback((e: ChangeEvent<HTMLInputElement>) => {
setShouldAutoLayout(e.target.checked);
}, []);
const parse = useCallback(() => {
const graph = JSON.parse(graphRaw);
const workflow = graphToWorkflow(graph, shouldAutoLayout);
setWorkflowRaw(JSON.stringify(workflow, null, 2));
}, [graphRaw, shouldAutoLayout]);
const loadWorkflow = useCallback(() => {
const workflow = JSON.parse(workflowRaw);
dispatch(workflowLoadRequested({ workflow, asCopy: true }));
onClose();
}, [dispatch, onClose, workflowRaw]);
return (
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent w="80vw" h="80vh" maxW="unset" maxH="unset">
<ModalHeader>{t('workflows.loadFromGraph')}</ModalHeader>
<ModalCloseButton />
<ModalBody as={Flex} flexDir="column" gap={4} w="full" h="full" pb={4}>
<Flex gap={4}>
<Button onClick={parse} size="sm" flexShrink={0}>
{t('workflows.convertGraph')}
</Button>
<FormControl>
<FormLabel>{t('workflows.autoLayout')}</FormLabel>
<Checkbox isChecked={shouldAutoLayout} onChange={onChangeShouldAutoLayout} />
</FormControl>
<Spacer />
<Button onClick={loadWorkflow} size="sm" flexShrink={0}>
{t('workflows.loadWorkflow')}
</Button>
</Flex>
<FormControl orientation="vertical" h="50%">
<FormLabel>{t('nodes.graph')}</FormLabel>
<Textarea
h="full"
value={graphRaw}
fontFamily="monospace"
whiteSpace="pre-wrap"
overflowWrap="normal"
onChange={onChangeGraphRaw}
/>
</FormControl>
<FormControl orientation="vertical" h="50%">
<FormLabel>{t('nodes.workflow')}</FormLabel>
<Textarea
h="full"
value={workflowRaw}
fontFamily="monospace"
whiteSpace="pre-wrap"
overflowWrap="normal"
onChange={onChangeWorkflowRaw}
/>
</FormControl>
</ModalBody>
</ModalContent>
</Modal>
);
};

View File

@ -0,0 +1,18 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useLoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFlaskBold } from 'react-icons/pi';
const LoadWorkflowFromGraphMenuItem = () => {
const { t } = useTranslation();
const { onOpen } = useLoadWorkflowFromGraphModal();
return (
<MenuItem as="button" icon={<PiFlaskBold />} onClick={onOpen}>
{t('workflows.loadFromGraph')}
</MenuItem>
);
};
export default memo(LoadWorkflowFromGraphMenuItem);

View File

@ -6,8 +6,10 @@ import {
MenuList,
useDisclosure,
useGlobalMenuClose,
useShiftModifier,
} from '@invoke-ai/ui-library';
import DownloadWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/DownloadWorkflowMenuItem';
import LoadWorkflowFromGraphMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem';
import { NewWorkflowMenuItem } from 'features/workflowLibrary/components/WorkflowLibraryMenu/NewWorkflowMenuItem';
import SaveWorkflowAsMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowAsMenuItem';
import SaveWorkflowMenuItem from 'features/workflowLibrary/components/WorkflowLibraryMenu/SaveWorkflowMenuItem';
@ -20,6 +22,7 @@ import { PiDotsThreeOutlineFill } from 'react-icons/pi';
const WorkflowLibraryMenu = () => {
const { t } = useTranslation();
const { isOpen, onOpen, onClose } = useDisclosure();
const shift = useShiftModifier();
useGlobalMenuClose(onClose);
return (
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>
@ -38,6 +41,8 @@ const WorkflowLibraryMenu = () => {
<DownloadWorkflowMenuItem />
<MenuDivider />
<SettingsMenuItem />
{shift && <MenuDivider />}
{shift && <LoadWorkflowFromGraphMenuItem />}
</MenuList>
</Menu>
);

File diff suppressed because one or more lines are too long

View File

@ -27,6 +27,7 @@ from invokeai.app.invocations.fields import (
OutputField,
UIComponent,
UIType,
WithBoard,
WithMetadata,
WithWorkflow,
)
@ -105,6 +106,7 @@ __all__ = [
"OutputField",
"UIComponent",
"UIType",
"WithBoard",
"WithMetadata",
"WithWorkflow",
# invokeai.app.invocations.latent

View File

@ -1 +1 @@
__version__ = "4.0.3"
__version__ = "4.0.4"