feat(ui, app): use layer as control (wip)

This commit is contained in:
psychedelicious 2024-08-14 18:10:51 +10:00
parent 3b36eb0223
commit 636d9a7209
65 changed files with 1734 additions and 292 deletions

View File

@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
QueueItemOrigin,
SessionQueueItem,
SessionQueueStatus,
)
@ -89,7 +88,7 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase):
@ -284,7 +283,7 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":

View File

@ -86,7 +86,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.")
origin: str | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
@ -205,7 +205,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ")
origin: str | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
)
@ -305,7 +305,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
origin: QueueItemOrigin | None = Field(..., description="The origin of the batch")
origin: str | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -451,7 +451,7 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json
origin: QueueItemOrigin | None
origin: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]

View File

@ -10,6 +10,7 @@ import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client';
import { assert } from 'tsafe';
// Inject socket options and url into window for debugging
declare global {
@ -18,6 +19,14 @@ declare global {
}
}
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const getSocket = () => {
const socket = $socket.get();
assert(socket !== null, 'Socket is not initialized');
return socket;
};
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false);
@ -61,7 +70,8 @@ export const useSocketIO = () => {
return;
}
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions);
const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ dispatch, socket });
socket.connect();

View File

@ -12,7 +12,7 @@ import {
caRecalled,
} from 'features/controlLayers/store/canvasV2Slice';
import { selectCA } from 'features/controlLayers/store/controlAdaptersReducers';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
@ -95,7 +95,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
}
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image.image, config as never);
const processorNode = IMAGE_FILTERS[config.type].buildNode(image.image, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {

View File

@ -1,4 +1,5 @@
import type { createStore } from 'app/store/store';
import { useStore } from '@nanostores/react';
import type { AppStore } from 'app/store/store';
import { atom } from 'nanostores';
// Inject socket options and url into window for debugging
@ -22,7 +23,7 @@ class ReduxStoreNotInitialized extends Error {
}
}
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>();
export const $store = atom<Readonly<AppStore | undefined>>();
export const getStore = () => {
const store = $store.get();
@ -31,3 +32,11 @@ export const getStore = () => {
}
return store;
};
export const useAppStore = () => {
const store = useStore($store);
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@ -180,7 +180,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
},
});
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>;
export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch'];

View File

@ -1,8 +1,8 @@
import type { AppThunkDispatch, RootState } from 'app/store/store';
import type { AppStore, AppThunkDispatch, RootState } from 'app/store/store';
import type { TypedUseSelectorHook } from 'react-redux';
import { useDispatch, useSelector, useStore } from 'react-redux';
import {useDispatch, useSelector, useStore } from 'react-redux';
// Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore = () => useStore<RootState>();
export const useAppStore = () => useStore<AppStore>();

View File

