Use metadata ip adapter (#4715)

* add control net to useRecallParams

* got recall controlnets working

* fix metadata viewer controlnet

* fix type errors

* fix controlnet metadata viewer

* add ip adapter to metadata

* added ip adapter to recall parameters

* got ip adapter recall working, still need to fix type errors

* fix type issues

* clean up logs

* python formatting

* cleanup

* fix(ui): only store `image_name` as ip adapter image

* fix(ui): use nullish coalescing operator for numbers

Need to use the nullish coalescing operator `??` instead of false-y coalescing operator `||` when the value being check is a number. This prevents unintended coalescing when the value is zero and therefore false-y.

* feat(ui): fall back on default values for ip adapter metadata

* fix(ui): remove unused schema

* feat(ui): re-use existing schemas in metadata schema

* fix(ui): do not disable invocationCache

---------

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
chainchompa 2023-09-28 05:05:32 -04:00 committed by GitHub
parent 309e2414ce
commit c7f80cd163
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 225 additions and 40 deletions

View File

@ -12,7 +12,9 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
@ -25,6 +27,18 @@ class LoRAMetadataField(BaseModelExcludeNull):
weight: float = Field(description="The weight of the LoRA model")
class IPAdapterMetadataField(BaseModelExcludeNull):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
weight: float = Field(description="The weight of the IP-Adapter model")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
class CoreMetadata(BaseModelExcludeNull):
"""Core generation metadata for an image generated in InvokeAI."""
@ -48,6 +62,7 @@ class CoreMetadata(BaseModelExcludeNull):
)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Optional[VAEModelField] = Field(
default=None,
@ -123,6 +138,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
)
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,

View File

@ -113,7 +113,7 @@ export const addRequestedSingleImageDeletionListener = () => {
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
@ -238,7 +238,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));

View File

@ -118,7 +118,7 @@ export const addImageDroppedListener = () => {
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
return;
}

View File

