resolve conflicts

This commit is contained in:
Lincoln Stein 2023-06-10 10:13:54 -04:00
commit 3d2ff7755e
26 changed files with 535 additions and 217 deletions

View File

@ -14,13 +14,12 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
# Do this early so that other modules pick up configuration
#This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger()
logger = InvokeAILogger.getLogger(config=app_config)
import invokeai.frontend.web as web_dir

View File

@ -11,19 +11,26 @@ from typing import Union, get_type_hints
from pydantic import BaseModel, ValidationError
from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService
from invokeai.app.services.metadata import CoreMetadataService
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import (create_system_graphs,
default_text_to_image_graph_id)
from .services.events import EventServiceBase
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
@ -32,13 +39,11 @@ from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
from .services.invoker import Invoker
from .services.latent_storage import (DiskLatentsStorage,
ForwardCacheLatentsStorage)
from .services.model_manager_service import ModelManagerService
from .services.processor import DefaultInvocationProcessor
from .services.restoration_services import RestorationServices
from .services.sqlite import SqliteItemStorage
from .services.config import InvokeAIAppConfig
class CliCommand(BaseModel):
command: Union[BaseCommand.get_commands() + BaseInvocation.get_invocations()] = Field(discriminator="type") # type: ignore
@ -47,7 +52,6 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@ -192,11 +196,6 @@ def invoke_all(context: CliContext):
raise SessionError()
def invoke_cli():
# this gets the basic configuration
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger()
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument('commands',nargs='*')

View File

@ -859,11 +859,9 @@ class GraphExecutionState(BaseModel):
if next_node is None:
prepared_id = self._prepare()
# TODO: prepare multiple nodes at once?
# while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
# prepared_id = self._prepare()
if prepared_id is not None:
# Prepare as many nodes as we can
while prepared_id is not None:
prepared_id = self._prepare()
next_node = self._get_next_node()
# Get values from edges
@ -1010,14 +1008,30 @@ class GraphExecutionState(BaseModel):
# Get flattened source graph
g = self.graph.nx_graph_flat()
# Find next unprepared node where all source nodes are executed
# Find next node that:
# - was not already prepared
# - is not an iterate node whose inputs have not been executed
# - does not have an unexecuted iterate ancestor
sorted_nodes = nx.topological_sort(g)
next_node_id = next(
(
n
for n in sorted_nodes
# exclude nodes that have already been prepared
if n not in self.source_prepared_mapping
and all((e[0] in self.executed for e in g.in_edges(n)))
# exclude iterate nodes whose inputs have not been executed
and not (
isinstance(self.graph.get_node(n), IterateInvocation) # `n` is an iterate node...
and not all((e[0] in self.executed for e in g.in_edges(n))) # ...that has unexecuted inputs
)
# exclude nodes who have unexecuted iterate ancestors
and not any(
(
isinstance(self.graph.get_node(a), IterateInvocation) # `a` is an iterate ancestor of `n`...
and a not in self.executed # ...that is not executed
for a in nx.ancestors(g, n) # for all ancestors `a` of node `n`
)
)
),
None,
)
@ -1114,9 +1128,22 @@ class GraphExecutionState(BaseModel):
)
def _get_next_node(self) -> Optional[BaseInvocation]:
"""Gets the deepest node that is ready to be executed"""
g = self.execution_graph.nx_graph()
sorted_nodes = nx.topological_sort(g)
next_node = next((n for n in sorted_nodes if n not in self.executed), None)
# Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
if next_node is None:
return None

View File

@ -22,7 +22,8 @@ class Invoker:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> str | None:
"""Determines the next node to invoke and returns the id of the invoked node, or None if there are no nodes to execute"""
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
# Get the next invocation
invocation = graph_execution_state.next()

View File

