feat: add UI for IP Adapter Method

This commit is contained in:
blessedcoolant 2024-04-13 12:06:59 +05:30
parent 6ea183f0d4
commit e9f16ac8c7
11 changed files with 120 additions and 6 deletions

View File

@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model") clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter") weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)") begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -213,6 +213,10 @@
"resize": "Resize", "resize": "Resize",
"resizeSimple": "Resize (Simple)", "resizeSimple": "Resize (Simple)",
"resizeMode": "Resize Mode", "resizeMode": "Resize Mode",
"ipAdapterMethod": "Method",
"full": "Full",
"style": "Style Only",
"composition": "Composition Only",
"safe": "Safe", "safe": "Safe",
"saveControlImage": "Save Control Image", "saveControlImage": "Save Control Image",
"scribble": "scribble", "scribble": "scribble",

View File

@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports'; import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd'; import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode'; import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect'; import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode'; import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight'; import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
<Flex w="full" flexDir="column" gap={4}> <Flex w="full" flexDir="column" gap={4}>
<Flex gap={8} w="full" alignItems="center"> <Flex gap={8} w="full" alignItems="center">
<Flex flexDir="column" gap={2} h={32} w="full"> <Flex flexDir="column" gap={4} h={40} w="full">
<ParamControlAdapterIPMethod id={id} />
<ParamControlAdapterWeight id={id} /> <ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} /> <ParamControlAdapterBeginEnd id={id} />
</Flex> </Flex>

View File

@ -0,0 +1,63 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import type { IPMethod } from 'features/controlAdapters/store/types';
import { isIPMethod } from 'features/controlAdapters/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
id: string;
};
const ParamControlAdapterIPMethod = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const method = useControlAdapterIPMethod(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const options: { label: string; value: IPMethod }[] = useMemo(
() => [
{ label: t('controlnet.full'), value: 'full' },
{ label: t('controlnet.style'), value: 'style' },
{ label: t('controlnet.composition'), value: 'composition' },
],
[t]
);
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
(v) => {
if (!isIPMethod(v?.value)) {
return;
}
dispatch(
controlAdapterIPMethodChanged({
id,
method: v.value,
})
);
},
[id, dispatch]
);
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
if (!method) {
return null;
}
return (
<FormControl>
<InformationalPopover feature="controlNetResizeMode">
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
</InformationalPopover>
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
</FormControl>
);
};
export default memo(ParamControlAdapterIPMethod);

View File

@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterIPMethod = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.method;
}
}),
[id]
);
const method = useAppSelector(selector);
return method;
};

View File

@ -21,6 +21,7 @@ import type {
ControlAdapterType, ControlAdapterType,
ControlMode, ControlMode,
ControlNetConfig, ControlNetConfig,
IPMethod,
RequiredControlAdapterProcessorNode, RequiredControlAdapterProcessorNode,
ResizeMode, ResizeMode,
T2IAdapterConfig, T2IAdapterConfig,
@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
} }
caAdapter.updateOne(state, { id, changes: { controlMode } }); caAdapter.updateOne(state, { id, changes: { controlMode } });
}, },
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
const { id, method } = action.payload;
caAdapter.updateOne(state, { id, changes: { method } });
},
controlAdapterCLIPVisionModelChanged: ( controlAdapterCLIPVisionModelChanged: (
state, state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }> action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
@ -390,6 +395,7 @@ export const {
controlAdapterIsEnabledChanged, controlAdapterIsEnabledChanged,
controlAdapterModelChanged, controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged, controlAdapterCLIPVisionModelChanged,
controlAdapterIPMethodChanged,
controlAdapterWeightChanged, controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged, controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged, controlAdapterEndStepPctChanged,

View File

@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
export type ResizeMode = z.infer<typeof zResizeMode>; export type ResizeMode = z.infer<typeof zResizeMode>;
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success; export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
const zIPMethod = z.enum(['full', 'style', 'composition']);
export type IPMethod = z.infer<typeof zIPMethod>;
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
export type ControlNetConfig = { export type ControlNetConfig = {
type: 'controlnet'; type: 'controlnet';
id: string; id: string;
@ -253,6 +257,7 @@ export type IPAdapterConfig = {
model: ParameterIPAdapterModel | null; model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel; clipVisionModel: CLIPVisionModel;
weight: number; weight: number;
method: IPMethod;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;
}; };

View File

@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true, isEnabled: true,
controlImage: null, controlImage: null,
model: null, model: null,
method: 'full',
clipVisionModel: 'ViT-H', clipVisionModel: 'ViT-H',
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,

View File

@ -386,6 +386,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
clipVisionModel: 'ViT-H', clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight, weight: weight ?? initialIPAdapter.weight,
method: 'full',
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct, endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
}; };

View File

@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter; const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required'); assert(controlImage, 'IP Adapter image is required');
@ -57,7 +57,7 @@ export const addIPAdapterToLinearGraph = async (
type: 'ip_adapter', type: 'ip_adapter',
is_intermediate: true, is_intermediate: true,
weight: weight, weight: weight,
method: 'composition', method: method,
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel, clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
@ -85,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
}; };
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => { const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter; const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
@ -103,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel, clip_vision_model: clipVisionModel,
weight, weight,
method,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
image, image,

File diff suppressed because one or more lines are too long