mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: add UI for IP Adapter Method
This commit is contained in:
parent
6ea183f0d4
commit
e9f16ac8c7
@ -36,6 +36,7 @@ class IPAdapterMetadataField(BaseModel):
|
|||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
|
||||||
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
|
||||||
|
method: Literal["full", "style", "composition"] = Field(description="Method to apply IP Weights with")
|
||||||
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
|
||||||
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
|
||||||
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")
|
||||||
|
@ -213,6 +213,10 @@
|
|||||||
"resize": "Resize",
|
"resize": "Resize",
|
||||||
"resizeSimple": "Resize (Simple)",
|
"resizeSimple": "Resize (Simple)",
|
||||||
"resizeMode": "Resize Mode",
|
"resizeMode": "Resize Mode",
|
||||||
|
"ipAdapterMethod": "Method",
|
||||||
|
"full": "Full",
|
||||||
|
"style": "Style Only",
|
||||||
|
"composition": "Composition Only",
|
||||||
"safe": "Safe",
|
"safe": "Safe",
|
||||||
"saveControlImage": "Save Control Image",
|
"saveControlImage": "Save Control Image",
|
||||||
"scribble": "scribble",
|
"scribble": "scribble",
|
||||||
|
@ -21,6 +21,7 @@ import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
|
|||||||
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
|
||||||
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
|
import { ParamControlAdapterBeginEnd } from './parameters/ParamControlAdapterBeginEnd';
|
||||||
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
|
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
|
||||||
|
import ParamControlAdapterIPMethod from './parameters/ParamControlAdapterIPMethod';
|
||||||
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
|
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
|
||||||
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
|
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
|
||||||
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
|
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
|
||||||
@ -111,7 +112,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
|
|||||||
|
|
||||||
<Flex w="full" flexDir="column" gap={4}>
|
<Flex w="full" flexDir="column" gap={4}>
|
||||||
<Flex gap={8} w="full" alignItems="center">
|
<Flex gap={8} w="full" alignItems="center">
|
||||||
<Flex flexDir="column" gap={2} h={32} w="full">
|
<Flex flexDir="column" gap={4} h={40} w="full">
|
||||||
|
<ParamControlAdapterIPMethod id={id} />
|
||||||
<ParamControlAdapterWeight id={id} />
|
<ParamControlAdapterWeight id={id} />
|
||||||
<ParamControlAdapterBeginEnd id={id} />
|
<ParamControlAdapterBeginEnd id={id} />
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -0,0 +1,63 @@
|
|||||||
|
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
|
||||||
|
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
|
||||||
|
import { useControlAdapterIPMethod } from 'features/controlAdapters/hooks/useControlAdapterIPMethod';
|
||||||
|
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||||
|
import { controlAdapterIPMethodChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import type { IPMethod } from 'features/controlAdapters/store/types';
|
||||||
|
import { isIPMethod } from 'features/controlAdapters/store/types';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
id: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
const ParamControlAdapterIPMethod = ({ id }: Props) => {
|
||||||
|
const isEnabled = useControlAdapterIsEnabled(id);
|
||||||
|
const method = useControlAdapterIPMethod(id);
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const options: { label: string; value: IPMethod }[] = useMemo(
|
||||||
|
() => [
|
||||||
|
{ label: t('controlnet.full'), value: 'full' },
|
||||||
|
{ label: t('controlnet.style'), value: 'style' },
|
||||||
|
{ label: t('controlnet.composition'), value: 'composition' },
|
||||||
|
],
|
||||||
|
[t]
|
||||||
|
);
|
||||||
|
|
||||||
|
const handleIPMethodChanged = useCallback<ComboboxOnChange>(
|
||||||
|
(v) => {
|
||||||
|
if (!isIPMethod(v?.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(
|
||||||
|
controlAdapterIPMethodChanged({
|
||||||
|
id,
|
||||||
|
method: v.value,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[id, dispatch]
|
||||||
|
);
|
||||||
|
|
||||||
|
const value = useMemo(() => options.find((o) => o.value === method), [options, method]);
|
||||||
|
|
||||||
|
if (!method) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormControl>
|
||||||
|
<InformationalPopover feature="controlNetResizeMode">
|
||||||
|
<FormLabel>{t('controlnet.ipAdapterMethod')}</FormLabel>
|
||||||
|
</InformationalPopover>
|
||||||
|
<Combobox value={value} options={options} isDisabled={!isEnabled} onChange={handleIPMethodChanged} />
|
||||||
|
</FormControl>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ParamControlAdapterIPMethod);
|
@ -0,0 +1,24 @@
|
|||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import {
|
||||||
|
selectControlAdapterById,
|
||||||
|
selectControlAdaptersSlice,
|
||||||
|
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||||
|
import { useMemo } from 'react';
|
||||||
|
|
||||||
|
export const useControlAdapterIPMethod = (id: string) => {
|
||||||
|
const selector = useMemo(
|
||||||
|
() =>
|
||||||
|
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||||
|
const cn = selectControlAdapterById(controlAdapters, id);
|
||||||
|
if (cn && cn?.type === 'ip_adapter') {
|
||||||
|
return cn.method;
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
[id]
|
||||||
|
);
|
||||||
|
|
||||||
|
const method = useAppSelector(selector);
|
||||||
|
|
||||||
|
return method;
|
||||||
|
};
|
@ -21,6 +21,7 @@ import type {
|
|||||||
ControlAdapterType,
|
ControlAdapterType,
|
||||||
ControlMode,
|
ControlMode,
|
||||||
ControlNetConfig,
|
ControlNetConfig,
|
||||||
|
IPMethod,
|
||||||
RequiredControlAdapterProcessorNode,
|
RequiredControlAdapterProcessorNode,
|
||||||
ResizeMode,
|
ResizeMode,
|
||||||
T2IAdapterConfig,
|
T2IAdapterConfig,
|
||||||
@ -245,6 +246,10 @@ export const controlAdaptersSlice = createSlice({
|
|||||||
}
|
}
|
||||||
caAdapter.updateOne(state, { id, changes: { controlMode } });
|
caAdapter.updateOne(state, { id, changes: { controlMode } });
|
||||||
},
|
},
|
||||||
|
controlAdapterIPMethodChanged: (state, action: PayloadAction<{ id: string; method: IPMethod }>) => {
|
||||||
|
const { id, method } = action.payload;
|
||||||
|
caAdapter.updateOne(state, { id, changes: { method } });
|
||||||
|
},
|
||||||
controlAdapterCLIPVisionModelChanged: (
|
controlAdapterCLIPVisionModelChanged: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
|
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
|
||||||
@ -390,6 +395,7 @@ export const {
|
|||||||
controlAdapterIsEnabledChanged,
|
controlAdapterIsEnabledChanged,
|
||||||
controlAdapterModelChanged,
|
controlAdapterModelChanged,
|
||||||
controlAdapterCLIPVisionModelChanged,
|
controlAdapterCLIPVisionModelChanged,
|
||||||
|
controlAdapterIPMethodChanged,
|
||||||
controlAdapterWeightChanged,
|
controlAdapterWeightChanged,
|
||||||
controlAdapterBeginStepPctChanged,
|
controlAdapterBeginStepPctChanged,
|
||||||
controlAdapterEndStepPctChanged,
|
controlAdapterEndStepPctChanged,
|
||||||
|
@ -210,6 +210,10 @@ const zResizeMode = z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_r
|
|||||||
export type ResizeMode = z.infer<typeof zResizeMode>;
|
export type ResizeMode = z.infer<typeof zResizeMode>;
|
||||||
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
|
export const isResizeMode = (v: unknown): v is ResizeMode => zResizeMode.safeParse(v).success;
|
||||||
|
|
||||||
|
const zIPMethod = z.enum(['full', 'style', 'composition']);
|
||||||
|
export type IPMethod = z.infer<typeof zIPMethod>;
|
||||||
|
export const isIPMethod = (v: unknown): v is IPMethod => zIPMethod.safeParse(v).success;
|
||||||
|
|
||||||
export type ControlNetConfig = {
|
export type ControlNetConfig = {
|
||||||
type: 'controlnet';
|
type: 'controlnet';
|
||||||
id: string;
|
id: string;
|
||||||
@ -253,6 +257,7 @@ export type IPAdapterConfig = {
|
|||||||
model: ParameterIPAdapterModel | null;
|
model: ParameterIPAdapterModel | null;
|
||||||
clipVisionModel: CLIPVisionModel;
|
clipVisionModel: CLIPVisionModel;
|
||||||
weight: number;
|
weight: number;
|
||||||
|
method: IPMethod;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
endStepPct: number;
|
endStepPct: number;
|
||||||
};
|
};
|
||||||
|
@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
|
|||||||
isEnabled: true,
|
isEnabled: true,
|
||||||
controlImage: null,
|
controlImage: null,
|
||||||
model: null,
|
model: null,
|
||||||
|
method: 'full',
|
||||||
clipVisionModel: 'ViT-H',
|
clipVisionModel: 'ViT-H',
|
||||||
weight: 1,
|
weight: 1,
|
||||||
beginStepPct: 0,
|
beginStepPct: 0,
|
||||||
|
@ -386,6 +386,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
|
|||||||
clipVisionModel: 'ViT-H',
|
clipVisionModel: 'ViT-H',
|
||||||
controlImage: image?.image_name ?? null,
|
controlImage: image?.image_name ?? null,
|
||||||
weight: weight ?? initialIPAdapter.weight,
|
weight: weight ?? initialIPAdapter.weight,
|
||||||
|
method: 'full',
|
||||||
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,
|
||||||
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
endStepPct: end_step_percent ?? initialIPAdapter.endStepPct,
|
||||||
};
|
};
|
||||||
|
@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
if (!ipAdapter.model) {
|
if (!ipAdapter.model) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
const { id, weight, model, clipVisionModel, method, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||||
|
|
||||||
assert(controlImage, 'IP Adapter image is required');
|
assert(controlImage, 'IP Adapter image is required');
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
type: 'ip_adapter',
|
type: 'ip_adapter',
|
||||||
is_intermediate: true,
|
is_intermediate: true,
|
||||||
weight: weight,
|
weight: weight,
|
||||||
method: 'composition',
|
method: method,
|
||||||
ip_adapter_model: model,
|
ip_adapter_model: model,
|
||||||
clip_vision_model: clipVisionModel,
|
clip_vision_model: clipVisionModel,
|
||||||
begin_step_percent: beginStepPct,
|
begin_step_percent: beginStepPct,
|
||||||
@ -85,7 +85,7 @@ export const addIPAdapterToLinearGraph = async (
|
|||||||
};
|
};
|
||||||
|
|
||||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
|
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
|
||||||
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
|
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, method, weight } = ipAdapter;
|
||||||
|
|
||||||
assert(model, 'IP Adapter model is required');
|
assert(model, 'IP Adapter model is required');
|
||||||
|
|
||||||
@ -103,6 +103,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
|
|||||||
ip_adapter_model: model,
|
ip_adapter_model: model,
|
||||||
clip_vision_model: clipVisionModel,
|
clip_vision_model: clipVisionModel,
|
||||||
weight,
|
weight,
|
||||||
|
method,
|
||||||
begin_step_percent: beginStepPct,
|
begin_step_percent: beginStepPct,
|
||||||
end_step_percent: endStepPct,
|
end_step_percent: endStepPct,
|
||||||
image,
|
image,
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user