diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index c8dfd36b4b..fa46762d56 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -14,13 +14,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware from pathlib import Path from pydantic.schema import schema -# Do this early so that other modules pick up configuration +#This should come early so that modules can log their initialization properly from .services.config import InvokeAIAppConfig +from ..backend.util.logging import InvokeAILogger app_config = InvokeAIAppConfig.get_config() app_config.parse_args() - -from invokeai.backend.util.logging import InvokeAILogger -logger = InvokeAILogger.getLogger() +logger = InvokeAILogger.getLogger(config=app_config) import invokeai.frontend.web as web_dir diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 0b8609817e..26e058166b 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -11,19 +11,26 @@ from typing import Union, get_type_hints from pydantic import BaseModel, ValidationError from pydantic.fields import Field +# This should come early so that the logger can pick up its configuration options +from .services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger +config = InvokeAIAppConfig.get_config() +config.parse_args() +logger = InvokeAILogger().getLogger(config=config) + from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.images import ImageService from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.urls import LocalUrlService +from .services.default_graphs import (default_text_to_image_graph_id, + create_system_graphs) +from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .cli.commands import (BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers) from .cli.completer import set_autocompleter from .invocations.baseinvocation import BaseInvocation -from .services.default_graphs import (create_system_graphs, - default_text_to_image_graph_id) from .services.events import EventServiceBase from .services.graph import (Edge, EdgeConnection, GraphExecutionState, GraphInvocation, LibraryGraph, @@ -32,13 +39,11 @@ from .services.image_file_storage import DiskImageFileStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices from .services.invoker import Invoker -from .services.latent_storage import (DiskLatentsStorage, - ForwardCacheLatentsStorage) from .services.model_manager_service import ModelManagerService from .services.processor import DefaultInvocationProcessor from .services.restoration_services import RestorationServices from .services.sqlite import SqliteItemStorage -from .services.config import InvokeAIAppConfig + class CliCommand(BaseModel): command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore @@ -47,7 +52,6 @@ class CliCommand(BaseModel): class InvalidArgs(Exception): pass - def add_invocation_args(command_parser): # Add linking capability command_parser.add_argument( @@ -192,11 +196,6 @@ def invoke_all(context: CliContext): raise SessionError() def invoke_cli(): - # this gets the basic configuration - config = InvokeAIAppConfig.get_config() - config.parse_args() - logger = InvokeAILogger.getLogger() - # get the optional list of invocations to execute on the command line parser = config.get_parser() parser.add_argument('commands',nargs='*') diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 60e196faa1..94be8225da 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -859,11 +859,9 @@ class GraphExecutionState(BaseModel): if next_node is None: prepared_id = self._prepare() - # TODO: prepare multiple nodes at once? - # while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): - # prepared_id = self._prepare() - - if prepared_id is not None: + # Prepare as many nodes as we can + while prepared_id is not None: + prepared_id = self._prepare() next_node = self._get_next_node() # Get values from edges @@ -1010,14 +1008,30 @@ class GraphExecutionState(BaseModel): # Get flattened source graph g = self.graph.nx_graph_flat() - # Find next unprepared node where all source nodes are executed + # Find next node that: + # - was not already prepared + # - is not an iterate node whose inputs have not been executed + # - does not have an unexecuted iterate ancestor sorted_nodes = nx.topological_sort(g) next_node_id = next( ( n for n in sorted_nodes + # exclude nodes that have already been prepared if n not in self.source_prepared_mapping - and all((e[0] in self.executed for e in g.in_edges(n))) + # exclude iterate nodes whose inputs have not been executed + and not ( + isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node... + and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs + ) + # exclude nodes who have unexecuted iterate ancestors + and not any( + ( + isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`... + and a not in self.executed # ...that is not executed + for a in nx.ancestors(g, n) # for all ancestors `a` of node `n` + ) + ) ), None, ) @@ -1114,9 +1128,22 @@ class GraphExecutionState(BaseModel): ) def _get_next_node(self) -> Optional[BaseInvocation]: + """Gets the deepest node that is ready to be executed""" g = self.execution_graph.nx_graph() - sorted_nodes = nx.topological_sort(g) - next_node = next((n for n in sorted_nodes if n not in self.executed), None) + + # Depth-first search with pre-order traversal is a depth-first topological sort + sorted_nodes = nx.dfs_preorder_nodes(g) + + next_node = next( + ( + n + for n in sorted_nodes + if n not in self.executed # the node must not already be executed... + and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed + ), + None, + ) + if next_node is None: return None diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index a7c9ae444d..f12ba79c15 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -22,7 +22,8 @@ class Invoker: def invoke( self, graph_execution_state: GraphExecutionState, invoke_all: bool = False ) -> str | None: - """Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" + """Determines the next node to invoke and enqueues it, preparing if needed. + Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" # Get the next invocation invocation = graph_execution_state.next() diff --git a/invokeai/backend/install/invokeai_configure.py b/invokeai/backend/install/invokeai_configure.py index aaa744ba09..603760c0c1 100755 --- a/invokeai/backend/install/invokeai_configure.py +++ b/invokeai/backend/install/invokeai_configure.py @@ -40,6 +40,7 @@ import invokeai.configs as configs from invokeai.app.services.config import ( InvokeAIAppConfig, ) +from invokeai.backend.util.logging import InvokeAILogger from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.widgets import ( CenteredButtonPress, @@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file # or renaming it and then running invokeai-configure again. """ +logger=None # -------------------------------------------- def postscript(errors: None): @@ -824,6 +826,7 @@ def main(): if opt.full_precision: invoke_args.extend(['--precision','float32']) config.parse_args(invoke_args) + logger = InvokeAILogger().getLogger(config=config) errors = set() diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 6a0f1f3be6..09ae600633 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team -"""invokeai.util.logging +""" +invokeai.util.logging Logging class for InvokeAI that produces console messages @@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization (or) logger = InvokeAILogger.getLogger(__name__) // To use the filename +logger.configure() logger.critical('this is critical') // Critical Message logger.error('this is an error') // Error Message @@ -28,6 +30,149 @@ Console messages: Alternate Method (in this case the logger name will be set to InvokeAI): import invokeai.backend.util.logging as IAILogger IAILogger.debug('this is a debugging message') + +## Configuration + +The default configuration will print to stderr on the console. To add +additional logging handlers, call getLogger with an initialized InvokeAIAppConfig +object: + + + config = InvokeAIAppConfig.get_config() + config.parse_args() + logger = InvokeAILogger.getLogger(config=config) + +### Three command-line options control logging: + +`--log_handlers ...` + +This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces: + +``` +invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log +``` + +The format of these options is described below. + +### `--log_format {plain|color|legacy|syslog}` + +This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting. + +* "plain" provides formatted messages like this: + +```bash + +[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message +[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages +[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning +[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error +[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error +``` + +* "color" produces similar output, but the text will be color coded to indicate the severity of the message. + +* "legacy" produces output similar to InvokeAI versions 2.3 and earlier: + +``` +### this is a critical error +*** this is an error +** this is a warning +>> this is an informational messages + | this is a debug message +``` + +* "syslog" produces messages suitable for syslog entries: + +```bash +InvokeAI [2691178] this is a critical error +InvokeAI [2691178] this is an error +InvokeAI [2691178] this is a warning +InvokeAI [2691178] this is an informational messages +InvokeAI [2691178] this is a debug message +``` + +(note that the date, time and hostname will be added by the syslog system) + +### `--log_level {debug|info|warning|error|critical}` + +Providing this command-line option will cause only messages at the specified level or above to be emitted. + +## Console logging + +When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`. + +## File logging + +When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used: + +```bash +invokeai-web --log_handlers file=/var/log/invokeai.log +``` + +## Syslog logging + +When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent: + +* Send to the local machine using the `/dev/log` socket: + +``` +invokeai-web --log_handlers syslog=/dev/log +``` + +* Send to the local machine using a UDP message: + +``` +invokeai-web --log_handlers syslog=localhost +``` + +* Send to the local machine using a UDP message on a nonstandard port: + +``` +invokeai-web --log_handlers syslog=localhost:512 +``` + +* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets: + +``` +invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM +``` + +This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults. + +* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket: + +``` +invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM +``` + +If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled. + +## Web logging + +If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method: + +``` +invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST +``` + +The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages. + +Currently password authentication and SSL are not supported. + +## Using the configuration file + +You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`: + +``` +InvokeAI: + [... other settings...] + Logging: + log_handlers: + - console + - syslog=/dev/log + log_level: info + log_format: color +``` """ import logging.handlers @@ -180,14 +325,17 @@ class InvokeAILogger(object): loggers = dict() @classmethod - def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger: - config = get_invokeai_config() - - if name not in cls.loggers: + def getLogger(cls, + name: str = 'InvokeAI', + config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger: + if name in cls.loggers: + logger = cls.loggers[name] + logger.handlers.clear() + else: logger = logging.getLogger(name) - logger.setLevel(config.log_level.upper()) # yes, strings work here - for ch in cls.getLoggers(config): - logger.addHandler(ch) + logger.setLevel(config.log_level.upper()) # yes, strings work here + for ch in cls.getLoggers(config): + logger.addHandler(ch) cls.loggers[name] = logger return cls.loggers[name] @@ -199,9 +347,11 @@ class InvokeAILogger(object): handler_name,*args = handler.split('=',2) args = args[0] if len(args) > 0 else None - # console and file are the only handlers that gets a custom formatter + # console and file get the fancy formatter. + # syslog gets a simple one + # http gets no custom formatter + formatter = LOG_FORMATTERS[config.log_format] if handler_name=='console': - formatter = LOG_FORMATTERS[config.log_format] ch = logging.StreamHandler() ch.setFormatter(formatter()) handlers.append(ch) @@ -212,7 +362,7 @@ class InvokeAILogger(object): elif handler_name=='file': ch = cls._parse_file_args(args) - ch.setFormatter(InvokeAISyslogFormatter()) + ch.setFormatter(formatter()) handlers.append(ch) elif handler_name=='http': diff --git a/invokeai/frontend/install/model_install.py b/invokeai/frontend/install/model_install.py index d956299bdf..265c456e3a 100644 --- a/invokeai/frontend/install/model_install.py +++ b/invokeai/frontend/install/model_install.py @@ -28,7 +28,7 @@ import torch from npyscreen import widget from omegaconf import OmegaConf -import invokeai.backend.util.logging as logger +from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.install.model_install_backend import ( Dataset_path, @@ -939,6 +939,7 @@ def main(): if opt.full_precision: invoke_args.extend(['--precision','float32']) config.parse_args(invoke_args) + logger = InvokeAILogger().getLogger(config=config) if not (config.root_dir / config.conf_path.parent).exists(): logger.info( diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index bb2f140716..ddc6dace27 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants'; import GlobalHotkeys from './GlobalHotkeys'; import Toaster from './Toaster'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; +import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; const DEFAULT_CONFIG = {}; @@ -66,10 +67,17 @@ const App = ({ setIsReady(true); } + if (isApplicationReady) { + // TODO: This is a jank fix for canvas not filling the screen on first load + setTimeout(() => { + dispatch(requestCanvasRescale()); + }, 200); + } + return () => { setIsReady && setIsReady(false); }; - }, [isApplicationReady, setIsReady]); + }, [dispatch, isApplicationReady, setIsReady]); return ( <> diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx index c6349be9a7..104073c023 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx @@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => { ); const mouseSensor = useSensor(MouseSensor, { - activationConstraint: { delay: 250, tolerance: 5 }, + activationConstraint: { delay: 150, tolerance: 5 }, }); const touchSensor = useSensor(TouchSensor, { - activationConstraint: { delay: 250, tolerance: 5 }, + activationConstraint: { delay: 150, tolerance: 5 }, }); // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx index 9accceb846..c764018829 100644 --- a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx @@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons'; import { Box, Flex, - FlexProps, FormControl, FormControlProps, FormLabel, @@ -16,6 +15,7 @@ import { } from '@chakra-ui/react'; import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom'; import { useSelect } from 'downshift'; +import { isString } from 'lodash-es'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { memo, useMemo } from 'react'; @@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles'; export type ItemTooltips = { [key: string]: string }; +export type IAICustomSelectOption = { + value: string; + label: string; + tooltip?: string; +}; + type IAICustomSelectProps = { label?: string; - items: string[]; - itemTooltips?: ItemTooltips; - selectedItem: string; - setSelectedItem: (v: string | null | undefined) => void; + value: string; + data: IAICustomSelectOption[] | string[]; + onChange: (v: string) => void; withCheckIcon?: boolean; formControlProps?: FormControlProps; - buttonProps?: FlexProps; tooltip?: string; tooltipProps?: Omit; ellipsisPosition?: 'start' | 'end'; @@ -40,18 +44,33 @@ type IAICustomSelectProps = { const IAICustomSelect = (props: IAICustomSelectProps) => { const { label, - items, - itemTooltips, - setSelectedItem, - selectedItem, withCheckIcon, formControlProps, tooltip, - buttonProps, tooltipProps, ellipsisPosition = 'end', + data, + value, + onChange, } = props; + const values = useMemo(() => { + return data.map((v) => { + if (isString(v)) { + return { value: v, label: v }; + } + return v; + }); + }, [data]); + + const stringValues = useMemo(() => { + return values.map((v) => v.value); + }, [values]); + + const valueData = useMemo(() => { + return values.find((v) => v.value === value); + }, [values, value]); + const { isOpen, getToggleButtonProps, @@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { highlightedIndex, getItemProps, } = useSelect({ - items, - selectedItem, - onSelectedItemChange: ({ selectedItem: newSelectedItem }) => - setSelectedItem(newSelectedItem), + items: stringValues, + selectedItem: value, + onSelectedItemChange: ({ selectedItem: newSelectedItem }) => { + newSelectedItem && onChange(newSelectedItem); + }, }); const { refs, floatingStyles } = useFloating({ @@ -93,8 +113,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { )} { direction: labelTextDirection, }} > - {selectedItem} + {valueData?.label} { {isOpen && ( { }} > - {items.map((item, index) => { - const isSelected = selectedItem === item; + {values.map((v, index) => { + const isSelected = value === v.value; const isHighlighted = highlightedIndex === index; const fontWeight = isSelected ? 700 : 500; const bg = isHighlighted @@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { : undefined; return ( @@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { transitionProperty: 'common', transitionDuration: '0.15s', }} - key={`${item}${index}`} - {...getItemProps({ item, index })} + {...getItemProps({ item: v.value, index })} > {withCheckIcon ? ( @@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { fontWeight, }} > - {item} + {v.label} @@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => { fontWeight, }} > - {item} + {v.label} )} diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 9c6af169a8..669a68c88a 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -14,7 +14,7 @@ import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { AnimatePresence } from 'framer-motion'; import { ReactElement, SyntheticEvent, useCallback } from 'react'; import { memo, useRef } from 'react'; -import { FaImage, FaTimes, FaUpload } from 'react-icons/fa'; +import { FaImage, FaTimes, FaUndo, FaUpload } from 'react-icons/fa'; import { ImageDTO } from 'services/api'; import { v4 as uuidv4 } from 'uuid'; import IAIDropOverlay from './IAIDropOverlay'; @@ -174,14 +174,13 @@ const IAIDndImage = (props: IAIDndImageProps) => { position: 'absolute', top: 0, right: 0, - p: 2, }} > } + icon={} onClick={onReset} /> diff --git a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx index 638332809c..ae03df8409 100644 --- a/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx +++ b/invokeai/frontend/web/src/features/canvas/components/IAICanvasToolbar/IAICanvasSettingsButtonPopover.tsx @@ -16,7 +16,6 @@ import { setShouldShowIntermediates, setShouldSnapToGrid, } from 'features/canvas/store/canvasSlice'; -import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal'; import { isEqual } from 'lodash-es'; import { ChangeEvent } from 'react'; @@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => { onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} /> - ); diff --git a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts index 4742de0483..dc86783642 100644 --- a/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/canvas/store/canvasSlice.ts @@ -30,7 +30,10 @@ import { } from './canvasTypes'; import { ImageDTO } from 'services/api'; import { sessionCanceled } from 'services/thunks/session'; -import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice'; +import { + setActiveTab, + setShouldUseCanvasBetaLayout, +} from 'features/ui/store/uiSlice'; import { imageUrlsReceived } from 'services/thunks/image'; export const initialLayerState: CanvasLayerState = { @@ -857,6 +860,11 @@ export const canvasSlice = createSlice({ builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => { state.doesCanvasNeedScaling = true; }); + + builder.addCase(setActiveTab, (state, action) => { + state.doesCanvasNeedScaling = true; + }); + builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { const { image_name, image_origin, image_url, thumbnail_url } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx index 187d296a4f..222e8d657b 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetModel.tsx @@ -1,17 +1,26 @@ import { useAppDispatch } from 'app/store/storeHooks'; -import IAICustomSelect from 'common/components/IAICustomSelect'; +import IAICustomSelect, { + IAICustomSelectOption, +} from 'common/components/IAICustomSelect'; import { CONTROLNET_MODELS, - ControlNetModel, + ControlNetModelName, } from 'features/controlNet/store/constants'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; +import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; type ParamControlNetModelProps = { controlNetId: string; - model: ControlNetModel; + model: ControlNetModelName; }; +const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({ + value: m.type, + label: m.label, + tooltip: m.type, +})); + const ParamControlNetModel = (props: ParamControlNetModelProps) => { const { controlNetId, model } = props; const dispatch = useAppDispatch(); @@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { const handleModelChanged = useCallback( (val: string | null | undefined) => { // TODO: do not cast - const model = val as ControlNetModel; + const model = val as ControlNetModelName; dispatch(controlNetModelChanged({ controlNetId, model })); }, [controlNetId, dispatch] @@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => { diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx index 019b5ef849..19f05bc53d 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetProcessorSelect.tsx @@ -1,4 +1,6 @@ -import IAICustomSelect from 'common/components/IAICustomSelect'; +import IAICustomSelect, { + IAICustomSelectOption, +} from 'common/components/IAICustomSelect'; import { memo, useCallback } from 'react'; import { ControlNetProcessorNode, @@ -7,15 +9,28 @@ import { import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { useAppDispatch } from 'app/store/storeHooks'; import { CONTROLNET_PROCESSORS } from '../../store/constants'; +import { map } from 'lodash-es'; type ParamControlNetProcessorSelectProps = { controlNetId: string; processorNode: ControlNetProcessorNode; }; -const CONTROLNET_PROCESSOR_TYPES = Object.keys( - CONTROLNET_PROCESSORS -) as ControlNetProcessorType[]; +const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map( + CONTROLNET_PROCESSORS, + (p) => ({ + value: p.type, + label: p.label, + tooltip: p.description, + }) +).sort((a, b) => + // sort 'none' to the top + a.value === 'none' + ? -1 + : b.value === 'none' + ? 1 + : a.label.localeCompare(b.label) +); const ParamControlNetProcessorSelect = ( props: ParamControlNetProcessorSelectProps @@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = ( return ( ); diff --git a/invokeai/frontend/web/src/features/controlNet/store/constants.ts b/invokeai/frontend/web/src/features/controlNet/store/constants.ts index c8689badf5..cca3dbf644 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/constants.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/constants.ts @@ -5,12 +5,12 @@ import { } from './types'; type ControlNetProcessorsDict = Record< - ControlNetProcessorType, + string, { - type: ControlNetProcessorType; + type: ControlNetProcessorType | 'none'; label: string; description: string; - default: RequiredControlNetProcessorNode; + default: RequiredControlNetProcessorNode | { type: 'none' }; } >; @@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record< * * TODO: Generate from the OpenAPI schema */ -export const CONTROLNET_PROCESSORS = { +export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = { none: { type: 'none', - label: 'None', + label: 'none', description: '', default: { type: 'none', @@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = { }, mlsd_image_processor: { type: 'mlsd_image_processor', - label: 'MLSD', + label: 'M-LSD', description: '', default: { id: 'mlsd_image_processor', @@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = { }, }; -export const CONTROLNET_MODELS = [ - 'lllyasviel/control_v11p_sd15_canny', - 'lllyasviel/control_v11p_sd15_inpaint', - 'lllyasviel/control_v11p_sd15_mlsd', - 'lllyasviel/control_v11f1p_sd15_depth', - 'lllyasviel/control_v11p_sd15_normalbae', - 'lllyasviel/control_v11p_sd15_seg', - 'lllyasviel/control_v11p_sd15_lineart', - 'lllyasviel/control_v11p_sd15s2_lineart_anime', - 'lllyasviel/control_v11p_sd15_scribble', - 'lllyasviel/control_v11p_sd15_softedge', - 'lllyasviel/control_v11e_sd15_shuffle', - 'lllyasviel/control_v11p_sd15_openpose', - 'lllyasviel/control_v11f1e_sd15_tile', - 'lllyasviel/control_v11e_sd15_ip2p', - 'CrucibleAI/ControlNetMediaPipeFace', -]; - -export type ControlNetModel = (typeof CONTROLNET_MODELS)[number]; - -export const CONTROLNET_MODEL_MAP: Record< - ControlNetModel, - ControlNetProcessorType -> = { - 'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor', - 'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor', - 'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor', - 'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor', - 'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor', - 'lllyasviel/control_v11p_sd15s2_lineart_anime': - 'lineart_anime_image_processor', - 'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor', - 'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor', - 'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor', - 'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor', +type ControlNetModel = { + type: string; + label: string; + description?: string; + defaultProcessor?: ControlNetProcessorType; }; + +export const CONTROLNET_MODELS: Record = { + 'lllyasviel/control_v11p_sd15_canny': { + type: 'lllyasviel/control_v11p_sd15_canny', + label: 'Canny', + defaultProcessor: 'canny_image_processor', + }, + 'lllyasviel/control_v11p_sd15_inpaint': { + type: 'lllyasviel/control_v11p_sd15_inpaint', + label: 'Inpaint', + }, + 'lllyasviel/control_v11p_sd15_mlsd': { + type: 'lllyasviel/control_v11p_sd15_mlsd', + label: 'M-LSD', + defaultProcessor: 'mlsd_image_processor', + }, + 'lllyasviel/control_v11f1p_sd15_depth': { + type: 'lllyasviel/control_v11f1p_sd15_depth', + label: 'Depth', + defaultProcessor: 'midas_depth_image_processor', + }, + 'lllyasviel/control_v11p_sd15_normalbae': { + type: 'lllyasviel/control_v11p_sd15_normalbae', + label: 'Normal Map (BAE)', + defaultProcessor: 'normalbae_image_processor', + }, + 'lllyasviel/control_v11p_sd15_seg': { + type: 'lllyasviel/control_v11p_sd15_seg', + label: 'Segment Anything', + }, + 'lllyasviel/control_v11p_sd15_lineart': { + type: 'lllyasviel/control_v11p_sd15_lineart', + label: 'Lineart', + defaultProcessor: 'lineart_image_processor', + }, + 'lllyasviel/control_v11p_sd15s2_lineart_anime': { + type: 'lllyasviel/control_v11p_sd15s2_lineart_anime', + label: 'Lineart Anime', + defaultProcessor: 'lineart_anime_image_processor', + }, + 'lllyasviel/control_v11p_sd15_scribble': { + type: 'lllyasviel/control_v11p_sd15_scribble', + label: 'Scribble', + }, + 'lllyasviel/control_v11p_sd15_softedge': { + type: 'lllyasviel/control_v11p_sd15_softedge', + label: 'Soft Edge', + defaultProcessor: 'hed_image_processor', + }, + 'lllyasviel/control_v11e_sd15_shuffle': { + type: 'lllyasviel/control_v11e_sd15_shuffle', + label: 'Content Shuffle', + defaultProcessor: 'content_shuffle_image_processor', + }, + 'lllyasviel/control_v11p_sd15_openpose': { + type: 'lllyasviel/control_v11p_sd15_openpose', + label: 'Openpose', + defaultProcessor: 'openpose_image_processor', + }, + 'lllyasviel/control_v11f1e_sd15_tile': { + type: 'lllyasviel/control_v11f1e_sd15_tile', + label: 'Tile (experimental)', + }, + 'lllyasviel/control_v11e_sd15_ip2p': { + type: 'lllyasviel/control_v11e_sd15_ip2p', + label: 'Pix2Pix (experimental)', + }, + 'CrucibleAI/ControlNetMediaPipeFace': { + type: 'CrucibleAI/ControlNetMediaPipeFace', + label: 'Mediapipe Face', + defaultProcessor: 'mediapipe_face_processor', + }, +}; + +export type ControlNetModelName = keyof typeof CONTROLNET_MODELS; diff --git a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts index 2558a38ab2..d71ff4da68 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts @@ -9,9 +9,8 @@ import { } from './types'; import { CONTROLNET_MODELS, - CONTROLNET_MODEL_MAP, CONTROLNET_PROCESSORS, - ControlNetModel, + ControlNetModelName, } from './constants'; import { controlNetImageProcessed } from './actions'; import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; @@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions'; export const initialControlNet: Omit = { isEnabled: true, - model: CONTROLNET_MODELS[0], + model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type, weight: 1, beginStepPct: 0, endStepPct: 1, @@ -36,7 +35,7 @@ export const initialControlNet: Omit = { export type ControlNetConfig = { controlNetId: string; isEnabled: boolean; - model: ControlNetModel; + model: ControlNetModelName; weight: number; beginStepPct: number; endStepPct: number; @@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({ }, controlNetModelChanged: ( state, - action: PayloadAction<{ controlNetId: string; model: ControlNetModel }> + action: PayloadAction<{ + controlNetId: string; + model: ControlNetModelName; + }> ) => { const { controlNetId, model } = action.payload; state.controlNets[controlNetId].model = model; state.controlNets[controlNetId].processedControlImage = null; if (state.controlNets[controlNetId].shouldAutoConfig) { - const processorType = CONTROLNET_MODEL_MAP[model]; + const processorType = CONTROLNET_MODELS[model].defaultProcessor; if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ @@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({ if (newShouldAutoConfig) { // manage the processor for the user const processorType = - CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model]; + CONTROLNET_MODELS[state.controlNets[controlNetId].model] + .defaultProcessor; if (processorType) { state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts index a1dc5d48ab..f8fea7e4d7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildImageToImageGraph.ts @@ -342,7 +342,10 @@ export const buildImageToImageGraph = (state: RootState): Graph => { }); } - if (shouldFitToWidthHeight) { + if ( + shouldFitToWidthHeight && + (initialImage.width !== width || initialImage.height !== height) + ) { // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` // Create a resize node, explicitly setting its image diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index f4413c4cf6..2aa762b477 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -14,9 +14,11 @@ const selector = createSelector( (ui, generation) => { // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413 // but we need to wait for the next release before removing this special handling. - const allSchedulers = ui.schedulers.filter((scheduler) => { - return !['dpmpp_2s'].includes(scheduler); - }); + const allSchedulers = ui.schedulers + .filter((scheduler) => { + return !['dpmpp_2s'].includes(scheduler); + }) + .sort((a, b) => a.localeCompare(b)); return { scheduler: generation.scheduler, @@ -45,9 +47,9 @@ const ParamScheduler = () => { return ( ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx index 2ddd3fbb9e..fa415074e6 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ImageToImage/InitialImagePreview.tsx @@ -58,6 +58,7 @@ const InitialImagePreview = () => { onReset={handleReset} fallback={} postUploadAction={{ type: 'SET_INITIAL_IMAGE' }} + withResetIcon /> ); diff --git a/invokeai/frontend/web/src/features/system/components/ClearTempFolderButtonModal.tsx b/invokeai/frontend/web/src/features/system/components/ClearTempFolderButtonModal.tsx deleted file mode 100644 index a220c93b3f..0000000000 --- a/invokeai/frontend/web/src/features/system/components/ClearTempFolderButtonModal.tsx +++ /dev/null @@ -1,41 +0,0 @@ -// import { emptyTempFolder } from 'app/socketio/actions'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAIAlertDialog from 'common/components/IAIAlertDialog'; -import IAIButton from 'common/components/IAIButton'; -import { isStagingSelector } from 'features/canvas/store/canvasSelectors'; -import { - clearCanvasHistory, - resetCanvas, -} from 'features/canvas/store/canvasSlice'; -import { useTranslation } from 'react-i18next'; -import { FaTrash } from 'react-icons/fa'; - -const EmptyTempFolderButtonModal = () => { - const isStaging = useAppSelector(isStagingSelector); - const dispatch = useAppDispatch(); - const { t } = useTranslation(); - - const acceptCallback = () => { - dispatch(emptyTempFolder()); - dispatch(resetCanvas()); - dispatch(clearCanvasHistory()); - }; - - return ( - } size="sm" isDisabled={isStaging}> - {t('unifiedCanvas.emptyTempImageFolder')} - - } - > -