@ -1,4 +1,4 @@
import type { ProcessorTypeV2 } from 'features/controlLayers/store/types';
import type { FilterType } from 'features/controlLayers/store/types';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt';
@ -83,7 +83,7 @@ export type AppConfig = {
sd: {
defaultModel?: string;
disabledControlNetModels: string[];
disabledControlNetProcessors: ProcessorTypeV2[];
disabledControlNetProcessors: FilterType[];
// Core parameters
iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model

View File

@ -9,12 +9,12 @@ import { MediapipeFaceProcessor } from 'features/controlLayers/components/Contro
import { MidasDepthProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MidasDepthProcessor';
import { MlsdImageProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MlsdImageProcessor';
import { PidiProcessor } from 'features/controlLayers/components/ControlAdapter/processors/PidiProcessor';
import type { ProcessorConfig } from 'features/controlLayers/store/types';
import type { FilterConfig } from 'features/controlLayers/store/types';
import { memo } from 'react';
type Props = {
config: ProcessorConfig | null;
onChange: (config: ProcessorConfig | null) => void;
config: FilterConfig | null;
onChange: (config: FilterConfig | null) => void;
};
export const ControlAdapterProcessorConfig = memo(({ config, onChange }: Props) => {

View File

@ -3,8 +3,8 @@ import { Combobox, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/u
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type {ProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA, isProcessorTypeV2 } from 'features/controlLayers/store/types';
import type {FilterConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/types';
import { configSelector } from 'features/system/store/configSelectors';
import { includes, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
@ -13,8 +13,8 @@ import { PiXBold } from 'react-icons/pi';
import { assert } from 'tsafe';
type Props = {
config: ProcessorConfig | null;
onChange: (config: ProcessorConfig | null) => void;
config: FilterConfig | null;
onChange: (config: FilterConfig | null) => void;
};
const selectDisabledProcessors = createMemoizedSelector(
@ -26,7 +26,7 @@ export const ControlAdapterProcessorTypeSelect = memo(({ config, onChange }: Pro
const { t } = useTranslation();
const disabledProcessors = useAppSelector(selectDisabledProcessors);
const options = useMemo(() => {
return map(CA_PROCESSOR_DATA, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter(
return map(IMAGE_FILTERS, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter(
(o) => !includes(disabledProcessors, o.value)
);
}, [disabledProcessors, t]);
@ -36,8 +36,8 @@ export const ControlAdapterProcessorTypeSelect = memo(({ config, onChange }: Pro
if (!v) {
onChange(null);
} else {
assert(isProcessorTypeV2(v.value));
onChange(CA_PROCESSOR_DATA[v.value].buildDefaults());
assert(isFilterType(v.value));
onChange(IMAGE_FILTERS[v.value].buildDefaults());
}
},
[onChange]

View File

@ -19,7 +19,7 @@ import {
caWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { selectCAOrThrow } from 'features/controlLayers/store/controlAdaptersReducers';
import type { ControlModeV2, ProcessorConfig } from 'features/controlLayers/store/types';
import type { ControlModeV2, FilterConfig } from 'features/controlLayers/store/types';
import type { CAImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@ -62,7 +62,7 @@ export const ControlAdapterSettings = memo(() => {
);
const onChangeProcessorConfig = useCallback(
(processorConfig: ProcessorConfig | null) => {
(processorConfig: FilterConfig | null) => {
dispatch(caProcessorConfigChanged({ id, processorConfig }));
},
[dispatch, id]

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { CannyProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<CannyProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['canny_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['canny_image_processor'].buildDefaults();
export const CannyProcessor = ({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<ColorMapProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['color_map_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['color_map_image_processor'].buildDefaults();
export const ColorMapProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { ContentShuffleProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<ContentShuffleProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['content_shuffle_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['content_shuffle_image_processor'].buildDefaults();
export const ContentShuffleProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,7 +1,7 @@
import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<DWOpenposeProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['dw_openpose_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['dw_openpose_image_processor'].buildDefaults();
export const DWOpenposeProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MediapipeFaceProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['mediapipe_face_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['mediapipe_face_processor'].buildDefaults();
export const MediapipeFaceProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MidasDepthProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['midas_depth_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['midas_depth_image_processor'].buildDefaults();
export const MidasDepthProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MlsdProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MlsdProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['mlsd_image_processor'].buildDefaults();
const DEFAULTS = IMAGE_FILTERS['mlsd_image_processor'].buildDefaults();
export const MlsdImageProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();

View File

@ -1,6 +1,6 @@
import type { ProcessorConfig } from 'features/controlLayers/store/types';
import type { FilterConfig } from 'features/controlLayers/store/types';
export type ProcessorComponentProps<T extends ProcessorConfig> = {
export type ProcessorComponentProps<T extends FilterConfig> = {
onChange: (config: T) => void;
config: T;
};

View File

@ -1,19 +1,37 @@
/* eslint-disable i18next/no-literal-string */
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { AddLayerButton } from 'features/controlLayers/components/AddLayerButton';
import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList';
import { DeleteAllLayersButton } from 'features/controlLayers/components/DeleteAllLayersButton';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { $filteringEntity } from 'features/controlLayers/store/canvasV2Slice';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
export const ControlLayersPanelContent = memo(() => {
const filteringEntity = useStore($filteringEntity);
return (
<Flex flexDir="column" gap={2} w="full" h="full">
<Flex justifyContent="space-around">
<AddLayerButton />
<DeleteAllLayersButton />
</Flex>
<CanvasEntityList />
</Flex>
<PanelGroup direction="vertical">
<Panel id="canvas-entity-list-panel" order={0}>
<Flex flexDir="column" gap={2} w="full" h="full">
<Flex justifyContent="space-around">
<AddLayerButton />
<DeleteAllLayersButton />
</Flex>
<CanvasEntityList />
</Flex>
</Panel>
{Boolean(filteringEntity) && (
<>
<ResizeHandle orientation="horizontal" />
<Panel id="filter-panel" order={1}>
<Filter />
</Panel>
</>
)}
</PanelGroup>
);
});

View File

@ -12,11 +12,25 @@ import { ResetCanvasButton } from 'features/controlLayers/components/ResetCanvas
import { ToolChooser } from 'features/controlLayers/components/ToolChooser';
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { nanoid } from 'features/controlLayers/konva/util';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
const filter = () => {
const entity = $canvasManager.get()?.stateApi.getSelectedEntity();
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter({
type: 'canny_image_processor',
id: nanoid(),
low_threshold: 50,
high_threshold: 50,
});
};
export const ControlLayersToolbar = memo(() => {
const tool = useAppSelector((s) => s.canvasV2.tool.selected);
const canvasManager = useStore($canvasManager);
@ -47,6 +61,7 @@ export const ControlLayersToolbar = memo(() => {
<Flex gap={2} marginInlineEnd="auto" alignItems="center">
<ToggleProgressButton />
<ToolChooser />
<Button onClick={filter}>Filter</Button>
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center" alignItems="center">

View File

@ -0,0 +1,76 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { $filteringEntity } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react';
export const Filter = memo(() => {
const filteringEntity = useStore($filteringEntity);
const preview = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter();
}, [filteringEntity]);
const apply = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.applyFilter();
}, [filteringEntity]);
const cancel = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.cancelFilter();
}, [filteringEntity]);
return (
<Flex flexDir="column" gap={3} w="full" h="full">
<FilterTypeSelect />
<ButtonGroup isAttached={false}>
<Button onClick={preview} isDisabled={!filteringEntity}>
Preview
</Button>
<Button onClick={apply} isDisabled={!filteringEntity}>
Apply
</Button>
<Button onClick={cancel} isDisabled={!filteringEntity}>
Cancel
</Button>
</ButtonGroup>
<FilterSettings />
</Flex>
);
});
Filter.displayName = 'Filter';

View File

@ -0,0 +1,67 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { CannyProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<CannyProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['canny_image_processor'].buildDefaults();
export const FilterCanny = ({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleLowThresholdChanged = useCallback(
(v: number) => {
onChange({ ...config, low_threshold: v });
},
[onChange, config]
);
const handleHighThresholdChanged = useCallback(
(v: number) => {
onChange({ ...config, high_threshold: v });
},
[onChange, config]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.lowThreshold')}</FormLabel>
<CompositeSlider
value={config.low_threshold}
onChange={handleLowThresholdChanged}
defaultValue={DEFAULTS.low_threshold}
min={0}
max={255}
/>
<CompositeNumberInput
value={config.low_threshold}
onChange={handleLowThresholdChanged}
defaultValue={DEFAULTS.low_threshold}
min={0}
max={255}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.highThreshold')}</FormLabel>
<CompositeSlider
value={config.high_threshold}
onChange={handleHighThresholdChanged}
defaultValue={DEFAULTS.high_threshold}
min={0}
max={255}
/>
<CompositeNumberInput
value={config.high_threshold}
onChange={handleHighThresholdChanged}
defaultValue={DEFAULTS.high_threshold}
min={0}
max={255}
/>
</FormControl>
</>
);
};
FilterCanny.displayName = 'FilterCanny';

View File

@ -0,0 +1,47 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<ColorMapProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['color_map_image_processor'].buildDefaults();
export const FilterColorMap = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleColorMapTileSizeChanged = useCallback(
(v: number) => {
onChange({ ...config, color_map_tile_size: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.colorMapTileSize')}</FormLabel>
<CompositeSlider
value={config.color_map_tile_size}
defaultValue={DEFAULTS.color_map_tile_size}
onChange={handleColorMapTileSizeChanged}
min={1}
max={256}
step={1}
marks
/>
<CompositeNumberInput
value={config.color_map_tile_size}
defaultValue={DEFAULTS.color_map_tile_size}
onChange={handleColorMapTileSizeChanged}
min={1}
max={4096}
step={1}
/>
</FormControl>
</>
);
});
FilterColorMap.displayName = 'FilterColorMap';

View File

@ -0,0 +1,78 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ContentShuffleProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<ContentShuffleProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['content_shuffle_image_processor'].buildDefaults();
export const FilterContentShuffle = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleWChanged = useCallback(
(v: number) => {
onChange({ ...config, w: v });
},
[config, onChange]
);
const handleHChanged = useCallback(
(v: number) => {
onChange({ ...config, h: v });
},
[config, onChange]
);
const handleFChanged = useCallback(
(v: number) => {
onChange({ ...config, f: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.w')}</FormLabel>
<CompositeSlider
value={config.w}
defaultValue={DEFAULTS.w}
onChange={handleWChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.w} defaultValue={DEFAULTS.w} onChange={handleWChanged} min={0} max={4096} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.h')}</FormLabel>
<CompositeSlider
value={config.h}
defaultValue={DEFAULTS.h}
onChange={handleHChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.h} defaultValue={DEFAULTS.h} onChange={handleHChanged} min={0} max={4096} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.f')}</FormLabel>
<CompositeSlider
value={config.f}
defaultValue={DEFAULTS.f}
onChange={handleFChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.f} defaultValue={DEFAULTS.f} onChange={handleFChanged} min={0} max={4096} />
</FormControl>
</>
);
});
FilterContentShuffle.displayName = 'FilterContentShuffle';

View File

@ -0,0 +1,61 @@
import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<DWOpenposeProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['dw_openpose_image_processor'].buildDefaults();
export const FilterDWOpenpose = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleDrawBodyChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_body: e.target.checked });
},
[config, onChange]
);
const handleDrawFaceChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_face: e.target.checked });
},
[config, onChange]
);
const handleDrawHandsChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_hands: e.target.checked });
},
[config, onChange]
);
return (
<>
<Flex sx={{ flexDir: 'row', gap: 6 }}>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.body')}</FormLabel>
<Switch defaultChecked={DEFAULTS.draw_body} isChecked={config.draw_body} onChange={handleDrawBodyChanged} />
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.face')}</FormLabel>
<Switch defaultChecked={DEFAULTS.draw_face} isChecked={config.draw_face} onChange={handleDrawFaceChanged} />
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.hands')}</FormLabel>
<Switch
defaultChecked={DEFAULTS.draw_hands}
isChecked={config.draw_hands}
onChange={handleDrawHandsChanged}
/>
</FormControl>
</Flex>
</>
);
});
FilterDWOpenpose.displayName = 'FilterDWOpenpose';

View File

@ -0,0 +1,46 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { DepthAnythingModelSize, DepthAnythingProcessorConfig } from 'features/controlLayers/store/types';
import { isDepthAnythingModelSize } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<DepthAnythingProcessorConfig>;
export const FilterDepthAnything = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleModelSizeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isDepthAnythingModelSize(v?.value)) {
return;
}
onChange({ ...config, model_size: v.value });
},
[config, onChange]
);
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.model_size)[0], [options, config.model_size]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.modelSize')}</FormLabel>
<Combobox value={value} options={options} onChange={handleModelSizeChange} isSearchable={false} />
</FormControl>
</>
);
});
FilterDepthAnything.displayName = 'FilterDepthAnything';

View File

@ -0,0 +1,31 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { HedProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<HedProcessorConfig>;
export const FilterHed = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, scribble: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.scribble')}</FormLabel>
<Switch isChecked={config.scribble} onChange={handleScribbleChanged} />
</FormControl>
</>
);
});
FilterHed.displayName = 'FilterHed';

View File

@ -0,0 +1,31 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { LineartProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<LineartProcessorConfig>;
export const FilterLineart = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleCoarseChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, coarse: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.coarse')}</FormLabel>
<Switch isChecked={config.coarse} onChange={handleCoarseChanged} />
</FormControl>
</>
);
});
FilterLineart.displayName = 'FilterLineart';

View File

@ -0,0 +1,73 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MediapipeFaceProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['mediapipe_face_processor'].buildDefaults();
export const FilterMediapipeFace = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleMaxFacesChanged = useCallback(
(v: number) => {
onChange({ ...config, max_faces: v });
},
[config, onChange]
);
const handleMinConfidenceChanged = useCallback(
(v: number) => {
onChange({ ...config, min_confidence: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.maxFaces')}</FormLabel>
<CompositeSlider
value={config.max_faces}
onChange={handleMaxFacesChanged}
defaultValue={DEFAULTS.max_faces}
min={1}
max={20}
marks
/>
<CompositeNumberInput
value={config.max_faces}
onChange={handleMaxFacesChanged}
defaultValue={DEFAULTS.max_faces}
min={1}
max={20}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.minConfidence')}</FormLabel>
<CompositeSlider
value={config.min_confidence}
onChange={handleMinConfidenceChanged}
defaultValue={DEFAULTS.min_confidence}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.min_confidence}
onChange={handleMinConfidenceChanged}
defaultValue={DEFAULTS.min_confidence}
min={0}
max={1}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMediapipeFace.displayName = 'FilterMediapipeFace';

View File

@ -0,0 +1,75 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MidasDepthProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['midas_depth_image_processor'].buildDefaults();
export const FilterMidasDepth = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleAMultChanged = useCallback(
(v: number) => {
onChange({ ...config, a_mult: v });
},
[config, onChange]
);
const handleBgThChanged = useCallback(
(v: number) => {
onChange({ ...config, bg_th: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.amult')}</FormLabel>
<CompositeSlider
value={config.a_mult}
onChange={handleAMultChanged}
defaultValue={DEFAULTS.a_mult}
min={0}
max={20}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.a_mult}
onChange={handleAMultChanged}
defaultValue={DEFAULTS.a_mult}
min={0}
max={20}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.bgth')}</FormLabel>
<CompositeSlider
value={config.bg_th}
onChange={handleBgThChanged}
defaultValue={DEFAULTS.bg_th}
min={0}
max={20}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.bg_th}
onChange={handleBgThChanged}
defaultValue={DEFAULTS.bg_th}
min={0}
max={20}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMidasDepth.displayName = 'FilterMidasDepth';

View File

@ -0,0 +1,75 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MlsdProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MlsdProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['mlsd_image_processor'].buildDefaults();
export const FilterMlsdImage = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleThrDChanged = useCallback(
(v: number) => {
onChange({ ...config, thr_d: v });
},
[config, onChange]
);
const handleThrVChanged = useCallback(
(v: number) => {
onChange({ ...config, thr_v: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.w')} </FormLabel>
<CompositeSlider
value={config.thr_d}
onChange={handleThrDChanged}
defaultValue={DEFAULTS.thr_d}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.thr_d}
onChange={handleThrDChanged}
defaultValue={DEFAULTS.thr_d}
min={0}
max={1}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.h')} </FormLabel>
<CompositeSlider
value={config.thr_v}
onChange={handleThrVChanged}
defaultValue={DEFAULTS.thr_v}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.thr_v}
onChange={handleThrVChanged}
defaultValue={DEFAULTS.thr_v}
min={0}
max={1}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMlsdImage.displayName = 'FilterMlsdImage';

View File

@ -0,0 +1,42 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { PidiProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<PidiProcessorConfig>;
export const FilterPidi = ({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, scribble: e.target.checked });
},
[config, onChange]
);
const handleSafeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, safe: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.scribble')}</FormLabel>
<Switch isChecked={config.scribble} onChange={handleScribbleChanged} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.safe')}</FormLabel>
<Switch isChecked={config.safe} onChange={handleSafeChanged} />
</FormControl>
</>
);
};
FilterPidi.displayName = 'FilterPidi';

View File

@ -0,0 +1,77 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterCanny } from 'features/controlLayers/components/Filters/FilterCanny';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
import { FilterDepthAnything } from 'features/controlLayers/components/Filters/FilterDepthAnything';
import { FilterDWOpenpose } from 'features/controlLayers/components/Filters/FilterDWOpenpose';
import { FilterHed } from 'features/controlLayers/components/Filters/FilterHed';
import { FilterLineart } from 'features/controlLayers/components/Filters/FilterLineart';
import { FilterMediapipeFace } from 'features/controlLayers/components/Filters/FilterMediapipeFace';
import { FilterMidasDepth } from 'features/controlLayers/components/Filters/FilterMidasDepth';
import { FilterMlsdImage } from 'features/controlLayers/components/Filters/FilterMlsdImage';
import { FilterPidi } from 'features/controlLayers/components/Filters/FilterPidi';
import { filterConfigChanged } from 'features/controlLayers/store/canvasV2Slice';
import { type FilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const FilterSettings = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const config = useAppSelector((s) => s.canvasV2.filter.config);
const updateFilter = useCallback(
(config: FilterConfig) => {
dispatch(filterConfigChanged({ config }));
},
[dispatch]
);
if (config.type === 'canny_image_processor') {
return <FilterCanny config={config} onChange={updateFilter} />;
}
if (config.type === 'color_map_image_processor') {
return <FilterColorMap config={config} onChange={updateFilter} />;
}
if (config.type === 'content_shuffle_image_processor') {
return <FilterContentShuffle config={config} onChange={updateFilter} />;
}
if (config.type === 'depth_anything_image_processor') {
return <FilterDepthAnything config={config} onChange={updateFilter} />;
}
if (config.type === 'dw_openpose_image_processor') {
return <FilterDWOpenpose config={config} onChange={updateFilter} />;
}
if (config.type === 'hed_image_processor') {
return <FilterHed config={config} onChange={updateFilter} />;
}
if (config.type === 'lineart_image_processor') {
return <FilterLineart config={config} onChange={updateFilter} />;
}
if (config.type === 'mediapipe_face_processor') {
return <FilterMediapipeFace config={config} onChange={updateFilter} />;
}
if (config.type === 'midas_depth_image_processor') {
return <FilterMidasDepth config={config} onChange={updateFilter} />;
}
if (config.type === 'mlsd_image_processor') {
return <FilterMlsdImage config={config} onChange={updateFilter} />;
}
if (config.type === 'pidi_image_processor') {
return <FilterPidi config={config} onChange={updateFilter} />;
}
return <IAINoContentFallback label={`${t(IMAGE_FILTERS[config.type].labelTKey)} has no settings`} icon={null} />;
});
FilterSettings.displayName = 'Filter';

View File

@ -0,0 +1,54 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { filterSelected } from 'features/controlLayers/store/canvasV2Slice';
import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/types';
import { configSelector } from 'features/system/store/configSelectors';
import { includes, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
const selectDisabledProcessors = createMemoizedSelector(
configSelector,
(config) => config.sd.disabledControlNetProcessors
);
export const FilterTypeSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filterType = useAppSelector((s) => s.canvasV2.filter.config.type);
const disabledProcessors = useAppSelector(selectDisabledProcessors);
const options = useMemo(() => {
return map(IMAGE_FILTERS, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter(
(o) => !includes(disabledProcessors, o.value)
);
}, [disabledProcessors, t]);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
return;
}
assert(isFilterType(v.value));
dispatch(filterSelected({ type: v.value }));
},
[dispatch]
);
const value = useMemo(() => options.find((o) => o.value === filterType) ?? null, [options, filterType]);
return (
<Flex gap={2}>
<FormControl>
<InformationalPopover feature="controlNetProcessor">
<FormLabel m={0}>{t('controlLayers.filter')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} onChange={_onChange} isSearchable={false} isClearable={false} />
</FormControl>
</Flex>
);
});
FilterTypeSelect.displayName = 'FilterTypeSelect';

View File

@ -0,0 +1,17 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useFilter } from 'features/controlLayers/components/Filters/Filter';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
const FilterWrapper = (props: PropsWithChildren) => {
const isPreviewDisabled = useAppSelector((s) => s.canvasV2.selectedEntityIdentifier?.type !== 'layer');
const filter = useFilter();
return (
<Flex flexDir="column" gap={3} w="full" h="full">
{props.children}
</Flex>
);
};
export default memo(FilterWrapper);

View File

@ -0,0 +1,6 @@
import type { FilterConfig } from 'features/controlLayers/store/types';
export type FilterComponentProps<T extends FilterConfig> = {
onChange: (config: T) => void;
config: T;
};

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -18,12 +18,11 @@ type Props = {
export const Layer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'layer' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: false });
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}>
<CanvasEntityHeader>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
<Spacer />
@ -31,7 +30,7 @@ export const Layer = memo(({ id }: Props) => {
<LayerActionsMenu />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>
{isOpen && <LayerSettings />}
<LayerSettings />
</CanvasEntityContainer>
</EntityIdentifierContext.Provider>
);

View File

@ -0,0 +1,66 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlAdapterControlModeSelect } from 'features/controlLayers/components/ControlAdapter/ControlAdapterControlModeSelect';
import { ControlAdapterModel } from 'features/controlLayers/components/ControlAdapter/ControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import {
layerControlAdapterBeginEndStepPctChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterModelChanged,
layerControlAdapterWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type { ControlModeV2, ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
controlAdapter: ControlNetConfig | T2IAdapterConfig;
};
export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
const dispatch = useAppDispatch();
const { id } = useEntityIdentifierContext();
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(layerControlAdapterBeginEndStepPctChanged({ id, beginEndStepPct }));
},
[dispatch, id]
);
const onChangeControlMode = useCallback(
(controlMode: ControlModeV2) => {
dispatch(layerControlAdapterControlModeChanged({ id, controlMode }));
},
[dispatch, id]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(layerControlAdapterWeightChanged({ id, weight }));
},
[dispatch, id]
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(layerControlAdapterModelChanged({ id, modelConfig }));
},
[dispatch, id]
);
return (
<Flex flexDir="column" gap={3} position="relative" w="full">
<ControlAdapterModel modelKey={controlAdapter.model?.key ?? null} onChange={onChangeModel} />
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && (
<ControlAdapterControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
)}
</Flex>
);
});
LayerControlAdapter.displayName = 'LayerControlAdapter';

View File

@ -1,10 +1,22 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { LayerControlAdapter } from 'features/controlLayers/components/Layer/LayerControlAdapter';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { memo } from 'react';
export const LayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
return <CanvasEntitySettings>PLACEHOLDER</CanvasEntitySettings>;
const controlAdapter = useLayerControlAdapter(entityIdentifier);
if (!controlAdapter) {
return null;
}
return (
<CanvasEntitySettings>
<LayerControlAdapter controlAdapter={controlAdapter} />
</CanvasEntitySettings>
);
});
LayerSettings.displayName = 'LayerSettings';

View File

@ -1,6 +1,8 @@
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $socket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks';
import { useAppStore } from 'app/store/nanostores/store';
import { HeadsUpDisplay } from 'features/controlLayers/components/HeadsUpDisplay';
import { $canvasManager, CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import Konva from 'konva';
@ -17,6 +19,7 @@ Konva.showWarnings = false;
const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null, asPreview: boolean) => {
const store = useAppStore();
const socket = useStore($socket);
const dpr = useDevicePixelRatio({ round: false });
useLayoutEffect(() => {
@ -27,12 +30,17 @@ const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null,
return () => {};
}
const manager = new CanvasManager(stage, container, store);
if (!socket) {
log.debug('Socket not connected, skipping initialization');
return () => {};
}
const manager = new CanvasManager(stage, container, store, socket);
$canvasManager.set(manager);
console.log(manager);
const cleanup = manager.initialize();
return cleanup;
}, [asPreview, container, stage, store]);
}, [asPreview, container, socket, stage, store]);
useLayoutEffect(() => {
Konva.pixelRatio = dpr;

View File

@ -2,7 +2,8 @@ import { Button, IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { memo, useCallback, useEffect, useState } from 'react';
import { $transformingEntity } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiResizeBold } from 'react-icons/pi';
@ -10,20 +11,11 @@ import { PiResizeBold } from 'react-icons/pi';
export const TransformToolButton = memo(() => {
const { t } = useTranslation();
const canvasManager = useStore($canvasManager);
const [isTransforming, setIsTransforming] = useState(false);
const transformingEntity = useStore($transformingEntity);
const isDisabled = useAppSelector(
(s) => s.canvasV2.selectedEntityIdentifier === null || s.canvasV2.session.isStaging
);
useEffect(() => {
if (!canvasManager) {
return;
}
return canvasManager.stateApi.$transformingEntity.listen((newValue) => {
setIsTransforming(Boolean(newValue));
});
}, [canvasManager]);
const onTransform = useCallback(() => {
if (!canvasManager) {
return;
@ -47,7 +39,7 @@ export const TransformToolButton = memo(() => {
useHotkeys(['ctrl+t', 'meta+t'], onTransform, { enabled: !isDisabled }, [isDisabled, onTransform]);
if (isTransforming) {
if (transformingEntity) {
return (
<>
<Button onClick={onApplyTransformation}>{t('common.apply')}</Button>

View File

@ -1,8 +1,12 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { MenuDivider, MenuItem } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerUseAsControl } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import {
$filteringEntity,
entityArrangedBackwardOne,
entityArrangedForwardOne,
entityArrangedToBack,
@ -20,7 +24,11 @@ import {
PiArrowLineDownBold,
PiArrowLineUpBold,
PiArrowUpBold,
PiCheckBold,
PiQuestionMarkBold,
PiStarHalfBold,
PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi';
const getIndexAndCount = (
@ -52,18 +60,15 @@ const getIndexAndCount = (
export const CanvasEntityActionMenuItems = memo(() => {
const { t } = useTranslation();
const canvasManager = useStore($canvasManager);
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
const useAsControl = useLayerUseAsControl(entityIdentifier);
const selectValidActions = useMemo(
() =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
const { index, count } = getIndexAndCount(canvasV2, entityIdentifier);
return {
isArrangeable:
entityIdentifier.type === 'layer' ||
entityIdentifier.type === 'control_adapter' ||
entityIdentifier.type === 'regional_guidance',
isDeleteable: entityIdentifier.type !== 'inpaint_mask',
canMoveForwardOne: index < count - 1,
canMoveBackwardOne: index > 0,
canMoveToFront: index < count - 1,
@ -75,6 +80,18 @@ export const CanvasEntityActionMenuItems = memo(() => {
const validActions = useAppSelector(selectValidActions);
const isArrangeable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isDeleteable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isFilterable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const isUseAsControlable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const deleteEntity = useCallback(() => {
dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
@ -93,10 +110,23 @@ export const CanvasEntityActionMenuItems = memo(() => {
const moveToBack = useCallback(() => {
dispatch(entityArrangedToBack({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
const filter = useCallback(() => {
$filteringEntity.set(entityIdentifier);
}, [entityIdentifier]);
const debug = useCallback(() => {
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(entityIdentifier);
if (!entity) {
return;
}
console.debug(entity);
}, [canvasManager, entityIdentifier]);
return (
<>
{validActions.isArrangeable && (
{isArrangeable && (
<>
<MenuItem onClick={moveToFront} isDisabled={!validActions.canMoveToFront} icon={<PiArrowLineUpBold />}>
{t('controlLayers.moveToFront')}
@ -112,14 +142,29 @@ export const CanvasEntityActionMenuItems = memo(() => {
</MenuItem>
</>
)}
{isFilterable && (
<MenuItem onClick={filter} icon={<PiStarHalfBold />}>
{t('common.filter')}
</MenuItem>
)}
{isUseAsControlable && (
<MenuItem onClick={useAsControl.toggle} icon={useAsControl.hasControlAdapter ? <PiXBold /> : <PiCheckBold />}>
{useAsControl.hasControlAdapter ? t('common.removeControl') : t('common.useAsControl')}
</MenuItem>
)}
<MenuDivider />
<MenuItem onClick={resetEntity} icon={<PiArrowCounterClockwiseBold />}>
{t('accessibility.reset')}
</MenuItem>
{validActions.isDeleteable && (
{isDeleteable && (
<MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} color="error.300">
{t('common.delete')}
</MenuItem>
)}
<MenuDivider />
<MenuItem onClick={debug} icon={<PiQuestionMarkBold />} color="warn.300">
{t('common.debug')}
</MenuItem>
</>
);
});

View File

@ -11,6 +11,7 @@ import { PiCheckBold } from 'react-icons/pi';
export const CanvasEntityEnabledToggle = memo(() => {
const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext();
const isEnabled = useEntityIsEnabled(entityIdentifier);
const dispatch = useAppDispatch();
const onClick = useCallback(() => {

View File

@ -2,11 +2,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { caAdded, ipaAdded, rgIPAdapterAdded } from 'features/controlLayers/store/canvasV2Slice';
import {
CA_PROCESSOR_DATA,
IMAGE_FILTERS,
initialControlNetV2,
initialIPAdapterV2,
initialT2IAdapterV2,
isProcessorTypeV2,
isFilterType,
} from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react';
@ -30,8 +30,8 @@ export const useAddCALayer = () => {
}
const defaultPreprocessor = model.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults(baseModel)
const processorConfig = isFilterType(defaultPreprocessor)
? IMAGE_FILTERS[defaultPreprocessor].buildDefaults(baseModel)
: null;
const initialConfig = deepClone(model.type === 'controlnet' ? initialControlNetV2 : initialT2IAdapterV2);

View File

@ -0,0 +1,8 @@
import { useStore } from '@nanostores/react';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
export const useEntityAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const canvasManager = useStore($canvasManager);
console.log(canvasManager);
};

View File

@ -0,0 +1,57 @@
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { layerUsedAsControlChanged, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectLayer } from 'features/controlLayers/store/layersReducers';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { initialControlNetV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const selectControlAdapter = useMemo(
() =>
createMemoizedAppSelector(selectCanvasV2Slice, (canvasV2) => {
const layer = selectLayer(canvasV2, entityIdentifier.id);
if (!layer) {
return null;
}
return layer.controlAdapter;
}),
[entityIdentifier]
);
const controlAdapter = useAppSelector(selectControlAdapter);
return controlAdapter;
};
export const useLayerUseAsControl = (entityIdentifier: CanvasEntityIdentifier) => {
const dispatch = useAppDispatch();
const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
const controlAdapter = useLayerControlAdapter(entityIdentifier);
const model: ControlNetModelConfig | T2IAdapterModelConfig | null = useMemo(() => {
// prefer to use a model that matches the base model
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
return compatibleModels[0] ?? modelConfigs[0] ?? null;
}, [baseModel, modelConfigs]);
const toggle = useCallback(() => {
if (controlAdapter) {
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: null }));
return;
}
const newControlAdapter = deepClone(model?.type === 't2i_adapter' ? initialT2IAdapterV2 : initialControlNetV2);
if (model) {
newControlAdapter.model = zModelIdentifierField.parse(model);
}
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: newControlAdapter }));
}, [controlAdapter, dispatch, entityIdentifier.id, model]);
return { hasControlAdapter: Boolean(controlAdapter), toggle };
};

View File

@ -0,0 +1,132 @@
import type { JSONObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasImageState } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { InvocationCompleteEvent } from 'services/events/types';
import { assert } from 'tsafe';
const TYPE = 'entity_filter_preview';
export class CanvasFilter {
readonly type = TYPE;
id: string;
path: string[];
parent: CanvasLayerAdapter;
manager: CanvasManager;
log: Logger;
imageState: CanvasImageState | null = null;
constructor(parent: CanvasLayerAdapter) {
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = parent.manager;
this.path = this.parent.path.concat(this.id);
this.log = this.manager.buildLogger(this.getLoggingContext);
this.log.trace('Creating filter');
}
previewFilter = async () => {
const { config } = this.manager.stateApi.getFilterState();
this.log.trace({ config }, 'Previewing filter');
const dispatch = this.manager.stateApi._store.dispatch;
const imageDTO = await this.parent.renderer.rasterize();
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const filterNode = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[filterNode.id]: {
...filterNode,
// Control images are always intermediate - do not save to gallery
// is_intermediate: true,
is_intermediate: false, // false for testing
},
},
edges: [],
},
origin: this.id,
runs: 1,
},
};
// Listen for the filter processing completion event
const listener = async (event: InvocationCompleteEvent) => {
if (event.origin !== this.id || event.invocation_source_id !== filterNode.id) {
return;
}
this.log.trace({ event: parseify(event) }, 'Handling filter processing completion');
const { result } = event;
assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
const imageDTO = await getImageDTO(result.image.image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
this.imageState = imageDTOToImageObject(imageDTO);
this.parent.renderer.clearBuffer();
await this.parent.renderer.setBuffer(this.imageState);
this.parent.renderer.hideObjects([this.imageState.id]);
this.manager.socket.off('invocation_complete', listener);
};
this.manager.socket.on('invocation_complete', listener);
this.log.trace({ enqueueBatchArg: parseify(enqueueBatchArg) }, 'Enqueuing filter batch');
dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
};
applyFilter = () => {
this.log.trace('Applying filter');
if (!this.imageState) {
this.log.warn('No image state to apply filter to');
return;
}
this.parent.renderer.commitBuffer();
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.getEntityIdentifier(),
imageObject: this.imageState,
position: { x: Math.round(rect.x), y: Math.round(rect.y) },
});
this.parent.renderer.showObjects();
this.manager.stateApi.$filteringEntity.set(null);
this.imageState = null;
};
cancelFilter = () => {
this.log.trace('Cancelling filter');
this.parent.renderer.clearBuffer();
this.parent.renderer.showObjects();
this.manager.stateApi.$filteringEntity.set(null);
this.imageState = null;
};
destroy = () => {
this.log.trace('Destroying filter');
};
repr = () => {
return {
id: this.id,
type: this.type,
};
};
getLoggingContext = (): JSONObject => {
return { ...this.parent.getLoggingContext(), path: this.path.join('.') };
};
}

View File

@ -1,5 +1,6 @@
import type { JSONObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer';
@ -23,6 +24,7 @@ export class CanvasLayerAdapter {
};
transformer: CanvasTransformer;
renderer: CanvasObjectRenderer;
filter: CanvasFilter;
isFirstRender: boolean = true;
@ -47,6 +49,7 @@ export class CanvasLayerAdapter {
this.renderer = new CanvasObjectRenderer(this);
this.transformer = new CanvasTransformer(this);
this.filter = new CanvasFilter(this);
}
/**

View File

@ -1,6 +1,6 @@
import type { Store } from '@reduxjs/toolkit';
import type { AppSocket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { AppStore } from 'app/store/store';
import type { JSONObject } from 'common/types';
import { MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from 'features/controlLayers/konva/constants';
import {
@ -12,7 +12,7 @@ import {
} from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import type { CanvasV2State, Coordinate, Dimensions, GenerationMode, Rect } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers';
import { isValidLayerWithoutControlAdapter } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva';
import { clamp } from 'lodash-es';
import { atom } from 'nanostores';
@ -49,8 +49,9 @@ export class CanvasManager {
log: Logger;
workerLog: Logger;
socket: AppSocket;
_store: Store<RootState>;
_store: AppStore;
_prevState: CanvasV2State;
_isFirstRender: boolean = true;
_isDebugging: boolean = false;
@ -58,12 +59,13 @@ export class CanvasManager {
_worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' });
_tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map();
constructor(stage: Konva.Stage, container: HTMLDivElement, store: Store<RootState>) {
constructor(stage: Konva.Stage, container: HTMLDivElement, store: AppStore, socket: AppSocket) {
this.id = getPrefixedId(this.type);
this.path = [this.id];
this.stage = stage;
this.container = container;
this._store = store;
this.socket = socket;
this.stateApi = new CanvasStateApi(this._store, this);
this._prevState = this.stateApi.getState();
@ -547,7 +549,7 @@ export class CanvasManager {
stageClone.x(0);
stageClone.y(0);
const validLayers = layersState.entities.filter(isValidLayer);
const validLayers = layersState.entities.filter(isValidLayerWithoutControlAdapter);
// getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will
// mutate that array. We need to clone the array to avoid mutating the original.
for (const konvaLayer of stageClone.getLayers().slice()) {

View File

@ -20,7 +20,7 @@ import type {
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import Konva from 'konva';
import type { Logger } from 'roarr';
import { uploadImage } from 'services/api/endpoints/images';
import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe';
@ -62,6 +62,12 @@ export class CanvasObjectRenderer {
*/
renderers: Map<string, AnyObjectRenderer> = new Map();
/**
* A cache of the rasterized image data URL. If the cache is null, the parent has not been rasterized since its last
* change.
*/
rasterizedImageCache: string | null = null;
/**
* A object containing singleton Konva nodes.
*/
@ -162,6 +168,19 @@ export class CanvasObjectRenderer {
didRender = (await this.renderObject(this.buffer)) || didRender;
}
if (didRender && this.rasterizedImageCache) {
const hasOneObject = this.renderers.size === 1;
const firstObject = Array.from(this.renderers.values())[0];
if (
hasOneObject &&
firstObject &&
firstObject.state.type === 'image' &&
firstObject.state.image.image_name !== this.rasterizedImageCache
) {
this.rasterizedImageCache = null;
}
}
return didRender;
};
@ -313,6 +332,18 @@ export class CanvasObjectRenderer {
this.buffer = null;
};
hideObjects = (except: string[] = []) => {
for (const renderer of this.renderers.values()) {
renderer.setVisibility(except.includes(renderer.id));
}
};
showObjects = (except: string[] = []) => {
for (const renderer of this.renderers.values()) {
renderer.setVisibility(!except.includes(renderer.id));
}
};
/**
* Determines if the objects in the renderer require a pixel bbox calculation.
*
@ -345,15 +376,33 @@ export class CanvasObjectRenderer {
return this.renderers.size > 0 || this.buffer !== null;
};
rasterize = async () => {
/**
* Rasterizes the parent entity. If the entity has a rasterization cache, the cached image is returned after
* validating that it exists on the server.
*
* The rasterization cache is reset when the entity's objects change. The buffer object is not considered part of the
* entity's objects for this purpose.
*
* @returns A promise that resolves to the rasterized image DTO.
*/
rasterize = async (): Promise<ImageDTO> => {
this.log.debug('Rasterizing entity');
let imageDTO: ImageDTO | null = null;
if (this.rasterizedImageCache) {
imageDTO = await getImageDTO(this.rasterizedImageCache);
}
if (imageDTO) {
return imageDTO;
}
const rect = this.parent.transformer.getRelativeRect();
const blob = await this.getBlob({ rect });
if (this.manager._isDebugging) {
previewBlob(blob, 'Rasterized entity');
}
const imageDTO = await uploadImage(blob, `${this.id}_rasterized.png`, 'other', true);
imageDTO = await uploadImage(blob, `${this.id}_rasterized.png`, 'other', true);
const imageObject = imageDTOToImageObject(imageDTO);
await this.renderObject(imageObject, true);
this.manager.stateApi.rasterizeEntity({
@ -361,6 +410,10 @@ export class CanvasObjectRenderer {
imageObject,
position: { x: Math.round(rect.x), y: Math.round(rect.y) },
});
this.rasterizedImageCache = imageDTO.image_name;
return imageDTO;
};
getBlob = ({ rect }: { rect?: Rect }): Promise<Blob> => {

View File

@ -1,11 +1,11 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { AppStore } from 'app/store/store';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
import {
$filteringEntity,
$isDrawing,
$isMouseDown,
$lastAddedPoint,
@ -15,6 +15,7 @@ import {
$shouldShowStagedImage,
$spaceKey,
$stageAttrs,
$transformingEntity,
bboxChanged,
brushWidthChanged,
entityBrushLineAdded,
@ -81,10 +82,10 @@ type EntityStateAndAdapter =
const log = logger('canvas');
export class CanvasStateApi {
_store: Store<RootState>;
_store: AppStore;
manager: CanvasManager;
constructor(store: Store<RootState>, manager: CanvasManager) {
constructor(store: AppStore, manager: CanvasManager) {
this._store = store;
this.manager = manager;
}
@ -188,6 +189,9 @@ export class CanvasStateApi {
getLogLevel = () => {
return this._store.getState().system.consoleLogLevel;
};
getFilterState = () => {
return this._store.getState().canvasV2.filter;
};
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.getState();
@ -256,7 +260,9 @@ export class CanvasStateApi {
return currentFill;
};
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom();
$transformingEntity = $transformingEntity;
$filteringEntity = $filteringEntity;
$toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();

View File

@ -155,7 +155,7 @@ export class CanvasTool {
const isMouseDown = this.manager.stateApi.$isMouseDown.get();
const tool = toolState.selected;
console.log(selectedEntity);
const isDrawableEntity =
selectedEntity?.state.type === 'regional_guidance' ||
selectedEntity?.state.type === 'layer' ||

View File

@ -33,9 +33,10 @@ import type {
EntityMovedPayload,
EntityRasterizedPayload,
EntityRectAddedPayload,
FilterConfig,
StageAttrs,
} from './types';
import { RGBA_RED } from './types';
import { IMAGE_FILTERS, RGBA_RED } from './types';
const initialState: CanvasV2State = {
_version: 3,
@ -133,6 +134,10 @@ const initialState: CanvasV2State = {
stagedImages: [],
selectedStagedImageIndex: 0,
},
filter: {
autoProcess: true,
config: IMAGE_FILTERS.canny_image_processor.buildDefaults(),
},
};
export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) {
@ -222,11 +227,12 @@ export const canvasV2Slice = createSlice({
} else if (entity.type === 'layer') {
entity.objects = [imageObject];
entity.position = position;
entity.imageCache = imageObject.image.image_name;
state.layers.imageCache = null;
} else if (entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
entity.objects = [imageObject];
entity.position = position;
entity.imageCache = null;
entity.imageCache = imageObject.image.image_name;
} else {
assert(false, 'Not implemented');
}
@ -354,6 +360,12 @@ export const canvasV2Slice = createSlice({
state.ipAdapters.entities = [];
state.controlAdapters.entities = [];
},
filterSelected: (state, action: PayloadAction<{ type: FilterConfig['type'] }>) => {
state.filter.config = IMAGE_FILTERS[action.payload.type].buildDefaults();
},
filterConfigChanged: (state, action: PayloadAction<{ config: FilterConfig }>) => {
state.filter.config = action.payload.config;
},
canvasReset: (state) => {
state.bbox = deepClone(initialState.bbox);
const optimalDimension = getOptimalDimension(state.params.model);
@ -415,6 +427,11 @@ export const {
layerOpacityChanged,
layerAllDeleted,
layerImageCacheChanged,
layerUsedAsControlChanged,
layerControlAdapterModelChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterWeightChanged,
layerControlAdapterBeginEndStepPctChanged,
// IP Adapters
ipaAdded,
ipaRecalled,
@ -513,6 +530,9 @@ export const {
sessionStagingAreaReset,
sessionNextStagedImageSelected,
sessionPrevStagedImageSelected,
// Filter
filterSelected,
filterConfigChanged,
} = canvasV2Slice.actions;
export const selectCanvasV2Slice = (state: RootState) => state.canvasV2;
@ -539,6 +559,8 @@ export const $lastAddedPoint = atom<Coordinate | null>(null);
export const $lastMouseDownPos = atom<Coordinate | null>(null);
export const $lastCursorPos = atom<Coordinate | null>(null);
export const $spaceKey = atom<boolean>(false);
export const $transformingEntity = atom<CanvasEntityIdentifier | null>(null);
export const $filteringEntity = atom<CanvasEntityIdentifier | null>(null);
export const canvasV2PersistConfig: PersistConfig<CanvasV2State> = {
name: canvasV2Slice.name,

View File

@ -13,7 +13,7 @@ import type {
ControlModeV2,
ControlNetConfig,
Filter,
ProcessorConfig,
FilterConfig,
T2IAdapterConfig,
} from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageObject } from './types';
@ -145,7 +145,7 @@ export const controlAdaptersReducers = {
}
ca.controlMode = controlMode;
},
caProcessorConfigChanged: (state, action: PayloadAction<{ id: string; processorConfig: ProcessorConfig | null }>) => {
caProcessorConfigChanged: (state, action: PayloadAction<{ id: string; processorConfig: FilterConfig | null }>) => {
const { id, processorConfig } = action.payload;
const ca = selectCA(state, id);
if (!ca) {

View File

@ -1,10 +1,11 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge } from 'lodash-es';
import type { ImageDTO } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { CanvasLayerState, CanvasV2State } from './types';
import type { CanvasLayerState, CanvasV2State, ControlModeV2, ControlNetConfig, T2IAdapterConfig } from './types';
import { imageDTOToImageWithDims } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.entities.find((layer) => layer.id === id);
@ -29,6 +30,7 @@ export const layersReducers = {
opacity: 1,
position: { x: 0, y: 0 },
imageCache: null,
controlAdapter: null,
};
merge(layer, overrides);
state.layers.entities.push(layer);
@ -64,4 +66,76 @@ export const layersReducers = {
const { imageDTO } = action.payload;
state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
layerUsedAsControlChanged: (
state,
action: PayloadAction<{ id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig | null }>
) => {
const { id, controlAdapter } = action.payload;
const layer = selectLayer(state, id);
if (!layer) {
return;
}
layer.controlAdapter = controlAdapter;
},
layerControlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
layerControlAdapterControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
layerControlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
layerControlAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,7 +1,7 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import { isEqual } from 'lodash-es';
@ -99,7 +99,7 @@ export const regionsReducers = {
if (!rg) {
return;
}
rg.imageCache = imageDTOToImageWithDims(imageDTO);
rg.imageCache = imageDTO.image_name;
},
rgAutoNegativeChanged: (state, action: PayloadAction<{ id: string; autoNegative: ParameterAutoNegative }>) => {
const { id, autoNegative } = action.payload;

View File

@ -21,14 +21,14 @@ import type {
MlsdProcessorConfig,
NormalbaeProcessorConfig,
PidiProcessorConfig,
ProcessorConfig,
ProcessorTypeV2,
FilterConfig,
FilterType,
ZoeDepthProcessorConfig,
} from './types';
describe('Control Adapter Types', () => {
test('ProcessorType', () => {
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>();
assert<Equals<FilterConfig['type'], FilterType>>();
});
test('IP Adapter Method', () => {
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();

View File

@ -1,4 +1,3 @@
import type { JSONObject } from 'common/types';
import type { CanvasControlAdapter } from 'features/controlLayers/konva/CanvasControlAdapter';
import { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
@ -36,6 +35,7 @@ import type {
BaseModelType,
ControlNetModelConfig,
ImageDTO,
S,
T2IAdapterModelConfig,
} from 'services/api/types';
import { z } from 'zod';
@ -175,7 +175,7 @@ const zZoeDepthProcessorConfig = z.object({
});
export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
export const zProcessorConfig = z.discriminatedUnion('type', [
export const zFilterConfig = z.discriminatedUnion('type', [
zCannyProcessorConfig,
zColorMapProcessorConfig,
zContentShuffleProcessorConfig,
@ -191,9 +191,9 @@ export const zProcessorConfig = z.discriminatedUnion('type', [
zPidiProcessorConfig,
zZoeDepthProcessorConfig,
]);
export type ProcessorConfig = z.infer<typeof zProcessorConfig>;
export type FilterConfig = z.infer<typeof zFilterConfig>;
const zProcessorTypeV2 = z.enum([
const zFilterType = z.enum([
'canny_image_processor',
'color_map_image_processor',
'content_shuffle_image_processor',
@ -209,22 +209,19 @@ const zProcessorTypeV2 = z.enum([
'pidi_image_processor',
'zoe_depth_image_processor',
]);
export type ProcessorTypeV2 = z.infer<typeof zProcessorTypeV2>;
export const isProcessorTypeV2 = (v: unknown): v is ProcessorTypeV2 => zProcessorTypeV2.safeParse(v).success;
type ProcessorData<T extends ProcessorTypeV2> = {
type: T;
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
};
export type FilterType = z.infer<typeof zFilterType>;
export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
type CAProcessorsData = {
[key in ProcessorTypeV2]: ProcessorData<key>;
type ImageFilterData<T extends FilterConfig['type']> = {
type: T;
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<FilterConfig, { type: T }>;
buildNode(imageDTO: ImageWithDims, config: Extract<FilterConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
};
/**
* A dict of ControlNet processors, including:
* - label translation key
@ -234,234 +231,243 @@ type CAProcessorsData = {
*
* TODO: Generate from the OpenAPI schema
*/
export const CA_PROCESSOR_DATA: CAProcessorsData = {
export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key> } = {
canny_image_processor: {
type: 'canny_image_processor',
labelTKey: 'controlnet.canny',
descriptionTKey: 'controlnet.cannyDescription',
buildDefaults: () => ({
buildDefaults: (): CannyProcessorConfig => ({
id: 'canny_image_processor',
type: 'canny_image_processor',
low_threshold: 100,
high_threshold: 200,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: CannyProcessorConfig): S['CannyImageProcessorInvocation'] => ({
...config,
type: 'canny_image_processor',
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
color_map_image_processor: {
type: 'color_map_image_processor',
labelTKey: 'controlnet.colorMap',
descriptionTKey: 'controlnet.colorMapDescription',
buildDefaults: () => ({
buildDefaults: (): ColorMapProcessorConfig => ({
id: 'color_map_image_processor',
type: 'color_map_image_processor',
color_map_tile_size: 64,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: ColorMapProcessorConfig): S['ColorMapImageProcessorInvocation'] => ({
...config,
type: 'color_map_image_processor',
image: { image_name: image.image_name },
image: { image_name: imageDTO.image_name },
}),
},
content_shuffle_image_processor: {
type: 'content_shuffle_image_processor',
labelTKey: 'controlnet.contentShuffle',
descriptionTKey: 'controlnet.contentShuffleDescription',
buildDefaults: (baseModel) => ({
buildDefaults: (baseModel: BaseModelType): ContentShuffleProcessorConfig => ({
id: 'content_shuffle_image_processor',
type: 'content_shuffle_image_processor',
h: baseModel === 'sdxl' ? 1024 : 512,
w: baseModel === 'sdxl' ? 1024 : 512,
f: baseModel === 'sdxl' ? 512 : 256,
}),
buildNode: (image, config) => ({
buildNode: (
imageDTO: ImageDTO,
config: ContentShuffleProcessorConfig
): S['ContentShuffleImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
depth_anything_image_processor: {
type: 'depth_anything_image_processor',
labelTKey: 'controlnet.depthAnything',
descriptionTKey: 'controlnet.depthAnythingDescription',
buildDefaults: () => ({
buildDefaults: (): DepthAnythingProcessorConfig => ({
id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor',
model_size: 'small_v2',
}),
buildNode: (image, config) => ({
buildNode: (
imageDTO: ImageDTO,
config: DepthAnythingProcessorConfig
): S['DepthAnythingImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
resolution: minDim(image),
image: { image_name: imageDTO.image_name },
resolution: minDim(imageDTO),
}),
},
hed_image_processor: {
type: 'hed_image_processor',
labelTKey: 'controlnet.hed',
descriptionTKey: 'controlnet.hedDescription',
buildDefaults: () => ({
buildDefaults: (): HedProcessorConfig => ({
id: 'hed_image_processor',
type: 'hed_image_processor',
scribble: false,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: HedProcessorConfig): S['HedImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
lineart_anime_image_processor: {
type: 'lineart_anime_image_processor',
labelTKey: 'controlnet.lineartAnime',
descriptionTKey: 'controlnet.lineartAnimeDescription',
buildDefaults: () => ({
buildDefaults: (): LineartAnimeProcessorConfig => ({
id: 'lineart_anime_image_processor',
type: 'lineart_anime_image_processor',
}),
buildNode: (image, config) => ({
buildNode: (
imageDTO: ImageDTO,
config: LineartAnimeProcessorConfig
): S['LineartAnimeImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
lineart_image_processor: {
type: 'lineart_image_processor',
labelTKey: 'controlnet.lineart',
descriptionTKey: 'controlnet.lineartDescription',
buildDefaults: () => ({
buildDefaults: (): LineartProcessorConfig => ({
id: 'lineart_image_processor',
type: 'lineart_image_processor',
coarse: false,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: LineartProcessorConfig): S['LineartImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
mediapipe_face_processor: {
type: 'mediapipe_face_processor',
labelTKey: 'controlnet.mediapipeFace',
descriptionTKey: 'controlnet.mediapipeFaceDescription',
buildDefaults: () => ({
buildDefaults: (): MediapipeFaceProcessorConfig => ({
id: 'mediapipe_face_processor',
type: 'mediapipe_face_processor',
max_faces: 1,
min_confidence: 0.5,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: MediapipeFaceProcessorConfig): S['MediapipeFaceProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
midas_depth_image_processor: {
type: 'midas_depth_image_processor',
labelTKey: 'controlnet.depthMidas',
descriptionTKey: 'controlnet.depthMidasDescription',
buildDefaults: () => ({
buildDefaults: (): MidasDepthProcessorConfig => ({
id: 'midas_depth_image_processor',
type: 'midas_depth_image_processor',
a_mult: 2,
bg_th: 0.1,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: MidasDepthProcessorConfig): S['MidasDepthImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
mlsd_image_processor: {
type: 'mlsd_image_processor',
labelTKey: 'controlnet.mlsd',
descriptionTKey: 'controlnet.mlsdDescription',
buildDefaults: () => ({
buildDefaults: (): MlsdProcessorConfig => ({
id: 'mlsd_image_processor',
type: 'mlsd_image_processor',
thr_d: 0.1,
thr_v: 0.1,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: MlsdProcessorConfig): S['MlsdImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
normalbae_image_processor: {
type: 'normalbae_image_processor',
labelTKey: 'controlnet.normalBae',
descriptionTKey: 'controlnet.normalBaeDescription',
buildDefaults: () => ({
buildDefaults: (): NormalbaeProcessorConfig => ({
id: 'normalbae_image_processor',
type: 'normalbae_image_processor',
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: NormalbaeProcessorConfig): S['NormalbaeImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
dw_openpose_image_processor: {
type: 'dw_openpose_image_processor',
labelTKey: 'controlnet.dwOpenpose',
descriptionTKey: 'controlnet.dwOpenposeDescription',
buildDefaults: () => ({
buildDefaults: (): DWOpenposeProcessorConfig => ({
id: 'dw_openpose_image_processor',
type: 'dw_openpose_image_processor',
draw_body: true,
draw_face: false,
draw_hands: false,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: DWOpenposeProcessorConfig): S['DWOpenposeImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
image_resolution: minDim(imageDTO),
}),
},
pidi_image_processor: {
type: 'pidi_image_processor',
labelTKey: 'controlnet.pidi',
descriptionTKey: 'controlnet.pidiDescription',
buildDefaults: () => ({
buildDefaults: (): PidiProcessorConfig => ({
id: 'pidi_image_processor',
type: 'pidi_image_processor',
scribble: false,
safe: false,
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: PidiProcessorConfig): S['PidiImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
detect_resolution: minDim(image),
image_resolution: minDim(image),
image: { image_name: imageDTO.image_name },
detect_resolution: minDim(imageDTO),
image_resolution: minDim(imageDTO),
}),
},
zoe_depth_image_processor: {
type: 'zoe_depth_image_processor',
labelTKey: 'controlnet.depthZoe',
descriptionTKey: 'controlnet.depthZoeDescription',
buildDefaults: () => ({
buildDefaults: (): ZoeDepthProcessorConfig => ({
id: 'zoe_depth_image_processor',
type: 'zoe_depth_image_processor',
}),
buildNode: (image, config) => ({
buildNode: (imageDTO: ImageDTO, config: ZoeDepthProcessorConfig): S['ZoeDepthImageProcessorInvocation'] => ({
...config,
image: { image_name: image.image_name },
image: { image_name: imageDTO.image_name },
}),
},
};
} as const;
const zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'view', 'bbox']);
export type Tool = z.infer<typeof zTool>;
@ -575,17 +581,6 @@ export function isCanvasBrushLineState(obj: CanvasObjectState): obj is CanvasBru
return obj.type === 'brush_line';
}
export const zCanvasLayerState = z.object({
id: zId,
type: z.literal('layer'),
isEnabled: z.boolean(),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
imageCache: z.string().min(1).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
export const zCanvasIPAdapterState = z.object({
id: zId,
type: z.literal('ip_adapter'),
@ -689,7 +684,7 @@ const zCanvasControlAdapterStateBase = z.object({
weight: z.number().gte(-1).lte(2),
imageObject: zCanvasImageState.nullable(),
processedImageObject: zCanvasImageState.nullable(),
processorConfig: zProcessorConfig.nullable(),
processorConfig: zFilterConfig.nullable(),
processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct,
model: zModelIdentifierField.nullable(),
@ -709,41 +704,55 @@ export const zCanvasControlAdapterState = z.discriminatedUnion('adapterType', [
zCanvasT2IAdapteState,
]);
export type CanvasControlAdapterState = z.infer<typeof zCanvasControlAdapterState>;
export type ControlNetConfig = Pick<
CanvasControlNetState,
| 'adapterType'
| 'weight'
| 'imageObject'
| 'processedImageObject'
| 'processorConfig'
| 'beginEndStepPct'
| 'model'
| 'controlMode'
>;
export type T2IAdapterConfig = Pick<
CanvasT2IAdapterState,
'adapterType' | 'weight' | 'imageObject' | 'processedImageObject' | 'processorConfig' | 'beginEndStepPct' | 'model'
>;
const zControlNetConfig = z.object({
type: z.literal('controlnet'),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
controlMode: zControlModeV2,
});
export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
const zT2IAdapterConfig = z.object({
type: z.literal('t2i_adapter'),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
export const zCanvasLayerState = z.object({
id: zId,
type: z.literal('layer'),
isEnabled: z.boolean(),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
imageCache: z.string().min(1).nullable(),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
export type CanvasLayerStateWithValidControlNet = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<ControlNetConfig, 'model'> & { model: ControlNetModelConfig };
};
export type CanvasLayerStateWithValidT2IAdapter = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<T2IAdapterConfig, 'model'> & { model: T2IAdapterModelConfig };
};
export const initialControlNetV2: ControlNetConfig = {
adapterType: 'controlnet',
type: 'controlnet',
model: null,
weight: 1,
beginEndStepPct: [0, 1],
controlMode: 'balanced',
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
export const initialT2IAdapterV2: T2IAdapterConfig = {
adapterType: 't2i_adapter',
type: 't2i_adapter',
model: null,
weight: 1,
beginEndStepPct: [0, 1],
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
};
export const initialIPAdapterV2: IPAdapterConfig = {
@ -757,12 +766,12 @@ export const initialIPAdapterV2: IPAdapterConfig = {
export const buildControlAdapterProcessorV2 = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
): ProcessorConfig | null => {
): FilterConfig | null => {
const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
if (!isProcessorTypeV2(defaultPreprocessor)) {
if (!isFilterType(defaultPreprocessor)) {
return null;
}
const processorConfig = CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults(modelConfig.base);
const processorConfig = IMAGE_FILTERS[defaultPreprocessor].buildDefaults(modelConfig.base);
return processorConfig;
};
@ -901,6 +910,10 @@ export type CanvasV2State = {
stagedImages: StagingAreaImage[];
selectedStagedImageIndex: number;
};
filter: {
autoProcess: boolean;
config: FilterConfig;
};
};
export type StageAttrs = {
@ -964,5 +977,3 @@ export function isDrawableEntityType(
): entityType is 'layer' | 'regional_guidance' | 'inpaint_mask' {
return entityType === 'layer' || entityType === 'regional_guidance' || entityType === 'inpaint_mask';
}
export type GetLoggingContext = (extra?: JSONObject) => JSONObject;

View File

@ -2,12 +2,12 @@ import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/contro
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasLayerState, LoRA } from 'features/controlLayers/store/types';
import {
CA_PROCESSOR_DATA,
IMAGE_FILTERS,
imageDTOToImageWithDims,
initialControlNetV2,
initialIPAdapterV2,
initialT2IAdapterV2,
isProcessorTypeV2,
isFilterType,
zCanvasLayerState,
} from 'features/controlLayers/store/types';
import type {
@ -559,8 +559,8 @@ const parseControlNetToControlAdapterLayer: MetadataParseFunc<CanvasControlAdapt
.parse(await getProperty(metadataItem, 'control_mode'));
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
const processorConfig = isFilterType(defaultPreprocessor)
? IMAGE_FILTERS[defaultPreprocessor].buildDefaults()
: null;
const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialControlNetV2.beginEndStepPct[0],
@ -620,8 +620,8 @@ const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<CanvasControlAdapt
.parse(await getProperty(metadataItem, 'end_step_percent'));
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults()
const processorConfig = isFilterType(defaultPreprocessor)
? IMAGE_FILTERS[defaultPreprocessor].buildDefaults()
: null;
const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialT2IAdapterV2.beginEndStepPct[0],

View File

@ -1,35 +1,42 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type {
CanvasControlAdapterState,
CanvasControlNetState,
CanvasLayerState,
CanvasLayerStateWithValidControlNet,
CanvasLayerStateWithValidT2IAdapter,
ControlNetConfig,
FilterConfig,
ImageWithDims,
ProcessorConfig,
Rect,
CanvasT2IAdapterState,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common';
import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
export const addControlAdapters = async (
manager: CanvasManager,
controlAdapters: CanvasControlAdapterState[],
layers: CanvasLayerState[],
g: Graph,
bbox: Rect,
denoise: Invocation<'denoise_latents'>,
base: BaseModelType
): Promise<CanvasControlAdapterState[]> => {
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base));
for (const ca of validControlAdapters) {
if (ca.adapterType === 'controlnet') {
await addControlNetToGraph(manager, ca, g, bbox, denoise);
): Promise<(CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter)[]> => {
const layersWithValidControlAdapters = layers
.filter((layer) => layer.isEnabled)
.filter((layer) => doesLayerHaveValidControlAdapter(layer, base));
for (const layer of layersWithValidControlAdapters) {
const adapter = manager.layers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.getImageDTO({ rect: bbox, is_intermediate: true, category: 'control' });
if (layer.controlAdapter.type === 'controlnet') {
await addControlNetToGraph(g, layer, imageDTO, denoise);
} else {
await addT2IAdapterToGraph(manager, ca, g, bbox, denoise);
await addT2IAdapterToGraph(g, layer, imageDTO, denoise);
}
}
return validControlAdapters;
return layersWithValidControlAdapters;
};
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
@ -49,16 +56,15 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}
};
const addControlNetToGraph = async (
manager: CanvasManager,
ca: CanvasControlNetState,
const addControlNetToGraph = (
g: Graph,
bbox: Rect,
layer: CanvasLayerStateWithValidControlNet,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, controlMode, model, weight } = ca;
assert(model, 'ControlNet model is required');
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
const { id, controlAdapter } = layer;
const { beginEndStepPct, model, weight, controlMode } = controlAdapter;
const { image_name } = imageDTO;
const controlNetCollect = addControlNetCollectorSafe(g, denoise);
@ -94,16 +100,15 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
}
};
const addT2IAdapterToGraph = async (
manager: CanvasManager,
ca: CanvasT2IAdapterState,
const addT2IAdapterToGraph = (
g: Graph,
bbox: Rect,
layer: CanvasLayerStateWithValidT2IAdapter,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'>
) => {
const { id, beginEndStepPct, model, weight } = ca;
assert(model, 'T2I Adapter model is required');
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true });
const { id, controlAdapter } = layer;
const { beginEndStepPct, model, weight } = controlAdapter;
const { image_name } = imageDTO;
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
@ -124,7 +129,7 @@ const addT2IAdapterToGraph = async (
const buildControlImage = (
image: ImageWithDims | null,
processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null
processorConfig: FilterConfig | null
): ImageField => {
if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image.
@ -140,10 +145,29 @@ const buildControlImage = (
assert(false, 'Attempted to add unprocessed control image');
};
const isValidControlAdapter = (ca: CanvasControlAdapterState, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ca.model);
const modelMatchesBase = ca.model?.base === base;
const hasControlImage = Boolean(ca.imageObject || (ca.processedImageObject && ca.processorConfig));
return hasModel && modelMatchesBase && hasControlImage;
const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model
const hasModel = Boolean(controlAdapter.model);
// Model must match the current base model
const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase;
};
const doesLayerHaveValidControlAdapter = (
layer: CanvasLayerState,
base: BaseModelType
): layer is CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter => {
if (!layer.controlAdapter) {
// Must have a control adapter
return false;
}
if (!layer.controlAdapter.model) {
// Control adapter must have a model selected
return false;
}
if (layer.controlAdapter.model.base !== base) {
// Selected model must match current base model
return false;
}
return true;
};

View File

@ -1,9 +1,10 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types';
export const isValidLayer = (entity: CanvasLayerState) => {
export const isValidLayerWithoutControlAdapter = (layer: CanvasLayerState) => {
return (
entity.isEnabled &&
layer.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0
layer.objects.length > 0 &&
layer.controlAdapter === null
);
};

View File

@ -215,7 +215,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
state.canvasV2.layers.entities,
g,
state.canvasV2.bbox.rect,
denoise,

View File

@ -219,7 +219,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.controlAdapters.entities,
state.canvasV2.layers.entities,
g,
state.canvasV2.bbox.rect,
denoise,

View File

@ -1635,8 +1635,11 @@ export type components = {
* @description The ID of the batch
*/
batch_id?: string;
/** @description The origin of this batch. */
origin?: components["schemas"]["QueueItemOrigin"] | null;
/**
* Origin
* @description The origin of this batch.
*/
origin?: string | null;
/**
* Data
* @description The batch data collection.
@ -1707,10 +1710,11 @@ export type components = {
*/
priority: number;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
};
/** BatchStatus */
BatchStatus: {
@ -1724,8 +1728,11 @@ export type components = {
* @description The ID of the batch
*/
batch_id: string;
/** @description The origin of the batch */
origin: components["schemas"]["QueueItemOrigin"] | null;
/**
* Origin
* @description The origin of the batch
*/
origin: string | null;
/**
* Pending
* @description Number of queue items with status 'pending'
@ -8330,10 +8337,11 @@ export type components = {
*/
batch_id: string;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
/**
* Session Id
* @description The ID of the session (aka graph execution state)
@ -8381,10 +8389,11 @@ export type components = {
*/
batch_id: string;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
/**
* Session Id
* @description The ID of the session (aka graph execution state)
@ -8449,10 +8458,11 @@ export type components = {
*/
batch_id: string;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
/**
* Session Id
* @description The ID of the session (aka graph execution state)
@ -8670,10 +8680,11 @@ export type components = {
*/
batch_id: string;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
/**
* Session Id
* @description The ID of the session (aka graph execution state)
@ -11654,10 +11665,11 @@ export type components = {
*/
batch_id: string;
/**
* Origin
* @description The origin of the batch
* @default null
*/
origin: components["schemas"]["QueueItemOrigin"] | null;
origin: string | null;
/**
* Status
* @description The new status of the queue item
@ -13014,8 +13026,11 @@ export type components = {
* @description The ID of the batch associated with this queue item
*/
batch_id: string;
/** @description The origin of this queue item. */
origin?: components["schemas"]["QueueItemOrigin"] | null;
/**
* Origin
* @description The origin of this queue item.
*/
origin?: string | null;
/**
* Session Id
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.
@ -13096,8 +13111,11 @@ export type components = {
* @description The ID of the batch associated with this queue item
*/
batch_id: string;
/** @description The origin of this queue item. */
origin?: components["schemas"]["QueueItemOrigin"] | null;
/**
* Origin
* @description The origin of this queue item.
*/
origin?: string | null;
/**
* Session Id
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.