mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
resolve conflicts
This commit is contained in:
commit
3d2ff7755e
@ -14,13 +14,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pathlib import Path
|
||||
from pydantic.schema import schema
|
||||
|
||||
# Do this early so that other modules pick up configuration
|
||||
#This should come early so that modules can log their initialization properly
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
app_config.parse_args()
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.getLogger()
|
||||
logger = InvokeAILogger.getLogger(config=app_config)
|
||||
|
||||
import invokeai.frontend.web as web_dir
|
||||
|
||||
|
@ -11,19 +11,26 @@ from typing import Union, get_type_hints
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic.fields import Field
|
||||
|
||||
# This should come early so that the logger can pick up its configuration options
|
||||
from .services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService
|
||||
from invokeai.app.services.metadata import CoreMetadataService
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from .services.default_graphs import (default_text_to_image_graph_id,
|
||||
create_system_graphs)
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
|
||||
from .cli.commands import (BaseCommand, CliContext, ExitCli,
|
||||
SortedHelpFormatter, add_graph_parsers, add_parsers)
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.default_graphs import (create_system_graphs,
|
||||
default_text_to_image_graph_id)
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
|
||||
GraphInvocation, LibraryGraph,
|
||||
@ -32,13 +39,11 @@ from .services.image_file_storage import DiskImageFileStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
from .services.invocation_services import InvocationServices
|
||||
from .services.invoker import Invoker
|
||||
from .services.latent_storage import (DiskLatentsStorage,
|
||||
ForwardCacheLatentsStorage)
|
||||
from .services.model_manager_service import ModelManagerService
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.restoration_services import RestorationServices
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
|
||||
class CliCommand(BaseModel):
|
||||
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
|
||||
@ -47,7 +52,6 @@ class CliCommand(BaseModel):
|
||||
class InvalidArgs(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def add_invocation_args(command_parser):
|
||||
# Add linking capability
|
||||
command_parser.add_argument(
|
||||
@ -192,11 +196,6 @@ def invoke_all(context: CliContext):
|
||||
raise SessionError()
|
||||
|
||||
def invoke_cli():
|
||||
# this gets the basic configuration
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
# get the optional list of invocations to execute on the command line
|
||||
parser = config.get_parser()
|
||||
parser.add_argument('commands',nargs='*')
|
||||
|
@ -859,11 +859,9 @@ class GraphExecutionState(BaseModel):
|
||||
if next_node is None:
|
||||
prepared_id = self._prepare()
|
||||
|
||||
# TODO: prepare multiple nodes at once?
|
||||
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
|
||||
# prepared_id = self._prepare()
|
||||
|
||||
if prepared_id is not None:
|
||||
# Prepare as many nodes as we can
|
||||
while prepared_id is not None:
|
||||
prepared_id = self._prepare()
|
||||
next_node = self._get_next_node()
|
||||
|
||||
# Get values from edges
|
||||
@ -1010,14 +1008,30 @@ class GraphExecutionState(BaseModel):
|
||||
# Get flattened source graph
|
||||
g = self.graph.nx_graph_flat()
|
||||
|
||||
# Find next unprepared node where all source nodes are executed
|
||||
# Find next node that:
|
||||
# - was not already prepared
|
||||
# - is not an iterate node whose inputs have not been executed
|
||||
# - does not have an unexecuted iterate ancestor
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node_id = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
# exclude nodes that have already been prepared
|
||||
if n not in self.source_prepared_mapping
|
||||
and all((e[0] in self.executed for e in g.in_edges(n)))
|
||||
# exclude iterate nodes whose inputs have not been executed
|
||||
and not (
|
||||
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
|
||||
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
|
||||
)
|
||||
# exclude nodes who have unexecuted iterate ancestors
|
||||
and not any(
|
||||
(
|
||||
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
|
||||
and a not in self.executed # ...that is not executed
|
||||
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
|
||||
)
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@ -1114,9 +1128,22 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
def _get_next_node(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the deepest node that is ready to be executed"""
|
||||
g = self.execution_graph.nx_graph()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
|
||||
|
||||
# Depth-first search with pre-order traversal is a depth-first topological sort
|
||||
sorted_nodes = nx.dfs_preorder_nodes(g)
|
||||
|
||||
next_node = next(
|
||||
(
|
||||
n
|
||||
for n in sorted_nodes
|
||||
if n not in self.executed # the node must not already be executed...
|
||||
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if next_node is None:
|
||||
return None
|
||||
|
||||
|
@ -22,7 +22,8 @@ class Invoker:
|
||||
def invoke(
|
||||
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> str | None:
|
||||
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
|
@ -40,6 +40,7 @@ import invokeai.configs as configs
|
||||
from invokeai.app.services.config import (
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.widgets import (
|
||||
CenteredButtonPress,
|
||||
@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
|
||||
# or renaming it and then running invokeai-configure again.
|
||||
"""
|
||||
|
||||
logger=None
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
@ -824,6 +826,7 @@ def main():
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
errors = set()
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
||||
|
||||
"""invokeai.util.logging
|
||||
"""
|
||||
invokeai.util.logging
|
||||
|
||||
Logging class for InvokeAI that produces console messages
|
||||
|
||||
@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
|
||||
(or)
|
||||
logger = InvokeAILogger.getLogger(__name__) // To use the filename
|
||||
logger.configure()
|
||||
|
||||
logger.critical('this is critical') // Critical Message
|
||||
logger.error('this is an error') // Error Message
|
||||
@ -28,6 +30,149 @@ Console messages:
|
||||
Alternate Method (in this case the logger name will be set to InvokeAI):
|
||||
import invokeai.backend.util.logging as IAILogger
|
||||
IAILogger.debug('this is a debugging message')
|
||||
|
||||
## Configuration
|
||||
|
||||
The default configuration will print to stderr on the console. To add
|
||||
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
|
||||
object:
|
||||
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
config.parse_args()
|
||||
logger = InvokeAILogger.getLogger(config=config)
|
||||
|
||||
### Three command-line options control logging:
|
||||
|
||||
`--log_handlers <handler1> <handler2> ...`
|
||||
|
||||
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
|
||||
```
|
||||
|
||||
The format of these options is described below.
|
||||
|
||||
### `--log_format {plain|color|legacy|syslog}`
|
||||
|
||||
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
|
||||
|
||||
* "plain" provides formatted messages like this:
|
||||
|
||||
```bash
|
||||
|
||||
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
|
||||
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
|
||||
```
|
||||
|
||||
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
|
||||
|
||||
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
|
||||
|
||||
```
|
||||
### this is a critical error
|
||||
*** this is an error
|
||||
** this is a warning
|
||||
>> this is an informational messages
|
||||
| this is a debug message
|
||||
```
|
||||
|
||||
* "syslog" produces messages suitable for syslog entries:
|
||||
|
||||
```bash
|
||||
InvokeAI [2691178] <CRITICAL> this is a critical error
|
||||
InvokeAI [2691178] <ERROR> this is an error
|
||||
InvokeAI [2691178] <WARNING> this is a warning
|
||||
InvokeAI [2691178] <INFO> this is an informational messages
|
||||
InvokeAI [2691178] <DEBUG> this is a debug message
|
||||
```
|
||||
|
||||
(note that the date, time and hostname will be added by the syslog system)
|
||||
|
||||
### `--log_level {debug|info|warning|error|critical}`
|
||||
|
||||
Providing this command-line option will cause only messages at the specified level or above to be emitted.
|
||||
|
||||
## Console logging
|
||||
|
||||
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
|
||||
|
||||
## File logging
|
||||
|
||||
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
|
||||
|
||||
```bash
|
||||
invokeai-web --log_handlers file=/var/log/invokeai.log
|
||||
```
|
||||
|
||||
## Syslog logging
|
||||
|
||||
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
|
||||
|
||||
* Send to the local machine using the `/dev/log` socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=/dev/log
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost
|
||||
```
|
||||
|
||||
* Send to the local machine using a UDP message on a nonstandard port:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=localhost:512
|
||||
```
|
||||
|
||||
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
|
||||
```
|
||||
|
||||
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
|
||||
|
||||
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
|
||||
```
|
||||
|
||||
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
|
||||
|
||||
## Web logging
|
||||
|
||||
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
|
||||
|
||||
```
|
||||
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
|
||||
```
|
||||
|
||||
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
|
||||
|
||||
Currently password authentication and SSL are not supported.
|
||||
|
||||
## Using the configuration file
|
||||
|
||||
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
|
||||
|
||||
```
|
||||
InvokeAI:
|
||||
[... other settings...]
|
||||
Logging:
|
||||
log_handlers:
|
||||
- console
|
||||
- syslog=/dev/log
|
||||
log_level: info
|
||||
log_format: color
|
||||
```
|
||||
"""
|
||||
|
||||
import logging.handlers
|
||||
@ -180,14 +325,17 @@ class InvokeAILogger(object):
|
||||
loggers = dict()
|
||||
|
||||
@classmethod
|
||||
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
|
||||
config = get_invokeai_config()
|
||||
|
||||
if name not in cls.loggers:
|
||||
def getLogger(cls,
|
||||
name: str = 'InvokeAI',
|
||||
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
|
||||
if name in cls.loggers:
|
||||
logger = cls.loggers[name]
|
||||
logger.handlers.clear()
|
||||
else:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||
for ch in cls.getLoggers(config):
|
||||
logger.addHandler(ch)
|
||||
logger.setLevel(config.log_level.upper()) # yes, strings work here
|
||||
for ch in cls.getLoggers(config):
|
||||
logger.addHandler(ch)
|
||||
cls.loggers[name] = logger
|
||||
return cls.loggers[name]
|
||||
|
||||
@ -199,9 +347,11 @@ class InvokeAILogger(object):
|
||||
handler_name,*args = handler.split('=',2)
|
||||
args = args[0] if len(args) > 0 else None
|
||||
|
||||
# console and file are the only handlers that gets a custom formatter
|
||||
# console and file get the fancy formatter.
|
||||
# syslog gets a simple one
|
||||
# http gets no custom formatter
|
||||
formatter = LOG_FORMATTERS[config.log_format]
|
||||
if handler_name=='console':
|
||||
formatter = LOG_FORMATTERS[config.log_format]
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter())
|
||||
handlers.append(ch)
|
||||
@ -212,7 +362,7 @@ class InvokeAILogger(object):
|
||||
|
||||
elif handler_name=='file':
|
||||
ch = cls._parse_file_args(args)
|
||||
ch.setFormatter(InvokeAISyslogFormatter())
|
||||
ch.setFormatter(formatter())
|
||||
handlers.append(ch)
|
||||
|
||||
elif handler_name=='http':
|
||||
|
@ -28,7 +28,7 @@ import torch
|
||||
from npyscreen import widget
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from invokeai.backend.install.model_install_backend import (
|
||||
Dataset_path,
|
||||
@ -939,6 +939,7 @@ def main():
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(['--precision','float32'])
|
||||
config.parse_args(invoke_args)
|
||||
logger = InvokeAILogger().getLogger(config=config)
|
||||
|
||||
if not (config.root_dir / config.conf_path.parent).exists():
|
||||
logger.info(
|
||||
|
@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
|
||||
import GlobalHotkeys from './GlobalHotkeys';
|
||||
import Toaster from './Toaster';
|
||||
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
|
||||
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
|
||||
|
||||
const DEFAULT_CONFIG = {};
|
||||
|
||||
@ -66,10 +67,17 @@ const App = ({
|
||||
setIsReady(true);
|
||||
}
|
||||
|
||||
if (isApplicationReady) {
|
||||
// TODO: This is a jank fix for canvas not filling the screen on first load
|
||||
setTimeout(() => {
|
||||
dispatch(requestCanvasRescale());
|
||||
}, 200);
|
||||
}
|
||||
|
||||
return () => {
|
||||
setIsReady && setIsReady(false);
|
||||
};
|
||||
}, [isApplicationReady, setIsReady]);
|
||||
}, [dispatch, isApplicationReady, setIsReady]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => {
|
||||
);
|
||||
|
||||
const mouseSensor = useSensor(MouseSensor, {
|
||||
activationConstraint: { delay: 250, tolerance: 5 },
|
||||
activationConstraint: { delay: 150, tolerance: 5 },
|
||||
});
|
||||
|
||||
const touchSensor = useSensor(TouchSensor, {
|
||||
activationConstraint: { delay: 250, tolerance: 5 },
|
||||
activationConstraint: { delay: 150, tolerance: 5 },
|
||||
});
|
||||
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
|
||||
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay
|
||||
|
@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
|
||||
import {
|
||||
Box,
|
||||
Flex,
|
||||
FlexProps,
|
||||
FormControl,
|
||||
FormControlProps,
|
||||
FormLabel,
|
||||
@ -16,6 +15,7 @@ import {
|
||||
} from '@chakra-ui/react';
|
||||
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
|
||||
import { useSelect } from 'downshift';
|
||||
import { isString } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
|
||||
import { memo, useMemo } from 'react';
|
||||
@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
|
||||
|
||||
export type ItemTooltips = { [key: string]: string };
|
||||
|
||||
export type IAICustomSelectOption = {
|
||||
value: string;
|
||||
label: string;
|
||||
tooltip?: string;
|
||||
};
|
||||
|
||||
type IAICustomSelectProps = {
|
||||
label?: string;
|
||||
items: string[];
|
||||
itemTooltips?: ItemTooltips;
|
||||
selectedItem: string;
|
||||
setSelectedItem: (v: string | null | undefined) => void;
|
||||
value: string;
|
||||
data: IAICustomSelectOption[] | string[];
|
||||
onChange: (v: string) => void;
|
||||
withCheckIcon?: boolean;
|
||||
formControlProps?: FormControlProps;
|
||||
buttonProps?: FlexProps;
|
||||
tooltip?: string;
|
||||
tooltipProps?: Omit<TooltipProps, 'children'>;
|
||||
ellipsisPosition?: 'start' | 'end';
|
||||
@ -40,18 +44,33 @@ type IAICustomSelectProps = {
|
||||
const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
const {
|
||||
label,
|
||||
items,
|
||||
itemTooltips,
|
||||
setSelectedItem,
|
||||
selectedItem,
|
||||
withCheckIcon,
|
||||
formControlProps,
|
||||
tooltip,
|
||||
buttonProps,
|
||||
tooltipProps,
|
||||
ellipsisPosition = 'end',
|
||||
data,
|
||||
value,
|
||||
onChange,
|
||||
} = props;
|
||||
|
||||
const values = useMemo(() => {
|
||||
return data.map<IAICustomSelectOption>((v) => {
|
||||
if (isString(v)) {
|
||||
return { value: v, label: v };
|
||||
}
|
||||
return v;
|
||||
});
|
||||
}, [data]);
|
||||
|
||||
const stringValues = useMemo(() => {
|
||||
return values.map((v) => v.value);
|
||||
}, [values]);
|
||||
|
||||
const valueData = useMemo(() => {
|
||||
return values.find((v) => v.value === value);
|
||||
}, [values, value]);
|
||||
|
||||
const {
|
||||
isOpen,
|
||||
getToggleButtonProps,
|
||||
@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
highlightedIndex,
|
||||
getItemProps,
|
||||
} = useSelect({
|
||||
items,
|
||||
selectedItem,
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
|
||||
setSelectedItem(newSelectedItem),
|
||||
items: stringValues,
|
||||
selectedItem: value,
|
||||
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
|
||||
newSelectedItem && onChange(newSelectedItem);
|
||||
},
|
||||
});
|
||||
|
||||
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
|
||||
@ -93,8 +113,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
)}
|
||||
<Tooltip label={tooltip} {...tooltipProps}>
|
||||
<Flex
|
||||
{...getToggleButtonProps({ ref: refs.setReference })}
|
||||
{...buttonProps}
|
||||
{...getToggleButtonProps({ ref: refs.reference })}
|
||||
sx={{
|
||||
alignItems: 'center',
|
||||
userSelect: 'none',
|
||||
@ -119,7 +138,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
direction: labelTextDirection,
|
||||
}}
|
||||
>
|
||||
{selectedItem}
|
||||
{valueData?.label}
|
||||
</Text>
|
||||
<ChevronUpIcon
|
||||
sx={{
|
||||
@ -135,7 +154,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
{isOpen && (
|
||||
<List
|
||||
as={Flex}
|
||||
ref={refs.setFloating}
|
||||
ref={refs.floating}
|
||||
sx={{
|
||||
...floatingStyles,
|
||||
top: 0,
|
||||
@ -155,8 +174,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
}}
|
||||
>
|
||||
<OverlayScrollbarsComponent>
|
||||
{items.map((item, index) => {
|
||||
const isSelected = selectedItem === item;
|
||||
{values.map((v, index) => {
|
||||
const isSelected = value === v.value;
|
||||
const isHighlighted = highlightedIndex === index;
|
||||
const fontWeight = isSelected ? 700 : 500;
|
||||
const bg = isHighlighted
|
||||
@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
: undefined;
|
||||
return (
|
||||
<Tooltip
|
||||
isDisabled={!itemTooltips}
|
||||
key={`${item}${index}`}
|
||||
label={itemTooltips?.[item]}
|
||||
isDisabled={!v.tooltip}
|
||||
key={`${v.value}${index}`}
|
||||
label={v.tooltip}
|
||||
hasArrow
|
||||
placement="right"
|
||||
>
|
||||
@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.15s',
|
||||
}}
|
||||
key={`${item}${index}`}
|
||||
{...getItemProps({ item, index })}
|
||||
{...getItemProps({ item: v.value, index })}
|
||||
>
|
||||
{withCheckIcon ? (
|
||||
<Grid gridTemplateColumns="1.25rem auto">
|
||||
@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
fontWeight,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
{v.label}
|
||||
</Text>
|
||||
</GridItem>
|
||||
</Grid>
|
||||
@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
|
||||
fontWeight,
|
||||
}}
|
||||
>
|
||||
{item}
|
||||
{v.label}
|
||||
</Text>
|
||||
)}
|
||||
</ListItem>
|
||||
|
@ -14,7 +14,7 @@ import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { AnimatePresence } from 'framer-motion';
|
||||
import { ReactElement, SyntheticEvent, useCallback } from 'react';
|
||||
import { memo, useRef } from 'react';
|
||||
import { FaImage, FaTimes, FaUpload } from 'react-icons/fa';
|
||||
import { FaImage, FaTimes, FaUndo, FaUpload } from 'react-icons/fa';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import IAIDropOverlay from './IAIDropOverlay';
|
||||
@ -174,14 +174,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
right: 0,
|
||||
p: 2,
|
||||
}}
|
||||
>
|
||||
<IAIIconButton
|
||||
size={resetIconSize}
|
||||
tooltip="Reset Image"
|
||||
aria-label="Reset Image"
|
||||
icon={<FaTimes />}
|
||||
icon={<FaUndo />}
|
||||
onClick={onReset}
|
||||
/>
|
||||
</Box>
|
||||
|
@ -16,7 +16,6 @@ import {
|
||||
setShouldShowIntermediates,
|
||||
setShouldSnapToGrid,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
|
||||
import { isEqual } from 'lodash-es';
|
||||
|
||||
import { ChangeEvent } from 'react';
|
||||
@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => {
|
||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||
/>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
<EmptyTempFolderButtonModal />
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
);
|
||||
|
@ -30,7 +30,10 @@ import {
|
||||
} from './canvasTypes';
|
||||
import { ImageDTO } from 'services/api';
|
||||
import { sessionCanceled } from 'services/thunks/session';
|
||||
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
|
||||
import {
|
||||
setActiveTab,
|
||||
setShouldUseCanvasBetaLayout,
|
||||
} from 'features/ui/store/uiSlice';
|
||||
import { imageUrlsReceived } from 'services/thunks/image';
|
||||
|
||||
export const initialLayerState: CanvasLayerState = {
|
||||
@ -857,6 +860,11 @@ export const canvasSlice = createSlice({
|
||||
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
|
||||
builder.addCase(setActiveTab, (state, action) => {
|
||||
state.doesCanvasNeedScaling = true;
|
||||
});
|
||||
|
||||
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
|
||||
const { image_name, image_origin, image_url, thumbnail_url } =
|
||||
action.payload;
|
||||
|
@ -1,17 +1,26 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||
import IAICustomSelect, {
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
ControlNetModel,
|
||||
ControlNetModelName,
|
||||
} from 'features/controlNet/store/constants';
|
||||
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
|
||||
import { map } from 'lodash-es';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type ParamControlNetModelProps = {
|
||||
controlNetId: string;
|
||||
model: ControlNetModel;
|
||||
model: ControlNetModelName;
|
||||
};
|
||||
|
||||
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
|
||||
value: m.type,
|
||||
label: m.label,
|
||||
tooltip: m.type,
|
||||
}));
|
||||
|
||||
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const { controlNetId, model } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
const handleModelChanged = useCallback(
|
||||
(val: string | null | undefined) => {
|
||||
// TODO: do not cast
|
||||
const model = val as ControlNetModel;
|
||||
const model = val as ControlNetModelName;
|
||||
dispatch(controlNetModelChanged({ controlNetId, model }));
|
||||
},
|
||||
[controlNetId, dispatch]
|
||||
@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
|
||||
<IAICustomSelect
|
||||
tooltip={model}
|
||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
items={CONTROLNET_MODELS}
|
||||
selectedItem={model}
|
||||
setSelectedItem={handleModelChanged}
|
||||
data={DATA}
|
||||
value={model}
|
||||
onChange={handleModelChanged}
|
||||
ellipsisPosition="start"
|
||||
withCheckIcon
|
||||
/>
|
||||
|
@ -1,4 +1,6 @@
|
||||
import IAICustomSelect from 'common/components/IAICustomSelect';
|
||||
import IAICustomSelect, {
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
import { memo, useCallback } from 'react';
|
||||
import {
|
||||
ControlNetProcessorNode,
|
||||
@ -7,15 +9,28 @@ import {
|
||||
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { CONTROLNET_PROCESSORS } from '../../store/constants';
|
||||
import { map } from 'lodash-es';
|
||||
|
||||
type ParamControlNetProcessorSelectProps = {
|
||||
controlNetId: string;
|
||||
processorNode: ControlNetProcessorNode;
|
||||
};
|
||||
|
||||
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
|
||||
CONTROLNET_PROCESSORS
|
||||
) as ControlNetProcessorType[];
|
||||
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
|
||||
CONTROLNET_PROCESSORS,
|
||||
(p) => ({
|
||||
value: p.type,
|
||||
label: p.label,
|
||||
tooltip: p.description,
|
||||
})
|
||||
).sort((a, b) =>
|
||||
// sort 'none' to the top
|
||||
a.value === 'none'
|
||||
? -1
|
||||
: b.value === 'none'
|
||||
? 1
|
||||
: a.label.localeCompare(b.label)
|
||||
);
|
||||
|
||||
const ParamControlNetProcessorSelect = (
|
||||
props: ParamControlNetProcessorSelectProps
|
||||
@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = (
|
||||
return (
|
||||
<IAICustomSelect
|
||||
label="Processor"
|
||||
items={CONTROLNET_PROCESSOR_TYPES}
|
||||
selectedItem={processorNode.type ?? 'canny_image_processor'}
|
||||
setSelectedItem={handleProcessorTypeChanged}
|
||||
value={processorNode.type ?? 'canny_image_processor'}
|
||||
data={CONTROLNET_PROCESSOR_TYPES}
|
||||
onChange={handleProcessorTypeChanged}
|
||||
withCheckIcon
|
||||
/>
|
||||
);
|
||||
|
@ -5,12 +5,12 @@ import {
|
||||
} from './types';
|
||||
|
||||
type ControlNetProcessorsDict = Record<
|
||||
ControlNetProcessorType,
|
||||
string,
|
||||
{
|
||||
type: ControlNetProcessorType;
|
||||
type: ControlNetProcessorType | 'none';
|
||||
label: string;
|
||||
description: string;
|
||||
default: RequiredControlNetProcessorNode;
|
||||
default: RequiredControlNetProcessorNode | { type: 'none' };
|
||||
}
|
||||
>;
|
||||
|
||||
@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record<
|
||||
*
|
||||
* TODO: Generate from the OpenAPI schema
|
||||
*/
|
||||
export const CONTROLNET_PROCESSORS = {
|
||||
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
none: {
|
||||
type: 'none',
|
||||
label: 'None',
|
||||
label: 'none',
|
||||
description: '',
|
||||
default: {
|
||||
type: 'none',
|
||||
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
|
||||
},
|
||||
mlsd_image_processor: {
|
||||
type: 'mlsd_image_processor',
|
||||
label: 'MLSD',
|
||||
label: 'M-LSD',
|
||||
description: '',
|
||||
default: {
|
||||
id: 'mlsd_image_processor',
|
||||
@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = {
|
||||
},
|
||||
};
|
||||
|
||||
export const CONTROLNET_MODELS = [
|
||||
'lllyasviel/control_v11p_sd15_canny',
|
||||
'lllyasviel/control_v11p_sd15_inpaint',
|
||||
'lllyasviel/control_v11p_sd15_mlsd',
|
||||
'lllyasviel/control_v11f1p_sd15_depth',
|
||||
'lllyasviel/control_v11p_sd15_normalbae',
|
||||
'lllyasviel/control_v11p_sd15_seg',
|
||||
'lllyasviel/control_v11p_sd15_lineart',
|
||||
'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
||||
'lllyasviel/control_v11p_sd15_scribble',
|
||||
'lllyasviel/control_v11p_sd15_softedge',
|
||||
'lllyasviel/control_v11e_sd15_shuffle',
|
||||
'lllyasviel/control_v11p_sd15_openpose',
|
||||
'lllyasviel/control_v11f1e_sd15_tile',
|
||||
'lllyasviel/control_v11e_sd15_ip2p',
|
||||
'CrucibleAI/ControlNetMediaPipeFace',
|
||||
];
|
||||
|
||||
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
|
||||
|
||||
export const CONTROLNET_MODEL_MAP: Record<
|
||||
ControlNetModel,
|
||||
ControlNetProcessorType
|
||||
> = {
|
||||
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
|
||||
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
|
||||
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
|
||||
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
|
||||
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
|
||||
'lllyasviel/control_v11p_sd15s2_lineart_anime':
|
||||
'lineart_anime_image_processor',
|
||||
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
|
||||
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
|
||||
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
|
||||
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
|
||||
type ControlNetModel = {
|
||||
type: string;
|
||||
label: string;
|
||||
description?: string;
|
||||
defaultProcessor?: ControlNetProcessorType;
|
||||
};
|
||||
|
||||
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
|
||||
'lllyasviel/control_v11p_sd15_canny': {
|
||||
type: 'lllyasviel/control_v11p_sd15_canny',
|
||||
label: 'Canny',
|
||||
defaultProcessor: 'canny_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_inpaint': {
|
||||
type: 'lllyasviel/control_v11p_sd15_inpaint',
|
||||
label: 'Inpaint',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_mlsd': {
|
||||
type: 'lllyasviel/control_v11p_sd15_mlsd',
|
||||
label: 'M-LSD',
|
||||
defaultProcessor: 'mlsd_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1p_sd15_depth': {
|
||||
type: 'lllyasviel/control_v11f1p_sd15_depth',
|
||||
label: 'Depth',
|
||||
defaultProcessor: 'midas_depth_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_normalbae': {
|
||||
type: 'lllyasviel/control_v11p_sd15_normalbae',
|
||||
label: 'Normal Map (BAE)',
|
||||
defaultProcessor: 'normalbae_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_seg': {
|
||||
type: 'lllyasviel/control_v11p_sd15_seg',
|
||||
label: 'Segment Anything',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_lineart': {
|
||||
type: 'lllyasviel/control_v11p_sd15_lineart',
|
||||
label: 'Lineart',
|
||||
defaultProcessor: 'lineart_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
|
||||
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
|
||||
label: 'Lineart Anime',
|
||||
defaultProcessor: 'lineart_anime_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_scribble': {
|
||||
type: 'lllyasviel/control_v11p_sd15_scribble',
|
||||
label: 'Scribble',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_softedge': {
|
||||
type: 'lllyasviel/control_v11p_sd15_softedge',
|
||||
label: 'Soft Edge',
|
||||
defaultProcessor: 'hed_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_shuffle': {
|
||||
type: 'lllyasviel/control_v11e_sd15_shuffle',
|
||||
label: 'Content Shuffle',
|
||||
defaultProcessor: 'content_shuffle_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11p_sd15_openpose': {
|
||||
type: 'lllyasviel/control_v11p_sd15_openpose',
|
||||
label: 'Openpose',
|
||||
defaultProcessor: 'openpose_image_processor',
|
||||
},
|
||||
'lllyasviel/control_v11f1e_sd15_tile': {
|
||||
type: 'lllyasviel/control_v11f1e_sd15_tile',
|
||||
label: 'Tile (experimental)',
|
||||
},
|
||||
'lllyasviel/control_v11e_sd15_ip2p': {
|
||||
type: 'lllyasviel/control_v11e_sd15_ip2p',
|
||||
label: 'Pix2Pix (experimental)',
|
||||
},
|
||||
'CrucibleAI/ControlNetMediaPipeFace': {
|
||||
type: 'CrucibleAI/ControlNetMediaPipeFace',
|
||||
label: 'Mediapipe Face',
|
||||
defaultProcessor: 'mediapipe_face_processor',
|
||||
},
|
||||
};
|
||||
|
||||
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;
|
||||
|
@ -9,9 +9,8 @@ import {
|
||||
} from './types';
|
||||
import {
|
||||
CONTROLNET_MODELS,
|
||||
CONTROLNET_MODEL_MAP,
|
||||
CONTROLNET_PROCESSORS,
|
||||
ControlNetModel,
|
||||
ControlNetModelName,
|
||||
} from './constants';
|
||||
import { controlNetImageProcessed } from './actions';
|
||||
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
|
||||
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
|
||||
|
||||
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
isEnabled: true,
|
||||
model: CONTROLNET_MODELS[0],
|
||||
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
|
||||
export type ControlNetConfig = {
|
||||
controlNetId: string;
|
||||
isEnabled: boolean;
|
||||
model: ControlNetModel;
|
||||
model: ControlNetModelName;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
|
||||
},
|
||||
controlNetModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
|
||||
action: PayloadAction<{
|
||||
controlNetId: string;
|
||||
model: ControlNetModelName;
|
||||
}>
|
||||
) => {
|
||||
const { controlNetId, model } = action.payload;
|
||||
state.controlNets[controlNetId].model = model;
|
||||
state.controlNets[controlNetId].processedControlImage = null;
|
||||
|
||||
if (state.controlNets[controlNetId].shouldAutoConfig) {
|
||||
const processorType = CONTROLNET_MODEL_MAP[model];
|
||||
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
|
||||
if (newShouldAutoConfig) {
|
||||
// manage the processor for the user
|
||||
const processorType =
|
||||
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model];
|
||||
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
|
||||
.defaultProcessor;
|
||||
if (processorType) {
|
||||
state.controlNets[controlNetId].processorType = processorType;
|
||||
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
|
||||
|
@ -342,7 +342,10 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
|
||||
});
|
||||
}
|
||||
|
||||
if (shouldFitToWidthHeight) {
|
||||
if (
|
||||
shouldFitToWidthHeight &&
|
||||
(initialImage.width !== width || initialImage.height !== height)
|
||||
) {
|
||||
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
|
||||
|
||||
// Create a resize node, explicitly setting its image
|
||||
|
@ -14,9 +14,11 @@ const selector = createSelector(
|
||||
(ui, generation) => {
|
||||
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
|
||||
// but we need to wait for the next release before removing this special handling.
|
||||
const allSchedulers = ui.schedulers.filter((scheduler) => {
|
||||
return !['dpmpp_2s'].includes(scheduler);
|
||||
});
|
||||
const allSchedulers = ui.schedulers
|
||||
.filter((scheduler) => {
|
||||
return !['dpmpp_2s'].includes(scheduler);
|
||||
})
|
||||
.sort((a, b) => a.localeCompare(b));
|
||||
|
||||
return {
|
||||
scheduler: generation.scheduler,
|
||||
@ -45,9 +47,9 @@ const ParamScheduler = () => {
|
||||
return (
|
||||
<IAICustomSelect
|
||||
label={t('parameters.scheduler')}
|
||||
selectedItem={scheduler}
|
||||
setSelectedItem={handleChange}
|
||||
items={allSchedulers}
|
||||
value={scheduler}
|
||||
data={allSchedulers}
|
||||
onChange={handleChange}
|
||||
withCheckIcon
|
||||
/>
|
||||
);
|
||||
|
@ -58,6 +58,7 @@ const InitialImagePreview = () => {
|
||||
onReset={handleReset}
|
||||
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
|
||||
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
|
||||
withResetIcon
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
|
@ -1,41 +0,0 @@
|
||||
// import { emptyTempFolder } from 'app/socketio/actions';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIAlertDialog from 'common/components/IAIAlertDialog';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import {
|
||||
clearCanvasHistory,
|
||||
resetCanvas,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTrash } from 'react-icons/fa';
|
||||
|
||||
const EmptyTempFolderButtonModal = () => {
|
||||
const isStaging = useAppSelector(isStagingSelector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const acceptCallback = () => {
|
||||
dispatch(emptyTempFolder());
|
||||
dispatch(resetCanvas());
|
||||
dispatch(clearCanvasHistory());
|
||||
};
|
||||
|
||||
return (
|
||||
<IAIAlertDialog
|
||||
title={t('unifiedCanvas.emptyTempImageFolder')}
|
||||
acceptCallback={acceptCallback}
|
||||
acceptButtonText={t('unifiedCanvas.emptyFolder')}
|
||||
triggerComponent={
|
||||
<IAIButton leftIcon={<FaTrash />} size="sm" isDisabled={isStaging}>
|
||||
{t('unifiedCanvas.emptyTempImageFolder')}
|
||||
</IAIButton>
|
||||
}
|
||||
>
|
||||
<p>{t('unifiedCanvas.emptyTempImagesFolderMessage')}</p>
|
||||
<br />
|
||||
<p>{t('unifiedCanvas.emptyTempImagesFolderConfirm')}</p>
|
||||
</IAIAlertDialog>
|
||||
);
|
||||
};
|
||||
export default EmptyTempFolderButtonModal;
|
@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectModelsAll,
|
||||
selectModelsById,
|
||||
selectModelsIds,
|
||||
} from '../store/modelSlice';
|
||||
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
|
||||
import { RootState } from 'app/store/store';
|
||||
import { modelSelected } from 'features/parameters/store/generationSlice';
|
||||
import { generationSelector } from 'features/parameters/store/generationSelectors';
|
||||
import IAICustomSelect, {
|
||||
ItemTooltips,
|
||||
IAICustomSelectOption,
|
||||
} from 'common/components/IAICustomSelect';
|
||||
|
||||
const selector = createSelector(
|
||||
[(state: RootState) => state, generationSelector],
|
||||
(state, generation) => {
|
||||
const selectedModel = selectModelsById(state, generation.model);
|
||||
const allModelNames = selectModelsIds(state).map((id) => String(id));
|
||||
const allModelTooltips = selectModelsAll(state).reduce(
|
||||
(allModelTooltips, model) => {
|
||||
allModelTooltips[model.name] = model.description ?? '';
|
||||
return allModelTooltips;
|
||||
},
|
||||
{} as ItemTooltips
|
||||
);
|
||||
|
||||
const modelData = selectModelsAll(state)
|
||||
.map<IAICustomSelectOption>((m) => ({
|
||||
value: m.name,
|
||||
label: m.name,
|
||||
tooltip: m.description,
|
||||
}))
|
||||
.sort((a, b) => a.label.localeCompare(b.label));
|
||||
return {
|
||||
allModelNames,
|
||||
allModelTooltips,
|
||||
selectedModel,
|
||||
modelData,
|
||||
};
|
||||
},
|
||||
{
|
||||
@ -44,8 +39,7 @@ const selector = createSelector(
|
||||
const ModelSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const { allModelNames, allModelTooltips, selectedModel } =
|
||||
useAppSelector(selector);
|
||||
const { selectedModel, modelData } = useAppSelector(selector);
|
||||
const handleChangeModel = useCallback(
|
||||
(v: string | null | undefined) => {
|
||||
if (!v) {
|
||||
@ -60,10 +54,9 @@ const ModelSelect = () => {
|
||||
<IAICustomSelect
|
||||
label={t('modelManager.model')}
|
||||
tooltip={selectedModel?.description}
|
||||
items={allModelNames}
|
||||
itemTooltips={allModelTooltips}
|
||||
selectedItem={selectedModel?.name ?? ''}
|
||||
setSelectedItem={handleChangeModel}
|
||||
data={modelData}
|
||||
value={selectedModel?.name ?? ''}
|
||||
onChange={handleChangeModel}
|
||||
withCheckIcon={true}
|
||||
tooltipProps={{ placement: 'top', hasArrow: true }}
|
||||
/>
|
||||
|
@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
|
||||
import { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
|
||||
import { memo, ReactNode, useCallback, useMemo } from 'react';
|
||||
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
|
||||
import { GoTextSize } from 'react-icons/go';
|
||||
@ -47,22 +47,22 @@ export interface InvokeTabInfo {
|
||||
const tabs: InvokeTabInfo[] = [
|
||||
{
|
||||
id: 'txt2img',
|
||||
icon: <Icon as={GoTextSize} sx={{ boxSize: 6 }} />,
|
||||
icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
content: <TextToImageTab />,
|
||||
},
|
||||
{
|
||||
id: 'img2img',
|
||||
icon: <Icon as={FaImage} sx={{ boxSize: 6 }} />,
|
||||
icon: <Icon as={FaImage} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
content: <ImageTab />,
|
||||
},
|
||||
{
|
||||
id: 'unifiedCanvas',
|
||||
icon: <Icon as={MdGridOn} sx={{ boxSize: 6 }} />,
|
||||
icon: <Icon as={MdGridOn} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
content: <UnifiedCanvasTab />,
|
||||
},
|
||||
{
|
||||
id: 'nodes',
|
||||
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6 }} />,
|
||||
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
|
||||
content: <NodesTab />,
|
||||
},
|
||||
];
|
||||
@ -119,6 +119,12 @@ const InvokeTabs = () => {
|
||||
}
|
||||
}, [dispatch, activeTabName]);
|
||||
|
||||
const handleClickTab = useCallback((e: MouseEvent<HTMLElement>) => {
|
||||
if (e.target instanceof HTMLElement) {
|
||||
e.target.blur();
|
||||
}
|
||||
}, []);
|
||||
|
||||
const tabs = useMemo(
|
||||
() =>
|
||||
enabledTabs.map((tab) => (
|
||||
@ -128,7 +134,7 @@ const InvokeTabs = () => {
|
||||
label={String(t(`common.${tab.id}` as ResourceKey))}
|
||||
placement="end"
|
||||
>
|
||||
<Tab>
|
||||
<Tab onClick={handleClickTab}>
|
||||
<VisuallyHidden>
|
||||
{String(t(`common.${tab.id}` as ResourceKey))}
|
||||
</VisuallyHidden>
|
||||
@ -136,7 +142,7 @@ const InvokeTabs = () => {
|
||||
</Tab>
|
||||
</Tooltip>
|
||||
)),
|
||||
[t, enabledTabs]
|
||||
[enabledTabs, t, handleClickTab]
|
||||
);
|
||||
|
||||
const tabPanels = useMemo(
|
||||
|
@ -12,7 +12,6 @@ import {
|
||||
setShouldShowCanvasDebugInfo,
|
||||
setShouldShowIntermediates,
|
||||
} from 'features/canvas/store/canvasSlice';
|
||||
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
|
||||
|
||||
import { FaWrench } from 'react-icons/fa';
|
||||
|
||||
@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => {
|
||||
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
|
||||
/>
|
||||
<ClearCanvasHistoryButtonModal />
|
||||
<EmptyTempFolderButtonModal />
|
||||
</Flex>
|
||||
</IAIPopover>
|
||||
);
|
||||
|
@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => {
|
||||
});
|
||||
|
||||
useLayoutEffect(() => {
|
||||
dispatch(requestCanvasRescale());
|
||||
|
||||
const resizeCallback = () => {
|
||||
dispatch(requestCanvasRescale());
|
||||
};
|
||||
|
@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
|
||||
assert isinstance(n6[0], CollectInvocation)
|
||||
|
||||
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
|
||||
|
||||
|
||||
def test_graph_state_prepares_eagerly(mock_services):
|
||||
"""Tests that all prepareable nodes are prepared"""
|
||||
graph = Graph()
|
||||
|
||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||
graph.add_node(IterateInvocation(id="iterate"))
|
||||
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||
|
||||
# separated, fully-preparable chain of nodes
|
||||
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
|
||||
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
|
||||
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
|
||||
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
|
||||
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
g.next()
|
||||
|
||||
assert "prompt_collection" in g.source_prepared_mapping
|
||||
assert "prompt_chain_1" in g.source_prepared_mapping
|
||||
assert "prompt_chain_2" in g.source_prepared_mapping
|
||||
assert "prompt_chain_3" in g.source_prepared_mapping
|
||||
assert "iterate" not in g.source_prepared_mapping
|
||||
assert "prompt_iterated" not in g.source_prepared_mapping
|
||||
|
||||
|
||||
def test_graph_executes_depth_first(mock_services):
|
||||
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
|
||||
graph = Graph()
|
||||
|
||||
test_prompts = ["Banana sushi", "Cat sushi"]
|
||||
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
|
||||
graph.add_node(IterateInvocation(id="iterate"))
|
||||
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
|
||||
graph.add_node(PromptTestInvocation(id="prompt_successor"))
|
||||
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
|
||||
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
|
||||
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
|
||||
|
||||
g = GraphExecutionState(graph=graph)
|
||||
n1 = invoke_next(g, mock_services)
|
||||
n2 = invoke_next(g, mock_services)
|
||||
n3 = invoke_next(g, mock_services)
|
||||
n4 = invoke_next(g, mock_services)
|
||||
|
||||
# Because ordering is not guaranteed, we cannot compare results directly.
|
||||
# Instead, we must count the number of results.
|
||||
def get_completed_count(g, id):
|
||||
ids = [i for i in g.source_prepared_mapping[id]]
|
||||
completed_ids = [i for i in g.executed if i in ids]
|
||||
return len(completed_ids)
|
||||
|
||||
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 0
|
||||
|
||||
n5 = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 1
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
n6 = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 1
|
||||
|
||||
n7 = invoke_next(g, mock_services)
|
||||
|
||||
assert get_completed_count(g, "prompt_iterated") == 2
|
||||
assert get_completed_count(g, "prompt_successor") == 2
|
||||
|
Loading…
Reference in New Issue
Block a user