fix(ui): ip adapters work

This commit is contained in:
psychedelicious 2024-08-15 19:39:47 +10:00
parent 4adb2eabf5
commit 7178fc6253
33 changed files with 307 additions and 255 deletions

View File

@ -5,6 +5,7 @@ import type { JSONObject } from 'common/types';
import {
bboxHeightChanged,
bboxWidthChanged,
controlLayerModelChanged,
ipaModelChanged,
loraDeleted,
modelChanged,
@ -20,6 +21,7 @@ import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isControlNetOrT2IAdapterModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
@ -31,7 +33,7 @@ import {
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
predicate: modelsApi.endpoints.getModelConfigs.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
effect: (action, { getState, dispatch }) => {
// models loaded, we need to ensure the selected model is available and if not, select the first one
const log = logger('models');
log.info({ models: action.payload.entities }, `Models loaded (${action.payload.ids.length})`);
@ -169,24 +171,24 @@ const handleLoRAModels: ModelHandler = (models, state, dispatch, _log) => {
};
const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
// const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
// state.canvasV2.controlAdapters.entities.forEach((ca) => {
// const isModelAvailable = caModels.some((m) => m.key === ca.model?.key);
// if (isModelAvailable) {
// return;
// }
// dispatch(caModelChanged({ id: ca.id, modelConfig: null }));
// });
const caModels = models.filter(isControlNetOrT2IAdapterModelConfig);
state.canvasV2.controlLayers.entities.forEach((entity) => {
const isModelAvailable = caModels.some((m) => m.key === entity.controlAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(controlLayerModelChanged({ id: entity.id, modelConfig: null }));
});
};
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, _log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
state.canvasV2.ipAdapters.entities.forEach(({ id, model }) => {
const isModelAvailable = ipaModels.some((m) => m.key === model?.key);
state.canvasV2.ipAdapters.entities.forEach((entity) => {
const isModelAvailable = ipaModels.some((m) => m.key === entity.ipAdapter.model?.key);
if (isModelAvailable) {
return;
}
dispatch(ipaModelChanged({ id, modelConfig: null }));
dispatch(ipaModelChanged({ id: entity.id, modelConfig: null }));
});
state.canvasV2.regions.entities.forEach(({ id, ipAdapters }) => {

View File

@ -154,24 +154,24 @@ const createSelector = (templates: Templates) =>
});
canvasV2.ipAdapters.entities
.filter((ipa) => ipa.isEnabled)
.forEach((ipa, i) => {
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[ipa.type]);
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have model
if (!ipa.model) {
if (!entity.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipa.model?.base !== model?.base) {
if (entity.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipa.imageObject) {
if (!entity.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
@ -182,22 +182,22 @@ const createSelector = (templates: Templates) =>
});
canvasV2.regions.entities
.filter((rg) => rg.isEnabled)
.forEach((rg, i) => {
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[rg.type]);
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
// Must have a region
if (rg.objects.length === 0) {
if (entity.objects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (rg.positivePrompt === null && rg.negativePrompt === null && rg.ipAdapters.length === 0) {
if (entity.positivePrompt === null && entity.negativePrompt === null && entity.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
rg.ipAdapters.forEach((ipAdapter) => {
entity.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
@ -207,7 +207,7 @@ const createSelector = (templates: Templates) =>
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.imageObject) {
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
@ -219,12 +219,11 @@ const createSelector = (templates: Templates) =>
});
canvasV2.rasterLayers.entities
.filter((l) => l.isEnabled)
.filter((l) => l.type === 'raster_layer')
.forEach((l, i) => {
.filter((entity) => entity.isEnabled)
.forEach((entity, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];

View File

@ -21,7 +21,7 @@ export const AddLayerButton = memo(() => {
dispatch(controlLayerAdded({ isSelected: true, overrides: { controlAdapter: defaultControlAdapter } }));
}, [defaultControlAdapter, dispatch]);
const addIPAdapter = useCallback(() => {
dispatch(ipaAdded({ config: defaultIPAdapter }));
dispatch(ipaAdded({ ipAdapter: defaultIPAdapter }));
}, [defaultIPAdapter, dispatch]);
return (

View File

@ -1,4 +1,5 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -17,14 +18,14 @@ type Props = {
export const ControlLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'control_layer' }), [id]);
const editing = useDisclosure({ defaultIsOpen: false });
const editing = useBoolean(false);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={editing.onOpen}>
<CanvasEntityHeader onDoubleClick={editing.setTrue}>
<CanvasEntityEnabledToggle />
{editing.isOpen ? <CanvasEntityTitleEdit onStopEditing={editing.onClose} /> : <CanvasEntityTitle />}
{editing.isTrue ? <CanvasEntityTitleEdit onStopEditing={editing.setFalse} /> : <CanvasEntityTitle />}
<Spacer />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>

View File

@ -1,9 +1,11 @@
import { Spacer } from '@invoke-ai/ui-library';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
import { CanvasEntityHeader } from 'features/controlLayers/components/common/CanvasEntityHeader';
import { CanvasEntityTitle } from 'features/controlLayers/components/common/CanvasEntityTitle';
import { CanvasEntityTitleEdit } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { IPAdapterSettings } from 'features/controlLayers/components/IPAdapter/IPAdapterSettings';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
@ -15,13 +17,14 @@ type Props = {
export const IPAdapter = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'ip_adapter' }), [id]);
const editing = useBoolean(false);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader>
<CanvasEntityHeader onDoubleClick={editing.setTrue}>
<CanvasEntityEnabledToggle />
<CanvasEntityTitle />
{editing.isTrue ? <CanvasEntityTitleEdit onStopEditing={editing.setFalse} /> : <CanvasEntityTitle />}
<Spacer />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>

View File

@ -13,7 +13,7 @@ import {
ipaModelChanged,
ipaWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { selectIPAOrThrow } from 'features/controlLayers/store/ipAdaptersReducers';
import { selectIPAdapterEntityOrThrow } from 'features/controlLayers/store/ipAdaptersReducers';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { IPAImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react';
@ -25,7 +25,7 @@ import { IPAdapterModel } from './IPAdapterModel';
export const IPAdapterSettings = memo(() => {
const dispatch = useAppDispatch();
const { id } = useEntityIdentifierContext();
const ipAdapter = useAppSelector((s) => selectIPAOrThrow(s.canvasV2, id));
const ipAdapter = useAppSelector((s) => selectIPAdapterEntityOrThrow(s.canvasV2, id).ipAdapter);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@ -93,9 +93,9 @@ export const IPAdapterSettings = memo(() => {
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.imageObject?.image ?? null}
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
ipAdapterId={id}
droppableData={droppableData}
postUploadAction={postUploadAction}
/>

View File

@ -1,4 +1,5 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -15,14 +16,14 @@ type Props = {
export const RasterLayer = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'raster_layer' }), [id]);
const editing = useDisclosure({ defaultIsOpen: false });
const editing = useBoolean(false);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={editing.onOpen}>
<CanvasEntityHeader onDoubleClick={editing.setTrue}>
<CanvasEntityEnabledToggle />
{editing.isOpen ? <CanvasEntityTitleEdit onStopEditing={editing.onClose} /> : <CanvasEntityTitle />}
{editing.isTrue ? <CanvasEntityTitleEdit onStopEditing={editing.setFalse} /> : <CanvasEntityTitle />}
<Spacer />
<CanvasEntityDeleteButton />
</CanvasEntityHeader>

View File

@ -1,4 +1,5 @@
import { Spacer, useDisclosure } from '@invoke-ai/ui-library';
import { Spacer } from '@invoke-ai/ui-library';
import { useBoolean } from 'common/hooks/useBoolean';
import { CanvasEntityContainer } from 'features/controlLayers/components/common/CanvasEntityContainer';
import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton';
import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle';
@ -20,14 +21,14 @@ type Props = {
export const RegionalGuidance = memo(({ id }: Props) => {
const entityIdentifier = useMemo<CanvasEntityIdentifier>(() => ({ id, type: 'regional_guidance' }), [id]);
const editing = useDisclosure({ defaultIsOpen: false });
const editing = useBoolean(false);
return (
<EntityIdentifierContext.Provider value={entityIdentifier}>
<CanvasEntityContainer>
<CanvasEntityHeader onDoubleClick={editing.onOpen}>
<CanvasEntityHeader onDoubleClick={editing.setTrue}>
<CanvasEntityEnabledToggle />
{editing.isOpen ? <CanvasEntityTitleEdit onStopEditing={editing.onClose} /> : <CanvasEntityTitle />}
{editing.isTrue ? <CanvasEntityTitleEdit onStopEditing={editing.setFalse} /> : <CanvasEntityTitle />}
<Spacer />
<RegionalGuidanceBadges />
<RegionalGuidanceMaskFillColorPicker />

View File

@ -1,14 +1,14 @@
import { Badge } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const RegionalGuidanceBadges = memo(() => {
const { id } = useEntityIdentifierContext();
const { t } = useTranslation();
const autoNegative = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).autoNegative);
const autoNegative = useAppSelector((s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).autoNegative);
return (
<>

View File

@ -14,7 +14,7 @@ import {
rgIPAdapterModelChanged,
rgIPAdapterWeightChanged,
} from 'features/controlLayers/store/canvasV2Slice';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import type { CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import type { RGIPAdapterImageDropData } from 'features/dnd/types';
import { memo, useCallback, useMemo } from 'react';
@ -34,7 +34,7 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ id, ipAdapterId, ipAdap
dispatch(rgIPAdapterDeleted({ id, ipAdapterId }));
}, [dispatch, ipAdapterId, id]);
const ipAdapter = useAppSelector((s) => {
const ipa = selectRGOrThrow(s.canvasV2, id).ipAdapters.find((ipa) => ipa.id === ipAdapterId);
const ipa = selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).ipAdapters.find((ipa) => ipa.id === ipAdapterId);
assert(ipa, `Regional GuidanceIP Adapter with id ${ipAdapterId} not found`);
return ipa;
});
@ -123,7 +123,7 @@ export const RegionalGuidanceIPAdapterSettings = memo(({ id, ipAdapterId, ipAdap
</Flex>
<Flex alignItems="center" justifyContent="center" h={36} w={36} aspectRatio="1/1">
<IPAdapterImagePreview
image={ipAdapter.imageObject?.image ?? null}
image={ipAdapter.image ?? null}
onChangeImage={onChangeImage}
ipAdapterId={ipAdapter.id}
droppableData={droppableData}

View File

@ -1,15 +1,30 @@
import { Divider, Flex } from '@invoke-ai/ui-library';
import { Divider } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { RegionalGuidanceIPAdapterSettings } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettings';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo } from 'react';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { Fragment, memo, useMemo } from 'react';
type Props = {
id: string;
};
export const RegionalGuidanceIPAdapters = memo(({ id }: Props) => {
const ipAdapterIds = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).ipAdapters.map(({ id }) => id));
const selectIPAdapterIds = useMemo(
() =>
createMemoizedSelector(selectCanvasV2Slice, (canvasV2) => {
const ipAdapterIds = selectRegionalGuidanceEntityOrThrow(canvasV2, id).ipAdapters.map(({ id }) => id);
if (ipAdapterIds.length === 0) {
return EMPTY_ARRAY;
}
return ipAdapterIds;
}),
[id]
);
const ipAdapterIds = useAppSelector(selectIPAdapterIds);
if (ipAdapterIds.length === 0) {
return null;
@ -17,15 +32,11 @@ export const RegionalGuidanceIPAdapters = memo(({ id }: Props) => {
return (
<>
{ipAdapterIds.map((id, index) => (
<Flex flexDir="column" key={id}>
{index > 0 && (
<Flex pb={3}>
<Divider />
</Flex>
)}
<RegionalGuidanceIPAdapterSettings id={id} ipAdapterId={id} ipAdapterNumber={index + 1} />
</Flex>
{ipAdapterIds.map((ipAdapterId, index) => (
<Fragment key={ipAdapterId}>
{index > 0 && <Divider />}
<RegionalGuidanceIPAdapterSettings id={id} ipAdapterId={ipAdapterId} ipAdapterNumber={index + 1} />
</Fragment>
))}
</>
);

View File

@ -5,7 +5,7 @@ import { rgbColorToString } from 'common/util/colorCodeTransformers';
import { stopPropagation } from 'common/util/stopPropagation';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { rgFillChanged } from 'features/controlLayers/store/canvasV2Slice';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo, useCallback } from 'react';
import type { RgbColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
@ -14,7 +14,7 @@ export const RegionalGuidanceMaskFillColorPicker = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const { t } = useTranslation();
const dispatch = useAppDispatch();
const fill = useAppSelector((s) => selectRGOrThrow(s.canvasV2, entityIdentifier.id).fill);
const fill = useAppSelector((s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, entityIdentifier.id).fill);
const onChange = useCallback(
(fill: RgbColor) => {
dispatch(rgFillChanged({ id: entityIdentifier.id, fill }));

View File

@ -37,9 +37,7 @@ export const RegionalGuidanceMenuItemsAddPromptsAndIPAdapter = memo(() => {
dispatch(rgNegativePromptChanged({ id: id, prompt: '' }));
}, [dispatch, id]);
const addIPAdapter = useCallback(() => {
dispatch(
rgIPAdapterAdded({ id, ipAdapter: { ...defaultIPAdapter, id: nanoid(), type: 'ip_adapter', isEnabled: true } })
);
dispatch(rgIPAdapterAdded({ id, ipAdapter: { ...defaultIPAdapter, id: nanoid() } }));
}, [defaultIPAdapter, dispatch, id]);
return (

View File

@ -2,7 +2,7 @@ import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton';
import { rgNegativePromptChanged } from 'features/controlLayers/store/canvasV2Slice';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@ -15,7 +15,7 @@ type Props = {
};
export const RegionalGuidanceNegativePrompt = memo(({ id }: Props) => {
const prompt = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).negativePrompt ?? '');
const prompt = useAppSelector((s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).negativePrompt ?? '');
const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();

View File

@ -2,7 +2,7 @@ import { Box, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton';
import { rgPositivePromptChanged } from 'features/controlLayers/store/canvasV2Slice';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper';
import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton';
import { PromptPopover } from 'features/prompt/PromptPopover';
@ -15,7 +15,7 @@ type Props = {
};
export const RegionalGuidancePositivePrompt = memo(({ id }: Props) => {
const prompt = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).positivePrompt ?? '');
const prompt = useAppSelector((s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).positivePrompt ?? '');
const dispatch = useAppDispatch();
const textareaRef = useRef<HTMLTextAreaElement>(null);
const { t } = useTranslation();

View File

@ -1,8 +1,9 @@
import { Divider } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { AddPromptButtons } from 'features/controlLayers/components/AddPromptButtons';
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import { memo } from 'react';
import { RegionalGuidanceIPAdapters } from './RegionalGuidanceIPAdapters';
@ -11,15 +12,31 @@ import { RegionalGuidancePositivePrompt } from './RegionalGuidancePositivePrompt
export const RegionalGuidanceSettings = memo(() => {
const { id } = useEntityIdentifierContext();
const hasPositivePrompt = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).positivePrompt !== null);
const hasNegativePrompt = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).negativePrompt !== null);
const hasIPAdapters = useAppSelector((s) => selectRGOrThrow(s.canvasV2, id).ipAdapters.length > 0);
const hasPositivePrompt = useAppSelector(
(s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).positivePrompt !== null
);
const hasNegativePrompt = useAppSelector(
(s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).negativePrompt !== null
);
const hasIPAdapters = useAppSelector(
(s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, id).ipAdapters.length > 0
);
return (
<CanvasEntitySettingsWrapper>
{!hasPositivePrompt && !hasNegativePrompt && !hasIPAdapters && <AddPromptButtons id={id} />}
{hasPositivePrompt && <RegionalGuidancePositivePrompt id={id} />}
{hasNegativePrompt && <RegionalGuidanceNegativePrompt id={id} />}
{hasPositivePrompt && (
<>
<RegionalGuidancePositivePrompt id={id} />
{(hasNegativePrompt || hasIPAdapters) && <Divider />}
</>
)}
{hasNegativePrompt && (
<>
<RegionalGuidanceNegativePrompt id={id} />
{hasIPAdapters && <Divider />}
</>
)}
{hasIPAdapters && <RegionalGuidanceIPAdapters id={id} />}
</CanvasEntitySettingsWrapper>
);

View File

@ -14,7 +14,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { stopPropagation } from 'common/util/stopPropagation';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { rgAutoNegativeChanged } from 'features/controlLayers/store/canvasV2Slice';
import { selectRGOrThrow } from 'features/controlLayers/store/regionsReducers';
import { selectRegionalGuidanceEntityOrThrow } from 'features/controlLayers/store/regionsReducers';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
@ -24,7 +24,7 @@ export const RegionalGuidanceSettingsPopover = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoNegative = useAppSelector((s) => selectRGOrThrow(s.canvasV2, entityIdentifier.id).autoNegative);
const autoNegative = useAppSelector((s) => selectRegionalGuidanceEntityOrThrow(s.canvasV2, entityIdentifier.id).autoNegative);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(rgAutoNegativeChanged({ id: entityIdentifier.id, autoNegative: e.target.checked ? 'invert' : 'off' }));

View File

@ -49,7 +49,15 @@ export const CanvasEntityTitleEdit = memo(({ onStopEditing }: Props) => {
}, []);
return (
<Input ref={ref} value={localTitle} onChange={onChange} onBlur={onBlur} onKeyDown={onKeyDown} variant="outline" />
<Input
ref={ref}
value={localTitle}
onChange={onChange}
onBlur={onBlur}
onKeyDown={onKeyDown}
variant="outline"
_focusVisible={{ borderWidth: 1, borderColor: 'invokeBlueAlpha.400', borderRadius: 'base' }}
/>
);
});

View File

@ -38,7 +38,7 @@ export const useEntityTitle = (entityIdentifier: CanvasEntityIdentifier) => {
} else if (entityIdentifier.type === 'raster_layer') {
parts.push(t('controlLayers.rasterLayer'));
} else if (entityIdentifier.type === 'ip_adapter') {
parts.push(t('controlLayers.ipAdapter'));
parts.push(t('common.ipAdapter'));
} else if (entityIdentifier.type === 'regional_guidance') {
parts.push(t('controlLayers.regionalGuidance'));
} else {

View File

@ -3,7 +3,12 @@ import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { selectCanvasV2Slice } from 'features/controlLayers/store/canvasV2Slice';
import { selectControlLayerOrThrow } from 'features/controlLayers/store/controlLayersReducers';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type {
CanvasEntityIdentifier,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNetV2, initialIPAdapterV2, initialT2IAdapterV2 } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { useMemo } from 'react';
@ -22,7 +27,7 @@ export const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIden
return controlAdapter;
};
export const useDefaultControlAdapter = () => {
export const useDefaultControlAdapter = (): ControlNetConfig | T2IAdapterConfig => {
const [modelConfigs] = useControlNetAndT2IAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);
@ -43,7 +48,7 @@ export const useDefaultControlAdapter = () => {
return defaultControlAdapter;
};
export const useDefaultIPAdapter = () => {
export const useDefaultIPAdapter = (): IPAdapterConfig => {
const [modelConfigs] = useIPAdapterModels();
const baseModel = useAppSelector((s) => s.canvasV2.params.model?.base);

View File

@ -38,8 +38,8 @@ export class CanvasFilter {
const { config } = this.manager.stateApi.getFilterState();
this.log.trace({ config }, 'Previewing filter');
const dispatch = this.manager.stateApi._store.dispatch;
const imageDTO = await this.parent.renderer.rasterize();
const rect = this.parent.transformer.getRelativeRect()
const imageDTO = await this.parent.renderer.rasterize(rect, false);
// TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now
const filterNode = IMAGE_FILTERS[config.type].buildNode(imageDTO, config as never);
const enqueueBatchArg: BatchConfig = {
@ -106,6 +106,7 @@ export class CanvasFilter {
width: this.imageState.image.height,
height: this.imageState.image.width,
},
replaceObjects: true,
});
this.parent.renderer.showObjects();
this.manager.stateApi.$filteringEntity.set(null);

View File

@ -20,7 +20,6 @@ import type {
ImageCache,
Rect,
} from 'features/controlLayers/store/types';
import { isValidLayerWithoutControlAdapter } from 'features/nodes/util/graph/generation/addLayers';
import type Konva from 'konva';
import { clamp, isEqual } from 'lodash-es';
import { atom } from 'nanostores';
@ -538,7 +537,8 @@ export class CanvasManager {
stageClone.x(0);
stageClone.y(0);
const validLayers = layersState.entities.filter(isValidLayerWithoutControlAdapter);
const validLayers = layersState.entities.filter((entity) => entity.isEnabled && entity.objects.length > 0);
// getLayers() returns the internal `children` array of the stage directly - calling destroy on a layer will
// mutate that array. We need to clone the array to avoid mutating the original.
for (const konvaLayer of stageClone.getLayers().slice()) {

View File

@ -374,8 +374,7 @@ export class CanvasObjectRenderer {
* @param rect The rect to rasterize. If omitted, the entity's full rect will be used.
* @returns A promise that resolves to the rasterized image DTO.
*/
rasterize = async (rect?: Rect): Promise<ImageDTO> => {
rect = rect ?? this.parent.transformer.getRelativeRect();
rasterize = async (rect: Rect, replaceObjects: boolean = false): Promise<ImageDTO> => {
let imageDTO: ImageDTO | null = null;
const rasterizedImageCache = this.getRasterizedImageCache(rect);
@ -400,6 +399,7 @@ export class CanvasObjectRenderer {
entityIdentifier: this.parent.getEntityIdentifier(),
imageObject,
rect: { x: Math.round(rect.x), y: Math.round(rect.y), width: imageDTO.width, height: imageDTO.height },
replaceObjects,
});
return imageDTO;

View File

@ -58,7 +58,7 @@ export class CanvasTransformer {
/**
* Whether the transformer is currently calculating the rect of the parent.
*/
isPendingRectCalculation: boolean = false;
isPendingRectCalculation: boolean = true;
/**
* A set of subscriptions that should be cleaned up when the transformer is destroyed.
@ -506,7 +506,8 @@ export class CanvasTransformer {
*/
applyTransform = async () => {
this.log.debug('Applying transform');
await this.parent.renderer.rasterize();
const rect = this.getRelativeRect();
await this.parent.renderer.rasterize(rect, true);
this.requestRectCalculation();
this.stopTransform();
};
@ -589,7 +590,7 @@ export class CanvasTransformer {
};
updateBbox = () => {
this.log.trace('Updating bbox');
this.log.trace({ nodeRect: this.nodeRect, pixelRect: this.pixelRect }, 'Updating bbox');
if (this.isPendingRectCalculation) {
this.syncInteractionState();
@ -600,10 +601,8 @@ export class CanvasTransformer {
// eraser lines, fully clipped brush lines or if it has been fully erased.
if (this.pixelRect.width === 0 || this.pixelRect.height === 0) {
// We shouldn't reset on the first render - the bbox will be calculated on the next render
if (!this.parent.renderer.hasObjects()) {
// The layer is fully transparent but has objects - reset it
this.manager.stateApi.resetEntity({ entityIdentifier: this.parent.getEntityIdentifier() });
}
// The layer is fully transparent but has objects - reset it
this.manager.stateApi.resetEntity({ entityIdentifier: this.parent.getEntityIdentifier() });
this.syncInteractionState();
return;
}

View File

@ -154,6 +154,8 @@ export function selectEntity(state: CanvasV2State, { id, type }: CanvasEntityIde
return state.inpaintMask;
case 'regional_guidance':
return state.regions.entities.find((rg) => rg.id === id);
case 'ip_adapter':
return state.ipAdapters.entities.find((ipa) => ipa.id === id);
default:
return;
}
@ -246,19 +248,22 @@ export const canvasV2Slice = createSlice({
}
},
entityRasterized: (state, action: PayloadAction<EntityRasterizedPayload>) => {
const { entityIdentifier, imageObject, rect } = action.payload;
const { entityIdentifier, imageObject, rect, replaceObjects } = action.payload;
const entity = selectEntity(state, entityIdentifier);
if (!entity) {
return;
}
if (isDrawableEntity(entity)) {
entity.objects = [imageObject];
entity.position = { x: rect.x, y: rect.y };
// Remove the cache for the given rect. This should never happen, because we should never rasterize the same
// rect twice. Just in case, we remove the old cache.
entity.rasterizationCache = entity.rasterizationCache.filter((cache) => !isEqual(cache.rect, rect));
entity.rasterizationCache.push({ imageName: imageObject.image.image_name, rect });
if (replaceObjects) {
entity.objects = [imageObject];
entity.position = { x: rect.x, y: rect.y };
}
}
},
entityBrushLineAdded: (state, action: PayloadAction<EntityBrushLineAddedPayload>) => {
@ -328,6 +333,13 @@ export const canvasV2Slice = createSlice({
if (region) {
selectedEntityIdentifier = { type: region.type, id: region.id };
}
} else if (entityIdentifier.type === 'ip_adapter') {
const index = state.ipAdapters.entities.findIndex((layer) => layer.id === entityIdentifier.id);
state.ipAdapters.entities = state.ipAdapters.entities.filter((rg) => rg.id !== entityIdentifier.id);
const entity = state.ipAdapters.entities[index];
if (entity) {
selectedEntityIdentifier = { type: entity.type, id: entity.id };
}
} else {
assert(false, 'Not implemented');
}

View File

@ -4,30 +4,32 @@ import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
import type { CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, CanvasIPAdapterState, IPMethodV2 } from './types';
import { imageDTOToImageObject } from './types';
import type { CanvasIPAdapterState, CanvasV2State, CLIPVisionModelV2, IPAdapterConfig, IPMethodV2 } from './types';
import { imageDTOToImageWithDims } from './types';
export const selectIPA = (state: CanvasV2State, id: string) => state.ipAdapters.entities.find((ipa) => ipa.id === id);
export const selectIPAOrThrow = (state: CanvasV2State, id: string) => {
const ipa = selectIPA(state, id);
assert(ipa, `IP Adapter with id ${id} not found`);
return ipa;
export const selectIPAdapterEntity = (state: CanvasV2State, id: string) =>
state.ipAdapters.entities.find((ipa) => ipa.id === id);
export const selectIPAdapterEntityOrThrow = (state: CanvasV2State, id: string) => {
const entity = selectIPAdapterEntity(state, id);
assert(entity, `IP Adapter with id ${id} not found`);
return entity;
};
export const ipAdaptersReducers = {
ipaAdded: {
reducer: (state, action: PayloadAction<{ id: string; config: IPAdapterConfig }>) => {
const { id, config } = action.payload;
reducer: (state, action: PayloadAction<{ id: string; ipAdapter: IPAdapterConfig }>) => {
const { id, ipAdapter } = action.payload;
const layer: CanvasIPAdapterState = {
id,
type: 'ip_adapter',
name: null,
isEnabled: true,
...config,
ipAdapter,
};
state.ipAdapters.entities.push(layer);
state.selectedEntityIdentifier = { type: 'ip_adapter', id };
},
prepare: (payload: { config: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }),
prepare: (payload: { ipAdapter: IPAdapterConfig }) => ({ payload: { id: uuidv4(), ...payload } }),
},
ipaRecalled: (state, action: PayloadAction<{ data: CanvasIPAdapterState }>) => {
const { data } = action.payload;
@ -36,7 +38,7 @@ export const ipAdaptersReducers = {
},
ipaIsEnabledToggled: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
const ipa = selectIPA(state, id);
const ipa = selectIPAdapterEntity(state, id);
if (ipa) {
ipa.isEnabled = !ipa.isEnabled;
}
@ -49,64 +51,54 @@ export const ipAdaptersReducers = {
state.ipAdapters.entities = [];
},
ipaImageChanged: {
reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null; objectId: string }>) => {
const { id, imageDTO, objectId } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
reducer: (state, action: PayloadAction<{ id: string; imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
ipa.imageObject = imageDTO ? imageDTOToImageObject(imageDTO) : null;
entity.ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
prepare: (payload: { id: string; imageDTO: ImageDTO | null }) => ({ payload: { ...payload, objectId: uuidv4() } }),
},
ipaMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethodV2 }>) => {
const { id, method } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
ipa.method = method;
entity.ipAdapter.method = method;
},
ipaModelChanged: (
state,
action: PayloadAction<{
id: string;
modelConfig: IPAdapterModelConfig | null;
}>
) => {
ipaModelChanged: (state, action: PayloadAction<{ id: string; modelConfig: IPAdapterModelConfig | null }>) => {
const { id, modelConfig } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
if (modelConfig) {
ipa.model = zModelIdentifierField.parse(modelConfig);
} else {
ipa.model = null;
}
entity.ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
},
ipaCLIPVisionModelChanged: (state, action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModelV2 }>) => {
const { id, clipVisionModel } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
ipa.clipVisionModel = clipVisionModel;
entity.ipAdapter.clipVisionModel = clipVisionModel;
},
ipaWeightChanged: (state, action: PayloadAction<{ id: string; weight: number }>) => {
const { id, weight } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
ipa.weight = weight;
entity.ipAdapter.weight = weight;
},
ipaBeginEndStepPctChanged: (state, action: PayloadAction<{ id: string; beginEndStepPct: [number, number] }>) => {
const { id, beginEndStepPct } = action.payload;
const ipa = selectIPA(state, id);
if (!ipa) {
const entity = selectIPAdapterEntity(state, id);
if (!entity) {
return;
}
ipa.beginEndStepPct = beginEndStepPct;
entity.ipAdapter.beginEndStepPct = beginEndStepPct;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -1,18 +1,32 @@
import type { PayloadAction, SliceCaseReducers } from '@reduxjs/toolkit';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasV2State, CLIPVisionModelV2, IPMethodV2 } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/types';
import type {
CanvasV2State,
CLIPVisionModelV2,
IPMethodV2,
RegionalGuidanceIPAdapterConfig,
} from 'features/controlLayers/store/types';
import { imageDTOToImageWithDims } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type { ParameterAutoNegative } from 'features/parameters/types/parameterSchemas';
import { isEqual } from 'lodash-es';
import type { ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { CanvasIPAdapterState, CanvasRegionalGuidanceState, RgbColor } from './types';
import type { CanvasRegionalGuidanceState, RgbColor } from './types';
export const selectRG = (state: CanvasV2State, id: string) => state.regions.entities.find((rg) => rg.id === id);
export const selectRGOrThrow = (state: CanvasV2State, id: string) => {
const rg = selectRG(state, id);
export const selectRegionalGuidanceEntity = (state: CanvasV2State, id: string) => {
return state.regions.entities.find((rg) => rg.id === id);
};
export const selectRegionalGuidanceIPAdapter = (state: CanvasV2State, id: string, ipAdapterId: string) => {
const entity = state.regions.entities.find((rg) => rg.id === id);
if (!entity) {
return;
}
return entity.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
};
export const selectRegionalGuidanceEntityOrThrow = (state: CanvasV2State, id: string) => {
const rg = selectRegionalGuidanceEntity(state, id);
assert(rg, `Region with id ${id} not found`);
return rg;
};
@ -72,105 +86,89 @@ export const regionsReducers = {
},
rgPositivePromptChanged: (state, action: PayloadAction<{ id: string; prompt: string | null }>) => {
const { id, prompt } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const entity = selectRegionalGuidanceEntity(state, id);
if (!entity) {
return;
}
rg.positivePrompt = prompt;
entity.positivePrompt = prompt;
},
rgNegativePromptChanged: (state, action: PayloadAction<{ id: string; prompt: string | null }>) => {
const { id, prompt } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const entity = selectRegionalGuidanceEntity(state, id);
if (!entity) {
return;
}
rg.negativePrompt = prompt;
entity.negativePrompt = prompt;
},
rgFillChanged: (state, action: PayloadAction<{ id: string; fill: RgbColor }>) => {
const { id, fill } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const entity = selectRegionalGuidanceEntity(state, id);
if (!entity) {
return;
}
rg.fill = fill;
entity.fill = fill;
},
rgAutoNegativeChanged: (state, action: PayloadAction<{ id: string; autoNegative: ParameterAutoNegative }>) => {
const { id, autoNegative } = action.payload;
const rg = selectRG(state, id);
const rg = selectRegionalGuidanceEntity(state, id);
if (!rg) {
return;
}
rg.autoNegative = autoNegative;
},
rgIPAdapterAdded: (state, action: PayloadAction<{ id: string; ipAdapter: CanvasIPAdapterState }>) => {
rgIPAdapterAdded: (state, action: PayloadAction<{ id: string; ipAdapter: RegionalGuidanceIPAdapterConfig }>) => {
const { id, ipAdapter } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const entity = selectRegionalGuidanceEntity(state, id);
if (!entity) {
return;
}
rg.ipAdapters.push(ipAdapter);
entity.ipAdapters.push(ipAdapter);
},
rgIPAdapterDeleted: (state, action: PayloadAction<{ id: string; ipAdapterId: string }>) => {
const { id, ipAdapterId } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const entity = selectRegionalGuidanceEntity(state, id);
if (!entity) {
return;
}
rg.ipAdapters = rg.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
entity.ipAdapters = entity.ipAdapters.filter((ipAdapter) => ipAdapter.id !== ipAdapterId);
},
rgIPAdapterImageChanged: (
state,
action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null; objectId: string }>
action: PayloadAction<{ id: string; ipAdapterId: string; imageDTO: ImageDTO | null }>
) => {
const { id, ipAdapterId, imageDTO } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.imageObject = imageDTO ? imageDTOToImageObject(imageDTO) : null;
ipAdapter.image = imageDTO ? imageDTOToImageWithDims(imageDTO) : null;
},
rgIPAdapterWeightChanged: (state, action: PayloadAction<{ id: string; ipAdapterId: string; weight: number }>) => {
const { id, ipAdapterId, weight } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.weight = weight;
ipAdapter.weight = weight;
},
rgIPAdapterBeginEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; ipAdapterId: string; beginEndStepPct: [number, number] }>
) => {
const { id, ipAdapterId, beginEndStepPct } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.beginEndStepPct = beginEndStepPct;
ipAdapter.beginEndStepPct = beginEndStepPct;
},
rgIPAdapterMethodChanged: (state, action: PayloadAction<{ id: string; ipAdapterId: string; method: IPMethodV2 }>) => {
const { id, ipAdapterId, method } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.method = method;
ipAdapter.method = method;
},
rgIPAdapterModelChanged: (
state,
@ -181,33 +179,21 @@ export const regionsReducers = {
}>
) => {
const { id, ipAdapterId, modelConfig } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
if (modelConfig) {
ipa.model = zModelIdentifierField.parse(modelConfig);
} else {
ipa.model = null;
}
ipAdapter.model = modelConfig ? zModelIdentifierField.parse(modelConfig) : null;
},
rgIPAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; ipAdapterId: string; clipVisionModel: CLIPVisionModelV2 }>
) => {
const { id, ipAdapterId, clipVisionModel } = action.payload;
const rg = selectRG(state, id);
if (!rg) {
const ipAdapter = selectRegionalGuidanceIPAdapter(state, id, ipAdapterId);
if (!ipAdapter) {
return;
}
const ipa = rg.ipAdapters.find((ipa) => ipa.id === ipAdapterId);
if (!ipa) {
return;
}
ipa.clipVisionModel = clipVisionModel;
ipAdapter.clipVisionModel = clipVisionModel;
},
} satisfies SliceCaseReducers<CanvasV2State>;

View File

@ -581,22 +581,24 @@ export function isCanvasBrushLineState(obj: CanvasObjectState): obj is CanvasBru
return obj.type === 'brush_line';
}
const zIPAdapterConfig = z.object({
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
clipVisionModel: zCLIPVisionModelV2,
});
export type IPAdapterConfig = z.infer<typeof zIPAdapterConfig>;
export const zCanvasIPAdapterState = z.object({
id: zId,
name: z.string().nullable(),
type: z.literal('ip_adapter'),
isEnabled: z.boolean(),
weight: z.number().gte(-1).lte(2),
method: zIPMethodV2,
imageObject: zCanvasImageState.nullable(),
model: zModelIdentifierField.nullable(),
clipVisionModel: zCLIPVisionModelV2,
beginEndStepPct: zBeginEndStepPct,
ipAdapter: zIPAdapterConfig,
});
export type CanvasIPAdapterState = z.infer<typeof zCanvasIPAdapterState>;
export type IPAdapterConfig = Pick<
CanvasIPAdapterState,
'weight' | 'imageObject' | 'beginEndStepPct' | 'model' | 'clipVisionModel' | 'method'
>;
const zMaskObject = z
.discriminatedUnion('type', [
@ -645,6 +647,17 @@ const zImageCache = z.object({
});
export type ImageCache = z.infer<typeof zImageCache>;
const zRegionalGuidanceIPAdapterConfig = z.object({
id: zId,
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
weight: z.number().gte(-1).lte(2),
beginEndStepPct: zBeginEndStepPct,
method: zIPMethodV2,
clipVisionModel: zCLIPVisionModelV2,
});
export type RegionalGuidanceIPAdapterConfig = z.infer<typeof zRegionalGuidanceIPAdapterConfig>;
export const zCanvasRegionalGuidanceState = z.object({
id: zId,
name: z.string().nullable(),
@ -655,7 +668,7 @@ export const zCanvasRegionalGuidanceState = z.object({
fill: zRgbColor,
positivePrompt: zParameterPositivePrompt.nullable(),
negativePrompt: zParameterNegativePrompt.nullable(),
ipAdapters: z.array(zCanvasIPAdapterState),
ipAdapters: z.array(zRegionalGuidanceIPAdapterConfig),
autoNegative: zAutoNegative,
rasterizationCache: z.array(zImageCache),
});
@ -763,7 +776,7 @@ export const initialT2IAdapterV2: T2IAdapterConfig = {
};
export const initialIPAdapterV2: IPAdapterConfig = {
imageObject: null,
image: null,
model: null,
beginEndStepPct: [0, 1],
method: 'full',
@ -943,6 +956,7 @@ export type EntityRasterizedPayload = {
entityIdentifier: CanvasEntityIdentifier;
imageObject: CanvasImageState;
rect: Rect;
replaceObjects: boolean;
};
export type ImageObjectAddedArg = { id: string; imageDTO: ImageDTO; position?: Coordinate };

View File

@ -1,4 +1,4 @@
import type { CanvasIPAdapterState } from 'features/controlLayers/store/types';
import type { CanvasIPAdapterState, IPAdapterConfig } from 'features/controlLayers/store/types';
import { IP_ADAPTER_COLLECT } from 'features/nodes/util/graph/constants';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import type { BaseModelType, Invocation } from 'services/api/types';
@ -10,7 +10,7 @@ export const addIPAdapters = (
denoise: Invocation<'denoise_latents'>,
base: BaseModelType
): CanvasIPAdapterState[] => {
const validIPAdapters = ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity.ipAdapter, base));
for (const ipa of validIPAdapters) {
addIPAdapter(ipa, g, denoise);
}
@ -33,13 +33,14 @@ export const addIPAdapterCollectorSafe = (g: Graph, denoise: Invocation<'denoise
}
};
const addIPAdapter = (ipa: CanvasIPAdapterState, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject } = ipa;
assert(imageObject, 'IP Adapter image is required');
const addIPAdapter = (entity: CanvasIPAdapterState, g: Graph, denoise: Invocation<'denoise_latents'>) => {
const { id, ipAdapter } = entity;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(image, 'IP Adapter image is required');
assert(model, 'IP Adapter model is required');
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const ipAdapter = g.addNode({
const ipAdapterNode = g.addNode({
id: `ip_adapter_${id}`,
type: 'ip_adapter',
weight,
@ -49,16 +50,16 @@ const addIPAdapter = (ipa: CanvasIPAdapterState, g: Graph, denoise: Invocation<'
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: imageObject.image.image_name,
image_name: image.image_name,
},
});
g.addEdge(ipAdapter, 'ip_adapter', ipAdapterCollect, 'item');
g.addEdge(ipAdapterNode, 'ip_adapter', ipAdapterCollect, 'item');
};
export const isValidIPAdapter = (ipa: CanvasIPAdapterState, base: BaseModelType): boolean => {
export const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
// Must be have a model that matches the current base and must have a control image
const hasModel = Boolean(ipa.model);
const modelMatchesBase = ipa.model?.base === base;
const hasImage = Boolean(ipa.imageObject);
const hasModel = Boolean(ipAdapter.model);
const modelMatchesBase = ipAdapter.model?.base === base;
const hasImage = Boolean(ipAdapter.image);
return hasModel && modelMatchesBase && hasImage;
};

View File

@ -1,10 +1,5 @@
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
export const isValidLayerWithoutControlAdapter = (layer: CanvasRasterLayerState) => {
return (
layer.isEnabled &&
// Boolean(entity.bbox) && TODO(psyche): Re-enable this check when we have a way to calculate bbox for all layers
layer.objects.length > 0 &&
layer.controlAdapter === null
);
export const isValidLayer = (layer: CanvasRasterLayerState) => {
return layer.isEnabled && layer.objects.length > 0;
};

View File

@ -1,6 +1,10 @@
import { deepClone } from 'common/util/deepClone';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import type { CanvasIPAdapterState, CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
import type {
CanvasRegionalGuidanceState,
Rect,
RegionalGuidanceIPAdapterConfig,
} from 'features/controlLayers/store/types';
import {
PROMPT_REGION_INVERT_TENSOR_MASK_PREFIX,
PROMPT_REGION_MASK_TO_TENSOR_PREFIX,
@ -174,13 +178,15 @@ export const addRegions = async (
}
}
const validRGIPAdapters: CanvasIPAdapterState[] = region.ipAdapters.filter((ipa) => isValidIPAdapter(ipa, base));
const validRGIPAdapters: RegionalGuidanceIPAdapterConfig[] = region.ipAdapters.filter((ipAdapter) =>
isValidIPAdapter(ipAdapter, base)
);
for (const ipa of validRGIPAdapters) {
const ipAdapterCollect = addIPAdapterCollectorSafe(g, denoise);
const { id, weight, model, clipVisionModel, method, beginEndStepPct, imageObject } = ipa;
const { id, weight, model, clipVisionModel, method, beginEndStepPct, image } = ipa;
assert(model, 'IP Adapter model is required');
assert(imageObject, 'IP Adapter image is required');
assert(image, 'IP Adapter image is required');
const ipAdapter = g.addNode({
id: `ip_adapter_${id}`,
@ -192,7 +198,7 @@ export const addRegions = async (
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
image: {
image_name: imageObject.image.image_name,
image_name: image.image_name,
},
});

View File

@ -215,7 +215,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager): P
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.rasterLayers.entities,
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
denoise,

View File

@ -219,7 +219,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager):
const _addedCAs = await addControlAdapters(
manager,
state.canvasV2.rasterLayers.entities,
state.canvasV2.controlLayers.entities,
g,
state.canvasV2.bbox.rect,
denoise,