{t('unifiedCanvas.emptyTempImagesFolderMessage')}

-
-

{t('unifiedCanvas.emptyTempImagesFolderConfirm')}

-
- ); -}; -export default EmptyTempFolderButtonModal; diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index be4be8ceaa..1eb8e4cb4c 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es'; import { useTranslation } from 'react-i18next'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { - selectModelsAll, - selectModelsById, - selectModelsIds, -} from '../store/modelSlice'; +import { selectModelsAll, selectModelsById } from '../store/modelSlice'; import { RootState } from 'app/store/store'; import { modelSelected } from 'features/parameters/store/generationSlice'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import IAICustomSelect, { - ItemTooltips, + IAICustomSelectOption, } from 'common/components/IAICustomSelect'; const selector = createSelector( [(state: RootState) => state, generationSelector], (state, generation) => { const selectedModel = selectModelsById(state, generation.model); - const allModelNames = selectModelsIds(state).map((id) => String(id)); - const allModelTooltips = selectModelsAll(state).reduce( - (allModelTooltips, model) => { - allModelTooltips[model.name] = model.description ?? ''; - return allModelTooltips; - }, - {} as ItemTooltips - ); + + const modelData = selectModelsAll(state) + .map((m) => ({ + value: m.name, + label: m.name, + tooltip: m.description, + })) + .sort((a, b) => a.label.localeCompare(b.label)); return { - allModelNames, - allModelTooltips, selectedModel, + modelData, }; }, { @@ -44,8 +39,7 @@ const selector = createSelector( const ModelSelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { allModelNames, allModelTooltips, selectedModel } = - useAppSelector(selector); + const { selectedModel, modelData } = useAppSelector(selector); const handleChangeModel = useCallback( (v: string | null | undefined) => { if (!v) { @@ -60,10 +54,9 @@ const ModelSelect = () => { diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index c164b87515..b566967b7c 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; -import { memo, ReactNode, useCallback, useMemo } from 'react'; +import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { MdDeviceHub, MdGridOn } from 'react-icons/md'; import { GoTextSize } from 'react-icons/go'; @@ -47,22 +47,22 @@ export interface InvokeTabInfo { const tabs: InvokeTabInfo[] = [ { id: 'txt2img', - icon: , + icon: , content: , }, { id: 'img2img', - icon: , + icon: , content: , }, { id: 'unifiedCanvas', - icon: , + icon: , content: , }, { id: 'nodes', - icon: , + icon: , content: , }, ]; @@ -119,6 +119,12 @@ const InvokeTabs = () => { } }, [dispatch, activeTabName]); + const handleClickTab = useCallback((e: MouseEvent) => { + if (e.target instanceof HTMLElement) { + e.target.blur(); + } + }, []); + const tabs = useMemo( () => enabledTabs.map((tab) => ( @@ -128,7 +134,7 @@ const InvokeTabs = () => { label={String(t(`common.${tab.id}` as ResourceKey))} placement="end" > - + {String(t(`common.${tab.id}` as ResourceKey))} @@ -136,7 +142,7 @@ const InvokeTabs = () => {
)), - [t, enabledTabs] + [enabledTabs, t, handleClickTab] ); const tabPanels = useMemo( diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasBeta/UnifiedCanvasToolSettings/UnifiedCanvasSettings.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasBeta/UnifiedCanvasToolSettings/UnifiedCanvasSettings.tsx index a173211258..a179a95c3f 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasBeta/UnifiedCanvasToolSettings/UnifiedCanvasSettings.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasBeta/UnifiedCanvasToolSettings/UnifiedCanvasSettings.tsx @@ -12,7 +12,6 @@ import { setShouldShowCanvasDebugInfo, setShouldShowIntermediates, } from 'features/canvas/store/canvasSlice'; -import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal'; import { FaWrench } from 'react-icons/fa'; @@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => { onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} /> - ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx index 1fadd0ada5..898f7db839 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasContent.tsx @@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => { }); useLayoutEffect(() => { - dispatch(requestCanvasRescale()); - const resizeCallback = () => { dispatch(requestCanvasRescale()); }; diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 9f433aa330..5363cc480b 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services): assert isinstance(n6[0], CollectInvocation) assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) + + +def test_graph_state_prepares_eagerly(mock_services): + """Tests that all prepareable nodes are prepared""" + graph = Graph() + + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node(IterateInvocation(id="iterate")) + graph.add_node(PromptTestInvocation(id="prompt_iterated")) + graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) + + # separated, fully-preparable chain of nodes + graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi")) + graph.add_node(PromptTestInvocation(id="prompt_chain_2")) + graph.add_node(PromptTestInvocation(id="prompt_chain_3")) + graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt")) + graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt")) + + g = GraphExecutionState(graph=graph) + g.next() + + assert "prompt_collection" in g.source_prepared_mapping + assert "prompt_chain_1" in g.source_prepared_mapping + assert "prompt_chain_2" in g.source_prepared_mapping + assert "prompt_chain_3" in g.source_prepared_mapping + assert "iterate" not in g.source_prepared_mapping + assert "prompt_iterated" not in g.source_prepared_mapping + + +def test_graph_executes_depth_first(mock_services): + """Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch""" + graph = Graph() + + test_prompts = ["Banana sushi", "Cat sushi"] + graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts))) + graph.add_node(IterateInvocation(id="iterate")) + graph.add_node(PromptTestInvocation(id="prompt_iterated")) + graph.add_node(PromptTestInvocation(id="prompt_successor")) + graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection")) + graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt")) + graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt")) + + g = GraphExecutionState(graph=graph) + n1 = invoke_next(g, mock_services) + n2 = invoke_next(g, mock_services) + n3 = invoke_next(g, mock_services) + n4 = invoke_next(g, mock_services) + + # Because ordering is not guaranteed, we cannot compare results directly. + # Instead, we must count the number of results. + def get_completed_count(g, id): + ids = [i for i in g.source_prepared_mapping[id]] + completed_ids = [i for i in g.executed if i in ids] + return len(completed_ids) + + # Check at each step that the number of executed nodes matches the expectation for depth-first execution + assert get_completed_count(g, "prompt_iterated") == 1 + assert get_completed_count(g, "prompt_successor") == 0 + + n5 = invoke_next(g, mock_services) + + assert get_completed_count(g, "prompt_iterated") == 1 + assert get_completed_count(g, "prompt_successor") == 1 + + n6 = invoke_next(g, mock_services) + + assert get_completed_count(g, "prompt_iterated") == 2 + assert get_completed_count(g, "prompt_successor") == 1 + + n7 = invoke_next(g, mock_services) + + assert get_completed_count(g, "prompt_iterated") == 2 + assert get_completed_count(g, "prompt_successor") == 2