@ -111,7 +111,7 @@ export const addImageUploadedFulfilledListener = () => {
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO));
dispatch(ipAdapterImageChanged(imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({

View File

@ -33,7 +33,7 @@ const ParamIPAdapterImage = () => {
const { t } = useTranslation();
const { currentData: imageDTO } = useGetImageDTOQuery(
ipAdapterInfo.adapterImage?.image_name ?? skipToken
ipAdapterInfo.adapterImage ?? skipToken
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {

View File

@ -6,7 +6,6 @@ import {
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { ImageDTO } from 'services/api/types';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
@ -60,7 +59,7 @@ export type ControlNetConfig = {
};
export type IPAdapterConfig = {
adapterImage: ImageDTO | null;
adapterImage: string | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
@ -388,7 +387,10 @@ export const controlNetSlice = createSlice({
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
state.isIPAdapterEnabled = action.payload;
},
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
ipAdapterRecalled: (state, action: PayloadAction<IPAdapterConfig>) => {
state.ipAdapterInfo = action.payload;
},
ipAdapterImageChanged: (state, action: PayloadAction<string | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
},
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
@ -471,6 +473,7 @@ export const {
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnabledChanged,
ipAdapterRecalled,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,

View File

@ -27,8 +27,7 @@ export const getImageUsage = (state: RootState, image_name: string) => {
c.controlImage === image_name || c.processedControlImage === image_name
);
const isIPAdapterImage =
controlNet.ipAdapterInfo.adapterImage?.image_name === image_name;
const isIPAdapterImage = controlNet.ipAdapterInfo.adapterImage === image_name;
const imageUsage: ImageUsage = {
isInitialImage,

View File

@ -2,6 +2,7 @@ import {
ControlNetMetadataItem,
CoreMetadata,
LoRAMetadataItem,
IPAdapterMetadataItem,
} from 'features/nodes/types/types';
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
import { memo, useMemo, useCallback } from 'react';
@ -34,6 +35,7 @@ const ImageMetadataActions = (props: Props) => {
recallStrength,
recallLoRA,
recallControlNet,
recallIPAdapter,
} = useRecallParameters();
const handleRecallPositivePrompt = useCallback(() => {
@ -90,6 +92,13 @@ const ImageMetadataActions = (props: Props) => {
[recallControlNet]
);
const handleRecallIPAdapter = useCallback(
(ipAdapter: IPAdapterMetadataItem) => {
recallIPAdapter(ipAdapter);
},
[recallIPAdapter]
);
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
return metadata?.controlnets
? metadata.controlnets.filter((controlnet) =>
@ -98,6 +107,14 @@ const ImageMetadataActions = (props: Props) => {
: [];
}, [metadata?.controlnets]);
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
return metadata?.ipAdapters
? metadata.ipAdapters.filter((ipAdapter) =>
isValidControlNetModel(ipAdapter.ip_adapter_model)
)
: [];
}, [metadata?.ipAdapters]);
if (!metadata || Object.keys(metadata).length === 0) {
return null;
}
@ -211,6 +228,14 @@ const ImageMetadataActions = (props: Props) => {
onClick={() => handleRecallControlNet(controlnet)}
/>
))}
{validIPAdapters.map((ipAdapter, index) => (
<ImageMetadataItem
key={index}
label="IP Adapter"
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
onClick={() => handleRecallIPAdapter(ipAdapter)}
/>
))}
</>
);
};

View File

@ -412,8 +412,9 @@ export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModel,
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
@ -1145,6 +1146,10 @@ const zControlNetMetadataItem = zControlField.deepPartial();
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
export const zCoreMetadata = z
.object({
app_version: z.string().nullish().catch(null),
@ -1164,16 +1169,9 @@ export const zCoreMetadata = z
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
.nullish()
.catch(null),
controlnets: z.array(zControlField.deepPartial()).nullish().catch(null),
loras: z
.array(
z.object({
lora: zLoRAModelField.deepPartial(),
weight: z.number(),
})
)
.nullish()
.catch(null),
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
vae: zVaeModelField.nullish().catch(null),
strength: z.number().nullish().catch(null),
init_image: z.string().nullish().catch(null),

View File

@ -1,7 +1,14 @@
import { RootState } from 'app/store/store';
import { IPAdapterInvocation } from 'services/api/types';
import {
IPAdapterInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants';
import {
CANVAS_COHERENCE_DENOISE_LATENTS,
IP_ADAPTER,
METADATA_ACCUMULATOR,
} from './constants';
export const addIPAdapterToLinearGraph = (
state: RootState,
@ -10,9 +17,9 @@ export const addIPAdapterToLinearGraph = (
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
// | MetadataAccumulatorInvocation
// | undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapterNode: IPAdapterInvocation = {
@ -30,23 +37,29 @@ export const addIPAdapterToLinearGraph = (
if (ipAdapterInfo.adapterImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage.image_name,
image_name: ipAdapterInfo.adapterImage,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
if (metadataAccumulator?.ipAdapters) {
const ipAdapterField = {
image: {
image_name: ipAdapterInfo.adapterImage,
},
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
weight: ipAdapterInfo.weight,
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
};
// if (metadataAccumulator?.ip_adapters) {
// // metadata accumulator only needs the ip_adapter field - not the whole node
// // extract what we need and add to the accumulator
// const ipAdapterField = omit(ipAdapterNode, [
// 'id',
// 'type',
// ]) as IPAdapterField;
// metadataAccumulator.ip_adapters.push(ipAdapterField);
// }
metadataAccumulator.ipAdapters.push(ipAdapterField);
}
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },

View File

@ -327,6 +327,7 @@ export const buildCanvasImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.image_name,

View File

@ -338,6 +338,7 @@ export const buildCanvasSDXLImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
strength,
init_image: initialImage.image_name,
};

View File

@ -320,6 +320,7 @@ export const buildCanvasSDXLTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
};
graph.edges.push({

View File

@ -308,6 +308,7 @@ export const buildCanvasTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
};

View File

@ -328,6 +328,7 @@ export const buildLinearImageToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
strength,
init_image: initialImage.imageName,

View File

@ -348,6 +348,7 @@ export const buildLinearSDXLImageToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
ipAdapters: [],
strength: strength,
init_image: initialImage.imageName,
positive_style_prompt: positiveStylePrompt,

View File

@ -242,6 +242,7 @@ export const buildLinearSDXLTextToImageGraph = (
vae: undefined,
controlnets: [],
loras: [],
ipAdapters: [],
positive_style_prompt: positiveStylePrompt,
negative_style_prompt: negativeStylePrompt,
};

View File

@ -250,6 +250,7 @@ export const buildLinearTextToImageGraph = (
vae: undefined, // option; set in addVAEToGraph
controlnets: [], // populated in addControlNetToLinearGraph
loras: [], // populated in addLoRAsToGraph
ipAdapters: [], // populated in addIPAdapterToLinearGraph
clip_skip: clipSkip,
};

View File

@ -6,6 +6,7 @@ import {
CoreMetadata,
LoRAMetadataItem,
ControlNetMetadataItem,
IPAdapterMetadataItem,
} from 'features/nodes/types/types';
import {
refinerModelChanged,
@ -23,16 +24,22 @@ import { useTranslation } from 'react-i18next';
import { ImageDTO } from 'services/api/types';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
useGetIPAdapterModelsQuery,
loraModelsAdapter,
useGetControlNetModelsQuery,
useGetLoRAModelsQuery,
} from '../../../services/api/endpoints/models';
import {
ControlNetConfig,
IPAdapterConfig,
controlNetEnabled,
controlNetRecalled,
controlNetReset,
initialControlNet,
initialIPAdapterState,
ipAdapterRecalled,
isIPAdapterEnabledChanged,
} from '../../controlNet/store/controlNetSlice';
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
import { initialImageSelected, modelSelected } from '../store/actions';
@ -52,6 +59,7 @@ import {
isValidHeight,
isValidLoRAModel,
isValidControlNetModel,
isValidIPAdapterModel,
isValidMainModel,
isValidNegativePrompt,
isValidPositivePrompt,
@ -512,8 +520,6 @@ export const useRecallParameters = () => {
})
);
dispatch(controlNetEnabled());
parameterSetToast();
},
[
@ -524,6 +530,92 @@ export const useRecallParameters = () => {
]
);
/**
* Recall IP Adapter with toast
*/
const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, {
selectFromResult: (result) => ({
ipAdapters: result.data
? ipAdapterModelsAdapter.getSelectors().selectAll(result.data)
: [],
}),
});
const prepareIPAdapterMetadataItem = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
}
const {
image,
ip_adapter_model,
weight,
begin_step_percent,
end_step_percent,
} = ipAdapterMetadataItem;
const matchingIPAdapterModel = ipAdapters.find(
(c) =>
c.base_model === ip_adapter_model?.base_model &&
c.model_name === ip_adapter_model?.model_name
);
if (!matchingIPAdapterModel) {
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
}
const isCompatibleBaseModel =
matchingIPAdapterModel?.base_model === model?.base_model;
if (!isCompatibleBaseModel) {
return {
ipAdapter: null,
error: 'IP Adapter incompatible with currently-selected model',
};
}
const ipAdapter: IPAdapterConfig = {
adapterImage: image?.image_name ?? null,
model: matchingIPAdapterModel,
weight: weight ?? initialIPAdapterState.weight,
beginStepPct: begin_step_percent ?? initialIPAdapterState.beginStepPct,
endStepPct: end_step_percent ?? initialIPAdapterState.endStepPct,
};
return { ipAdapter, error: null };
},
[ipAdapters, model?.base_model]
);
const recallIPAdapter = useCallback(
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
if (!result.ipAdapter) {
parameterNotSetToast(result.error);
return;
}
dispatch(
ipAdapterRecalled({
...result.ipAdapter,
})
);
dispatch(isIPAdapterEnabledChanged(true));
parameterSetToast();
},
[
prepareIPAdapterMetadataItem,
dispatch,
parameterSetToast,
parameterNotSetToast,
]
);
/*
* Sets image as initial image with toast
*/
@ -563,6 +655,7 @@ export const useRecallParameters = () => {
refiner_start,
loras,
controlnets,
ipAdapters,
} = metadata;
if (isValidCfgScale(cfg_scale)) {
@ -653,7 +746,9 @@ export const useRecallParameters = () => {
});
dispatch(controlNetReset());
if (controlnets?.length) {
dispatch(controlNetEnabled());
}
controlnets?.forEach((controlnet) => {
const result = prepareControlNetMetadataItem(controlnet);
if (result.controlnet) {
@ -661,6 +756,16 @@ export const useRecallParameters = () => {
}
});
if (ipAdapters?.length) {
dispatch(isIPAdapterEnabledChanged(true));
}
ipAdapters?.forEach((ipAdapter) => {
const result = prepareIPAdapterMetadataItem(ipAdapter);
if (result.ipAdapter) {
dispatch(ipAdapterRecalled(result.ipAdapter));
}
});
allParameterSetToast();
},
[
@ -669,6 +774,7 @@ export const useRecallParameters = () => {
dispatch,
prepareLoRAMetadataItem,
prepareControlNetMetadataItem,
prepareIPAdapterMetadataItem,
]
);
@ -688,6 +794,7 @@ export const useRecallParameters = () => {
recallStrength,
recallLoRA,
recallControlNet,
recallIPAdapter,
recallAllParameters,
sendToImageToImage,
};

View File

@ -343,6 +343,12 @@ export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/**
* Zod schema for l2l strength parameter
*/
/**
* Validates/type-guards a value as a model parameter
*/
export const isValidIPAdapterModel = (
val: unknown
): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success;
export const zStrength = z.number().min(0).max(1);
/**
* Type alias for l2l strength parameter, inferred from its zod schema

View File

@ -2115,6 +2115,11 @@ export type components = {
* @description The ControlNets used for inference
*/
controlnets: components["schemas"]["ControlField"][];
/**
* Loras
* @description The LoRAs used for inference
*/
ipAdapters: components["schemas"]["IPAdapterField"][];
/**
* Loras
* @description The LoRAs used for inference
@ -3178,7 +3183,7 @@ export type components = {
* Image Encoder Model
* @description The name of the CLIP image encoder model.
*/
image_encoder_model: components["schemas"]["CLIPVisionModelField"];
image_encoder_model?: components["schemas"]["CLIPVisionModelField"];
/**
* Weight
* @description The weight given to the ControlNet
@ -5814,6 +5819,11 @@ export type components = {
* @description The LoRAs used for inference
*/
loras?: components["schemas"]["LoRAMetadataField"][];
/**
* Strength
* @description The strength used for latents-to-latents
*/
ipAdapters?: components["schemas"]["IPAdapterField"][];
/**
* Strength
* @description The strength used for latents-to-latents