feat(ui, app): use layer as control (wip)

This commit is contained in:
psychedelicious 2024-08-14 18:10:51 +10:00
parent 3b36eb0223
commit 636d9a7209
65 changed files with 1734 additions and 292 deletions

View File

@ -10,7 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS, QUEUE_ITEM_STATUS,
BatchStatus, BatchStatus,
EnqueueBatchResult, EnqueueBatchResult,
QueueItemOrigin,
SessionQueueItem, SessionQueueItem,
SessionQueueStatus, SessionQueueStatus,
) )
@ -89,7 +88,7 @@ class QueueItemEventBase(QueueEventBase):
item_id: int = Field(description="The ID of the queue item") item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch") batch_id: str = Field(description="The ID of the queue batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") origin: str | None = Field(default=None, description="The origin of the batch")
class InvocationEventBase(QueueItemEventBase): class InvocationEventBase(QueueItemEventBase):
@ -284,7 +283,7 @@ class BatchEnqueuedEvent(QueueEventBase):
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)" description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
) )
priority: int = Field(description="The priority of the batch") priority: int = Field(description="The priority of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of the batch") origin: str | None = Field(default=None, description="The origin of the batch")
@classmethod @classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":

View File

@ -86,7 +86,7 @@ BatchDataCollection: TypeAlias = list[list[BatchDatum]]
class Batch(BaseModel): class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch") batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this batch.") origin: str | None = Field(default=None, description="The origin of this batch.")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.") data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with") graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field( workflow: Optional[WorkflowWithoutID] = Field(
@ -205,7 +205,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item") status: QUEUE_ITEM_STATUS = Field(default="pending", description="The status of this queue item")
priority: int = Field(default=0, description="The priority of this queue item") priority: int = Field(default=0, description="The priority of this queue item")
batch_id: str = Field(description="The ID of the batch associated with this queue item") batch_id: str = Field(description="The ID of the batch associated with this queue item")
origin: QueueItemOrigin | None = Field(default=None, description="The origin of this queue item. ") origin: str | None = Field(default=None, description="The origin of this queue item. ")
session_id: str = Field( session_id: str = Field(
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed." description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
) )
@ -305,7 +305,7 @@ class SessionQueueStatus(BaseModel):
class BatchStatus(BaseModel): class BatchStatus(BaseModel):
queue_id: str = Field(..., description="The ID of the queue") queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch") batch_id: str = Field(..., description="The ID of the batch")
origin: QueueItemOrigin | None = Field(..., description="The origin of the batch") origin: str | None = Field(..., description="The origin of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'") pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'") in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'") completed: int = Field(..., description="Number of queue items with status 'complete'")
@ -451,7 +451,7 @@ class SessionQueueValueToInsert(NamedTuple):
field_values: Optional[str] # field_values json field_values: Optional[str] # field_values json
priority: int # priority priority: int # priority
workflow: Optional[str] # workflow json workflow: Optional[str] # workflow json
origin: QueueItemOrigin | None origin: str | None
ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert] ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]

View File

@ -10,6 +10,7 @@ import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types'; import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client'; import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client'; import { io } from 'socket.io-client';
import { assert } from 'tsafe';
// Inject socket options and url into window for debugging // Inject socket options and url into window for debugging
declare global { declare global {
@ -18,6 +19,14 @@ declare global {
} }
} }
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;
export const $socket = atom<AppSocket | null>(null);
export const getSocket = () => {
const socket = $socket.get();
assert(socket !== null, 'Socket is not initialized');
return socket;
};
export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({}); export const $socketOptions = map<Partial<ManagerOptions & SocketOptions>>({});
const $isSocketInitialized = atom<boolean>(false); const $isSocketInitialized = atom<boolean>(false);
@ -61,7 +70,8 @@ export const useSocketIO = () => {
return; return;
} }
const socket: Socket<ServerToClientEvents, ClientToServerEvents> = io(socketUrl, socketOptions); const socket: AppSocket = io(socketUrl, socketOptions);
$socket.set(socket);
setEventListeners({ dispatch, socket }); setEventListeners({ dispatch, socket });
socket.connect(); socket.connect();

View File

@ -12,7 +12,7 @@ import {
caRecalled, caRecalled,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import { selectCA } from 'features/controlLayers/store/controlAdaptersReducers'; import { selectCA } from 'features/controlLayers/store/controlAdaptersReducers';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { toast } from 'features/toast/toast'; import { toast } from 'features/toast/toast';
import { t } from 'i18next'; import { t } from 'i18next';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -95,7 +95,7 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
} }
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image.image, config as never); const processorNode = IMAGE_FILTERS[config.type].buildNode(image.image, config as never);
const enqueueBatchArg: BatchConfig = { const enqueueBatchArg: BatchConfig = {
prepend: true, prepend: true,
batch: { batch: {

View File

@ -1,4 +1,5 @@
import type { createStore } from 'app/store/store'; import { useStore } from '@nanostores/react';
import type { AppStore } from 'app/store/store';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
// Inject socket options and url into window for debugging // Inject socket options and url into window for debugging
@ -22,7 +23,7 @@ class ReduxStoreNotInitialized extends Error {
} }
} }
export const $store = atom<Readonly<ReturnType<typeof createStore>> | undefined>(); export const $store = atom<Readonly<AppStore | undefined>>();
export const getStore = () => { export const getStore = () => {
const store = $store.get(); const store = $store.get();
@ -31,3 +32,11 @@ export const getStore = () => {
} }
return store; return store;
}; };
export const useAppStore = () => {
const store = useStore($store);
if (!store) {
throw new ReduxStoreNotInitialized();
}
return store;
};

View File

@ -180,7 +180,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
}, },
}); });
export type RootState = ReturnType<ReturnType<typeof createStore>['getState']>; export type AppStore = ReturnType<typeof createStore>;
export type RootState = ReturnType<AppStore['getState']>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any // eslint-disable-next-line @typescript-eslint/no-explicit-any
export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>; export type AppThunkDispatch = ThunkDispatch<RootState, any, UnknownAction>;
export type AppDispatch = ReturnType<typeof createStore>['dispatch']; export type AppDispatch = ReturnType<typeof createStore>['dispatch'];

View File

@ -1,8 +1,8 @@
import type { AppThunkDispatch, RootState } from 'app/store/store'; import type { AppStore, AppThunkDispatch, RootState } from 'app/store/store';
import type { TypedUseSelectorHook } from 'react-redux'; import type { TypedUseSelectorHook } from 'react-redux';
import { useDispatch, useSelector, useStore } from 'react-redux'; import {useDispatch, useSelector, useStore } from 'react-redux';
// Use throughout your app instead of plain `useDispatch` and `useSelector` // Use throughout your app instead of plain `useDispatch` and `useSelector`
export const useAppDispatch = () => useDispatch<AppThunkDispatch>(); export const useAppDispatch = () => useDispatch<AppThunkDispatch>();
export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector; export const useAppSelector: TypedUseSelectorHook<RootState> = useSelector;
export const useAppStore = () => useStore<RootState>(); export const useAppStore = () => useStore<AppStore>();

View File

