resolve conflicts

This commit is contained in:
Lincoln Stein 2023-06-10 10:13:54 -04:00
commit 3d2ff7755e
26 changed files with 535 additions and 217 deletions

View File

@ -14,13 +14,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path from pathlib import Path
from pydantic.schema import schema from pydantic.schema import schema
# Do this early so that other modules pick up configuration #This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
app_config.parse_args() app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger()
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir

View File

@ -11,19 +11,26 @@ from typing import Union, get_type_hints
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic.fields import Field from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import (BaseCommand, CliContext, ExitCli, from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers) SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import (create_system_graphs,
default_text_to_image_graph_id)
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.graph import (Edge, EdgeConnection, GraphExecutionState, from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph, GraphInvocation, LibraryGraph,
@ -32,13 +39,11 @@ from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices from .services.invocation_services import InvocationServices
from .services.invoker import Invoker from .services.invoker import Invoker
from .services.latent_storage import (DiskLatentsStorage,
ForwardCacheLatentsStorage)
from .services.model_manager_service import ModelManagerService from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel): class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -47,7 +52,6 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception): class InvalidArgs(Exception):
pass pass
def add_invocation_args(command_parser): def add_invocation_args(command_parser):
# Add linking capability # Add linking capability
command_parser.add_argument( command_parser.add_argument(
@ -192,11 +196,6 @@ def invoke_all(context: CliContext):
raise SessionError() raise SessionError()
def invoke_cli(): def invoke_cli():
# this gets the basic configuration
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger()
# get the optional list of invocations to execute on the command line # get the optional list of invocations to execute on the command line
parser = config.get_parser() parser = config.get_parser()
parser.add_argument('commands',nargs='*') parser.add_argument('commands',nargs='*')

View File

@ -859,11 +859,9 @@ class GraphExecutionState(BaseModel):
if next_node is None: if next_node is None:
prepared_id = self._prepare() prepared_id = self._prepare()
# TODO: prepare multiple nodes at once? # Prepare as many nodes as we can
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation): while prepared_id is not None:
# prepared_id = self._prepare() prepared_id = self._prepare()
if prepared_id is not None:
next_node = self._get_next_node() next_node = self._get_next_node()
# Get values from edges # Get values from edges
@ -1010,14 +1008,30 @@ class GraphExecutionState(BaseModel):
# Get flattened source graph # Get flattened source graph
g = self.graph.nx_graph_flat() g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed # Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
sorted_nodes = nx.topological_sort(g) sorted_nodes = nx.topological_sort(g)
next_node_id = next( next_node_id = next(
( (
n n
for n in sorted_nodes for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping if n not in self.source_prepared_mapping
and all((e[0] in self.executed for e in g.in_edges(n))) # exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
), ),
None, None,
) )
@ -1114,9 +1128,22 @@ class GraphExecutionState(BaseModel):
) )
def _get_next_node(self) -> Optional[BaseInvocation]: def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph() g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None) # Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
if next_node is None: if next_node is None:
return None return None

View File

@ -22,7 +22,8 @@ class Invoker:
def invoke( def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None: ) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute""" """Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation # Get the next invocation
invocation = graph_execution_state.next() invocation = graph_execution_state.next()

View File

@ -40,6 +40,7 @@ import invokeai.configs as configs
from invokeai.app.services.config import ( from invokeai.app.services.config import (
InvokeAIAppConfig, InvokeAIAppConfig,
) )
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import ( from invokeai.frontend.install.widgets import (
CenteredButtonPress, CenteredButtonPress,
@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again. # or renaming it and then running invokeai-configure again.
""" """
logger=None
# -------------------------------------------- # --------------------------------------------
def postscript(errors: None): def postscript(errors: None):
@ -824,6 +826,7 @@ def main():
if opt.full_precision: if opt.full_precision:
invoke_args.extend(['--precision','float32']) invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
errors = set() errors = set()

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.util.logging """
invokeai.util.logging
Logging class for InvokeAI that produces console messages Logging class for InvokeAI that produces console messages
@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or) (or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.configure()
logger.critical('this is critical') // Critical Message logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message logger.error('this is an error') // Error Message
@ -28,6 +30,149 @@ Console messages:
Alternate Method (in this case the logger name will be set to InvokeAI): Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message') IAILogger.debug('this is a debugging message')
## Configuration
The default configuration will print to stderr on the console. To add
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
object:
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger(config=config)
### Three command-line options control logging:
`--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
```
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```
""" """
import logging.handlers import logging.handlers
@ -180,14 +325,17 @@ class InvokeAILogger(object):
loggers = dict() loggers = dict()
@classmethod @classmethod
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger: def getLogger(cls,
config = get_invokeai_config() name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
if name not in cls.loggers: if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config): for ch in cls.getLoggers(config):
logger.addHandler(ch) logger.addHandler(ch)
cls.loggers[name] = logger cls.loggers[name] = logger
return cls.loggers[name] return cls.loggers[name]
@ -199,9 +347,11 @@ class InvokeAILogger(object):
handler_name,*args = handler.split('=',2) handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None args = args[0] if len(args) > 0 else None
# console and file are the only handlers that gets a custom formatter # console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console': if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler() ch = logging.StreamHandler()
ch.setFormatter(formatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
@ -212,7 +362,7 @@ class InvokeAILogger(object):
elif handler_name=='file': elif handler_name=='file':
ch = cls._parse_file_args(args) ch = cls._parse_file_args(args)
ch.setFormatter(InvokeAISyslogFormatter()) ch.setFormatter(formatter())
handlers.append(ch) handlers.append(ch)
elif handler_name=='http': elif handler_name=='http':

View File

@ -28,7 +28,7 @@ import torch
from npyscreen import widget from npyscreen import widget
from omegaconf import OmegaConf from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import ( from invokeai.backend.install.model_install_backend import (
Dataset_path, Dataset_path,
@ -939,6 +939,7 @@ def main():
if opt.full_precision: if opt.full_precision:
invoke_args.extend(['--precision','float32']) invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args) config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.root_dir / config.conf_path.parent).exists(): if not (config.root_dir / config.conf_path.parent).exists():
logger.info( logger.info(

View File

@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import GlobalHotkeys from './GlobalHotkeys'; import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster'; import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal'; import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
const DEFAULT_CONFIG = {}; const DEFAULT_CONFIG = {};
@ -66,10 +67,17 @@ const App = ({
setIsReady(true); setIsReady(true);
} }
if (isApplicationReady) {
// TODO: This is a jank fix for canvas not filling the screen on first load
setTimeout(() => {
dispatch(requestCanvasRescale());
}, 200);
}
return () => { return () => {
setIsReady && setIsReady(false); setIsReady && setIsReady(false);
}; };
}, [isApplicationReady, setIsReady]); }, [dispatch, isApplicationReady, setIsReady]);
return ( return (
<> <>

View File

@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => {
); );
const mouseSensor = useSensor(MouseSensor, { const mouseSensor = useSensor(MouseSensor, {
activationConstraint: { delay: 250, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
const touchSensor = useSensor(TouchSensor, { const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 250, tolerance: 5 }, activationConstraint: { delay: 150, tolerance: 5 },
}); });
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos // TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay // Alternatively, fix `rectIntersection` collection detection to work with the drag overlay

View File

@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import { import {
Box, Box,
Flex, Flex,
FlexProps,
FormControl, FormControl,
FormControlProps, FormControlProps,
FormLabel, FormLabel,
@ -16,6 +15,7 @@ import {
} from '@chakra-ui/react'; } from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom'; import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift'; import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string }; export type ItemTooltips = { [key: string]: string };
export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};
type IAICustomSelectProps = { type IAICustomSelectProps = {
label?: string; label?: string;
items: string[]; value: string;
itemTooltips?: ItemTooltips; data: IAICustomSelectOption[] | string[];
selectedItem: string; onChange: (v: string) => void;
setSelectedItem: (v: string | null | undefined) => void;
withCheckIcon?: boolean; withCheckIcon?: boolean;
formControlProps?: FormControlProps; formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string; tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>; tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end'; ellipsisPosition?: 'start' | 'end';
@ -40,18 +44,33 @@ type IAICustomSelectProps = {
const IAICustomSelect = (props: IAICustomSelectProps) => { const IAICustomSelect = (props: IAICustomSelectProps) => {
const { const {
label, label,
items,
itemTooltips,
setSelectedItem,
selectedItem,
withCheckIcon, withCheckIcon,
formControlProps, formControlProps,
tooltip, tooltip,
buttonProps,
tooltipProps, tooltipProps,
ellipsisPosition = 'end', ellipsisPosition = 'end',
data,
value,
onChange,
} = props; } = props;
const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);
const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);
const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);
const { const {
isOpen, isOpen,
getToggleButtonProps, getToggleButtonProps,
@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
highlightedIndex, highlightedIndex,
getItemProps, getItemProps,
} = useSelect({ } = useSelect({
items, items: stringValues,
selectedItem, selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
setSelectedItem(newSelectedItem), newSelectedItem && onChange(newSelectedItem);
},
}); });
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({ const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
@ -93,8 +113,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
)} )}
<Tooltip label={tooltip} {...tooltipProps}> <Tooltip label={tooltip} {...tooltipProps}>
<Flex <Flex
{...getToggleButtonProps({ ref: refs.setReference })} {...getToggleButtonProps({ ref: refs.reference })}
{...buttonProps}
sx={{ sx={{
alignItems: 'center', alignItems: 'center',
userSelect: 'none', userSelect: 'none',
@ -119,7 +138,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
direction: labelTextDirection, direction: labelTextDirection,
}} }}
> >
{selectedItem} {valueData?.label}
</Text> </Text>
<ChevronUpIcon <ChevronUpIcon
sx={{ sx={{
@ -135,7 +154,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
{isOpen && ( {isOpen && (
<List <List
as={Flex} as={Flex}
ref={refs.setFloating} ref={refs.floating}
sx={{ sx={{
...floatingStyles, ...floatingStyles,
top: 0, top: 0,
@ -155,8 +174,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
}} }}
> >
<OverlayScrollbarsComponent> <OverlayScrollbarsComponent>
{items.map((item, index) => { {values.map((v, index) => {
const isSelected = selectedItem === item; const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index; const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500; const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted const bg = isHighlighted
@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
: undefined; : undefined;
return ( return (
<Tooltip <Tooltip
isDisabled={!itemTooltips} isDisabled={!v.tooltip}
key={`${item}${index}`} key={`${v.value}${index}`}
label={itemTooltips?.[item]} label={v.tooltip}
hasArrow hasArrow
placement="right" placement="right"
> >
@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
transitionProperty: 'common', transitionProperty: 'common',
transitionDuration: '0.15s', transitionDuration: '0.15s',
}} }}
key={`${item}${index}`} {...getItemProps({ item: v.value, index })}
{...getItemProps({ item, index })}
> >
{withCheckIcon ? ( {withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto"> <Grid gridTemplateColumns="1.25rem auto">
@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight, fontWeight,
}} }}
> >
{item} {v.label}
</Text> </Text>
</GridItem> </GridItem>
</Grid> </Grid>
@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight, fontWeight,
}} }}
> >
{item} {v.label}
</Text> </Text>
)} )}
</ListItem> </ListItem>

View File

@ -14,7 +14,7 @@ import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion'; import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent, useCallback } from 'react'; import { ReactElement, SyntheticEvent, useCallback } from 'react';
import { memo, useRef } from 'react'; import { memo, useRef } from 'react';
import { FaImage, FaTimes, FaUpload } from 'react-icons/fa'; import { FaImage, FaTimes, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay'; import IAIDropOverlay from './IAIDropOverlay';
@ -174,14 +174,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
position: 'absolute', position: 'absolute',
top: 0, top: 0,
right: 0, right: 0,
p: 2,
}} }}
> >
<IAIIconButton <IAIIconButton
size={resetIconSize} size={resetIconSize}
tooltip="Reset Image" tooltip="Reset Image"
aria-label="Reset Image" aria-label="Reset Image"
icon={<FaTimes />} icon={<FaUndo />}
onClick={onReset} onClick={onReset}
/> />
</Box> </Box>

View File

@ -16,7 +16,6 @@ import {
setShouldShowIntermediates, setShouldShowIntermediates,
setShouldSnapToGrid, setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react'; import { ChangeEvent } from 'react';
@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/> />
<ClearCanvasHistoryButtonModal /> <ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex> </Flex>
</IAIPopover> </IAIPopover>
); );

View File

@ -30,7 +30,10 @@ import {
} from './canvasTypes'; } from './canvasTypes';
import { ImageDTO } from 'services/api'; import { ImageDTO } from 'services/api';
import { sessionCanceled } from 'services/thunks/session'; import { sessionCanceled } from 'services/thunks/session';
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice'; import {
setActiveTab,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { imageUrlsReceived } from 'services/thunks/image'; import { imageUrlsReceived } from 'services/thunks/image';
export const initialLayerState: CanvasLayerState = { export const initialLayerState: CanvasLayerState = {
@ -857,6 +860,11 @@ export const canvasSlice = createSlice({
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => { builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
state.doesCanvasNeedScaling = true; state.doesCanvasNeedScaling = true;
}); });
builder.addCase(setActiveTab, (state, action) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } = const { image_name, image_origin, image_url, thumbnail_url } =
action.payload; action.payload;

View File

@ -1,17 +1,26 @@
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
ControlNetModel, ControlNetModelName,
} from 'features/controlNet/store/constants'; } from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice'; import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
type ParamControlNetModelProps = { type ParamControlNetModelProps = {
controlNetId: string; controlNetId: string;
model: ControlNetModel; model: ControlNetModelName;
}; };
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
value: m.type,
label: m.label,
tooltip: m.type,
}));
const ParamControlNetModel = (props: ParamControlNetModelProps) => { const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props; const { controlNetId, model } = props;
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const handleModelChanged = useCallback( const handleModelChanged = useCallback(
(val: string | null | undefined) => { (val: string | null | undefined) => {
// TODO: do not cast // TODO: do not cast
const model = val as ControlNetModel; const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model })); dispatch(controlNetModelChanged({ controlNetId, model }));
}, },
[controlNetId, dispatch] [controlNetId, dispatch]
@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
<IAICustomSelect <IAICustomSelect
tooltip={model} tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS} data={DATA}
selectedItem={model} value={model}
setSelectedItem={handleModelChanged} onChange={handleModelChanged}
ellipsisPosition="start" ellipsisPosition="start"
withCheckIcon withCheckIcon
/> />

View File

@ -1,4 +1,6 @@
import IAICustomSelect from 'common/components/IAICustomSelect'; import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { import {
ControlNetProcessorNode, ControlNetProcessorNode,
@ -7,15 +9,28 @@ import {
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice'; import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks'; import { useAppDispatch } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants'; import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
type ParamControlNetProcessorSelectProps = { type ParamControlNetProcessorSelectProps = {
controlNetId: string; controlNetId: string;
processorNode: ControlNetProcessorNode; processorNode: ControlNetProcessorNode;
}; };
const CONTROLNET_PROCESSOR_TYPES = Object.keys( const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
CONTROLNET_PROCESSORS CONTROLNET_PROCESSORS,
) as ControlNetProcessorType[]; (p) => ({
value: p.type,
label: p.label,
tooltip: p.description,
})
).sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
);
const ParamControlNetProcessorSelect = ( const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps props: ParamControlNetProcessorSelectProps
@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = (
return ( return (
<IAICustomSelect <IAICustomSelect
label="Processor" label="Processor"
items={CONTROLNET_PROCESSOR_TYPES} value={processorNode.type ?? 'canny_image_processor'}
selectedItem={processorNode.type ?? 'canny_image_processor'} data={CONTROLNET_PROCESSOR_TYPES}
setSelectedItem={handleProcessorTypeChanged} onChange={handleProcessorTypeChanged}
withCheckIcon withCheckIcon
/> />
); );

View File

@ -5,12 +5,12 @@ import {
} from './types'; } from './types';
type ControlNetProcessorsDict = Record< type ControlNetProcessorsDict = Record<
ControlNetProcessorType, string,
{ {
type: ControlNetProcessorType; type: ControlNetProcessorType | 'none';
label: string; label: string;
description: string; description: string;
default: RequiredControlNetProcessorNode; default: RequiredControlNetProcessorNode | { type: 'none' };
} }
>; >;
@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record<
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CONTROLNET_PROCESSORS = { export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: { none: {
type: 'none', type: 'none',
label: 'None', label: 'none',
description: '', description: '',
default: { default: {
type: 'none', type: 'none',
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
}, },
mlsd_image_processor: { mlsd_image_processor: {
type: 'mlsd_image_processor', type: 'mlsd_image_processor',
label: 'MLSD', label: 'M-LSD',
description: '', description: '',
default: { default: {
id: 'mlsd_image_processor', id: 'mlsd_image_processor',
@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = {
}, },
}; };
export const CONTROLNET_MODELS = [ type ControlNetModel = {
'lllyasviel/control_v11p_sd15_canny', type: string;
'lllyasviel/control_v11p_sd15_inpaint', label: string;
'lllyasviel/control_v11p_sd15_mlsd', description?: string;
'lllyasviel/control_v11f1p_sd15_depth', defaultProcessor?: ControlNetProcessorType;
'lllyasviel/control_v11p_sd15_normalbae',
'lllyasviel/control_v11p_sd15_seg',
'lllyasviel/control_v11p_sd15_lineart',
'lllyasviel/control_v11p_sd15s2_lineart_anime',
'lllyasviel/control_v11p_sd15_scribble',
'lllyasviel/control_v11p_sd15_softedge',
'lllyasviel/control_v11e_sd15_shuffle',
'lllyasviel/control_v11p_sd15_openpose',
'lllyasviel/control_v11f1e_sd15_tile',
'lllyasviel/control_v11e_sd15_ip2p',
'CrucibleAI/ControlNetMediaPipeFace',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const CONTROLNET_MODEL_MAP: Record<
ControlNetModel,
ControlNetProcessorType
> = {
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
'lllyasviel/control_v11p_sd15s2_lineart_anime':
'lineart_anime_image_processor',
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
}; };
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segment Anything',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -9,9 +9,8 @@ import {
} from './types'; } from './types';
import { import {
CONTROLNET_MODELS, CONTROLNET_MODELS,
CONTROLNET_MODEL_MAP,
CONTROLNET_PROCESSORS, CONTROLNET_PROCESSORS,
ControlNetModel, ControlNetModelName,
} from './constants'; } from './constants';
import { controlNetImageProcessed } from './actions'; import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image'; import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = { export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true, isEnabled: true,
model: CONTROLNET_MODELS[0], model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = { export type ControlNetConfig = {
controlNetId: string; controlNetId: string;
isEnabled: boolean; isEnabled: boolean;
model: ControlNetModel; model: ControlNetModelName;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
}, },
controlNetModelChanged: ( controlNetModelChanged: (
state, state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }> action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
}>
) => { ) => {
const { controlNetId, model } = action.payload; const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model; state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null; state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) { if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODEL_MAP[model]; const processorType = CONTROLNET_MODELS[model].defaultProcessor;
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) { if (newShouldAutoConfig) {
// manage the processor for the user // manage the processor for the user
const processorType = const processorType =
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model]; CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
if (processorType) { if (processorType) {
state.controlNets[controlNetId].processorType = processorType; state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[ state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[

View File

@ -342,7 +342,10 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
}); });
} }
if (shouldFitToWidthHeight) { if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS` // The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image // Create a resize node, explicitly setting its image

View File

@ -14,9 +14,11 @@ const selector = createSelector(
(ui, generation) => { (ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413 // TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling. // but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers.filter((scheduler) => { const allSchedulers = ui.schedulers
return !['dpmpp_2s'].includes(scheduler); .filter((scheduler) => {
}); return !['dpmpp_2s'].includes(scheduler);
})
.sort((a, b) => a.localeCompare(b));
return { return {
scheduler: generation.scheduler, scheduler: generation.scheduler,
@ -45,9 +47,9 @@ const ParamScheduler = () => {
return ( return (
<IAICustomSelect <IAICustomSelect
label={t('parameters.scheduler')} label={t('parameters.scheduler')}
selectedItem={scheduler} value={scheduler}
setSelectedItem={handleChange} data={allSchedulers}
items={allSchedulers} onChange={handleChange}
withCheckIcon withCheckIcon
/> />
); );

View File

@ -58,6 +58,7 @@ const InitialImagePreview = () => {
onReset={handleReset} onReset={handleReset}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />} fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }} postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
withResetIcon
/> />
</Flex> </Flex>
); );

View File

@ -1,41 +0,0 @@
// import { emptyTempFolder } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import {
clearCanvasHistory,
resetCanvas,
} from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const EmptyTempFolderButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const acceptCallback = () => {
dispatch(emptyTempFolder());
dispatch(resetCanvas());
dispatch(clearCanvasHistory());
};
return (
<IAIAlertDialog
title={t('unifiedCanvas.emptyTempImageFolder')}
acceptCallback={acceptCallback}
acceptButtonText={t('unifiedCanvas.emptyFolder')}
triggerComponent={
<IAIButton leftIcon={<FaTrash />} size="sm" isDisabled={isStaging}>
{t('unifiedCanvas.emptyTempImageFolder')}
</IAIButton>
}
>
<p>{t('unifiedCanvas.emptyTempImagesFolderMessage')}</p>
<br />
<p>{t('unifiedCanvas.emptyTempImagesFolderConfirm')}</p>
</IAIAlertDialog>
);
};
export default EmptyTempFolderButtonModal;

View File

@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { import { selectModelsAll, selectModelsById } from '../store/modelSlice';
selectModelsAll,
selectModelsById,
selectModelsIds,
} from '../store/modelSlice';
import { RootState } from 'app/store/store'; import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice'; import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors'; import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, { import IAICustomSelect, {
ItemTooltips, IAICustomSelectOption,
} from 'common/components/IAICustomSelect'; } from 'common/components/IAICustomSelect';
const selector = createSelector( const selector = createSelector(
[(state: RootState) => state, generationSelector], [(state: RootState) => state, generationSelector],
(state, generation) => { (state, generation) => {
const selectedModel = selectModelsById(state, generation.model); const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state).map((id) => String(id));
const allModelTooltips = selectModelsAll(state).reduce( const modelData = selectModelsAll(state)
(allModelTooltips, model) => { .map<IAICustomSelectOption>((m) => ({
allModelTooltips[model.name] = model.description ?? ''; value: m.name,
return allModelTooltips; label: m.name,
}, tooltip: m.description,
{} as ItemTooltips }))
); .sort((a, b) => a.label.localeCompare(b.label));
return { return {
allModelNames,
allModelTooltips,
selectedModel, selectedModel,
modelData,
}; };
}, },
{ {
@ -44,8 +39,7 @@ const selector = createSelector(
const ModelSelect = () => { const ModelSelect = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const { allModelNames, allModelTooltips, selectedModel } = const { selectedModel, modelData } = useAppSelector(selector);
useAppSelector(selector);
const handleChangeModel = useCallback( const handleChangeModel = useCallback(
(v: string | null | undefined) => { (v: string | null | undefined) => {
if (!v) { if (!v) {
@ -60,10 +54,9 @@ const ModelSelect = () => {
<IAICustomSelect <IAICustomSelect
label={t('modelManager.model')} label={t('modelManager.model')}
tooltip={selectedModel?.description} tooltip={selectedModel?.description}
items={allModelNames} data={modelData}
itemTooltips={allModelTooltips} value={selectedModel?.name ?? ''}
selectedItem={selectedModel?.name ?? ''} onChange={handleChangeModel}
setSelectedItem={handleChangeModel}
withCheckIcon={true} withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }} tooltipProps={{ placement: 'top', hasArrow: true }}
/> />

View File

@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice'; import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import { InvokeTabName } from 'features/ui/store/tabMap'; import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice'; import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { memo, ReactNode, useCallback, useMemo } from 'react'; import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { MdDeviceHub, MdGridOn } from 'react-icons/md'; import { MdDeviceHub, MdGridOn } from 'react-icons/md';
import { GoTextSize } from 'react-icons/go'; import { GoTextSize } from 'react-icons/go';
@ -47,22 +47,22 @@ export interface InvokeTabInfo {
const tabs: InvokeTabInfo[] = [ const tabs: InvokeTabInfo[] = [
{ {
id: 'txt2img', id: 'txt2img',
icon: <Icon as={GoTextSize} sx={{ boxSize: 6 }} />, icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <TextToImageTab />, content: <TextToImageTab />,
}, },
{ {
id: 'img2img', id: 'img2img',
icon: <Icon as={FaImage} sx={{ boxSize: 6 }} />, icon: <Icon as={FaImage} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <ImageTab />, content: <ImageTab />,
}, },
{ {
id: 'unifiedCanvas', id: 'unifiedCanvas',
icon: <Icon as={MdGridOn} sx={{ boxSize: 6 }} />, icon: <Icon as={MdGridOn} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <UnifiedCanvasTab />, content: <UnifiedCanvasTab />,
}, },
{ {
id: 'nodes', id: 'nodes',
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6 }} />, icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <NodesTab />, content: <NodesTab />,
}, },
]; ];
@ -119,6 +119,12 @@ const InvokeTabs = () => {
} }
}, [dispatch, activeTabName]); }, [dispatch, activeTabName]);
const handleClickTab = useCallback((e: MouseEvent<HTMLElement>) => {
if (e.target instanceof HTMLElement) {
e.target.blur();
}
}, []);
const tabs = useMemo( const tabs = useMemo(
() => () =>
enabledTabs.map((tab) => ( enabledTabs.map((tab) => (
@ -128,7 +134,7 @@ const InvokeTabs = () => {
label={String(t(`common.${tab.id}` as ResourceKey))} label={String(t(`common.${tab.id}` as ResourceKey))}
placement="end" placement="end"
> >
<Tab> <Tab onClick={handleClickTab}>
<VisuallyHidden> <VisuallyHidden>
{String(t(`common.${tab.id}` as ResourceKey))} {String(t(`common.${tab.id}` as ResourceKey))}
</VisuallyHidden> </VisuallyHidden>
@ -136,7 +142,7 @@ const InvokeTabs = () => {
</Tab> </Tab>
</Tooltip> </Tooltip>
)), )),
[t, enabledTabs] [enabledTabs, t, handleClickTab]
); );
const tabPanels = useMemo( const tabPanels = useMemo(

View File

@ -12,7 +12,6 @@ import {
setShouldShowCanvasDebugInfo, setShouldShowCanvasDebugInfo,
setShouldShowIntermediates, setShouldShowIntermediates,
} from 'features/canvas/store/canvasSlice'; } from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { FaWrench } from 'react-icons/fa'; import { FaWrench } from 'react-icons/fa';
@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))} onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/> />
<ClearCanvasHistoryButtonModal /> <ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex> </Flex>
</IAIPopover> </IAIPopover>
); );

View File

@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => {
}); });
useLayoutEffect(() => { useLayoutEffect(() => {
dispatch(requestCanvasRescale());
const resizeCallback = () => { const resizeCallback = () => {
dispatch(requestCanvasRescale()); dispatch(requestCanvasRescale());
}; };

View File

@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
assert isinstance(n6[0], CollectInvocation) assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts) assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
def test_graph_state_prepares_eagerly(mock_services):
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
# separated, fully-preparable chain of nodes
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
g = GraphExecutionState(graph=graph)
g.next()
assert "prompt_collection" in g.source_prepared_mapping
assert "prompt_chain_1" in g.source_prepared_mapping
assert "prompt_chain_2" in g.source_prepared_mapping
assert "prompt_chain_3" in g.source_prepared_mapping
assert "iterate" not in g.source_prepared_mapping
assert "prompt_iterated" not in g.source_prepared_mapping
def test_graph_executes_depth_first(mock_services):
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
def get_completed_count(g, id):
ids = [i for i in g.source_prepared_mapping[id]]
completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids)
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 0
n5 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 1
n6 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 1
n7 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 2