@ -40,6 +40,7 @@ import invokeai.configs as configs
from invokeai.app.services.config import (
InvokeAIAppConfig,
)
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
from invokeai.frontend.install.widgets import (
CenteredButtonPress,
@ -80,6 +81,7 @@ INIT_FILE_PREAMBLE = """# InvokeAI initialization file
# or renaming it and then running invokeai-configure again.
"""
logger=None
# --------------------------------------------
def postscript(errors: None):
@ -824,6 +826,7 @@ def main():
if opt.full_precision:
invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
errors = set()

View File

@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""invokeai.util.logging
"""
invokeai.util.logging
Logging class for InvokeAI that produces console messages
@ -11,6 +12,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.getLogger(name='InvokeAI') // Initialization
(or)
logger = InvokeAILogger.getLogger(__name__) // To use the filename
logger.configure()
logger.critical('this is critical') // Critical Message
logger.error('this is an error') // Error Message
@ -28,6 +30,149 @@ Console messages:
Alternate Method (in this case the logger name will be set to InvokeAI):
import invokeai.backend.util.logging as IAILogger
IAILogger.debug('this is a debugging message')
## Configuration
The default configuration will print to stderr on the console. To add
additional logging handlers, call getLogger with an initialized InvokeAIAppConfig
object:
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger.getLogger(config=config)
### Three command-line options control logging:
`--log_handlers <handler1> <handler2> ...`
This option activates one or more log handlers. Options are "console", "file", "syslog" and "http". To specify more than one, separate them by spaces:
```
invokeai-web --log_handlers console syslog=/dev/log file=C:\\Users\\fred\\invokeai.log
```
The format of these options is described below.
### `--log_format {plain|color|legacy|syslog}`
This controls the format of log messages written to the console. Only the "console" log handler is currently affected by this setting.
* "plain" provides formatted messages like this:
```bash
[2023-05-24 23:18:2[2023-05-24 23:18:50,352]::[InvokeAI]::DEBUG --> this is a debug message
[2023-05-24 23:18:50,352]::[InvokeAI]::INFO --> this is an informational messages
[2023-05-24 23:18:50,352]::[InvokeAI]::WARNING --> this is a warning
[2023-05-24 23:18:50,352]::[InvokeAI]::ERROR --> this is an error
[2023-05-24 23:18:50,352]::[InvokeAI]::CRITICAL --> this is a critical error
```
* "color" produces similar output, but the text will be color coded to indicate the severity of the message.
* "legacy" produces output similar to InvokeAI versions 2.3 and earlier:
```
### this is a critical error
*** this is an error
** this is a warning
>> this is an informational messages
| this is a debug message
```
* "syslog" produces messages suitable for syslog entries:
```bash
InvokeAI [2691178] <CRITICAL> this is a critical error
InvokeAI [2691178] <ERROR> this is an error
InvokeAI [2691178] <WARNING> this is a warning
InvokeAI [2691178] <INFO> this is an informational messages
InvokeAI [2691178] <DEBUG> this is a debug message
```
(note that the date, time and hostname will be added by the syslog system)
### `--log_level {debug|info|warning|error|critical}`
Providing this command-line option will cause only messages at the specified level or above to be emitted.
## Console logging
When "console" is provided to `--log_handlers`, messages will be written to the command line window in which InvokeAI was launched. By default, the color formatter will be used unless overridden by `--log_format`.
## File logging
When "file" is provided to `--log_handlers`, entries will be written to the file indicated in the path argument. By default, the "plain" format will be used:
```bash
invokeai-web --log_handlers file=/var/log/invokeai.log
```
## Syslog logging
When "syslog" is requested, entries will be sent to the syslog system. There are a variety of ways to control where the log message is sent:
* Send to the local machine using the `/dev/log` socket:
```
invokeai-web --log_handlers syslog=/dev/log
```
* Send to the local machine using a UDP message:
```
invokeai-web --log_handlers syslog=localhost
```
* Send to the local machine using a UDP message on a nonstandard port:
```
invokeai-web --log_handlers syslog=localhost:512
```
* Send to a remote machine named "loghost" on the local LAN using facility LOG_USER and UDP packets:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_USER,socktype=SOCK_DGRAM
```
This can be abbreviated `syslog=loghost`, as LOG_USER and SOCK_DGRAM are defaults.
* Send to a remote machine named "loghost" using the facility LOCAL0 and using a TCP socket:
```
invokeai-web --log_handlers syslog=loghost,facility=LOG_LOCAL0,socktype=SOCK_STREAM
```
If no arguments are specified (just a bare "syslog"), then the logging system will look for a UNIX socket named `/dev/log`, and if not found try to send a UDP message to `localhost`. The Macintosh OS used to support logging to a socket named `/var/run/syslog`, but this feature has since been disabled.
## Web logging
If you have access to a web server that is configured to log messages when a particular URL is requested, you can log using the "http" method:
```
invokeai-web --log_handlers http=http://my.server/path/to/logger,method=POST
```
The optional [,method=] part can be used to specify whether the URL accepts GET (default) or POST messages.
Currently password authentication and SSL are not supported.
## Using the configuration file
You can set and forget logging options by adding a "Logging" section to `invokeai.yaml`:
```
InvokeAI:
[... other settings...]
Logging:
log_handlers:
- console
- syslog=/dev/log
log_level: info
log_format: color
```
"""
import logging.handlers
@ -180,14 +325,17 @@ class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(cls, name: str = 'InvokeAI') -> logging.Logger:
config = get_invokeai_config()
if name not in cls.loggers:
def getLogger(cls,
name: str = 'InvokeAI',
config: InvokeAIAppConfig=InvokeAIAppConfig.get_config())->logging.Logger:
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
else:
logger = logging.getLogger(name)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.addHandler(ch)
logger.setLevel(config.log_level.upper()) # yes, strings work here
for ch in cls.getLoggers(config):
logger.addHandler(ch)
cls.loggers[name] = logger
return cls.loggers[name]
@ -199,9 +347,11 @@ class InvokeAILogger(object):
handler_name,*args = handler.split('=',2)
args = args[0] if len(args) > 0 else None
# console and file are the only handlers that gets a custom formatter
# console and file get the fancy formatter.
# syslog gets a simple one
# http gets no custom formatter
formatter = LOG_FORMATTERS[config.log_format]
if handler_name=='console':
formatter = LOG_FORMATTERS[config.log_format]
ch = logging.StreamHandler()
ch.setFormatter(formatter())
handlers.append(ch)
@ -212,7 +362,7 @@ class InvokeAILogger(object):
elif handler_name=='file':
ch = cls._parse_file_args(args)
ch.setFormatter(InvokeAISyslogFormatter())
ch.setFormatter(formatter())
handlers.append(ch)
elif handler_name=='http':

View File

@ -28,7 +28,7 @@ import torch
from npyscreen import widget
from omegaconf import OmegaConf
import invokeai.backend.util.logging as logger
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.install.model_install_backend import (
Dataset_path,
@ -939,6 +939,7 @@ def main():
if opt.full_precision:
invoke_args.extend(['--precision','float32'])
config.parse_args(invoke_args)
logger = InvokeAILogger().getLogger(config=config)
if not (config.root_dir / config.conf_path.parent).exists():
logger.info(

View File

@ -22,6 +22,7 @@ import { APP_HEIGHT, APP_WIDTH } from 'theme/util/constants';
import GlobalHotkeys from './GlobalHotkeys';
import Toaster from './Toaster';
import DeleteImageModal from 'features/gallery/components/DeleteImageModal';
import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale';
const DEFAULT_CONFIG = {};
@ -66,10 +67,17 @@ const App = ({
setIsReady(true);
}
if (isApplicationReady) {
// TODO: This is a jank fix for canvas not filling the screen on first load
setTimeout(() => {
dispatch(requestCanvasRescale());
}, 200);
}
return () => {
setIsReady && setIsReady(false);
};
}, [isApplicationReady, setIsReady]);
}, [dispatch, isApplicationReady, setIsReady]);
return (
<>

View File

@ -40,11 +40,11 @@ const ImageDndContext = (props: ImageDndContextProps) => {
);
const mouseSensor = useSensor(MouseSensor, {
activationConstraint: { delay: 250, tolerance: 5 },
activationConstraint: { delay: 150, tolerance: 5 },
});
const touchSensor = useSensor(TouchSensor, {
activationConstraint: { delay: 250, tolerance: 5 },
activationConstraint: { delay: 150, tolerance: 5 },
});
// TODO: Use KeyboardSensor - needs composition of multiple collisionDetection algos
// Alternatively, fix `rectIntersection` collection detection to work with the drag overlay

View File

@ -2,7 +2,6 @@ import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons';
import {
Box,
Flex,
FlexProps,
FormControl,
FormControlProps,
FormLabel,
@ -16,6 +15,7 @@ import {
} from '@chakra-ui/react';
import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom';
import { useSelect } from 'downshift';
import { isString } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useMemo } from 'react';
@ -23,15 +23,19 @@ import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles';
export type ItemTooltips = { [key: string]: string };
export type IAICustomSelectOption = {
value: string;
label: string;
tooltip?: string;
};
type IAICustomSelectProps = {
label?: string;
items: string[];
itemTooltips?: ItemTooltips;
selectedItem: string;
setSelectedItem: (v: string | null | undefined) => void;
value: string;
data: IAICustomSelectOption[] | string[];
onChange: (v: string) => void;
withCheckIcon?: boolean;
formControlProps?: FormControlProps;
buttonProps?: FlexProps;
tooltip?: string;
tooltipProps?: Omit<TooltipProps, 'children'>;
ellipsisPosition?: 'start' | 'end';
@ -40,18 +44,33 @@ type IAICustomSelectProps = {
const IAICustomSelect = (props: IAICustomSelectProps) => {
const {
label,
items,
itemTooltips,
setSelectedItem,
selectedItem,
withCheckIcon,
formControlProps,
tooltip,
buttonProps,
tooltipProps,
ellipsisPosition = 'end',
data,
value,
onChange,
} = props;
const values = useMemo(() => {
return data.map<IAICustomSelectOption>((v) => {
if (isString(v)) {
return { value: v, label: v };
}
return v;
});
}, [data]);
const stringValues = useMemo(() => {
return values.map((v) => v.value);
}, [values]);
const valueData = useMemo(() => {
return values.find((v) => v.value === value);
}, [values, value]);
const {
isOpen,
getToggleButtonProps,
@ -60,10 +79,11 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
highlightedIndex,
getItemProps,
} = useSelect({
items,
selectedItem,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) =>
setSelectedItem(newSelectedItem),
items: stringValues,
selectedItem: value,
onSelectedItemChange: ({ selectedItem: newSelectedItem }) => {
newSelectedItem && onChange(newSelectedItem);
},
});
const { refs, floatingStyles } = useFloating<HTMLButtonElement>({
@ -93,8 +113,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
)}
<Tooltip label={tooltip} {...tooltipProps}>
<Flex
{...getToggleButtonProps({ ref: refs.setReference })}
{...buttonProps}
{...getToggleButtonProps({ ref: refs.reference })}
sx={{
alignItems: 'center',
userSelect: 'none',
@ -119,7 +138,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
direction: labelTextDirection,
}}
>
{selectedItem}
{valueData?.label}
</Text>
<ChevronUpIcon
sx={{
@ -135,7 +154,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
{isOpen && (
<List
as={Flex}
ref={refs.setFloating}
ref={refs.floating}
sx={{
...floatingStyles,
top: 0,
@ -155,8 +174,8 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
}}
>
<OverlayScrollbarsComponent>
{items.map((item, index) => {
const isSelected = selectedItem === item;
{values.map((v, index) => {
const isSelected = value === v.value;
const isHighlighted = highlightedIndex === index;
const fontWeight = isSelected ? 700 : 500;
const bg = isHighlighted
@ -166,9 +185,9 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
: undefined;
return (
<Tooltip
isDisabled={!itemTooltips}
key={`${item}${index}`}
label={itemTooltips?.[item]}
isDisabled={!v.tooltip}
key={`${v.value}${index}`}
label={v.tooltip}
hasArrow
placement="right"
>
@ -182,8 +201,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
transitionProperty: 'common',
transitionDuration: '0.15s',
}}
key={`${item}${index}`}
{...getItemProps({ item, index })}
{...getItemProps({ item: v.value, index })}
>
{withCheckIcon ? (
<Grid gridTemplateColumns="1.25rem auto">
@ -198,7 +216,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight,
}}
>
{item}
{v.label}
</Text>
</GridItem>
</Grid>
@ -210,7 +228,7 @@ const IAICustomSelect = (props: IAICustomSelectProps) => {
fontWeight,
}}
>
{item}
{v.label}
</Text>
)}
</ListItem>

View File

@ -14,7 +14,7 @@ import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { AnimatePresence } from 'framer-motion';
import { ReactElement, SyntheticEvent, useCallback } from 'react';
import { memo, useRef } from 'react';
import { FaImage, FaTimes, FaUpload } from 'react-icons/fa';
import { FaImage, FaTimes, FaUndo, FaUpload } from 'react-icons/fa';
import { ImageDTO } from 'services/api';
import { v4 as uuidv4 } from 'uuid';
import IAIDropOverlay from './IAIDropOverlay';
@ -174,14 +174,13 @@ const IAIDndImage = (props: IAIDndImageProps) => {
position: 'absolute',
top: 0,
right: 0,
p: 2,
}}
>
<IAIIconButton
size={resetIconSize}
tooltip="Reset Image"
aria-label="Reset Image"
icon={<FaTimes />}
icon={<FaUndo />}
onClick={onReset}
/>
</Box>

View File

@ -16,7 +16,6 @@ import {
setShouldShowIntermediates,
setShouldSnapToGrid,
} from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { isEqual } from 'lodash-es';
import { ChangeEvent } from 'react';
@ -159,7 +158,6 @@ const IAICanvasSettingsButtonPopover = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/>
<ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex>
</IAIPopover>
);

View File

@ -30,7 +30,10 @@ import {
} from './canvasTypes';
import { ImageDTO } from 'services/api';
import { sessionCanceled } from 'services/thunks/session';
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
import {
setActiveTab,
setShouldUseCanvasBetaLayout,
} from 'features/ui/store/uiSlice';
import { imageUrlsReceived } from 'services/thunks/image';
export const initialLayerState: CanvasLayerState = {
@ -857,6 +860,11 @@ export const canvasSlice = createSlice({
builder.addCase(setShouldUseCanvasBetaLayout, (state, action) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(setActiveTab, (state, action) => {
state.doesCanvasNeedScaling = true;
});
builder.addCase(imageUrlsReceived.fulfilled, (state, action) => {
const { image_name, image_origin, image_url, thumbnail_url } =
action.payload;

View File

@ -1,17 +1,26 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import {
CONTROLNET_MODELS,
ControlNetModel,
ControlNetModelName,
} from 'features/controlNet/store/constants';
import { controlNetModelChanged } from 'features/controlNet/store/controlNetSlice';
import { map } from 'lodash-es';
import { memo, useCallback } from 'react';
type ParamControlNetModelProps = {
controlNetId: string;
model: ControlNetModel;
model: ControlNetModelName;
};
const DATA: IAICustomSelectOption[] = map(CONTROLNET_MODELS, (m) => ({
value: m.type,
label: m.label,
tooltip: m.type,
}));
const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const { controlNetId, model } = props;
const dispatch = useAppDispatch();
@ -19,7 +28,7 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
const handleModelChanged = useCallback(
(val: string | null | undefined) => {
// TODO: do not cast
const model = val as ControlNetModel;
const model = val as ControlNetModelName;
dispatch(controlNetModelChanged({ controlNetId, model }));
},
[controlNetId, dispatch]
@ -29,9 +38,9 @@ const ParamControlNetModel = (props: ParamControlNetModelProps) => {
<IAICustomSelect
tooltip={model}
tooltipProps={{ placement: 'top', hasArrow: true }}
items={CONTROLNET_MODELS}
selectedItem={model}
setSelectedItem={handleModelChanged}
data={DATA}
value={model}
onChange={handleModelChanged}
ellipsisPosition="start"
withCheckIcon
/>

View File

@ -1,4 +1,6 @@
import IAICustomSelect from 'common/components/IAICustomSelect';
import IAICustomSelect, {
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
import { memo, useCallback } from 'react';
import {
ControlNetProcessorNode,
@ -7,15 +9,28 @@ import {
import { controlNetProcessorTypeChanged } from '../../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { CONTROLNET_PROCESSORS } from '../../store/constants';
import { map } from 'lodash-es';
type ParamControlNetProcessorSelectProps = {
controlNetId: string;
processorNode: ControlNetProcessorNode;
};
const CONTROLNET_PROCESSOR_TYPES = Object.keys(
CONTROLNET_PROCESSORS
) as ControlNetProcessorType[];
const CONTROLNET_PROCESSOR_TYPES: IAICustomSelectOption[] = map(
CONTROLNET_PROCESSORS,
(p) => ({
value: p.type,
label: p.label,
tooltip: p.description,
})
).sort((a, b) =>
// sort 'none' to the top
a.value === 'none'
? -1
: b.value === 'none'
? 1
: a.label.localeCompare(b.label)
);
const ParamControlNetProcessorSelect = (
props: ParamControlNetProcessorSelectProps
@ -36,9 +51,9 @@ const ParamControlNetProcessorSelect = (
return (
<IAICustomSelect
label="Processor"
items={CONTROLNET_PROCESSOR_TYPES}
selectedItem={processorNode.type ?? 'canny_image_processor'}
setSelectedItem={handleProcessorTypeChanged}
value={processorNode.type ?? 'canny_image_processor'}
data={CONTROLNET_PROCESSOR_TYPES}
onChange={handleProcessorTypeChanged}
withCheckIcon
/>
);

View File

@ -5,12 +5,12 @@ import {
} from './types';
type ControlNetProcessorsDict = Record<
ControlNetProcessorType,
string,
{
type: ControlNetProcessorType;
type: ControlNetProcessorType | 'none';
label: string;
description: string;
default: RequiredControlNetProcessorNode;
default: RequiredControlNetProcessorNode | { type: 'none' };
}
>;
@ -23,10 +23,10 @@ type ControlNetProcessorsDict = Record<
*
* TODO: Generate from the OpenAPI schema
*/
export const CONTROLNET_PROCESSORS = {
export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
none: {
type: 'none',
label: 'None',
label: 'none',
description: '',
default: {
type: 'none',
@ -116,7 +116,7 @@ export const CONTROLNET_PROCESSORS = {
},
mlsd_image_processor: {
type: 'mlsd_image_processor',
label: 'MLSD',
label: 'M-LSD',
description: '',
default: {
id: 'mlsd_image_processor',
@ -174,39 +174,84 @@ export const CONTROLNET_PROCESSORS = {
},
};
export const CONTROLNET_MODELS = [
'lllyasviel/control_v11p_sd15_canny',
'lllyasviel/control_v11p_sd15_inpaint',
'lllyasviel/control_v11p_sd15_mlsd',
'lllyasviel/control_v11f1p_sd15_depth',
'lllyasviel/control_v11p_sd15_normalbae',
'lllyasviel/control_v11p_sd15_seg',
'lllyasviel/control_v11p_sd15_lineart',
'lllyasviel/control_v11p_sd15s2_lineart_anime',
'lllyasviel/control_v11p_sd15_scribble',
'lllyasviel/control_v11p_sd15_softedge',
'lllyasviel/control_v11e_sd15_shuffle',
'lllyasviel/control_v11p_sd15_openpose',
'lllyasviel/control_v11f1e_sd15_tile',
'lllyasviel/control_v11e_sd15_ip2p',
'CrucibleAI/ControlNetMediaPipeFace',
];
export type ControlNetModel = (typeof CONTROLNET_MODELS)[number];
export const CONTROLNET_MODEL_MAP: Record<
ControlNetModel,
ControlNetProcessorType
> = {
'lllyasviel/control_v11p_sd15_canny': 'canny_image_processor',
'lllyasviel/control_v11p_sd15_mlsd': 'mlsd_image_processor',
'lllyasviel/control_v11f1p_sd15_depth': 'midas_depth_image_processor',
'lllyasviel/control_v11p_sd15_normalbae': 'normalbae_image_processor',
'lllyasviel/control_v11p_sd15_lineart': 'lineart_image_processor',
'lllyasviel/control_v11p_sd15s2_lineart_anime':
'lineart_anime_image_processor',
'lllyasviel/control_v11p_sd15_softedge': 'hed_image_processor',
'lllyasviel/control_v11e_sd15_shuffle': 'content_shuffle_image_processor',
'lllyasviel/control_v11p_sd15_openpose': 'openpose_image_processor',
'CrucibleAI/ControlNetMediaPipeFace': 'mediapipe_face_processor',
type ControlNetModel = {
type: string;
label: string;
description?: string;
defaultProcessor?: ControlNetProcessorType;
};
export const CONTROLNET_MODELS: Record<string, ControlNetModel> = {
'lllyasviel/control_v11p_sd15_canny': {
type: 'lllyasviel/control_v11p_sd15_canny',
label: 'Canny',
defaultProcessor: 'canny_image_processor',
},
'lllyasviel/control_v11p_sd15_inpaint': {
type: 'lllyasviel/control_v11p_sd15_inpaint',
label: 'Inpaint',
},
'lllyasviel/control_v11p_sd15_mlsd': {
type: 'lllyasviel/control_v11p_sd15_mlsd',
label: 'M-LSD',
defaultProcessor: 'mlsd_image_processor',
},
'lllyasviel/control_v11f1p_sd15_depth': {
type: 'lllyasviel/control_v11f1p_sd15_depth',
label: 'Depth',
defaultProcessor: 'midas_depth_image_processor',
},
'lllyasviel/control_v11p_sd15_normalbae': {
type: 'lllyasviel/control_v11p_sd15_normalbae',
label: 'Normal Map (BAE)',
defaultProcessor: 'normalbae_image_processor',
},
'lllyasviel/control_v11p_sd15_seg': {
type: 'lllyasviel/control_v11p_sd15_seg',
label: 'Segment Anything',
},
'lllyasviel/control_v11p_sd15_lineart': {
type: 'lllyasviel/control_v11p_sd15_lineart',
label: 'Lineart',
defaultProcessor: 'lineart_image_processor',
},
'lllyasviel/control_v11p_sd15s2_lineart_anime': {
type: 'lllyasviel/control_v11p_sd15s2_lineart_anime',
label: 'Lineart Anime',
defaultProcessor: 'lineart_anime_image_processor',
},
'lllyasviel/control_v11p_sd15_scribble': {
type: 'lllyasviel/control_v11p_sd15_scribble',
label: 'Scribble',
},
'lllyasviel/control_v11p_sd15_softedge': {
type: 'lllyasviel/control_v11p_sd15_softedge',
label: 'Soft Edge',
defaultProcessor: 'hed_image_processor',
},
'lllyasviel/control_v11e_sd15_shuffle': {
type: 'lllyasviel/control_v11e_sd15_shuffle',
label: 'Content Shuffle',
defaultProcessor: 'content_shuffle_image_processor',
},
'lllyasviel/control_v11p_sd15_openpose': {
type: 'lllyasviel/control_v11p_sd15_openpose',
label: 'Openpose',
defaultProcessor: 'openpose_image_processor',
},
'lllyasviel/control_v11f1e_sd15_tile': {
type: 'lllyasviel/control_v11f1e_sd15_tile',
label: 'Tile (experimental)',
},
'lllyasviel/control_v11e_sd15_ip2p': {
type: 'lllyasviel/control_v11e_sd15_ip2p',
label: 'Pix2Pix (experimental)',
},
'CrucibleAI/ControlNetMediaPipeFace': {
type: 'CrucibleAI/ControlNetMediaPipeFace',
label: 'Mediapipe Face',
defaultProcessor: 'mediapipe_face_processor',
},
};
export type ControlNetModelName = keyof typeof CONTROLNET_MODELS;

View File

@ -9,9 +9,8 @@ import {
} from './types';
import {
CONTROLNET_MODELS,
CONTROLNET_MODEL_MAP,
CONTROLNET_PROCESSORS,
ControlNetModel,
ControlNetModelName,
} from './constants';
import { controlNetImageProcessed } from './actions';
import { imageDeleted, imageUrlsReceived } from 'services/thunks/image';
@ -21,7 +20,7 @@ import { appSocketInvocationError } from 'services/events/actions';
export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
isEnabled: true,
model: CONTROLNET_MODELS[0],
model: CONTROLNET_MODELS['lllyasviel/control_v11p_sd15_canny'].type,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
@ -36,7 +35,7 @@ export const initialControlNet: Omit<ControlNetConfig, 'controlNetId'> = {
export type ControlNetConfig = {
controlNetId: string;
isEnabled: boolean;
model: ControlNetModel;
model: ControlNetModelName;
weight: number;
beginStepPct: number;
endStepPct: number;
@ -138,14 +137,17 @@ export const controlNetSlice = createSlice({
},
controlNetModelChanged: (
state,
action: PayloadAction<{ controlNetId: string; model: ControlNetModel }>
action: PayloadAction<{
controlNetId: string;
model: ControlNetModelName;
}>
) => {
const { controlNetId, model } = action.payload;
state.controlNets[controlNetId].model = model;
state.controlNets[controlNetId].processedControlImage = null;
if (state.controlNets[controlNetId].shouldAutoConfig) {
const processorType = CONTROLNET_MODEL_MAP[model];
const processorType = CONTROLNET_MODELS[model].defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[
@ -225,7 +227,8 @@ export const controlNetSlice = createSlice({
if (newShouldAutoConfig) {
// manage the processor for the user
const processorType =
CONTROLNET_MODEL_MAP[state.controlNets[controlNetId].model];
CONTROLNET_MODELS[state.controlNets[controlNetId].model]
.defaultProcessor;
if (processorType) {
state.controlNets[controlNetId].processorType = processorType;
state.controlNets[controlNetId].processorNode = CONTROLNET_PROCESSORS[

View File

@ -342,7 +342,10 @@ export const buildImageToImageGraph = (state: RootState): Graph => {
});
}
if (shouldFitToWidthHeight) {
if (
shouldFitToWidthHeight &&
(initialImage.width !== width || initialImage.height !== height)
) {
// The init image needs to be resized to the specified width and height before being passed to `IMAGE_TO_LATENTS`
// Create a resize node, explicitly setting its image

View File

@ -14,9 +14,11 @@ const selector = createSelector(
(ui, generation) => {
// TODO: DPMSolverSinglestepScheduler is fixed in https://github.com/huggingface/diffusers/pull/3413
// but we need to wait for the next release before removing this special handling.
const allSchedulers = ui.schedulers.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
});
const allSchedulers = ui.schedulers
.filter((scheduler) => {
return !['dpmpp_2s'].includes(scheduler);
})
.sort((a, b) => a.localeCompare(b));
return {
scheduler: generation.scheduler,
@ -45,9 +47,9 @@ const ParamScheduler = () => {
return (
<IAICustomSelect
label={t('parameters.scheduler')}
selectedItem={scheduler}
setSelectedItem={handleChange}
items={allSchedulers}
value={scheduler}
data={allSchedulers}
onChange={handleChange}
withCheckIcon
/>
);

View File

@ -58,6 +58,7 @@ const InitialImagePreview = () => {
onReset={handleReset}
fallback={<IAIImageFallback sx={{ bg: 'none' }} />}
postUploadAction={{ type: 'SET_INITIAL_IMAGE' }}
withResetIcon
/>
</Flex>
);

View File

@ -1,41 +0,0 @@
// import { emptyTempFolder } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIAlertDialog from 'common/components/IAIAlertDialog';
import IAIButton from 'common/components/IAIButton';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import {
clearCanvasHistory,
resetCanvas,
} from 'features/canvas/store/canvasSlice';
import { useTranslation } from 'react-i18next';
import { FaTrash } from 'react-icons/fa';
const EmptyTempFolderButtonModal = () => {
const isStaging = useAppSelector(isStagingSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const acceptCallback = () => {
dispatch(emptyTempFolder());
dispatch(resetCanvas());
dispatch(clearCanvasHistory());
};
return (
<IAIAlertDialog
title={t('unifiedCanvas.emptyTempImageFolder')}
acceptCallback={acceptCallback}
acceptButtonText={t('unifiedCanvas.emptyFolder')}
triggerComponent={
<IAIButton leftIcon={<FaTrash />} size="sm" isDisabled={isStaging}>
{t('unifiedCanvas.emptyTempImageFolder')}
</IAIButton>
}
>
<p>{t('unifiedCanvas.emptyTempImagesFolderMessage')}</p>
<br />
<p>{t('unifiedCanvas.emptyTempImagesFolderConfirm')}</p>
</IAIAlertDialog>
);
};
export default EmptyTempFolderButtonModal;

View File

@ -4,34 +4,29 @@ import { isEqual } from 'lodash-es';
import { useTranslation } from 'react-i18next';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectModelsAll,
selectModelsById,
selectModelsIds,
} from '../store/modelSlice';
import { selectModelsAll, selectModelsById } from '../store/modelSlice';
import { RootState } from 'app/store/store';
import { modelSelected } from 'features/parameters/store/generationSlice';
import { generationSelector } from 'features/parameters/store/generationSelectors';
import IAICustomSelect, {
ItemTooltips,
IAICustomSelectOption,
} from 'common/components/IAICustomSelect';
const selector = createSelector(
[(state: RootState) => state, generationSelector],
(state, generation) => {
const selectedModel = selectModelsById(state, generation.model);
const allModelNames = selectModelsIds(state).map((id) => String(id));
const allModelTooltips = selectModelsAll(state).reduce(
(allModelTooltips, model) => {
allModelTooltips[model.name] = model.description ?? '';
return allModelTooltips;
},
{} as ItemTooltips
);
const modelData = selectModelsAll(state)
.map<IAICustomSelectOption>((m) => ({
value: m.name,
label: m.name,
tooltip: m.description,
}))
.sort((a, b) => a.label.localeCompare(b.label));
return {
allModelNames,
allModelTooltips,
selectedModel,
modelData,
};
},
{
@ -44,8 +39,7 @@ const selector = createSelector(
const ModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { allModelNames, allModelTooltips, selectedModel } =
useAppSelector(selector);
const { selectedModel, modelData } = useAppSelector(selector);
const handleChangeModel = useCallback(
(v: string | null | undefined) => {
if (!v) {
@ -60,10 +54,9 @@ const ModelSelect = () => {
<IAICustomSelect
label={t('modelManager.model')}
tooltip={selectedModel?.description}
items={allModelNames}
itemTooltips={allModelTooltips}
selectedItem={selectedModel?.name ?? ''}
setSelectedItem={handleChangeModel}
data={modelData}
value={selectedModel?.name ?? ''}
onChange={handleChangeModel}
withCheckIcon={true}
tooltipProps={{ placement: 'top', hasArrow: true }}
/>

View File

@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setIsLightboxOpen } from 'features/lightbox/store/lightboxSlice';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { setActiveTab, togglePanels } from 'features/ui/store/uiSlice';
import { memo, ReactNode, useCallback, useMemo } from 'react';
import { memo, MouseEvent, ReactNode, useCallback, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { MdDeviceHub, MdGridOn } from 'react-icons/md';
import { GoTextSize } from 'react-icons/go';
@ -47,22 +47,22 @@ export interface InvokeTabInfo {
const tabs: InvokeTabInfo[] = [
{
id: 'txt2img',
icon: <Icon as={GoTextSize} sx={{ boxSize: 6 }} />,
icon: <Icon as={GoTextSize} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <TextToImageTab />,
},
{
id: 'img2img',
icon: <Icon as={FaImage} sx={{ boxSize: 6 }} />,
icon: <Icon as={FaImage} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <ImageTab />,
},
{
id: 'unifiedCanvas',
icon: <Icon as={MdGridOn} sx={{ boxSize: 6 }} />,
icon: <Icon as={MdGridOn} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <UnifiedCanvasTab />,
},
{
id: 'nodes',
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6 }} />,
icon: <Icon as={MdDeviceHub} sx={{ boxSize: 6, pointerEvents: 'none' }} />,
content: <NodesTab />,
},
];
@ -119,6 +119,12 @@ const InvokeTabs = () => {
}
}, [dispatch, activeTabName]);
const handleClickTab = useCallback((e: MouseEvent<HTMLElement>) => {
if (e.target instanceof HTMLElement) {
e.target.blur();
}
}, []);
const tabs = useMemo(
() =>
enabledTabs.map((tab) => (
@ -128,7 +134,7 @@ const InvokeTabs = () => {
label={String(t(`common.${tab.id}` as ResourceKey))}
placement="end"
>
<Tab>
<Tab onClick={handleClickTab}>
<VisuallyHidden>
{String(t(`common.${tab.id}` as ResourceKey))}
</VisuallyHidden>
@ -136,7 +142,7 @@ const InvokeTabs = () => {
</Tab>
</Tooltip>
)),
[t, enabledTabs]
[enabledTabs, t, handleClickTab]
);
const tabPanels = useMemo(

View File

@ -12,7 +12,6 @@ import {
setShouldShowCanvasDebugInfo,
setShouldShowIntermediates,
} from 'features/canvas/store/canvasSlice';
import EmptyTempFolderButtonModal from 'features/system/components/ClearTempFolderButtonModal';
import { FaWrench } from 'react-icons/fa';
@ -105,7 +104,6 @@ const UnifiedCanvasSettings = () => {
onChange={(e) => dispatch(setShouldAntialias(e.target.checked))}
/>
<ClearCanvasHistoryButtonModal />
<EmptyTempFolderButtonModal />
</Flex>
</IAIPopover>
);

View File

@ -55,8 +55,6 @@ const UnifiedCanvasContent = () => {
});
useLayoutEffect(() => {
dispatch(requestCanvasRescale());
const resizeCallback = () => {
dispatch(requestCanvasRescale());
};

View File

@ -121,3 +121,78 @@ def test_graph_state_collects(mock_services):
assert isinstance(n6[0], CollectInvocation)
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
def test_graph_state_prepares_eagerly(mock_services):
"""Tests that all prepareable nodes are prepared"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
# separated, fully-preparable chain of nodes
graph.add_node(PromptTestInvocation(id="prompt_chain_1", prompt="Dinosaur sushi"))
graph.add_node(PromptTestInvocation(id="prompt_chain_2"))
graph.add_node(PromptTestInvocation(id="prompt_chain_3"))
graph.add_edge(create_edge("prompt_chain_1", "prompt", "prompt_chain_2", "prompt"))
graph.add_edge(create_edge("prompt_chain_2", "prompt", "prompt_chain_3", "prompt"))
g = GraphExecutionState(graph=graph)
g.next()
assert "prompt_collection" in g.source_prepared_mapping
assert "prompt_chain_1" in g.source_prepared_mapping
assert "prompt_chain_2" in g.source_prepared_mapping
assert "prompt_chain_3" in g.source_prepared_mapping
assert "iterate" not in g.source_prepared_mapping
assert "prompt_iterated" not in g.source_prepared_mapping
def test_graph_executes_depth_first(mock_services):
"""Tests that the graph executes depth-first, executing a branch as far as possible before moving to the next branch"""
graph = Graph()
test_prompts = ["Banana sushi", "Cat sushi"]
graph.add_node(PromptCollectionTestInvocation(id="prompt_collection", collection=list(test_prompts)))
graph.add_node(IterateInvocation(id="iterate"))
graph.add_node(PromptTestInvocation(id="prompt_iterated"))
graph.add_node(PromptTestInvocation(id="prompt_successor"))
graph.add_edge(create_edge("prompt_collection", "collection", "iterate", "collection"))
graph.add_edge(create_edge("iterate", "item", "prompt_iterated", "prompt"))
graph.add_edge(create_edge("prompt_iterated", "prompt", "prompt_successor", "prompt"))
g = GraphExecutionState(graph=graph)
n1 = invoke_next(g, mock_services)
n2 = invoke_next(g, mock_services)
n3 = invoke_next(g, mock_services)
n4 = invoke_next(g, mock_services)
# Because ordering is not guaranteed, we cannot compare results directly.
# Instead, we must count the number of results.
def get_completed_count(g, id):
ids = [i for i in g.source_prepared_mapping[id]]
completed_ids = [i for i in g.executed if i in ids]
return len(completed_ids)
# Check at each step that the number of executed nodes matches the expectation for depth-first execution
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 0
n5 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 1
assert get_completed_count(g, "prompt_successor") == 1
n6 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 1
n7 = invoke_next(g, mock_services)
assert get_completed_count(g, "prompt_iterated") == 2
assert get_completed_count(g, "prompt_successor") == 2