@ -1,4 +1,4 @@
import type { ProcessorTypeV2 } from 'features/controlLayers/store/types'; import type { FilterType } from 'features/controlLayers/store/types';
import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import type { ParameterPrecision, ParameterScheduler } from 'features/parameters/types/parameterSchemas';
import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { InvokeTabName } from 'features/ui/store/tabMap';
import type { O } from 'ts-toolbelt'; import type { O } from 'ts-toolbelt';
@ -83,7 +83,7 @@ export type AppConfig = {
sd: { sd: {
defaultModel?: string; defaultModel?: string;
disabledControlNetModels: string[]; disabledControlNetModels: string[];
disabledControlNetProcessors: ProcessorTypeV2[]; disabledControlNetProcessors: FilterType[];
// Core parameters // Core parameters
iterations: NumericalParameterConfig; iterations: NumericalParameterConfig;
width: NumericalParameterConfig; // initial value comes from model width: NumericalParameterConfig; // initial value comes from model

View File

@ -9,12 +9,12 @@ import { MediapipeFaceProcessor } from 'features/controlLayers/components/Contro
import { MidasDepthProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MidasDepthProcessor'; import { MidasDepthProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MidasDepthProcessor';
import { MlsdImageProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MlsdImageProcessor'; import { MlsdImageProcessor } from 'features/controlLayers/components/ControlAdapter/processors/MlsdImageProcessor';
import { PidiProcessor } from 'features/controlLayers/components/ControlAdapter/processors/PidiProcessor'; import { PidiProcessor } from 'features/controlLayers/components/ControlAdapter/processors/PidiProcessor';
import type { ProcessorConfig } from 'features/controlLayers/store/types'; import type { FilterConfig } from 'features/controlLayers/store/types';
import { memo } from 'react'; import { memo } from 'react';
type Props = { type Props = {
config: ProcessorConfig | null; config: FilterConfig | null;
onChange: (config: ProcessorConfig | null) => void; onChange: (config: FilterConfig | null) => void;
}; };
export const ControlAdapterProcessorConfig = memo(({ config, onChange }: Props) => { export const ControlAdapterProcessorConfig = memo(({ config, onChange }: Props) => {

View File

@ -3,8 +3,8 @@ import { Combobox, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/u
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import type {ProcessorConfig } from 'features/controlLayers/store/types'; import type {FilterConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA, isProcessorTypeV2 } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/types';
import { configSelector } from 'features/system/store/configSelectors'; import { configSelector } from 'features/system/store/configSelectors';
import { includes, map } from 'lodash-es'; import { includes, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
@ -13,8 +13,8 @@ import { PiXBold } from 'react-icons/pi';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
type Props = { type Props = {
config: ProcessorConfig | null; config: FilterConfig | null;
onChange: (config: ProcessorConfig | null) => void; onChange: (config: FilterConfig | null) => void;
}; };
const selectDisabledProcessors = createMemoizedSelector( const selectDisabledProcessors = createMemoizedSelector(
@ -26,7 +26,7 @@ export const ControlAdapterProcessorTypeSelect = memo(({ config, onChange }: Pro
const { t } = useTranslation(); const { t } = useTranslation();
const disabledProcessors = useAppSelector(selectDisabledProcessors); const disabledProcessors = useAppSelector(selectDisabledProcessors);
const options = useMemo(() => { const options = useMemo(() => {
return map(CA_PROCESSOR_DATA, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter( return map(IMAGE_FILTERS, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter(
(o) => !includes(disabledProcessors, o.value) (o) => !includes(disabledProcessors, o.value)
); );
}, [disabledProcessors, t]); }, [disabledProcessors, t]);
@ -36,8 +36,8 @@ export const ControlAdapterProcessorTypeSelect = memo(({ config, onChange }: Pro
if (!v) { if (!v) {
onChange(null); onChange(null);
} else { } else {
assert(isProcessorTypeV2(v.value)); assert(isFilterType(v.value));
onChange(CA_PROCESSOR_DATA[v.value].buildDefaults()); onChange(IMAGE_FILTERS[v.value].buildDefaults());
} }
}, },
[onChange] [onChange]

View File

@ -19,7 +19,7 @@ import {
caWeightChanged, caWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice'; } from 'features/controlLayers/store/canvasV2Slice';
import { selectCAOrThrow } from 'features/controlLayers/store/controlAdaptersReducers'; import { selectCAOrThrow } from 'features/controlLayers/store/controlAdaptersReducers';
import type { ControlModeV2, ProcessorConfig } from 'features/controlLayers/store/types'; import type { ControlModeV2, FilterConfig } from 'features/controlLayers/store/types';
import type { CAImageDropData } from 'features/dnd/types'; import type { CAImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -62,7 +62,7 @@ export const ControlAdapterSettings = memo(() => {
); );
const onChangeProcessorConfig = useCallback( const onChangeProcessorConfig = useCallback(
(processorConfig: ProcessorConfig | null) => { (processorConfig: FilterConfig | null) => {
dispatch(caProcessorConfigChanged({ id, processorConfig })); dispatch(caProcessorConfigChanged({ id, processorConfig }));
}, },
[dispatch, id] [dispatch, id]

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { CannyProcessorConfig } from 'features/controlLayers/store/types'; import type { CannyProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<CannyProcessorConfig>; type Props = ProcessorComponentProps<CannyProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['canny_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['canny_image_processor'].buildDefaults();
export const CannyProcessor = ({ onChange, config }: Props) => { export const CannyProcessor = ({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types'; import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<ColorMapProcessorConfig>; type Props = ProcessorComponentProps<ColorMapProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['color_map_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['color_map_image_processor'].buildDefaults();
export const ColorMapProcessor = memo(({ onChange, config }: Props) => { export const ColorMapProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { ContentShuffleProcessorConfig } from 'features/controlLayers/store/types'; import type { ContentShuffleProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<ContentShuffleProcessorConfig>; type Props = ProcessorComponentProps<ContentShuffleProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['content_shuffle_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['content_shuffle_image_processor'].buildDefaults();
export const ContentShuffleProcessor = memo(({ onChange, config }: Props) => { export const ContentShuffleProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,7 +1,7 @@
import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library'; import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types'; import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react'; import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<DWOpenposeProcessorConfig>; type Props = ProcessorComponentProps<DWOpenposeProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['dw_openpose_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['dw_openpose_image_processor'].buildDefaults();
export const DWOpenposeProcessor = memo(({ onChange, config }: Props) => { export const DWOpenposeProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types'; import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MediapipeFaceProcessorConfig>; type Props = ProcessorComponentProps<MediapipeFaceProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['mediapipe_face_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['mediapipe_face_processor'].buildDefaults();
export const MediapipeFaceProcessor = memo(({ onChange, config }: Props) => { export const MediapipeFaceProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types'; import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MidasDepthProcessorConfig>; type Props = ProcessorComponentProps<MidasDepthProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['midas_depth_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['midas_depth_image_processor'].buildDefaults();
export const MidasDepthProcessor = memo(({ onChange, config }: Props) => { export const MidasDepthProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,14 +1,14 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types'; import type { ProcessorComponentProps } from 'features/controlLayers/components/ControlAdapter/processors/types';
import type { MlsdProcessorConfig } from 'features/controlLayers/store/types'; import type { MlsdProcessorConfig } from 'features/controlLayers/store/types';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/store/types'; import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import ProcessorWrapper from './ProcessorWrapper'; import ProcessorWrapper from './ProcessorWrapper';
type Props = ProcessorComponentProps<MlsdProcessorConfig>; type Props = ProcessorComponentProps<MlsdProcessorConfig>;
const DEFAULTS = CA_PROCESSOR_DATA['mlsd_image_processor'].buildDefaults(); const DEFAULTS = IMAGE_FILTERS['mlsd_image_processor'].buildDefaults();
export const MlsdImageProcessor = memo(({ onChange, config }: Props) => { export const MlsdImageProcessor = memo(({ onChange, config }: Props) => {
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -1,6 +1,6 @@
import type { ProcessorConfig } from 'features/controlLayers/store/types'; import type { FilterConfig } from 'features/controlLayers/store/types';
export type ProcessorComponentProps<T extends ProcessorConfig> = { export type ProcessorComponentProps<T extends FilterConfig> = {
onChange: (config: T) => void; onChange: (config: T) => void;
config: T; config: T;
}; };

View File

@ -1,19 +1,37 @@
/* eslint-disable i18next/no-literal-string */ /* eslint-disable i18next/no-literal-string */
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { AddLayerButton } from 'features/controlLayers/components/AddLayerButton'; import { AddLayerButton } from 'features/controlLayers/components/AddLayerButton';
import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList'; import { CanvasEntityList } from 'features/controlLayers/components/CanvasEntityList';
import { DeleteAllLayersButton } from 'features/controlLayers/components/DeleteAllLayersButton'; import { DeleteAllLayersButton } from 'features/controlLayers/components/DeleteAllLayersButton';
import { Filter } from 'features/controlLayers/components/Filters/Filter';
import { $filteringEntity } from 'features/controlLayers/store/canvasV2Slice';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo } from 'react'; import { memo } from 'react';
import { Panel, PanelGroup } from 'react-resizable-panels';
export const ControlLayersPanelContent = memo(() => { export const ControlLayersPanelContent = memo(() => {
const filteringEntity = useStore($filteringEntity);
return ( return (
<Flex flexDir="column" gap={2} w="full" h="full"> <PanelGroup direction="vertical">
<Flex justifyContent="space-around"> <Panel id="canvas-entity-list-panel" order={0}>
<AddLayerButton /> <Flex flexDir="column" gap={2} w="full" h="full">
<DeleteAllLayersButton /> <Flex justifyContent="space-around">
</Flex> <AddLayerButton />
<CanvasEntityList /> <DeleteAllLayersButton />
</Flex> </Flex>
<CanvasEntityList />
</Flex>
</Panel>
{Boolean(filteringEntity) && (
<>
<ResizeHandle orientation="horizontal" />
<Panel id="filter-panel" order={1}>
<Filter />
</Panel>
</>
)}
</PanelGroup>
); );
}); });

View File

@ -12,11 +12,25 @@ import { ResetCanvasButton } from 'features/controlLayers/components/ResetCanvas
import { ToolChooser } from 'features/controlLayers/components/ToolChooser'; import { ToolChooser } from 'features/controlLayers/components/ToolChooser';
import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup'; import { UndoRedoButtonGroup } from 'features/controlLayers/components/UndoRedoButtonGroup';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { nanoid } from 'features/controlLayers/konva/util';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton'; import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu'; import { ViewerToggleMenu } from 'features/gallery/components/ImageViewer/ViewerToggleMenu';
import type { ChangeEvent } from 'react'; import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react'; import { memo, useCallback } from 'react';
const filter = () => {
const entity = $canvasManager.get()?.stateApi.getSelectedEntity();
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter({
type: 'canny_image_processor',
id: nanoid(),
low_threshold: 50,
high_threshold: 50,
});
};
export const ControlLayersToolbar = memo(() => { export const ControlLayersToolbar = memo(() => {
const tool = useAppSelector((s) => s.canvasV2.tool.selected); const tool = useAppSelector((s) => s.canvasV2.tool.selected);
const canvasManager = useStore($canvasManager); const canvasManager = useStore($canvasManager);
@ -47,6 +61,7 @@ export const ControlLayersToolbar = memo(() => {
<Flex gap={2} marginInlineEnd="auto" alignItems="center"> <Flex gap={2} marginInlineEnd="auto" alignItems="center">
<ToggleProgressButton /> <ToggleProgressButton />
<ToolChooser /> <ToolChooser />
<Button onClick={filter}>Filter</Button>
</Flex> </Flex>
</Flex> </Flex>
<Flex flex={1} gap={2} justifyContent="center" alignItems="center"> <Flex flex={1} gap={2} justifyContent="center" alignItems="center">

View File

@ -0,0 +1,76 @@
import { Button, ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { FilterSettings } from 'features/controlLayers/components/Filters/FilterSettings';
import { FilterTypeSelect } from 'features/controlLayers/components/Filters/FilterTypeSelect';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { $filteringEntity } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react';
export const Filter = memo(() => {
const filteringEntity = useStore($filteringEntity);
const preview = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.previewFilter();
}, [filteringEntity]);
const apply = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.applyFilter();
}, [filteringEntity]);
const cancel = useCallback(() => {
if (!filteringEntity) {
return;
}
const canvasManager = $canvasManager.get();
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(filteringEntity);
if (!entity || entity.type !== 'layer') {
return;
}
entity.adapter.filter.cancelFilter();
}, [filteringEntity]);
return (
<Flex flexDir="column" gap={3} w="full" h="full">
<FilterTypeSelect />
<ButtonGroup isAttached={false}>
<Button onClick={preview} isDisabled={!filteringEntity}>
Preview
</Button>
<Button onClick={apply} isDisabled={!filteringEntity}>
Apply
</Button>
<Button onClick={cancel} isDisabled={!filteringEntity}>
Cancel
</Button>
</ButtonGroup>
<FilterSettings />
</Flex>
);
});
Filter.displayName = 'Filter';

View File

@ -0,0 +1,67 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { CannyProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<CannyProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['canny_image_processor'].buildDefaults();
export const FilterCanny = ({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleLowThresholdChanged = useCallback(
(v: number) => {
onChange({ ...config, low_threshold: v });
},
[onChange, config]
);
const handleHighThresholdChanged = useCallback(
(v: number) => {
onChange({ ...config, high_threshold: v });
},
[onChange, config]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.lowThreshold')}</FormLabel>
<CompositeSlider
value={config.low_threshold}
onChange={handleLowThresholdChanged}
defaultValue={DEFAULTS.low_threshold}
min={0}
max={255}
/>
<CompositeNumberInput
value={config.low_threshold}
onChange={handleLowThresholdChanged}
defaultValue={DEFAULTS.low_threshold}
min={0}
max={255}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.highThreshold')}</FormLabel>
<CompositeSlider
value={config.high_threshold}
onChange={handleHighThresholdChanged}
defaultValue={DEFAULTS.high_threshold}
min={0}
max={255}
/>
<CompositeNumberInput
value={config.high_threshold}
onChange={handleHighThresholdChanged}
defaultValue={DEFAULTS.high_threshold}
min={0}
max={255}
/>
</FormControl>
</>
);
};
FilterCanny.displayName = 'FilterCanny';

View File

@ -0,0 +1,47 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ColorMapProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<ColorMapProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['color_map_image_processor'].buildDefaults();
export const FilterColorMap = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleColorMapTileSizeChanged = useCallback(
(v: number) => {
onChange({ ...config, color_map_tile_size: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.colorMapTileSize')}</FormLabel>
<CompositeSlider
value={config.color_map_tile_size}
defaultValue={DEFAULTS.color_map_tile_size}
onChange={handleColorMapTileSizeChanged}
min={1}
max={256}
step={1}
marks
/>
<CompositeNumberInput
value={config.color_map_tile_size}
defaultValue={DEFAULTS.color_map_tile_size}
onChange={handleColorMapTileSizeChanged}
min={1}
max={4096}
step={1}
/>
</FormControl>
</>
);
});
FilterColorMap.displayName = 'FilterColorMap';

View File

@ -0,0 +1,78 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { ContentShuffleProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<ContentShuffleProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['content_shuffle_image_processor'].buildDefaults();
export const FilterContentShuffle = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleWChanged = useCallback(
(v: number) => {
onChange({ ...config, w: v });
},
[config, onChange]
);
const handleHChanged = useCallback(
(v: number) => {
onChange({ ...config, h: v });
},
[config, onChange]
);
const handleFChanged = useCallback(
(v: number) => {
onChange({ ...config, f: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.w')}</FormLabel>
<CompositeSlider
value={config.w}
defaultValue={DEFAULTS.w}
onChange={handleWChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.w} defaultValue={DEFAULTS.w} onChange={handleWChanged} min={0} max={4096} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.h')}</FormLabel>
<CompositeSlider
value={config.h}
defaultValue={DEFAULTS.h}
onChange={handleHChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.h} defaultValue={DEFAULTS.h} onChange={handleHChanged} min={0} max={4096} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.f')}</FormLabel>
<CompositeSlider
value={config.f}
defaultValue={DEFAULTS.f}
onChange={handleFChanged}
min={0}
max={4096}
marks
/>
<CompositeNumberInput value={config.f} defaultValue={DEFAULTS.f} onChange={handleFChanged} min={0} max={4096} />
</FormControl>
</>
);
});
FilterContentShuffle.displayName = 'FilterContentShuffle';

View File

@ -0,0 +1,61 @@
import { Flex, FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { DWOpenposeProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<DWOpenposeProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['dw_openpose_image_processor'].buildDefaults();
export const FilterDWOpenpose = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleDrawBodyChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_body: e.target.checked });
},
[config, onChange]
);
const handleDrawFaceChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_face: e.target.checked });
},
[config, onChange]
);
const handleDrawHandsChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, draw_hands: e.target.checked });
},
[config, onChange]
);
return (
<>
<Flex sx={{ flexDir: 'row', gap: 6 }}>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.body')}</FormLabel>
<Switch defaultChecked={DEFAULTS.draw_body} isChecked={config.draw_body} onChange={handleDrawBodyChanged} />
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.face')}</FormLabel>
<Switch defaultChecked={DEFAULTS.draw_face} isChecked={config.draw_face} onChange={handleDrawFaceChanged} />
</FormControl>
<FormControl w="max-content">
<FormLabel m={0}>{t('controlnet.hands')}</FormLabel>
<Switch
defaultChecked={DEFAULTS.draw_hands}
isChecked={config.draw_hands}
onChange={handleDrawHandsChanged}
/>
</FormControl>
</Flex>
</>
);
});
FilterDWOpenpose.displayName = 'FilterDWOpenpose';

View File

@ -0,0 +1,46 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { DepthAnythingModelSize, DepthAnythingProcessorConfig } from 'features/controlLayers/store/types';
import { isDepthAnythingModelSize } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<DepthAnythingProcessorConfig>;
export const FilterDepthAnything = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleModelSizeChange = useCallback<ComboboxOnChange>(
(v) => {
if (!isDepthAnythingModelSize(v?.value)) {
return;
}
onChange({ ...config, model_size: v.value });
},
[config, onChange]
);
const options: { label: string; value: DepthAnythingModelSize }[] = useMemo(
() => [
{ label: t('controlnet.depthAnythingSmallV2'), value: 'small_v2' },
{ label: t('controlnet.small'), value: 'small' },
{ label: t('controlnet.base'), value: 'base' },
{ label: t('controlnet.large'), value: 'large' },
],
[t]
);
const value = useMemo(() => options.filter((o) => o.value === config.model_size)[0], [options, config.model_size]);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.modelSize')}</FormLabel>
<Combobox value={value} options={options} onChange={handleModelSizeChange} isSearchable={false} />
</FormControl>
</>
);
});
FilterDepthAnything.displayName = 'FilterDepthAnything';

View File

@ -0,0 +1,31 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { HedProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<HedProcessorConfig>;
export const FilterHed = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, scribble: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.scribble')}</FormLabel>
<Switch isChecked={config.scribble} onChange={handleScribbleChanged} />
</FormControl>
</>
);
});
FilterHed.displayName = 'FilterHed';

View File

@ -0,0 +1,31 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { LineartProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<LineartProcessorConfig>;
export const FilterLineart = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleCoarseChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, coarse: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.coarse')}</FormLabel>
<Switch isChecked={config.coarse} onChange={handleCoarseChanged} />
</FormControl>
</>
);
});
FilterLineart.displayName = 'FilterLineart';

View File

@ -0,0 +1,73 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MediapipeFaceProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MediapipeFaceProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['mediapipe_face_processor'].buildDefaults();
export const FilterMediapipeFace = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleMaxFacesChanged = useCallback(
(v: number) => {
onChange({ ...config, max_faces: v });
},
[config, onChange]
);
const handleMinConfidenceChanged = useCallback(
(v: number) => {
onChange({ ...config, min_confidence: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.maxFaces')}</FormLabel>
<CompositeSlider
value={config.max_faces}
onChange={handleMaxFacesChanged}
defaultValue={DEFAULTS.max_faces}
min={1}
max={20}
marks
/>
<CompositeNumberInput
value={config.max_faces}
onChange={handleMaxFacesChanged}
defaultValue={DEFAULTS.max_faces}
min={1}
max={20}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.minConfidence')}</FormLabel>
<CompositeSlider
value={config.min_confidence}
onChange={handleMinConfidenceChanged}
defaultValue={DEFAULTS.min_confidence}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.min_confidence}
onChange={handleMinConfidenceChanged}
defaultValue={DEFAULTS.min_confidence}
min={0}
max={1}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMediapipeFace.displayName = 'FilterMediapipeFace';

View File

@ -0,0 +1,75 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MidasDepthProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MidasDepthProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['midas_depth_image_processor'].buildDefaults();
export const FilterMidasDepth = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleAMultChanged = useCallback(
(v: number) => {
onChange({ ...config, a_mult: v });
},
[config, onChange]
);
const handleBgThChanged = useCallback(
(v: number) => {
onChange({ ...config, bg_th: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.amult')}</FormLabel>
<CompositeSlider
value={config.a_mult}
onChange={handleAMultChanged}
defaultValue={DEFAULTS.a_mult}
min={0}
max={20}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.a_mult}
onChange={handleAMultChanged}
defaultValue={DEFAULTS.a_mult}
min={0}
max={20}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.bgth')}</FormLabel>
<CompositeSlider
value={config.bg_th}
onChange={handleBgThChanged}
defaultValue={DEFAULTS.bg_th}
min={0}
max={20}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.bg_th}
onChange={handleBgThChanged}
defaultValue={DEFAULTS.bg_th}
min={0}
max={20}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMidasDepth.displayName = 'FilterMidasDepth';

View File

@ -0,0 +1,75 @@
import { CompositeNumberInput, CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import type { MlsdProcessorConfig } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<MlsdProcessorConfig>;
const DEFAULTS = IMAGE_FILTERS['mlsd_image_processor'].buildDefaults();
export const FilterMlsdImage = memo(({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleThrDChanged = useCallback(
(v: number) => {
onChange({ ...config, thr_d: v });
},
[config, onChange]
);
const handleThrVChanged = useCallback(
(v: number) => {
onChange({ ...config, thr_v: v });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.w')} </FormLabel>
<CompositeSlider
value={config.thr_d}
onChange={handleThrDChanged}
defaultValue={DEFAULTS.thr_d}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.thr_d}
onChange={handleThrDChanged}
defaultValue={DEFAULTS.thr_d}
min={0}
max={1}
step={0.01}
/>
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.h')} </FormLabel>
<CompositeSlider
value={config.thr_v}
onChange={handleThrVChanged}
defaultValue={DEFAULTS.thr_v}
min={0}
max={1}
step={0.01}
marks
/>
<CompositeNumberInput
value={config.thr_v}
onChange={handleThrVChanged}
defaultValue={DEFAULTS.thr_v}
min={0}
max={1}
step={0.01}
/>
</FormControl>
</>
);
});
FilterMlsdImage.displayName = 'FilterMlsdImage';

View File

@ -0,0 +1,42 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import type { PidiProcessorConfig } from 'features/controlLayers/store/types';
import type { ChangeEvent } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import type { FilterComponentProps } from './types';
type Props = FilterComponentProps<PidiProcessorConfig>;
export const FilterPidi = ({ onChange, config }: Props) => {
const { t } = useTranslation();
const handleScribbleChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, scribble: e.target.checked });
},
[config, onChange]
);
const handleSafeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChange({ ...config, safe: e.target.checked });
},
[config, onChange]
);
return (
<>
<FormControl>
<FormLabel m={0}>{t('controlnet.scribble')}</FormLabel>
<Switch isChecked={config.scribble} onChange={handleScribbleChanged} />
</FormControl>
<FormControl>
<FormLabel m={0}>{t('controlnet.safe')}</FormLabel>
<Switch isChecked={config.safe} onChange={handleSafeChanged} />
</FormControl>
</>
);
};
FilterPidi.displayName = 'FilterPidi';

View File

@ -0,0 +1,77 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { FilterCanny } from 'features/controlLayers/components/Filters/FilterCanny';
import { FilterColorMap } from 'features/controlLayers/components/Filters/FilterColorMap';
import { FilterContentShuffle } from 'features/controlLayers/components/Filters/FilterContentShuffle';
import { FilterDepthAnything } from 'features/controlLayers/components/Filters/FilterDepthAnything';
import { FilterDWOpenpose } from 'features/controlLayers/components/Filters/FilterDWOpenpose';
import { FilterHed } from 'features/controlLayers/components/Filters/FilterHed';
import { FilterLineart } from 'features/controlLayers/components/Filters/FilterLineart';
import { FilterMediapipeFace } from 'features/controlLayers/components/Filters/FilterMediapipeFace';
import { FilterMidasDepth } from 'features/controlLayers/components/Filters/FilterMidasDepth';
import { FilterMlsdImage } from 'features/controlLayers/components/Filters/FilterMlsdImage';
import { FilterPidi } from 'features/controlLayers/components/Filters/FilterPidi';
import { filterConfigChanged } from 'features/controlLayers/store/canvasV2Slice';
import { type FilterConfig, IMAGE_FILTERS } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const FilterSettings = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const config = useAppSelector((s) => s.canvasV2.filter.config);
const updateFilter = useCallback(
(config: FilterConfig) => {
dispatch(filterConfigChanged({ config }));
},
[dispatch]
);
if (config.type === 'canny_image_processor') {
return <FilterCanny config={config} onChange={updateFilter} />;
}
if (config.type === 'color_map_image_processor') {
return <FilterColorMap config={config} onChange={updateFilter} />;
}
if (config.type === 'content_shuffle_image_processor') {
return <FilterContentShuffle config={config} onChange={updateFilter} />;
}
if (config.type === 'depth_anything_image_processor') {
return <FilterDepthAnything config={config} onChange={updateFilter} />;
}
if (config.type === 'dw_openpose_image_processor') {
return <FilterDWOpenpose config={config} onChange={updateFilter} />;
}
if (config.type === 'hed_image_processor') {
return <FilterHed config={config} onChange={updateFilter} />;
}
if (config.type === 'lineart_image_processor') {
return <FilterLineart config={config} onChange={updateFilter} />;
}
if (config.type === 'mediapipe_face_processor') {
return <FilterMediapipeFace config={config} onChange={updateFilter} />;
}
if (config.type === 'midas_depth_image_processor') {
return <FilterMidasDepth config={config} onChange={updateFilter} />;
}
if (config.type === 'mlsd_image_processor') {
return <FilterMlsdImage config={config} onChange={updateFilter} />;
}
if (config.type === 'pidi_image_processor') {
return <FilterPidi config={config} onChange={updateFilter} />;
}
return <IAINoContentFallback label={`${t(IMAGE_FILTERS[config.type].labelTKey)} has no settings`} icon={null} />;
});
FilterSettings.displayName = 'Filter';

View File

@ -0,0 +1,54 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { filterSelected } from 'features/controlLayers/store/canvasV2Slice';
import { IMAGE_FILTERS, isFilterType } from 'features/controlLayers/store/types';
import { configSelector } from 'features/system/store/configSelectors';
import { includes, map } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
const selectDisabledProcessors = createMemoizedSelector(
configSelector,
(config) => config.sd.disabledControlNetProcessors
);
export const FilterTypeSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const filterType = useAppSelector((s) => s.canvasV2.filter.config.type);
const disabledProcessors = useAppSelector(selectDisabledProcessors);
const options = useMemo(() => {
return map(IMAGE_FILTERS, ({ labelTKey }, type) => ({ value: type, label: t(labelTKey) })).filter(
(o) => !includes(disabledProcessors, o.value)
);
}, [disabledProcessors, t]);
const _onChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v) {
return;
}
assert(isFilterType(v.value));
dispatch(filterSelected({ type: v.value }));
},
[dispatch]
);
const value = useMemo(() => options.find((o) => o.value === filterType) ?? null, [options, filterType]);
return (
<Flex gap={2}>
<FormControl>
<InformationalPopover feature="controlNetProcessor">
<FormLabel m={0}>{t('controlLayers.filter')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} onChange={_onChange} isSearchable={false} isClearable={false} />
</FormControl>
</Flex>
);
});
FilterTypeSelect.displayName = 'FilterTypeSelect';

View File

@ -0,0 +1,17 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useFilter } from 'features/controlLayers/components/Filters/Filter';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
const FilterWrapper = (props: PropsWithChildren) => {
const isPreviewDisabled = useAppSelector((s) => s.canvasV2.selectedEntityIdentifier?.type !== 'layer');
const filter = useFilter();
return (
<Flex flexDir="column" gap={3} w="full" h="full">
{props.children}
</Flex>
);
};
export default memo(FilterWrapper);

View File

@ -0,0 +1,6 @@
import type { FilterConfig } from 'features/controlLayers/store/types';
export type FilterComponentProps<T extends FilterConfig> = {
onChange: (config: T) => void;
config: T;
};

View File

@ -1,4 +1,4 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library'; import { Spacer } from '@invoke-ai/ui-library';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer'; import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton'; import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -18,12 +18,11 @@ type Props = {
export const Layer = memo(({ id }: Props) => { export const Layer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'layer' }), [id]); const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'layer' }), [id]);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: false });
return ( return (
<EntityIdentifierContext.Provider value={entityIdentifier}> <EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer> <CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={onToggle}> <CanvasEntityHeader>
<CanvasEntityEnabledToggle /> <CanvasEntityEnabledToggle />
<CanvasEntityTitle /> <CanvasEntityTitle />
<Spacer /> <Spacer />
@ -31,7 +30,7 @@ export const Layer = memo(({ id }: Props) => {
<LayerActionsMenu /> <LayerActionsMenu />
<CanvasEntityDeleteButton /> <CanvasEntityDeleteButton />
</CanvasEntityHeader> </CanvasEntityHeader>
{isOpen && <LayerSettings />} <LayerSettings />
</CanvasEntityContainer> </CanvasEntityContainer>
</EntityIdentifierContext.Provider> </EntityIdentifierContext.Provider>
); );

View File

@ -0,0 +1,66 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginEndStepPct';
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlAdapterControlModeSelect } from 'features/controlLayers/components/ControlAdapter/ControlAdapterControlModeSelect';
import { ControlAdapterModel } from 'features/controlLayers/components/ControlAdapter/ControlAdapterModel';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import {
layerControlAdapterBeginEndStepPctChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterModelChanged,
layerControlAdapterWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import type { ControlModeV2, ControlNetConfig, T2IAdapterConfig } from 'features/controlLayers/store/types';
import { memo, useCallback } from 'react';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
type Props = {
controlAdapter: ControlNetConfig | T2IAdapterConfig;
};
export const LayerControlAdapter = memo(({ controlAdapter }: Props) => {
const dispatch = useAppDispatch();
const { id } = useEntityIdentifierContext();
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
dispatch(layerControlAdapterBeginEndStepPctChanged({ id, beginEndStepPct }));
},
[dispatch, id]
);
const onChangeControlMode = useCallback(
(controlMode: ControlModeV2) => {
dispatch(layerControlAdapterControlModeChanged({ id, controlMode }));
},
[dispatch, id]
);
const onChangeWeight = useCallback(
(weight: number) => {
dispatch(layerControlAdapterWeightChanged({ id, weight }));
},
[dispatch, id]
);
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(layerControlAdapterModelChanged({ id, modelConfig }));
},
[dispatch, id]
);
return (
<Flex flexDir="column" gap={3} position="relative" w="full">
<ControlAdapterModel modelKey={controlAdapter.model?.key ?? null} onChange={onChangeModel} />
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && (
<ControlAdapterControlModeSelect controlMode={controlAdapter.controlMode} onChange={onChangeControlMode} />
)}
</Flex>
);
});
LayerControlAdapter.displayName = 'LayerControlAdapter';

View File

@ -1,10 +1,22 @@
import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings'; import { CanvasEntitySettings } from 'features/controlLayers/components/common/CanvasEntitySettings';
import { LayerControlAdapter } from 'features/controlLayers/components/Layer/LayerControlAdapter';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerControlAdapter } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { memo } from 'react'; import { memo } from 'react';
export const LayerSettings = memo(() => { export const LayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext(); const entityIdentifier = useEntityIdentifierContext();
return <CanvasEntitySettings>PLACEHOLDER</CanvasEntitySettings>; const controlAdapter = useLayerControlAdapter(entityIdentifier);
if (!controlAdapter) {
return null;
}
return (
<CanvasEntitySettings>
<LayerControlAdapter controlAdapter={controlAdapter} />
</CanvasEntitySettings>
);
}); });
LayerSettings.displayName = 'LayerSettings'; LayerSettings.displayName = 'LayerSettings';

View File

@ -1,6 +1,8 @@
import { Flex } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $socket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { useAppStore } from 'app/store/storeHooks'; import { useAppStore } from 'app/store/nanostores/store';
import { HeadsUpDisplay } from 'features/controlLayers/components/HeadsUpDisplay'; import { HeadsUpDisplay } from 'features/controlLayers/components/HeadsUpDisplay';
import { $canvasManager, CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager, CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import Konva from 'konva'; import Konva from 'konva';
@ -17,6 +19,7 @@ Konva.showWarnings = false;
const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null, asPreview: boolean) => { const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null, asPreview: boolean) => {
const store = useAppStore(); const store = useAppStore();
const socket = useStore($socket);
const dpr = useDevicePixelRatio({ round: false }); const dpr = useDevicePixelRatio({ round: false });
useLayoutEffect(() => { useLayoutEffect(() => {
@ -27,12 +30,17 @@ const useStageRenderer = (stage: Konva.Stage, container: HTMLDivElement | null,
return () => {}; return () => {};
} }
const manager = new CanvasManager(stage, container, store); if (!socket) {
log.debug('Socket not connected, skipping initialization');
return () => {};
}
const manager = new CanvasManager(stage, container, store, socket);
$canvasManager.set(manager); $canvasManager.set(manager);
console.log(manager); console.log(manager);
const cleanup = manager.initialize(); const cleanup = manager.initialize();
return cleanup; return cleanup;
}, [asPreview, container, stage, store]); }, [asPreview, container, socket, stage, store]);
useLayoutEffect(() => { useLayoutEffect(() => {
Konva.pixelRatio = dpr; Konva.pixelRatio = dpr;

View File

@ -2,7 +2,8 @@ import { Button, IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager'; import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { memo, useCallback, useEffect, useState } from 'react'; import { $transformingEntity } from 'features/controlLayers/store/canvasV2Slice';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook'; import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import { PiResizeBold } from 'react-icons/pi'; import { PiResizeBold } from 'react-icons/pi';
@ -10,20 +11,11 @@ import { PiResizeBold } from 'react-icons/pi';
export const TransformToolButton = memo(() => { export const TransformToolButton = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const canvasManager = useStore($canvasManager); const canvasManager = useStore($canvasManager);
const [isTransforming, setIsTransforming] = useState(false); const transformingEntity = useStore($transformingEntity);
const isDisabled = useAppSelector( const isDisabled = useAppSelector(
(s) => s.canvasV2.selectedEntityIdentifier === null || s.canvasV2.session.isStaging (s) => s.canvasV2.selectedEntityIdentifier === null || s.canvasV2.session.isStaging
); );
useEffect(() => {
if (!canvasManager) {
return;
}
return canvasManager.stateApi.$transformingEntity.listen((newValue) => {
setIsTransforming(Boolean(newValue));
});
}, [canvasManager]);
const onTransform = useCallback(() => { const onTransform = useCallback(() => {
if (!canvasManager) { if (!canvasManager) {
return; return;
@ -47,7 +39,7 @@ export const TransformToolButton = memo(() => {
useHotkeys(['ctrl+t', 'meta+t'], onTransform, { enabled: !isDisabled }, [isDisabled, onTransform]); useHotkeys(['ctrl+t', 'meta+t'], onTransform, { enabled: !isDisabled }, [isDisabled, onTransform]);
if (isTransforming) { if (transformingEntity) {
return ( return (
<> <>
<Button onClick={onApplyTransformation}>{t('common.apply')}</Button> <Button onClick={onApplyTransformation}>{t('common.apply')}</Button>

View File

@ -1,8 +1,12 @@
import { MenuItem } from '@invoke-ai/ui-library'; import { MenuDivider, MenuItem } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useLayerUseAsControl } from 'features/controlLayers/hooks/useLayerControlAdapter';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import { import {
$filteringEntity,
entityArrangedBackwardOne, entityArrangedBackwardOne,
entityArrangedForwardOne, entityArrangedForwardOne,
entityArrangedToBack, entityArrangedToBack,
@ -20,7 +24,11 @@ import {
PiArrowLineDownBold, PiArrowLineDownBold,
PiArrowLineUpBold, PiArrowLineUpBold,
PiArrowUpBold, PiArrowUpBold,
PiCheckBold,
PiQuestionMarkBold,
PiStarHalfBold,
PiTrashSimpleBold, PiTrashSimpleBold,
PiXBold,
} from 'react-icons/pi'; } from 'react-icons/pi';
const getIndexAndCount = ( const getIndexAndCount = (
@ -52,18 +60,15 @@ const getIndexAndCount = (
export const CanvasEntityActionMenuItems = memo(() => { export const CanvasEntityActionMenuItems = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const canvasManager = useStore($canvasManager);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext(); const entityIdentifier = useEntityIdentifierContext();
const useAsControl = useLayerUseAsControl(entityIdentifier);
const selectValidActions = useMemo( const selectValidActions = useMemo(
() => () =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => { createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
const { index, count } = getIndexAndCount(canvasV2, entityIdentifier); const { index, count } = getIndexAndCount(canvasV2, entityIdentifier);
return { return {
isArrangeable:
entityIdentifier.type === 'layer' ||
entityIdentifier.type === 'control_adapter' ||
entityIdentifier.type === 'regional_guidance',
isDeleteable: entityIdentifier.type !== 'inpaint_mask',
canMoveForwardOne: index < count - 1, canMoveForwardOne: index < count - 1,
canMoveBackwardOne: index > 0, canMoveBackwardOne: index > 0,
canMoveToFront: index < count - 1, canMoveToFront: index < count - 1,
@ -75,6 +80,18 @@ export const CanvasEntityActionMenuItems = memo(() => {
const validActions = useAppSelector(selectValidActions); const validActions = useAppSelector(selectValidActions);
const isArrangeable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isDeleteable = useMemo(
() => entityIdentifier.type === 'layer' || entityIdentifier.type === 'regional_guidance',
[entityIdentifier.type]
);
const isFilterable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const isUseAsControlable = useMemo(() => entityIdentifier.type === 'layer', [entityIdentifier.type]);
const deleteEntity = useCallback(() => { const deleteEntity = useCallback(() => {
dispatch(entityDeleted({ entityIdentifier })); dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]); }, [dispatch, entityIdentifier]);
@ -93,10 +110,23 @@ export const CanvasEntityActionMenuItems = memo(() => {
const moveToBack = useCallback(() => { const moveToBack = useCallback(() => {
dispatch(entityArrangedToBack({ entityIdentifier })); dispatch(entityArrangedToBack({ entityIdentifier }));
}, [dispatch, entityIdentifier]); }, [dispatch, entityIdentifier]);
const filter = useCallback(() => {
$filteringEntity.set(entityIdentifier);
}, [entityIdentifier]);
const debug = useCallback(() => {
if (!canvasManager) {
return;
}
const entity = canvasManager.stateApi.getEntity(entityIdentifier);
if (!entity) {
return;
}
console.debug(entity);
}, [canvasManager, entityIdentifier]);
return ( return (
<> <>
{validActions.isArrangeable && ( {isArrangeable && (
<> <>
<MenuItem onClick={moveToFront} isDisabled={!validActions.canMoveToFront} icon={<PiArrowLineUpBold />}> <MenuItem onClick={moveToFront} isDisabled={!validActions.canMoveToFront} icon={<PiArrowLineUpBold />}>
{t('controlLayers.moveToFront')} {t('controlLayers.moveToFront')}
@ -112,14 +142,29 @@ export const CanvasEntityActionMenuItems = memo(() => {
</MenuItem> </MenuItem>
</> </>
)} )}
{isFilterable && (
<MenuItem onClick={filter} icon={<PiStarHalfBold />}>
{t('common.filter')}
</MenuItem>
)}
{isUseAsControlable && (
<MenuItem onClick={useAsControl.toggle} icon={useAsControl.hasControlAdapter ? <PiXBold /> : <PiCheckBold />}>
{useAsControl.hasControlAdapter ? t('common.removeControl') : t('common.useAsControl')}
</MenuItem>
)}
<MenuDivider />
<MenuItem onClick={resetEntity} icon={<PiArrowCounterClockwiseBold />}> <MenuItem onClick={resetEntity} icon={<PiArrowCounterClockwiseBold />}>
{t('accessibility.reset')} {t('accessibility.reset')}
</MenuItem> </MenuItem>
{validActions.isDeleteable && ( {isDeleteable && (
<MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} color="error.300"> <MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} color="error.300">
{t('common.delete')} {t('common.delete')}
</MenuItem> </MenuItem>
)} )}
<MenuDivider />
<MenuItem onClick={debug} icon={<PiQuestionMarkBold />} color="warn.300">
{t('common.debug')}
</MenuItem>
</> </>
); );
}); });

View File

@ -11,6 +11,7 @@ import { PiCheckBold } from 'react-icons/pi';
export const CanvasEntityEnabledToggle = memo(() => { export const CanvasEntityEnabledToggle = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const entityIdentifier = useEntityIdentifierContext(); const entityIdentifier = useEntityIdentifierContext();
const isEnabled = useEntityIsEnabled(entityIdentifier); const isEnabled = useEntityIsEnabled(entityIdentifier);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const onClick = useCallback(() => { const onClick = useCallback(() => {

View File

@ -2,11 +2,11 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { caAdded, ipaAdded, rgIPAdapterAdded } from 'features/controlLayers/store/canvasV2Slice'; import { caAdded, ipaAdded, rgIPAdapterAdded } from 'features/controlLayers/store/canvasV2Slice';
import { import {
CA_PROCESSOR_DATA, IMAGE_FILTERS,
initialControlNetV2, initialControlNetV2,
initialIPAdapterV2, initialIPAdapterV2,
initialT2IAdapterV2, initialT2IAdapterV2,
isProcessorTypeV2, isFilterType,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -30,8 +30,8 @@ export const useAddCALayer = () => {
} }
const defaultPreprocessor = model.default_settings?.preprocessor; const defaultPreprocessor = model.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor) const processorConfig = isFilterType(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults(baseModel) ? IMAGE_FILTERS[defaultPreprocessor].buildDefaults(baseModel)
: null; : null;
const initialConfig = deepClone(model.type === 'controlnet' ? initialControlNetV2 : initialT2IAdapterV2); const initialConfig = deepClone(model.type === 'controlnet' ? initialControlNetV2 : initialT2IAdapterV2);

View File

@ -0,0 +1,8 @@
import { useStore } from '@nanostores/react';
import { $canvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
export const useEntityAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const canvasManager = useStore($canvasManager);
console.log(canvasManager);
};

View File

@ -0,0 +1,57 @@
import { createMemoizedAppSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { layerUsedAsControlChanged, selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectLayer } from 'features/controlLayers/store/layersReducers';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { initialControlNetV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useCallback, useMemo } from 'react';
import { useControlNetAndT2IAdapterModels } from 'services/api/hooks/modelsByType';
import type { ControlNetModelConfig, T2IAdapterModelConfig } from 'services/api/types';
export const useLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier) => {
const selectControlAdapter = useMemo(
() =>
createMemoizedAppSelector(selectCanvasV2Slice, (canvasV2) => {
const layer = selectLayer(canvasV2, entityIdentifier.id);
if (!layer) {
return null;
}
return layer.controlAdapter;
}),
[entityIdentifier]
);
const controlAdapter = useAppSelector(selectControlAdapter);
return controlAdapter;
};
export const useLayerUseAsControl = (entityIdentifier: CanvasEntityIdentifier) => {
const dispatch = useAppDispatch();
const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
const controlAdapter = useLayerControlAdapter(entityIdentifier);
const model: ControlNetModelConfig | T2IAdapterModelConfig | null = useMemo(() => {
// prefer to use a model that matches the base model
const compatibleModels = modelConfigs.filter((m) => (baseModel ? m.base === baseModel : true));
return compatibleModels[0] ?? modelConfigs[0] ?? null;
}, [baseModel, modelConfigs]);
const toggle = useCallback(() => {
if (controlAdapter) {
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: null }));
return;
}
const newControlAdapter = deepClone(model?.type === 't2i_adapter' ? initialT2IAdapterV2 : initialControlNetV2);
if (model) {
newControlAdapter.model = zModelIdentifierField.parse(model);
}
dispatch(layerUsedAsControlChanged({ id: entityIdentifier.id, controlAdapter: newControlAdapter }));
}, [controlAdapter, dispatch, entityIdentifier.id, model]);
return { hasControlAdapter: Boolean(controlAdapter), toggle };
};

View File

@ -0,0 +1,132 @@
import type { JSONObject } from 'common/types';
import { parseify } from 'common/util/serialize';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasImageState } from 'features/controlLayers/store/types';
import { IMAGE_FILTERS, imageDTOToImageObject } from 'features/controlLayers/store/types';
import type { Logger } from 'roarr';
import { getImageDTO } from 'services/api/endpoints/images';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { InvocationCompleteEvent } from 'services/events/types';
import { assert } from 'tsafe';
const TYPE = 'entity_filter_preview';
export class CanvasFilter {
readonly type = TYPE;
id: string;
path: string[];
parent: CanvasLayerAdapter;
manager: CanvasManager;
log: Logger;
imageState: CanvasImageState | null = null;
constructor(parent: CanvasLayerAdapter) {
this.id = getPrefixedId(this.type);
this.parent = parent;
this.manager = parent.manager;
this.path = this.parent.path.concat(this.id);
this.log = this.manager.buildLogger(this.getLoggingContext);
this.log.trace('Creating filter');
}
previewFilter = async () => {
const { config } = this.manager.stateApi.getFilterState();
this.log.trace({ config }, 'Previewing filter');
const dispatch = this.manager.stateApi._store.dispatch;
const imageDTO = await this.parent.renderer.rasterize();
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const filterNode = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: {
nodes: {
[filterNode.id]: {
...filterNode,
// Control images are always intermediate - do not save to gallery
// is_intermediate: true,
is_intermediate: false, // false for testing
},
},
edges: [],
},
origin: this.id,
runs: 1,
},
};
// Listen for the filter processing completion event
const listener = async (event: InvocationCompleteEvent) => {
if (event.origin !== this.id || event.invocation_source_id !== filterNode.id) {
return;
}
this.log.trace({ event: parseify(event) }, 'Handling filter processing completion');
const { result } = event;
assert(result.type === 'image_output', `Processor did not return an image output, got: ${result}`);
const imageDTO = await getImageDTO(result.image.image_name);
assert(imageDTO, "Failed to fetch processor output's image DTO");
this.imageState = imageDTOToImageObject(imageDTO);
this.parent.renderer.clearBuffer();
await this.parent.renderer.setBuffer(this.imageState);
this.parent.renderer.hideObjects([this.imageState.id]);
this.manager.socket.off('invocation_complete', listener);
};
this.manager.socket.on('invocation_complete', listener);
this.log.trace({ enqueueBatchArg: parseify(enqueueBatchArg) }, 'Enqueuing filter batch');
dispatch(
queueApi.endpoints.enqueueBatch.initiate(enqueueBatchArg, {
fixedCacheKey: 'enqueueBatch',
})
);
};
applyFilter = () => {
this.log.trace('Applying filter');
if (!this.imageState) {
this.log.warn('No image state to apply filter to');
return;
}
this.parent.renderer.commitBuffer();
const rect = this.parent.transformer.getRelativeRect();
this.manager.stateApi.rasterizeEntity({
entityIdentifier: this.parent.getEntityIdentifier(),
imageObject: this.imageState,
position: { x: Math.round(rect.x), y: Math.round(rect.y) },
});
this.parent.renderer.showObjects();
this.manager.stateApi.$filteringEntity.set(null);
this.imageState = null;
};
cancelFilter = () => {
this.log.trace('Cancelling filter');
this.parent.renderer.clearBuffer();
this.parent.renderer.showObjects();
this.manager.stateApi.$filteringEntity.set(null);
this.imageState = null;
};
destroy = () => {
this.log.trace('Destroying filter');
};
repr = () => {
return {
id: this.id,
type: this.type,
};
};
getLoggingContext = (): JSONObject => {
return { ...this.parent.getLoggingContext(), path: this.path.join('.') };
};
}

View File

@ -1,5 +1,6 @@
import type { JSONObject } from 'common/types'; import type { JSONObject } from 'common/types';
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { CanvasFilter } from 'features/controlLayers/konva/CanvasFilter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer'; import { CanvasObjectRenderer } from 'features/controlLayers/konva/CanvasObjectRenderer';
import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer'; import { CanvasTransformer } from 'features/controlLayers/konva/CanvasTransformer';
@ -23,6 +24,7 @@ export class CanvasLayerAdapter {
}; };
transformer: CanvasTransformer; transformer: CanvasTransformer;
renderer: CanvasObjectRenderer; renderer: CanvasObjectRenderer;
filter: CanvasFilter;
isFirstRender: boolean = true; isFirstRender: boolean = true;
@ -47,6 +49,7 @@ export class CanvasLayerAdapter {
this.renderer = new CanvasObjectRenderer(this); this.renderer = new CanvasObjectRenderer(this);
this.transformer = new CanvasTransformer(this); this.transformer = new CanvasTransformer(this);
this.filter = new CanvasFilter(this);
} }
/** /**

View File

@ -1,6 +1,6 @@
import type { Store } from '@reduxjs/toolkit'; import type { AppSocket } from 'app/hooks/useSocketIO';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { AppStore } from 'app/store/store';
import type { JSONObject } from 'common/types'; import type { JSONObject } from 'common/types';
import { MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from 'features/controlLayers/konva/constants'; import { MAX_CANVAS_SCALE, MIN_CANVAS_SCALE } from 'features/controlLayers/konva/constants';
import { import {
@ -12,7 +12,7 @@ import {
} from 'features/controlLayers/konva/util'; } from 'features/controlLayers/konva/util';
import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker'; import type { Extents, ExtentsResult, GetBboxTask, WorkerLogMessage } from 'features/controlLayers/konva/worker';
import type { CanvasV2State, Coordinate, Dimensions, GenerationMode, Rect } from 'features/controlLayers/store/types'; import type { CanvasV2State, Coordinate, Dimensions, GenerationMode, Rect } from 'features/controlLayers/store/types';
import { isValidLayer } from 'features/nodes/util/graph/generation/addLayers'; import { isValidLayerWithoutControlAdapter } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva'; import type Konva from 'konva';
import { clamp } from 'lodash-es'; import { clamp } from 'lodash-es';
import { atom } from 'nanostores'; import { atom } from 'nanostores';
@ -49,8 +49,9 @@ export class CanvasManager {
log: Logger; log: Logger;
workerLog: Logger; workerLog: Logger;
socket: AppSocket;
_store: Store<RootState>; _store: AppStore;
_prevState: CanvasV2State; _prevState: CanvasV2State;
_isFirstRender: boolean = true; _isFirstRender: boolean = true;
_isDebugging: boolean = false; _isDebugging: boolean = false;
@ -58,12 +59,13 @@ export class CanvasManager {
_worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' }); _worker: Worker = new Worker(new URL('./worker.ts', import.meta.url), { type: 'module', name: 'worker' });
_tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map(); _tasks: Map<string, { task: GetBboxTask; onComplete: (extents: Extents | null) => void }> = new Map();
constructor(stage: Konva.Stage, container: HTMLDivElement, store: Store<RootState>) { constructor(stage: Konva.Stage, container: HTMLDivElement, store: AppStore, socket: AppSocket) {
this.id = getPrefixedId(this.type); this.id = getPrefixedId(this.type);
this.path = [this.id]; this.path = [this.id];
this.stage = stage; this.stage = stage;
this.container = container; this.container = container;
this._store = store; this._store = store;
this.socket = socket;
this.stateApi = new CanvasStateApi(this._store, this); this.stateApi = new CanvasStateApi(this._store, this);
this._prevState = this.stateApi.getState(); this._prevState = this.stateApi.getState();
@ -547,7 +549,7 @@ export class CanvasManager {
stageClone.x(0); stageClone.x(0);
stageClone.y(0); stageClone.y(0);
const validLayers = layersState.entities.filter(isValidLayer); const validLayers = layersState.entities.filter(isValidLayerWithoutControlAdapter);
// getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will // getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will
// mutate that array. We need to clone the array to avoid mutating the original. // mutate that array. We need to clone the array to avoid mutating the original.
for (const konvaLayer of stageClone.getLayers().slice()) { for (const konvaLayer of stageClone.getLayers().slice()) {

View File

@ -20,7 +20,7 @@ import type {
import { imageDTOToImageObject } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import Konva from 'konva'; import Konva from 'konva';
import type { Logger } from 'roarr'; import type { Logger } from 'roarr';
import { uploadImage } from 'services/api/endpoints/images'; import { getImageDTO, uploadImage } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO } from 'services/api/types'; import type { ImageCategory, ImageDTO } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
@ -62,6 +62,12 @@ export class CanvasObjectRenderer {
*/ */
renderers: Map<string, AnyObjectRenderer> = new Map(); renderers: Map<string, AnyObjectRenderer> = new Map();
/**
* A cache of the rasterized image data URL. If the cache is null, the parent has not been rasterized since its last
* change.
*/
rasterizedImageCache: string | null = null;
/** /**
* A object containing singleton Konva nodes. * A object containing singleton Konva nodes.
*/ */
@ -162,6 +168,19 @@ export class CanvasObjectRenderer {
didRender = (await this.renderObject(this.buffer)) || didRender; didRender = (await this.renderObject(this.buffer)) || didRender;
} }
if (didRender && this.rasterizedImageCache) {
const hasOneObject = this.renderers.size === 1;
const firstObject = Array.from(this.renderers.values())[0];
if (
hasOneObject &&
firstObject &&
firstObject.state.type === 'image' &&
firstObject.state.image.image_name !== this.rasterizedImageCache
) {
this.rasterizedImageCache = null;
}
}
return didRender; return didRender;
}; };
@ -313,6 +332,18 @@ export class CanvasObjectRenderer {
this.buffer = null; this.buffer = null;
}; };
hideObjects = (except: string[] = []) => {
for (const renderer of this.renderers.values()) {
renderer.setVisibility(except.includes(renderer.id));
}
};
showObjects = (except: string[] = []) => {
for (const renderer of this.renderers.values()) {
renderer.setVisibility(!except.includes(renderer.id));
}
};
/** /**
* Determines if the objects in the renderer require a pixel bbox calculation. * Determines if the objects in the renderer require a pixel bbox calculation.
* *
@ -345,15 +376,33 @@ export class CanvasObjectRenderer {
return this.renderers.size > 0 || this.buffer !== null; return this.renderers.size > 0 || this.buffer !== null;
}; };
rasterize = async () => { /**
* Rasterizes the parent entity. If the entity has a rasterization cache, the cached image is returned after
* validating that it exists on the server.
*
* The rasterization cache is reset when the entity's objects change. The buffer object is not considered part of the
* entity's objects for this purpose.
*
* @returns A promise that resolves to the rasterized image DTO.
*/
rasterize = async (): Promise<ImageDTO> => {
this.log.debug('Rasterizing entity'); this.log.debug('Rasterizing entity');
let imageDTO: ImageDTO | null = null;
if (this.rasterizedImageCache) {
imageDTO = await getImageDTO(this.rasterizedImageCache);
}
if (imageDTO) {
return imageDTO;
}
const rect = this.parent.transformer.getRelativeRect(); const rect = this.parent.transformer.getRelativeRect();
const blob = await this.getBlob({ rect }); const blob = await this.getBlob({ rect });
if (this.manager._isDebugging) { if (this.manager._isDebugging) {
previewBlob(blob, 'Rasterized entity'); previewBlob(blob, 'Rasterized entity');
} }
const imageDTO = await uploadImage(blob, `${this.id}_rasterized.png`, 'other', true); imageDTO = await uploadImage(blob, `${this.id}_rasterized.png`, 'other', true);
const imageObject = imageDTOToImageObject(imageDTO); const imageObject = imageDTOToImageObject(imageDTO);
await this.renderObject(imageObject, true); await this.renderObject(imageObject, true);
this.manager.stateApi.rasterizeEntity({ this.manager.stateApi.rasterizeEntity({
@ -361,6 +410,10 @@ export class CanvasObjectRenderer {
imageObject, imageObject,
position: { x: Math.round(rect.x), y: Math.round(rect.y) }, position: { x: Math.round(rect.x), y: Math.round(rect.y) },
}); });
this.rasterizedImageCache = imageDTO.image_name;
return imageDTO;
}; };
getBlob = ({ rect }: { rect?: Rect }): Promise<Blob> => { getBlob = ({ rect }: { rect?: Rect }): Promise<Blob> => {

View File

@ -1,11 +1,11 @@
import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library'; import { $alt, $ctrl, $meta, $shift } from '@invoke-ai/ui-library';
import type { Store } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store'; import type { AppStore } from 'app/store/store';
import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter'; import type { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter'; import type { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
import { import {
$filteringEntity,
$isDrawing, $isDrawing,
$isMouseDown, $isMouseDown,
$lastAddedPoint, $lastAddedPoint,
@ -15,6 +15,7 @@ import {
$shouldShowStagedImage, $shouldShowStagedImage,
$spaceKey, $spaceKey,
$stageAttrs, $stageAttrs,
$transformingEntity,
bboxChanged, bboxChanged,
brushWidthChanged, brushWidthChanged,
entityBrushLineAdded, entityBrushLineAdded,
@ -81,10 +82,10 @@ type EntityStateAndAdapter =
const log = logger('canvas'); const log = logger('canvas');
export class CanvasStateApi { export class CanvasStateApi {
_store: Store<RootState>; _store: AppStore;
manager: CanvasManager; manager: CanvasManager;
constructor(store: Store<RootState>, manager: CanvasManager) { constructor(store: AppStore, manager: CanvasManager) {
this._store = store; this._store = store;
this.manager = manager; this.manager = manager;
} }
@ -188,6 +189,9 @@ export class CanvasStateApi {
getLogLevel = () => { getLogLevel = () => {
return this._store.getState().system.consoleLogLevel; return this._store.getState().system.consoleLogLevel;
}; };
getFilterState = () => {
return this._store.getState().canvasV2.filter;
};
getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null { getEntity(identifier: CanvasEntityIdentifier): EntityStateAndAdapter | null {
const state = this.getState(); const state = this.getState();
@ -256,7 +260,9 @@ export class CanvasStateApi {
return currentFill; return currentFill;
}; };
$transformingEntity: WritableAtom<CanvasEntityIdentifier | null> = atom(); $transformingEntity = $transformingEntity;
$filteringEntity = $filteringEntity;
$toolState: WritableAtom<CanvasV2State['tool']> = atom(); $toolState: WritableAtom<CanvasV2State['tool']> = atom();
$currentFill: WritableAtom<RgbaColor> = atom(); $currentFill: WritableAtom<RgbaColor> = atom();
$selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom(); $selectedEntity: WritableAtom<EntityStateAndAdapter | null> = atom();

View File

@ -155,7 +155,7 @@ export class CanvasTool {
const isMouseDown = this.manager.stateApi.$isMouseDown.get(); const isMouseDown = this.manager.stateApi.$isMouseDown.get();
const tool = toolState.selected; const tool = toolState.selected;
console.log(selectedEntity);
const isDrawableEntity = const isDrawableEntity =
selectedEntity?.state.type === 'regional_guidance' || selectedEntity?.state.type === 'regional_guidance' ||
selectedEntity?.state.type === 'layer' || selectedEntity?.state.type === 'layer' ||

View File

@ -33,9 +33,10 @@ import type {
EntityMovedPayload, EntityMovedPayload,
EntityRasterizedPayload, EntityRasterizedPayload,
EntityRectAddedPayload, EntityRectAddedPayload,
FilterConfig,
StageAttrs, StageAttrs,
} from './types'; } from './types';
import { RGBA_RED } from './types'; import { IMAGE_FILTERS, RGBA_RED } from './types';
const initialState: CanvasV2State = { const initialState: CanvasV2State = {
_version: 3, _version: 3,
@ -133,6 +134,10 @@ const initialState: CanvasV2State = {
stagedImages: [], stagedImages: [],
selectedStagedImageIndex: 0, selectedStagedImageIndex: 0,
}, },
filter: {
autoProcess: true,
config: IMAGE_FILTERS.canny_image_processor.buildDefaults(),
},
}; };
export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) { export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIdentifier) {
@ -222,11 +227,12 @@ export const canvasV2Slice = createSlice({
} else if (entity.type === 'layer') { } else if (entity.type === 'layer') {
entity.objects = [imageObject]; entity.objects = [imageObject];
entity.position = position; entity.position = position;
entity.imageCache = imageObject.image.image_name;
state.layers.imageCache = null; state.layers.imageCache = null;
} else if (entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') { } else if (entity.type === 'inpaint_mask' || entity.type === 'regional_guidance') {
entity.objects = [imageObject]; entity.objects = [imageObject];
entity.position = position; entity.position = position;
entity.imageCache = null; entity.imageCache = imageObject.image.image_name;
} else { } else {
assert(false, 'Not implemented'); assert(false, 'Not implemented');
} }
@ -354,6 +360,12 @@ export const canvasV2Slice = createSlice({
state.ipAdapters.entities = []; state.ipAdapters.entities = [];
state.controlAdapters.entities = []; state.controlAdapters.entities = [];
}, },
filterSelected: (state, action: PayloadAction<{ type: FilterConfig['type'] }>) => {
state.filter.config = IMAGE_FILTERS[action.payload.type].buildDefaults();
},
filterConfigChanged: (state, action: PayloadAction<{ config: FilterConfig }>) => {
state.filter.config = action.payload.config;
},
canvasReset: (state) => { canvasReset: (state) => {
state.bbox = deepClone(initialState.bbox); state.bbox = deepClone(initialState.bbox);
const optimalDimension = getOptimalDimension(state.params.model); const optimalDimension = getOptimalDimension(state.params.model);
@ -415,6 +427,11 @@ export const {
layerOpacityChanged, layerOpacityChanged,
layerAllDeleted, layerAllDeleted,
layerImageCacheChanged, layerImageCacheChanged,
layerUsedAsControlChanged,
layerControlAdapterModelChanged,
layerControlAdapterControlModeChanged,
layerControlAdapterWeightChanged,
layerControlAdapterBeginEndStepPctChanged,
// IP Adapters // IP Adapters
ipaAdded, ipaAdded,
ipaRecalled, ipaRecalled,
@ -513,6 +530,9 @@ export const {
sessionStagingAreaReset, sessionStagingAreaReset,
sessionNextStagedImageSelected, sessionNextStagedImageSelected,
sessionPrevStagedImageSelected, sessionPrevStagedImageSelected,
// Filter
filterSelected,
filterConfigChanged,
} = canvasV2Slice.actions; } = canvasV2Slice.actions;
export const selectCanvasV2Slice = (state: RootState) => state.canvasV2; export const selectCanvasV2Slice = (state: RootState) => state.canvasV2;
@ -539,6 +559,8 @@ export const $lastAddedPoint = atom<Coordinate | null>(null);
export const $lastMouseDownPos = atom<Coordinate | null>(null); export const $lastMouseDownPos = atom<Coordinate | null>(null);
export const $lastCursorPos = atom<Coordinate | null>(null); export const $lastCursorPos = atom<Coordinate | null>(null);
export const $spaceKey = atom<boolean>(false); export const $spaceKey = atom<boolean>(false);
export const $transformingEntity = atom<CanvasEntityIdentifier | null>(null);
export const $filteringEntity = atom<CanvasEntityIdentifier | null>(null);
export const canvasV2PersistConfig: PersistConfig<CanvasV2State> = { export const canvasV2PersistConfig: PersistConfig<CanvasV2State> = {
name: canvasV2Slice.name, name: canvasV2Slice.name,

View File

@ -13,7 +13,7 @@ import type {
ControlModeV2, ControlModeV2,
ControlNetConfig, ControlNetConfig,
Filter, Filter,
ProcessorConfig, FilterConfig,
T2IAdapterConfig, T2IAdapterConfig,
} from './types'; } from './types';
import { buildControlAdapterProcessorV2, imageDTOToImageObject } from './types'; import { buildControlAdapterProcessorV2, imageDTOToImageObject } from './types';
@ -145,7 +145,7 @@ export const controlAdaptersReducers = {
} }
ca.controlMode = controlMode; ca.controlMode = controlMode;
}, },
caProcessorConfigChanged: (state, action: PayloadAction<{ id: string; processorConfig: ProcessorConfig | null }>) => { caProcessorConfigChanged: (state, action: PayloadAction<{ id: string; processorConfig: FilterConfig | null }>) => {
const { id, processorConfig } = action.payload; const { id, processorConfig } = action.payload;
const ca = selectCA(state, id); const ca = selectCA(state, id);
if (!ca) { if (!ca) {

View File

@ -1,10 +1,11 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge } from 'lodash-es'; import { merge } from 'lodash-es';
import type { ImageDTO } from 'services/api/types'; import type { ControlNetModelConfig, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import type { CanvasLayerState, CanvasV2State } from './types'; import type { CanvasLayerState, CanvasV2State, ControlModeV2, ControlNetConfig, T2IAdapterConfig } from './types';
import { imageDTOToImageWithDims } from './types'; import { imageDTOToImageWithDims } from './types';
export const selectLayer = (state: CanvasV2State, id: string) => state.layers.entities.find((layer) => layer.id === id); export const selectLayer = (state: CanvasV2State, id: string) => state.layers.entities.find((layer) => layer.id === id);
@ -29,6 +30,7 @@ export const layersReducers = {
opacity: 1, opacity: 1,
position: { x: 0, y: 0 }, position: { x: 0, y: 0 },
imageCache: null, imageCache: null,
controlAdapter: null,
}; };
merge(layer, overrides); merge(layer, overrides);
state.layers.entities.push(layer); state.layers.entities.push(layer);
@ -64,4 +66,76 @@ export const layersReducers = {
const { imageDTO } = action.payload; const { imageDTO } = action.payload;
state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null; state.layers.imageCache = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
}, },
layerUsedAsControlChanged: (
state,
action: PayloadAction<{ id: string; controlAdapter: ControlNetConfig | T2IAdapterConfig | null }>
) => {
const { id, controlAdapter } = action.payload;
const layer = selectLayer(state, id);
if (!layer) {
return;
}
layer.controlAdapter = controlAdapter;
},
layerControlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null;
}>
) => {
const { id, modelConfig } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
if (!modelConfig) {
layer.controlAdapter.model = null;
return;
}
layer.controlAdapter.model = zModelIdentifierField.parse(modelConfig);
// We may need to convert the CA to match the model
if (layer.controlAdapter.type === 't2i_adapter' && layer.controlAdapter.model.type === 'controlnet') {
// Converting from T2I Adapter to ControlNet - add `controlMode`
const controlNetConfig: ControlNetConfig = {
...layer.controlAdapter,
type: 'controlnet',
controlMode: 'balanced',
};
layer.controlAdapter = controlNetConfig;
} else if (layer.controlAdapter.type === 'controlnet' && layer.controlAdapter.model.type === 't2i_adapter') {
// Converting from ControlNet to T2I Adapter - remove `controlMode`
const { controlMode: _, ...rest } = layer.controlAdapter;
const t2iAdapterConfig: T2IAdapterConfig = { ...rest, type: 't2i_adapter' };
layer.controlAdapter = t2iAdapterConfig;
}
},
layerControlAdapterControlModeChanged: (state, action: PayloadAction<{ id: string; controlMode: ControlModeV2 }>) => {
const { id, controlMode } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter || layer.controlAdapter.type !== 'controlnet') {
return;
}
layer.controlAdapter.controlMode = controlMode;
},
layerControlAdapterWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.weight = weight;
},
layerControlAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>
) => {
const { id, beginEndStepPct } = action.payload;
const layer = selectLayer(state, id);
if (!layer || !layer.controlAdapter) {
return;
}
layer.controlAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>; } satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,7 +1,7 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit'; import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util'; import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types'; import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/types'; import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas'; import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import { isEqual } from 'lodash-es'; import { isEqual } from 'lodash-es';
@ -99,7 +99,7 @@ export const regionsReducers = {
if (!rg) { if (!rg) {
return; return;
} }
rg.imageCache = imageDTOToImageWithDims(imageDTO); rg.imageCache = imageDTO.image_name;
}, },
rgAutoNegativeChanged: (state, action: PayloadAction<{ id: string; autoNegative: ParameterAutoNegative }>) => { rgAutoNegativeChanged: (state, action: PayloadAction<{ id: string; autoNegative: ParameterAutoNegative }>) => {
const { id, autoNegative } = action.payload; const { id, autoNegative } = action.payload;

View File

@ -21,14 +21,14 @@ import type {
MlsdProcessorConfig, MlsdProcessorConfig,
NormalbaeProcessorConfig, NormalbaeProcessorConfig,
PidiProcessorConfig, PidiProcessorConfig,
ProcessorConfig, FilterConfig,
ProcessorTypeV2, FilterType,
ZoeDepthProcessorConfig, ZoeDepthProcessorConfig,
} from './types'; } from './types';
describe('Control Adapter Types', () => { describe('Control Adapter Types', () => {
test('ProcessorType', () => { test('ProcessorType', () => {
assert<Equals<ProcessorConfig['type'], ProcessorTypeV2>>(); assert<Equals<FilterConfig['type'], FilterType>>();
}); });
test('IP Adapter Method', () => { test('IP Adapter Method', () => {
assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>(); assert<Equals<NonNullable<Invocation<'ip_adapter'>['method']>, IPMethodV2>>();

View File

@ -1,4 +1,3 @@
import type { JSONObject } from 'common/types';
import type { CanvasControlAdapter } from 'features/controlLayers/konva/CanvasControlAdapter'; import type { CanvasControlAdapter } from 'features/controlLayers/konva/CanvasControlAdapter';
import { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter'; import { CanvasLayerAdapter } from 'features/controlLayers/konva/CanvasLayerAdapter';
import { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter'; import { CanvasMaskAdapter } from 'features/controlLayers/konva/CanvasMaskAdapter';
@ -36,6 +35,7 @@ import type {
BaseModelType, BaseModelType,
ControlNetModelConfig, ControlNetModelConfig,
ImageDTO, ImageDTO,
S,
T2IAdapterModelConfig, T2IAdapterModelConfig,
} from 'services/api/types'; } from 'services/api/types';
import { z } from 'zod'; import { z } from 'zod';
@ -175,7 +175,7 @@ const zZoeDepthProcessorConfig = z.object({
}); });
export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>; export type ZoeDepthProcessorConfig = z.infer<typeof zZoeDepthProcessorConfig>;
export const zProcessorConfig = z.discriminatedUnion('type', [ export const zFilterConfig = z.discriminatedUnion('type', [
zCannyProcessorConfig, zCannyProcessorConfig,
zColorMapProcessorConfig, zColorMapProcessorConfig,
zContentShuffleProcessorConfig, zContentShuffleProcessorConfig,
@ -191,9 +191,9 @@ export const zProcessorConfig = z.discriminatedUnion('type', [
zPidiProcessorConfig, zPidiProcessorConfig,
zZoeDepthProcessorConfig, zZoeDepthProcessorConfig,
]); ]);
export type ProcessorConfig = z.infer<typeof zProcessorConfig>; export type FilterConfig = z.infer<typeof zFilterConfig>;
const zProcessorTypeV2 = z.enum([ const zFilterType = z.enum([
'canny_image_processor', 'canny_image_processor',
'color_map_image_processor', 'color_map_image_processor',
'content_shuffle_image_processor', 'content_shuffle_image_processor',
@ -209,22 +209,19 @@ const zProcessorTypeV2 = z.enum([
'pidi_image_processor', 'pidi_image_processor',
'zoe_depth_image_processor', 'zoe_depth_image_processor',
]); ]);
export type ProcessorTypeV2 = z.infer<typeof zProcessorTypeV2>; export type FilterType = z.infer<typeof zFilterType>;
export const isProcessorTypeV2 = (v: unknown): v is ProcessorTypeV2 => zProcessorTypeV2.safeParse(v).success; export const isFilterType = (v: unknown): v is FilterType => zFilterType.safeParse(v).success;
type ProcessorData<T extends ProcessorTypeV2> = {
type: T;
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
};
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height); const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
type CAProcessorsData = { type ImageFilterData<T extends FilterConfig['type']> = {
[key in ProcessorTypeV2]: ProcessorData<key>; type: T;
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<FilterConfig, { type: T }>;
buildNode(imageDTO: ImageWithDims, config: Extract<FilterConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
}; };
/** /**
* A dict of ControlNet processors, including: * A dict of ControlNet processors, including:
* - label translation key * - label translation key
@ -234,234 +231,243 @@ type CAProcessorsData = {
* *
* TODO: Generate from the OpenAPI schema * TODO: Generate from the OpenAPI schema
*/ */
export const CA_PROCESSOR_DATA: CAProcessorsData = { export const IMAGE_FILTERS: { [key in FilterConfig['type']]: ImageFilterData<key> } = {
canny_image_processor: { canny_image_processor: {
type: 'canny_image_processor', type: 'canny_image_processor',
labelTKey: 'controlnet.canny', labelTKey: 'controlnet.canny',
descriptionTKey: 'controlnet.cannyDescription', descriptionTKey: 'controlnet.cannyDescription',
buildDefaults: () => ({ buildDefaults: (): CannyProcessorConfig => ({
id: 'canny_image_processor', id: 'canny_image_processor',
type: 'canny_image_processor', type: 'canny_image_processor',
low_threshold: 100, low_threshold: 100,
high_threshold: 200, high_threshold: 200,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: CannyProcessorConfig): S['CannyImageProcessorInvocation'] => ({
...config, ...config,
type: 'canny_image_processor', type: 'canny_image_processor',
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
color_map_image_processor: { color_map_image_processor: {
type: 'color_map_image_processor', type: 'color_map_image_processor',
labelTKey: 'controlnet.colorMap', labelTKey: 'controlnet.colorMap',
descriptionTKey: 'controlnet.colorMapDescription', descriptionTKey: 'controlnet.colorMapDescription',
buildDefaults: () => ({ buildDefaults: (): ColorMapProcessorConfig => ({
id: 'color_map_image_processor', id: 'color_map_image_processor',
type: 'color_map_image_processor', type: 'color_map_image_processor',
color_map_tile_size: 64, color_map_tile_size: 64,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: ColorMapProcessorConfig): S['ColorMapImageProcessorInvocation'] => ({
...config, ...config,
type: 'color_map_image_processor', type: 'color_map_image_processor',
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
}), }),
}, },
content_shuffle_image_processor: { content_shuffle_image_processor: {
type: 'content_shuffle_image_processor', type: 'content_shuffle_image_processor',
labelTKey: 'controlnet.contentShuffle', labelTKey: 'controlnet.contentShuffle',
descriptionTKey: 'controlnet.contentShuffleDescription', descriptionTKey: 'controlnet.contentShuffleDescription',
buildDefaults: (baseModel) => ({ buildDefaults: (baseModel: BaseModelType): ContentShuffleProcessorConfig => ({
id: 'content_shuffle_image_processor', id: 'content_shuffle_image_processor',
type: 'content_shuffle_image_processor', type: 'content_shuffle_image_processor',
h: baseModel === 'sdxl' ? 1024 : 512, h: baseModel === 'sdxl' ? 1024 : 512,
w: baseModel === 'sdxl' ? 1024 : 512, w: baseModel === 'sdxl' ? 1024 : 512,
f: baseModel === 'sdxl' ? 512 : 256, f: baseModel === 'sdxl' ? 512 : 256,
}), }),
buildNode: (image, config) => ({ buildNode: (
imageDTO: ImageDTO,
config: ContentShuffleProcessorConfig
): S['ContentShuffleImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
depth_anything_image_processor: { depth_anything_image_processor: {
type: 'depth_anything_image_processor', type: 'depth_anything_image_processor',
labelTKey: 'controlnet.depthAnything', labelTKey: 'controlnet.depthAnything',
descriptionTKey: 'controlnet.depthAnythingDescription', descriptionTKey: 'controlnet.depthAnythingDescription',
buildDefaults: () => ({ buildDefaults: (): DepthAnythingProcessorConfig => ({
id: 'depth_anything_image_processor', id: 'depth_anything_image_processor',
type: 'depth_anything_image_processor', type: 'depth_anything_image_processor',
model_size: 'small_v2', model_size: 'small_v2',
}), }),
buildNode: (image, config) => ({ buildNode: (
imageDTO: ImageDTO,
config: DepthAnythingProcessorConfig
): S['DepthAnythingImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
resolution: minDim(image), resolution: minDim(imageDTO),
}), }),
}, },
hed_image_processor: { hed_image_processor: {
type: 'hed_image_processor', type: 'hed_image_processor',
labelTKey: 'controlnet.hed', labelTKey: 'controlnet.hed',
descriptionTKey: 'controlnet.hedDescription', descriptionTKey: 'controlnet.hedDescription',
buildDefaults: () => ({ buildDefaults: (): HedProcessorConfig => ({
id: 'hed_image_processor', id: 'hed_image_processor',
type: 'hed_image_processor', type: 'hed_image_processor',
scribble: false, scribble: false,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: HedProcessorConfig): S['HedImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
lineart_anime_image_processor: { lineart_anime_image_processor: {
type: 'lineart_anime_image_processor', type: 'lineart_anime_image_processor',
labelTKey: 'controlnet.lineartAnime', labelTKey: 'controlnet.lineartAnime',
descriptionTKey: 'controlnet.lineartAnimeDescription', descriptionTKey: 'controlnet.lineartAnimeDescription',
buildDefaults: () => ({ buildDefaults: (): LineartAnimeProcessorConfig => ({
id: 'lineart_anime_image_processor', id: 'lineart_anime_image_processor',
type: 'lineart_anime_image_processor', type: 'lineart_anime_image_processor',
}), }),
buildNode: (image, config) => ({ buildNode: (
imageDTO: ImageDTO,
config: LineartAnimeProcessorConfig
): S['LineartAnimeImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
lineart_image_processor: { lineart_image_processor: {
type: 'lineart_image_processor', type: 'lineart_image_processor',
labelTKey: 'controlnet.lineart', labelTKey: 'controlnet.lineart',
descriptionTKey: 'controlnet.lineartDescription', descriptionTKey: 'controlnet.lineartDescription',
buildDefaults: () => ({ buildDefaults: (): LineartProcessorConfig => ({
id: 'lineart_image_processor', id: 'lineart_image_processor',
type: 'lineart_image_processor', type: 'lineart_image_processor',
coarse: false, coarse: false,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: LineartProcessorConfig): S['LineartImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
mediapipe_face_processor: { mediapipe_face_processor: {
type: 'mediapipe_face_processor', type: 'mediapipe_face_processor',
labelTKey: 'controlnet.mediapipeFace', labelTKey: 'controlnet.mediapipeFace',
descriptionTKey: 'controlnet.mediapipeFaceDescription', descriptionTKey: 'controlnet.mediapipeFaceDescription',
buildDefaults: () => ({ buildDefaults: (): MediapipeFaceProcessorConfig => ({
id: 'mediapipe_face_processor', id: 'mediapipe_face_processor',
type: 'mediapipe_face_processor', type: 'mediapipe_face_processor',
max_faces: 1, max_faces: 1,
min_confidence: 0.5, min_confidence: 0.5,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: MediapipeFaceProcessorConfig): S['MediapipeFaceProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
midas_depth_image_processor: { midas_depth_image_processor: {
type: 'midas_depth_image_processor', type: 'midas_depth_image_processor',
labelTKey: 'controlnet.depthMidas', labelTKey: 'controlnet.depthMidas',
descriptionTKey: 'controlnet.depthMidasDescription', descriptionTKey: 'controlnet.depthMidasDescription',
buildDefaults: () => ({ buildDefaults: (): MidasDepthProcessorConfig => ({
id: 'midas_depth_image_processor', id: 'midas_depth_image_processor',
type: 'midas_depth_image_processor', type: 'midas_depth_image_processor',
a_mult: 2, a_mult: 2,
bg_th: 0.1, bg_th: 0.1,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: MidasDepthProcessorConfig): S['MidasDepthImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
mlsd_image_processor: { mlsd_image_processor: {
type: 'mlsd_image_processor', type: 'mlsd_image_processor',
labelTKey: 'controlnet.mlsd', labelTKey: 'controlnet.mlsd',
descriptionTKey: 'controlnet.mlsdDescription', descriptionTKey: 'controlnet.mlsdDescription',
buildDefaults: () => ({ buildDefaults: (): MlsdProcessorConfig => ({
id: 'mlsd_image_processor', id: 'mlsd_image_processor',
type: 'mlsd_image_processor', type: 'mlsd_image_processor',
thr_d: 0.1, thr_d: 0.1,
thr_v: 0.1, thr_v: 0.1,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: MlsdProcessorConfig): S['MlsdImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
normalbae_image_processor: { normalbae_image_processor: {
type: 'normalbae_image_processor', type: 'normalbae_image_processor',
labelTKey: 'controlnet.normalBae', labelTKey: 'controlnet.normalBae',
descriptionTKey: 'controlnet.normalBaeDescription', descriptionTKey: 'controlnet.normalBaeDescription',
buildDefaults: () => ({ buildDefaults: (): NormalbaeProcessorConfig => ({
id: 'normalbae_image_processor', id: 'normalbae_image_processor',
type: 'normalbae_image_processor', type: 'normalbae_image_processor',
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: NormalbaeProcessorConfig): S['NormalbaeImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
dw_openpose_image_processor: { dw_openpose_image_processor: {
type: 'dw_openpose_image_processor', type: 'dw_openpose_image_processor',
labelTKey: 'controlnet.dwOpenpose', labelTKey: 'controlnet.dwOpenpose',
descriptionTKey: 'controlnet.dwOpenposeDescription', descriptionTKey: 'controlnet.dwOpenposeDescription',
buildDefaults: () => ({ buildDefaults: (): DWOpenposeProcessorConfig => ({
id: 'dw_openpose_image_processor', id: 'dw_openpose_image_processor',
type: 'dw_openpose_image_processor', type: 'dw_openpose_image_processor',
draw_body: true, draw_body: true,
draw_face: false, draw_face: false,
draw_hands: false, draw_hands: false,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: DWOpenposeProcessorConfig): S['DWOpenposeImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
pidi_image_processor: { pidi_image_processor: {
type: 'pidi_image_processor', type: 'pidi_image_processor',
labelTKey: 'controlnet.pidi', labelTKey: 'controlnet.pidi',
descriptionTKey: 'controlnet.pidiDescription', descriptionTKey: 'controlnet.pidiDescription',
buildDefaults: () => ({ buildDefaults: (): PidiProcessorConfig => ({
id: 'pidi_image_processor', id: 'pidi_image_processor',
type: 'pidi_image_processor', type: 'pidi_image_processor',
scribble: false, scribble: false,
safe: false, safe: false,
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: PidiProcessorConfig): S['PidiImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
detect_resolution: minDim(image), detect_resolution: minDim(imageDTO),
image_resolution: minDim(image), image_resolution: minDim(imageDTO),
}), }),
}, },
zoe_depth_image_processor: { zoe_depth_image_processor: {
type: 'zoe_depth_image_processor', type: 'zoe_depth_image_processor',
labelTKey: 'controlnet.depthZoe', labelTKey: 'controlnet.depthZoe',
descriptionTKey: 'controlnet.depthZoeDescription', descriptionTKey: 'controlnet.depthZoeDescription',
buildDefaults: () => ({ buildDefaults: (): ZoeDepthProcessorConfig => ({
id: 'zoe_depth_image_processor', id: 'zoe_depth_image_processor',
type: 'zoe_depth_image_processor', type: 'zoe_depth_image_processor',
}), }),
buildNode: (image, config) => ({ buildNode: (imageDTO: ImageDTO, config: ZoeDepthProcessorConfig): S['ZoeDepthImageProcessorInvocation'] => ({
...config, ...config,
image: { image_name: image.image_name }, image: { image_name: imageDTO.image_name },
}), }),
}, },
}; } as const;
const zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'view', 'bbox']); const zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'view', 'bbox']);
export type Tool = z.infer<typeof zTool>; export type Tool = z.infer<typeof zTool>;
@ -575,17 +581,6 @@ export function isCanvasBrushLineState(obj: CanvasObjectState): obj is CanvasBru
return obj.type === 'brush_line'; return obj.type === 'brush_line';
} }
export const zCanvasLayerState = z.object({
id: zId,
type: z.literal('layer'),
isEnabled: z.boolean(),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
imageCache: z.string().min(1).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
export const zCanvasIPAdapterState = z.object({ export const zCanvasIPAdapterState = z.object({
id: zId, id: zId,
type: z.literal('ip_adapter'), type: z.literal('ip_adapter'),
@ -689,7 +684,7 @@ const zCanvasControlAdapterStateBase = z.object({
weight: z.number().gte(-1).lte(2), weight: z.number().gte(-1).lte(2),
imageObject: zCanvasImageState.nullable(), imageObject: zCanvasImageState.nullable(),
processedImageObject: zCanvasImageState.nullable(), processedImageObject: zCanvasImageState.nullable(),
processorConfig: zProcessorConfig.nullable(), processorConfig: zFilterConfig.nullable(),
processorPendingBatchId: z.string().nullable().default(null), processorPendingBatchId: z.string().nullable().default(null),
beginEndStepPct: zBeginEndStepPct, beginEndStepPct: zBeginEndStepPct,
model: zModelIdentifierField.nullable(), model: zModelIdentifierField.nullable(),
@ -709,41 +704,55 @@ export const zCanvasControlAdapterState = z.discriminatedUnion('adapterType', [
zCanvasT2IAdapteState, zCanvasT2IAdapteState,
]); ]);
export type CanvasControlAdapterState = z.infer<typeof zCanvasControlAdapterState>; export type CanvasControlAdapterState = z.infer<typeof zCanvasControlAdapterState>;
export type ControlNetConfig = Pick<
CanvasControlNetState, const zControlNetConfig = z.object({
| 'adapterType' type: z.literal('controlnet'),
| 'weight' model: zModelIdentifierField.nullable(),
| 'imageObject' weight: z.number().gte(-1).lte(2),
| 'processedImageObject' beginEndStepPct: zBeginEndStepPct,
| 'processorConfig' controlMode: zControlModeV2,
| 'beginEndStepPct' });
| 'model' export type ControlNetConfig = z.infer<typeof zControlNetConfig>;
| 'controlMode'
>; const zT2IAdapterConfig = z.object({
export type T2IAdapterConfig = Pick< type: z.literal('t2i_adapter'),
CanvasT2IAdapterState, model: zModelIdentifierField.nullable(),
'adapterType' | 'weight' | 'imageObject' | 'processedImageObject' | 'processorConfig' | 'beginEndStepPct' | 'model' weight: z.number().gte(-1).lte(2),
>; beginEndStepPct: zBeginEndStepPct,
});
export type T2IAdapterConfig = z.infer<typeof zT2IAdapterConfig>;
export const zCanvasLayerState = z.object({
id: zId,
type: z.literal('layer'),
isEnabled: z.boolean(),
position: zCoordinate,
opacity: zOpacity,
objects: z.array(zCanvasObjectState),
imageCache: z.string().min(1).nullable(),
controlAdapter: z.discriminatedUnion('type', [zControlNetConfig, zT2IAdapterConfig]).nullable(),
});
export type CanvasLayerState = z.infer<typeof zCanvasLayerState>;
export type CanvasLayerStateWithValidControlNet = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<ControlNetConfig, 'model'> & { model: ControlNetModelConfig };
};
export type CanvasLayerStateWithValidT2IAdapter = Omit<CanvasLayerState, 'controlAdapter'> & {
controlAdapter: Omit<T2IAdapterConfig, 'model'> & { model: T2IAdapterModelConfig };
};
export const initialControlNetV2: ControlNetConfig = { export const initialControlNetV2: ControlNetConfig = {
adapterType: 'controlnet', type: 'controlnet',
model: null, model: null,
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
controlMode: 'balanced', controlMode: 'balanced',
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
}; };
export const initialT2IAdapterV2: T2IAdapterConfig = { export const initialT2IAdapterV2: T2IAdapterConfig = {
adapterType: 't2i_adapter', type: 't2i_adapter',
model: null, model: null,
weight: 1, weight: 1,
beginEndStepPct: [0, 1], beginEndStepPct: [0, 1],
imageObject: null,
processedImageObject: null,
processorConfig: CA_PROCESSOR_DATA.canny_image_processor.buildDefaults(),
}; };
export const initialIPAdapterV2: IPAdapterConfig = { export const initialIPAdapterV2: IPAdapterConfig = {
@ -757,12 +766,12 @@ export const initialIPAdapterV2: IPAdapterConfig = {
export const buildControlAdapterProcessorV2 = ( export const buildControlAdapterProcessorV2 = (
modelConfig: ControlNetModelConfig | T2IAdapterModelConfig modelConfig: ControlNetModelConfig | T2IAdapterModelConfig
): ProcessorConfig | null => { ): FilterConfig | null => {
const defaultPreprocessor = modelConfig.default_settings?.preprocessor; const defaultPreprocessor = modelConfig.default_settings?.preprocessor;
if (!isProcessorTypeV2(defaultPreprocessor)) { if (!isFilterType(defaultPreprocessor)) {
return null; return null;
} }
const processorConfig = CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults(modelConfig.base); const processorConfig = IMAGE_FILTERS[defaultPreprocessor].buildDefaults(modelConfig.base);
return processorConfig; return processorConfig;
}; };
@ -901,6 +910,10 @@ export type CanvasV2State = {
stagedImages: StagingAreaImage[]; stagedImages: StagingAreaImage[];
selectedStagedImageIndex: number; selectedStagedImageIndex: number;
}; };
filter: {
autoProcess: boolean;
config: FilterConfig;
};
}; };
export type StageAttrs = { export type StageAttrs = {
@ -964,5 +977,3 @@ export function isDrawableEntityType(
): entityType is 'layer' | 'regional_guidance' | 'inpaint_mask' { ): entityType is 'layer' | 'regional_guidance' | 'inpaint_mask' {
return entityType === 'layer' || entityType === 'regional_guidance' || entityType === 'inpaint_mask'; return entityType === 'layer' || entityType === 'regional_guidance' || entityType === 'inpaint_mask';
} }
export type GetLoggingContext = (extra?: JSONObject) => JSONObject;

View File

@ -2,12 +2,12 @@ import { getCAId, getImageObjectId, getIPAId, getLayerId } from 'features/contro
import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers'; import { defaultLoRAConfig } from 'features/controlLayers/store/lorasReducers';
import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasLayerState, LoRA } from 'features/controlLayers/store/types'; import type { CanvasControlAdapterState, CanvasIPAdapterState, CanvasLayerState, LoRA } from 'features/controlLayers/store/types';
import { import {
CA_PROCESSOR_DATA, IMAGE_FILTERS,
imageDTOToImageWithDims, imageDTOToImageWithDims,
initialControlNetV2, initialControlNetV2,
initialIPAdapterV2, initialIPAdapterV2,
initialT2IAdapterV2, initialT2IAdapterV2,
isProcessorTypeV2, isFilterType,
zCanvasLayerState, zCanvasLayerState,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { import type {
@ -559,8 +559,8 @@ const parseControlNetToControlAdapterLayer: MetadataParseFunc<CanvasControlAdapt
.parse(await getProperty(metadataItem, 'control_mode')); .parse(await getProperty(metadataItem, 'control_mode'));
const defaultPreprocessor = controlNetModel.default_settings?.preprocessor; const defaultPreprocessor = controlNetModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor) const processorConfig = isFilterType(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults() ? IMAGE_FILTERS[defaultPreprocessor].buildDefaults()
: null; : null;
const beginEndStepPct: [number, number] = [ const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialControlNetV2.beginEndStepPct[0], begin_step_percent ?? initialControlNetV2.beginEndStepPct[0],
@ -620,8 +620,8 @@ const parseT2IAdapterToControlAdapterLayer: MetadataParseFunc<CanvasControlAdapt
.parse(await getProperty(metadataItem, 'end_step_percent')); .parse(await getProperty(metadataItem, 'end_step_percent'));
const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor; const defaultPreprocessor = t2iAdapterModel.default_settings?.preprocessor;
const processorConfig = isProcessorTypeV2(defaultPreprocessor) const processorConfig = isFilterType(defaultPreprocessor)
? CA_PROCESSOR_DATA[defaultPreprocessor].buildDefaults() ? IMAGE_FILTERS[defaultPreprocessor].buildDefaults()
: null; : null;
const beginEndStepPct: [number, number] = [ const beginEndStepPct: [number, number] = [
begin_step_percent ?? initialT2IAdapterV2.beginEndStepPct[0], begin_step_percent ?? initialT2IAdapterV2.beginEndStepPct[0],

View File

@ -1,35 +1,42 @@
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { import type {
CanvasControlAdapterState, CanvasLayerState,
CanvasControlNetState, CanvasLayerStateWithValidControlNet,
CanvasLayerStateWithValidT2IAdapter,
ControlNetConfig,
FilterConfig,
ImageWithDims, ImageWithDims,
ProcessorConfig,
Rect, Rect,
CanvasT2IAdapterState, T2IAdapterConfig,
} from 'features/controlLayers/store/types'; } from 'features/controlLayers/store/types';
import type { ImageField } from 'features/nodes/types/common'; import type { ImageField } from 'features/nodes/types/common';
import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants'; import { CONTROL_NET_COLLECT, T2I_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types'; import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
export const addControlAdapters = async ( export const addControlAdapters = async (
manager: CanvasManager, manager: CanvasManager,
controlAdapters: CanvasControlAdapterState[], layers: CanvasLayerState[],
g: Graph, g: Graph,
bbox: Rect, bbox: Rect,
denoise: Invocation<'denoise_latents'>, denoise: Invocation<'denoise_latents'>,
base: BaseModelType base: BaseModelType
): Promise<CanvasControlAdapterState[]> => { ): Promise<(CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter)[]> => {
const validControlAdapters = controlAdapters.filter((ca) => isValidControlAdapter(ca, base)); const layersWithValidControlAdapters = layers
for (const ca of validControlAdapters) { .filter((layer) => layer.isEnabled)
if (ca.adapterType === 'controlnet') { .filter((layer) => doesLayerHaveValidControlAdapter(layer, base));
await addControlNetToGraph(manager, ca, g, bbox, denoise); for (const layer of layersWithValidControlAdapters) {
const adapter = manager.layers.get(layer.id);
assert(adapter, 'Adapter not found');
const imageDTO = await adapter.renderer.getImageDTO({ rect: bbox, is_intermediate: true, category: 'control' });
if (layer.controlAdapter.type === 'controlnet') {
await addControlNetToGraph(g, layer, imageDTO, denoise);
} else { } else {
await addT2IAdapterToGraph(manager, ca, g, bbox, denoise); await addT2IAdapterToGraph(g, layer, imageDTO, denoise);
} }
} }
return validControlAdapters; return layersWithValidControlAdapters;
}; };
const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => { const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_latents'>): Invocation<'collect'> => {
@ -49,16 +56,15 @@ const addControlNetCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
} }
}; };
const addControlNetToGraph = async ( const addControlNetToGraph = (
manager: CanvasManager,
ca: CanvasControlNetState,
g: Graph, g: Graph,
bbox: Rect, layer: CanvasLayerStateWithValidControlNet,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'> denoise: Invocation<'denoise_latents'>
) => { ) => {
const { id, beginEndStepPct, controlMode, model, weight } = ca; const { id, controlAdapter } = layer;
assert(model, 'ControlNet model is required'); const { beginEndStepPct, model, weight, controlMode } = controlAdapter;
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true }); const { image_name } = imageDTO;
const controlNetCollect = addControlNetCollectorSafe(g, denoise); const controlNetCollect = addControlNetCollectorSafe(g, denoise);
@ -94,16 +100,15 @@ const addT2IAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise_laten
} }
}; };
const addT2IAdapterToGraph = async ( const addT2IAdapterToGraph = (
manager: CanvasManager,
ca: CanvasT2IAdapterState,
g: Graph, g: Graph,
bbox: Rect, layer: CanvasLayerStateWithValidT2IAdapter,
imageDTO: ImageDTO,
denoise: Invocation<'denoise_latents'> denoise: Invocation<'denoise_latents'>
) => { ) => {
const { id, beginEndStepPct, model, weight } = ca; const { id, controlAdapter } = layer;
assert(model, 'T2I Adapter model is required'); const { beginEndStepPct, model, weight } = controlAdapter;
const { image_name } = await manager.getControlAdapterImage({ id: ca.id, bbox, preview: true }); const { image_name } = imageDTO;
const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise); const t2iAdapterCollect = addT2IAdapterCollectorSafe(g, denoise);
@ -124,7 +129,7 @@ const addT2IAdapterToGraph = async (
const buildControlImage = ( const buildControlImage = (
image: ImageWithDims | null, image: ImageWithDims | null,
processedImage: ImageWithDims | null, processedImage: ImageWithDims | null,
processorConfig: ProcessorConfig | null processorConfig: FilterConfig | null
): ImageField => { ): ImageField => {
if (processedImage && processorConfig) { if (processedImage && processorConfig) {
// We've processed the image in the app - use it for the control image. // We've processed the image in the app - use it for the control image.
@ -140,10 +145,29 @@ const buildControlImage = (
assert(false, 'Attempted to add unprocessed control image'); assert(false, 'Attempted to add unprocessed control image');
}; };
const isValidControlAdapter = (ca: CanvasControlAdapterState, base: BaseModelType): boolean => { const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image // Must be have a model
const hasModel = Boolean(ca.model); const hasModel = Boolean(controlAdapter.model);
const modelMatchesBase = ca.model?.base === base; // Model must match the current base model
const hasControlImage = Boolean(ca.imageObject || (ca.processedImageObject && ca.processorConfig)); const modelMatchesBase = controlAdapter.model?.base === base;
return hasModel && modelMatchesBase && hasControlImage; return hasModel && modelMatchesBase;
};
const doesLayerHaveValidControlAdapter = (
layer: CanvasLayerState,
base: BaseModelType
): layer is CanvasLayerStateWithValidControlNet | CanvasLayerStateWithValidT2IAdapter => {
if (!layer.controlAdapter) {
// Must have a control adapter
return false;
}
if (!layer.controlAdapter.model) {
// Control adapter must have a model selected
return false;
}
if (layer.controlAdapter.model.base !== base) {
// Selected model must match current base model
return false;
}
return true;
}; };

View File

@ -1,9 +1,10 @@
import type { CanvasLayerState } from 'features/controlLayers/store/types'; import type { CanvasLayerState } from 'features/controlLayers/store/types';
export const isValidLayer = (entity: CanvasLayerState) => { export const isValidLayerWithoutControlAdapter = (layer: CanvasLayerState) => {
return ( return (
entity.isEnabled && layer.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers // Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
entity.objects.length > 0 layer.objects.length > 0 &&
layer.controlAdapter === null
); );
}; };

View File

@ -215,7 +215,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const _addedCAs = await addControlAdapters( const _addedCAs = await addControlAdapters(
manager, manager,
state.canvasV2.controlAdapters.entities, state.canvasV2.layers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, denoise,

View File

@ -219,7 +219,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const _addedCAs = await addControlAdapters( const _addedCAs = await addControlAdapters(
manager, manager,
state.canvasV2.controlAdapters.entities, state.canvasV2.layers.entities,
g, g,
state.canvasV2.bbox.rect, state.canvasV2.bbox.rect,
denoise, denoise,

View File

@ -1635,8 +1635,11 @@ export type components = {
* @description The ID of the batch * @description The ID of the batch
*/ */
batch_id?: string; batch_id?: string;
/** @description The origin of this batch. */ /**
origin?: components["schemas"]["QueueItemOrigin"] | null; * Origin
* @description The origin of this batch.
*/
origin?: string | null;
/** /**
* Data * Data
* @description The batch data collection. * @description The batch data collection.
@ -1707,10 +1710,11 @@ export type components = {
*/ */
priority: number; priority: number;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
}; };
/** BatchStatus */ /** BatchStatus */
BatchStatus: { BatchStatus: {
@ -1724,8 +1728,11 @@ export type components = {
* @description The ID of the batch * @description The ID of the batch
*/ */
batch_id: string; batch_id: string;
/** @description The origin of the batch */ /**
origin: components["schemas"]["QueueItemOrigin"] | null; * Origin
* @description The origin of the batch
*/
origin: string | null;
/** /**
* Pending * Pending
* @description Number of queue items with status 'pending' * @description Number of queue items with status 'pending'
@ -8330,10 +8337,11 @@ export type components = {
*/ */
batch_id: string; batch_id: string;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session (aka graph execution state) * @description The ID of the session (aka graph execution state)
@ -8381,10 +8389,11 @@ export type components = {
*/ */
batch_id: string; batch_id: string;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session (aka graph execution state) * @description The ID of the session (aka graph execution state)
@ -8449,10 +8458,11 @@ export type components = {
*/ */
batch_id: string; batch_id: string;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session (aka graph execution state) * @description The ID of the session (aka graph execution state)
@ -8670,10 +8680,11 @@ export type components = {
*/ */
batch_id: string; batch_id: string;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session (aka graph execution state) * @description The ID of the session (aka graph execution state)
@ -11654,10 +11665,11 @@ export type components = {
*/ */
batch_id: string; batch_id: string;
/** /**
* Origin
* @description The origin of the batch * @description The origin of the batch
* @default null * @default null
*/ */
origin: components["schemas"]["QueueItemOrigin"] | null; origin: string | null;
/** /**
* Status * Status
* @description The new status of the queue item * @description The new status of the queue item
@ -13014,8 +13026,11 @@ export type components = {
* @description The ID of the batch associated with this queue item * @description The ID of the batch associated with this queue item
*/ */
batch_id: string; batch_id: string;
/** @description The origin of this queue item. */ /**
origin?: components["schemas"]["QueueItemOrigin"] | null; * Origin
* @description The origin of this queue item.
*/
origin?: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed. * @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.
@ -13096,8 +13111,11 @@ export type components = {
* @description The ID of the batch associated with this queue item * @description The ID of the batch associated with this queue item
*/ */
batch_id: string; batch_id: string;
/** @description The origin of this queue item. */ /**
origin?: components["schemas"]["QueueItemOrigin"] | null; * Origin
* @description The origin of this queue item.
*/
origin?: string | null;
/** /**
* Session Id * Session Id
* @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed. * @description The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed.