mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Let users pick CLIP Vision model for Checkpoint IP Adapters
This commit is contained in:
@ -217,6 +217,7 @@
|
||||
"saveControlImage": "Save Control Image",
|
||||
"scribble": "scribble",
|
||||
"selectModel": "Select a model",
|
||||
"selectCLIPVisionModel": "Select a CLIP Vision model",
|
||||
"setControlImageDimensions": "Set Control Image Dimensions To W/H",
|
||||
"showAdvanced": "Show Advanced",
|
||||
"small": "Small",
|
||||
|
@ -1,12 +1,18 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
|
||||
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
|
||||
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
|
||||
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
|
||||
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
|
||||
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import {
|
||||
controlAdapterCLIPVisionModelChanged,
|
||||
controlAdapterModelChanged,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
|
||||
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@ -29,6 +35,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
const { modelConfig } = useControlAdapterModel(id);
|
||||
const dispatch = useAppDispatch();
|
||||
const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
|
||||
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
|
||||
const mainModel = useAppSelector(selectMainModel);
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -49,6 +56,16 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
[dispatch, id]
|
||||
);
|
||||
|
||||
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
|
||||
(v) => {
|
||||
if (!v?.value) {
|
||||
return;
|
||||
}
|
||||
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
|
||||
},
|
||||
[dispatch, id]
|
||||
);
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
|
||||
[controlAdapterType, modelConfig]
|
||||
@ -71,17 +88,42 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
|
||||
isLoading,
|
||||
});
|
||||
|
||||
const clipVisionOptions = useMemo<ComboboxOption[]>(
|
||||
() => [
|
||||
{ label: 'ViT-H', value: 'ViT-H' },
|
||||
{ label: 'ViT-G', value: 'ViT-G' },
|
||||
],
|
||||
[]
|
||||
);
|
||||
|
||||
const clipVisionModel = useMemo(
|
||||
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
|
||||
[clipVisionOptions, currentCLIPVisionModel]
|
||||
);
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
<Flex flexDirection="row" gap={2}>
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
|
||||
<Combobox
|
||||
options={options}
|
||||
placeholder={t('controlnet.selectModel')}
|
||||
value={value}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
|
||||
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
|
||||
<Combobox
|
||||
options={clipVisionOptions}
|
||||
placeholder={t('controlnet.selectCLIPVisionModel')}
|
||||
value={clipVisionModel}
|
||||
onChange={onCLIPVisionModelChange}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
@ -0,0 +1,24 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectControlAdapterById,
|
||||
selectControlAdaptersSlice,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useControlAdapterCLIPVisionModel = (id: string) => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
|
||||
const cn = selectControlAdapterById(controlAdapters, id);
|
||||
if (cn && cn?.type === 'ip_adapter') {
|
||||
return cn.clipVisionModel;
|
||||
}
|
||||
}),
|
||||
[id]
|
||||
);
|
||||
|
||||
const clipVisionModel = useAppSelector(selector);
|
||||
|
||||
return clipVisionModel;
|
||||
};
|
@ -13,6 +13,7 @@ import { v4 as uuidv4 } from 'uuid';
|
||||
import { controlAdapterImageProcessed } from './actions';
|
||||
import { CONTROLNET_PROCESSORS } from './constants';
|
||||
import type {
|
||||
CLIPVisionModel,
|
||||
ControlAdapterConfig,
|
||||
ControlAdapterProcessorType,
|
||||
ControlAdaptersState,
|
||||
@ -243,6 +244,13 @@ export const controlAdaptersSlice = createSlice({
|
||||
}
|
||||
caAdapter.updateOne(state, { id, changes: { controlMode } });
|
||||
},
|
||||
controlAdapterCLIPVisionModelChanged: (
|
||||
state,
|
||||
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
|
||||
) => {
|
||||
const { id, clipVisionModel } = action.payload;
|
||||
caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
|
||||
},
|
||||
controlAdapterResizeModeChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@ -380,6 +388,7 @@ export const {
|
||||
controlAdapterProcessedImageChanged,
|
||||
controlAdapterIsEnabledChanged,
|
||||
controlAdapterModelChanged,
|
||||
controlAdapterCLIPVisionModelChanged,
|
||||
controlAdapterWeightChanged,
|
||||
controlAdapterBeginStepPctChanged,
|
||||
controlAdapterEndStepPctChanged,
|
||||
|
@ -243,12 +243,15 @@ export type T2IAdapterConfig = {
|
||||
shouldAutoConfig: boolean;
|
||||
};
|
||||
|
||||
export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
|
||||
|
||||
export type IPAdapterConfig = {
|
||||
type: 'ip_adapter';
|
||||
id: string;
|
||||
isEnabled: boolean;
|
||||
controlImage: string | null;
|
||||
model: ParameterIPAdapterModel | null;
|
||||
clipVisionModel: CLIPVisionModel;
|
||||
weight: number;
|
||||
beginStepPct: number;
|
||||
endStepPct: number;
|
||||
|
@ -45,6 +45,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
|
||||
isEnabled: true,
|
||||
controlImage: null,
|
||||
model: null,
|
||||
clipVisionModel: 'ViT-H',
|
||||
weight: 1,
|
||||
beginStepPct: 0,
|
||||
endStepPct: 1,
|
||||
|
@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
if (!ipAdapter.model) {
|
||||
return;
|
||||
}
|
||||
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||
const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
|
||||
|
||||
assert(controlImage, 'IP Adapter image is required');
|
||||
|
||||
@ -58,6 +58,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
is_intermediate: true,
|
||||
weight: weight,
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
image: {
|
||||
@ -83,7 +84,7 @@ export const addIPAdapterToLinearGraph = async (
|
||||
};
|
||||
|
||||
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
|
||||
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter;
|
||||
const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
|
||||
|
||||
assert(model, 'IP Adapter model is required');
|
||||
|
||||
@ -99,6 +100,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
|
||||
|
||||
return {
|
||||
ip_adapter_model: model,
|
||||
clip_vision_model: clipVisionModel,
|
||||
weight,
|
||||
begin_step_percent: beginStepPct,
|
||||
end_step_percent: endStepPct,
|
||||
|
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user