feat: Let users pick CLIP Vision model for Checkpoint IP Adapters

This commit is contained in:
blessedcoolant 2024-03-27 20:32:41 +05:30
parent 688a0f30bb
commit 16c366a060
10 changed files with 145 additions and 46 deletions

View File

@ -1,5 +1,5 @@
from builtins import float from builtins import float
from typing import List, Union from typing import List, Literal, Union
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -49,12 +49,15 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
# Maps the CLIP Vision choice exposed on the node ("ViT-H" / "ViT-G") to the
# name of the image encoder model that is searched for (and installed if missing).
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
# Inputs # Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
@ -62,7 +65,9 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1, ui_order=-1,
ui_type=UIType.IPAdapterModel, ui_type=UIType.IPAdapterModel,
) )
# CLIP Vision image encoder selection. Only consulted when the selected IP-Adapter
# config is not an IPAdapterDiffusersConfig; Diffusers-format adapters use their
# own `image_encoder_model_id` instead.
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use", default="ViT-H", ui_order=2
)
weight: Union[float, List[float]] = InputField( weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight" default=1, description="The weight given to the IP-Adapter", title="Weight"
) )
@ -89,12 +94,12 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig)) assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig))
image_encoder_model_id = ( if isinstance(ip_adapter_info, IPAdapterDiffusersConfig):
ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
if isinstance(ip_adapter_info, IPAdapterDiffusersConfig) image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else "ip_adapter_sd_image_encoder" else:
) image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name) image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
return IPAdapterOutput( return IPAdapterOutput(
@ -109,19 +114,25 @@ class IPAdapterInvocation(BaseInvocation):
) )
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig: def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False image_encoder_models = context.models.search_by_attrs(
while not found: name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
if not len(image_encoder_models) > 0:
context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
)
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs( image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
) )
found = len(image_encoder_models) > 0
if not found: if len(image_encoder_models) == 0:
context.logger.warning( context.logger.error("Error while fetching CLIP Vision Image Encoder")
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed." assert len(image_encoder_models) == 1
)
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException
assert len(image_encoder_models) == 1
return image_encoder_models[0] return image_encoder_models[0]

View File

@ -2,16 +2,8 @@ from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
BaseInvocation, from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
@ -43,6 +35,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter") weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)") begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -217,6 +217,7 @@
"saveControlImage": "Save Control Image", "saveControlImage": "Save Control Image",
"scribble": "scribble", "scribble": "scribble",
"selectModel": "Select a model", "selectModel": "Select a model",
"selectCLIPVisionModel": "Select a CLIP Vision model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H", "setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced", "showAdvanced": "Show Advanced",
"small": "Small", "small": "Small",

View File

@ -1,12 +1,18 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels'; import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -29,6 +35,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { modelConfig } = useControlAdapterModel(id); const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel); const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation(); const { t } = useTranslation();
@ -49,6 +56,16 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[dispatch, id] [dispatch, id]
); );
// Dispatches the newly picked CLIP Vision model for this adapter; a cleared or
// empty selection is ignored so the stored value is never unset.
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
  (option) => {
    const selected = option?.value;
    if (selected) {
      dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: selected as CLIPVisionModel }));
    }
  },
  [dispatch, id]
);
const selectedModel = useMemo( const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null), () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig] [controlAdapterType, modelConfig]
@ -71,17 +88,42 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
isLoading, isLoading,
}); });
// Static list of selectable CLIP Vision models. The empty dependency list keeps
// the array identity stable across renders.
const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);
// The option matching the adapter's current CLIP Vision model; `undefined` when
// no option matches (e.g. the hook returned no value for this adapter).
const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);
return ( return (
<Tooltip label={value?.description}> <Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}> <Flex flexDirection="row" gap={2}>
<Combobox <FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
options={options} <Combobox
placeholder={t('controlnet.selectModel')} options={options}
value={value} placeholder={t('controlnet.selectModel')}
onChange={onChange} value={value}
noOptionsMessage={noOptionsMessage} onChange={onChange}
/> noOptionsMessage={noOptionsMessage}
</FormControl> />
</FormControl>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
</Tooltip> </Tooltip>
); );
}; };

View File

@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
/**
 * Returns the CLIP Vision model stored on the control adapter with the given id.
 *
 * The result is `undefined` when no adapter with that id exists or when the
 * adapter is not of type 'ip_adapter'.
 */
export const useControlAdapterCLIPVisionModel = (id: string) => {
  // The selector is rebuilt only when `id` changes, so its memoization survives re-renders.
  const selectCLIPVisionModel = useMemo(
    () =>
      createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
        const adapter = selectControlAdapterById(controlAdapters, id);
        if (!adapter || adapter.type !== 'ip_adapter') {
          return undefined;
        }
        return adapter.clipVisionModel;
      }),
    [id]
  );

  return useAppSelector(selectCLIPVisionModel);
};

View File

@ -13,6 +13,7 @@ import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions'; import { controlAdapterImageProcessed } from './actions';
import { CONTROLNET_PROCESSORS } from './constants'; import { CONTROLNET_PROCESSORS } from './constants';
import type { import type {
CLIPVisionModel,
ControlAdapterConfig, ControlAdapterConfig,
ControlAdapterProcessorType, ControlAdapterProcessorType,
ControlAdaptersState, ControlAdaptersState,
@ -243,6 +244,13 @@ export const controlAdaptersSlice = createSlice({
} }
caAdapter.updateOne(state, { id, changes: { controlMode } }); caAdapter.updateOne(state, { id, changes: { controlMode } });
}, },
// Stores the chosen CLIP Vision model on the adapter identified by `id`.
controlAdapterCLIPVisionModelChanged: (
  state,
  action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
) => {
  const payload = action.payload;
  caAdapter.updateOne(state, {
    id: payload.id,
    changes: { clipVisionModel: payload.clipVisionModel },
  });
},
controlAdapterResizeModeChanged: ( controlAdapterResizeModeChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -380,6 +388,7 @@ export const {
controlAdapterProcessedImageChanged, controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged, controlAdapterIsEnabledChanged,
controlAdapterModelChanged, controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterWeightChanged, controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged, controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged, controlAdapterEndStepPctChanged,

View File

@ -243,12 +243,15 @@ export type T2IAdapterConfig = {
shouldAutoConfig: boolean; shouldAutoConfig: boolean;
}; };
export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
export type IPAdapterConfig = { export type IPAdapterConfig = {
type: 'ip_adapter'; type: 'ip_adapter';
id: string; id: string;
isEnabled: boolean; isEnabled: boolean;
controlImage: string | null; controlImage: string | null;
model: ParameterIPAdapterModel | null; model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;

View File

@ -45,6 +45,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true, isEnabled: true,
controlImage: null, controlImage: null,
model: null, model: null,
clipVisionModel: 'ViT-H',
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,

View File

@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter; const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required'); assert(controlImage, 'IP Adapter image is required');
@ -58,6 +58,7 @@ export const addIPAdapterToLinearGraph = async (
is_intermediate: true, is_intermediate: true,
weight: weight, weight: weight,
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
image: { image: {
@ -83,7 +84,7 @@ export const addIPAdapterToLinearGraph = async (
}; };
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => { const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter; const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
@ -99,6 +100,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
return { return {
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight, weight,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,

File diff suppressed because one or more lines are too long