mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use metadata ip adapter (#4715)
* add control net to useRecallParams * got recall controlnets working * fix metadata viewer controlnet * fix type errors * fix controlnet metadata viewer * add ip adapter to metadata * added ip adapter to recall parameters * got ip adapter recall working, still need to fix type errors * fix type issues * clean up logs * python formatting * cleanup * fix(ui): only store `image_name` as ip adapter image * fix(ui): use nullish coalescing operator for numbers Need to use the nullish coalescing operator `??` instead of false-y coalescing operator `||` when the value being check is a number. This prevents unintended coalescing when the value is zero and therefore false-y. * feat(ui): fall back on default values for ip adapter metadata * fix(ui): remove unused schema * feat(ui): re-use existing schemas in metadata schema * fix(ui): do not disable invocationCache --------- Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
parent
309e2414ce
commit
c7f80cd163
@ -12,7 +12,9 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||||
|
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
|
||||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||||
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||||
|
|
||||||
from ...version import __version__
|
from ...version import __version__
|
||||||
@ -25,6 +27,18 @@ class LoRAMetadataField(BaseModelExcludeNull):
|
|||||||
weight: float = Field(description="The weight of the LoRA model")
|
weight: float = Field(description="The weight of the LoRA model")
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterMetadataField(BaseModelExcludeNull):
|
||||||
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
|
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||||
|
weight: float = Field(description="The weight of the IP-Adapter model")
|
||||||
|
begin_step_percent: float = Field(
|
||||||
|
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
|
||||||
|
)
|
||||||
|
end_step_percent: float = Field(
|
||||||
|
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CoreMetadata(BaseModelExcludeNull):
|
class CoreMetadata(BaseModelExcludeNull):
|
||||||
"""Core generation metadata for an image generated in InvokeAI."""
|
"""Core generation metadata for an image generated in InvokeAI."""
|
||||||
|
|
||||||
@ -48,6 +62,7 @@ class CoreMetadata(BaseModelExcludeNull):
|
|||||||
)
|
)
|
||||||
model: MainModelField = Field(description="The main model used for inference")
|
model: MainModelField = Field(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
|
||||||
|
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
|
||||||
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
|
||||||
vae: Optional[VAEModelField] = Field(
|
vae: Optional[VAEModelField] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -123,6 +138,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
model: MainModelField = InputField(description="The main model used for inference")
|
model: MainModelField = InputField(description="The main model used for inference")
|
||||||
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
|
||||||
|
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
|
||||||
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
|
||||||
strength: Optional[float] = InputField(
|
strength: Optional[float] = InputField(
|
||||||
default=None,
|
default=None,
|
||||||
|
@ -113,7 +113,7 @@ export const addRequestedSingleImageDeletionListener = () => {
|
|||||||
|
|
||||||
// Remove IP Adapter Set Image if image is deleted.
|
// Remove IP Adapter Set Image if image is deleted.
|
||||||
if (
|
if (
|
||||||
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
|
getState().controlNet.ipAdapterInfo.adapterImage ===
|
||||||
imageDTO.image_name
|
imageDTO.image_name
|
||||||
) {
|
) {
|
||||||
dispatch(ipAdapterImageChanged(null));
|
dispatch(ipAdapterImageChanged(null));
|
||||||
@ -238,7 +238,7 @@ export const addRequestedMultipleImageDeletionListener = () => {
|
|||||||
|
|
||||||
// Remove IP Adapter Set Image if image is deleted.
|
// Remove IP Adapter Set Image if image is deleted.
|
||||||
if (
|
if (
|
||||||
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
|
getState().controlNet.ipAdapterInfo.adapterImage ===
|
||||||
imageDTO.image_name
|
imageDTO.image_name
|
||||||
) {
|
) {
|
||||||
dispatch(ipAdapterImageChanged(null));
|
dispatch(ipAdapterImageChanged(null));
|
||||||
|
@ -118,7 +118,7 @@ export const addImageDroppedListener = () => {
|
|||||||
activeData.payloadType === 'IMAGE_DTO' &&
|
activeData.payloadType === 'IMAGE_DTO' &&
|
||||||
activeData.payload.imageDTO
|
activeData.payload.imageDTO
|
||||||
) {
|
) {
|
||||||
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
|
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
|
||||||
dispatch(isIPAdapterEnabledChanged(true));
|
dispatch(isIPAdapterEnabledChanged(true));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -111,7 +111,7 @@ export const addImageUploadedFulfilledListener = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
|
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
|
||||||
dispatch(ipAdapterImageChanged(imageDTO));
|
dispatch(ipAdapterImageChanged(imageDTO.image_name));
|
||||||
dispatch(isIPAdapterEnabledChanged(true));
|
dispatch(isIPAdapterEnabledChanged(true));
|
||||||
dispatch(
|
dispatch(
|
||||||
addToast({
|
addToast({
|
||||||
|
@ -33,7 +33,7 @@ const ParamIPAdapterImage = () => {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
const { currentData: imageDTO } = useGetImageDTOQuery(
|
const { currentData: imageDTO } = useGetImageDTOQuery(
|
||||||
ipAdapterInfo.adapterImage?.image_name ?? skipToken
|
ipAdapterInfo.adapterImage ?? skipToken
|
||||||
);
|
);
|
||||||
|
|
||||||
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
|
||||||
|
@ -6,7 +6,6 @@ import {
|
|||||||
import { cloneDeep, forEach } from 'lodash-es';
|
import { cloneDeep, forEach } from 'lodash-es';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
import { components } from 'services/api/schema';
|
import { components } from 'services/api/schema';
|
||||||
import { ImageDTO } from 'services/api/types';
|
|
||||||
import { appSocketInvocationError } from 'services/events/actions';
|
import { appSocketInvocationError } from 'services/events/actions';
|
||||||
import { controlNetImageProcessed } from './actions';
|
import { controlNetImageProcessed } from './actions';
|
||||||
import {
|
import {
|
||||||
@ -60,7 +59,7 @@ export type ControlNetConfig = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type IPAdapterConfig = {
|
export type IPAdapterConfig = {
|
||||||
adapterImage: ImageDTO | null;
|
adapterImage: string | null;
|
||||||
model: IPAdapterModelParam | null;
|
model: IPAdapterModelParam | null;
|
||||||
weight: number;
|
weight: number;
|
||||||
beginStepPct: number;
|
beginStepPct: number;
|
||||||
@ -388,7 +387,10 @@ export const controlNetSlice = createSlice({
|
|||||||
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
isIPAdapterEnabledChanged: (state, action: PayloadAction<boolean>) => {
|
||||||
state.isIPAdapterEnabled = action.payload;
|
state.isIPAdapterEnabled = action.payload;
|
||||||
},
|
},
|
||||||
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
ipAdapterRecalled: (state, action: PayloadAction<IPAdapterConfig>) => {
|
||||||
|
state.ipAdapterInfo = action.payload;
|
||||||
|
},
|
||||||
|
ipAdapterImageChanged: (state, action: PayloadAction<string | null>) => {
|
||||||
state.ipAdapterInfo.adapterImage = action.payload;
|
state.ipAdapterInfo.adapterImage = action.payload;
|
||||||
},
|
},
|
||||||
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
|
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
|
||||||
@ -471,6 +473,7 @@ export const {
|
|||||||
controlNetReset,
|
controlNetReset,
|
||||||
controlNetAutoConfigToggled,
|
controlNetAutoConfigToggled,
|
||||||
isIPAdapterEnabledChanged,
|
isIPAdapterEnabledChanged,
|
||||||
|
ipAdapterRecalled,
|
||||||
ipAdapterImageChanged,
|
ipAdapterImageChanged,
|
||||||
ipAdapterWeightChanged,
|
ipAdapterWeightChanged,
|
||||||
ipAdapterModelChanged,
|
ipAdapterModelChanged,
|
||||||
|
@ -27,8 +27,7 @@ export const getImageUsage = (state: RootState, image_name: string) => {
|
|||||||
c.controlImage === image_name || c.processedControlImage === image_name
|
c.controlImage === image_name || c.processedControlImage === image_name
|
||||||
);
|
);
|
||||||
|
|
||||||
const isIPAdapterImage =
|
const isIPAdapterImage = controlNet.ipAdapterInfo.adapterImage === image_name;
|
||||||
controlNet.ipAdapterInfo.adapterImage?.image_name === image_name;
|
|
||||||
|
|
||||||
const imageUsage: ImageUsage = {
|
const imageUsage: ImageUsage = {
|
||||||
isInitialImage,
|
isInitialImage,
|
||||||
|
@ -2,6 +2,7 @@ import {
|
|||||||
ControlNetMetadataItem,
|
ControlNetMetadataItem,
|
||||||
CoreMetadata,
|
CoreMetadata,
|
||||||
LoRAMetadataItem,
|
LoRAMetadataItem,
|
||||||
|
IPAdapterMetadataItem,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters';
|
||||||
import { memo, useMemo, useCallback } from 'react';
|
import { memo, useMemo, useCallback } from 'react';
|
||||||
@ -34,6 +35,7 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
recallControlNet,
|
recallControlNet,
|
||||||
|
recallIPAdapter,
|
||||||
} = useRecallParameters();
|
} = useRecallParameters();
|
||||||
|
|
||||||
const handleRecallPositivePrompt = useCallback(() => {
|
const handleRecallPositivePrompt = useCallback(() => {
|
||||||
@ -90,6 +92,13 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
[recallControlNet]
|
[recallControlNet]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const handleRecallIPAdapter = useCallback(
|
||||||
|
(ipAdapter: IPAdapterMetadataItem) => {
|
||||||
|
recallIPAdapter(ipAdapter);
|
||||||
|
},
|
||||||
|
[recallIPAdapter]
|
||||||
|
);
|
||||||
|
|
||||||
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
const validControlNets: ControlNetMetadataItem[] = useMemo(() => {
|
||||||
return metadata?.controlnets
|
return metadata?.controlnets
|
||||||
? metadata.controlnets.filter((controlnet) =>
|
? metadata.controlnets.filter((controlnet) =>
|
||||||
@ -98,6 +107,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
: [];
|
: [];
|
||||||
}, [metadata?.controlnets]);
|
}, [metadata?.controlnets]);
|
||||||
|
|
||||||
|
const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => {
|
||||||
|
return metadata?.ipAdapters
|
||||||
|
? metadata.ipAdapters.filter((ipAdapter) =>
|
||||||
|
isValidControlNetModel(ipAdapter.ip_adapter_model)
|
||||||
|
)
|
||||||
|
: [];
|
||||||
|
}, [metadata?.ipAdapters]);
|
||||||
|
|
||||||
if (!metadata || Object.keys(metadata).length === 0) {
|
if (!metadata || Object.keys(metadata).length === 0) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@ -211,6 +228,14 @@ const ImageMetadataActions = (props: Props) => {
|
|||||||
onClick={() => handleRecallControlNet(controlnet)}
|
onClick={() => handleRecallControlNet(controlnet)}
|
||||||
/>
|
/>
|
||||||
))}
|
))}
|
||||||
|
{validIPAdapters.map((ipAdapter, index) => (
|
||||||
|
<ImageMetadataItem
|
||||||
|
key={index}
|
||||||
|
label="IP Adapter"
|
||||||
|
value={`${ipAdapter.ip_adapter_model?.model_name} - ${ipAdapter.weight}`}
|
||||||
|
onClick={() => handleRecallIPAdapter(ipAdapter)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -412,8 +412,9 @@ export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
|
|||||||
export const zIPAdapterField = z.object({
|
export const zIPAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
ip_adapter_model: zIPAdapterModel,
|
ip_adapter_model: zIPAdapterModel,
|
||||||
image_encoder_model: z.string().trim().min(1),
|
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
|
begin_step_percent: z.number().optional(),
|
||||||
|
end_step_percent: z.number().optional(),
|
||||||
});
|
});
|
||||||
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
|
||||||
|
|
||||||
@ -1145,6 +1146,10 @@ const zControlNetMetadataItem = zControlField.deepPartial();
|
|||||||
|
|
||||||
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
export type ControlNetMetadataItem = z.infer<typeof zControlNetMetadataItem>;
|
||||||
|
|
||||||
|
const zIPAdapterMetadataItem = zIPAdapterField.deepPartial();
|
||||||
|
|
||||||
|
export type IPAdapterMetadataItem = z.infer<typeof zIPAdapterMetadataItem>;
|
||||||
|
|
||||||
export const zCoreMetadata = z
|
export const zCoreMetadata = z
|
||||||
.object({
|
.object({
|
||||||
app_version: z.string().nullish().catch(null),
|
app_version: z.string().nullish().catch(null),
|
||||||
@ -1164,16 +1169,9 @@ export const zCoreMetadata = z
|
|||||||
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
.union([zMainModel.deepPartial(), zOnnxModel.deepPartial()])
|
||||||
.nullish()
|
.nullish()
|
||||||
.catch(null),
|
.catch(null),
|
||||||
controlnets: z.array(zControlField.deepPartial()).nullish().catch(null),
|
controlnets: z.array(zControlNetMetadataItem).nullish().catch(null),
|
||||||
loras: z
|
ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null),
|
||||||
.array(
|
loras: z.array(zLoRAMetadataItem).nullish().catch(null),
|
||||||
z.object({
|
|
||||||
lora: zLoRAModelField.deepPartial(),
|
|
||||||
weight: z.number(),
|
|
||||||
})
|
|
||||||
)
|
|
||||||
.nullish()
|
|
||||||
.catch(null),
|
|
||||||
vae: zVaeModelField.nullish().catch(null),
|
vae: zVaeModelField.nullish().catch(null),
|
||||||
strength: z.number().nullish().catch(null),
|
strength: z.number().nullish().catch(null),
|
||||||
init_image: z.string().nullish().catch(null),
|
init_image: z.string().nullish().catch(null),
|
||||||
|
@ -1,7 +1,14 @@
|
|||||||
import { RootState } from 'app/store/store';
|
import { RootState } from 'app/store/store';
|
||||||
import { IPAdapterInvocation } from 'services/api/types';
|
import {
|
||||||
|
IPAdapterInvocation,
|
||||||
|
MetadataAccumulatorInvocation,
|
||||||
|
} from 'services/api/types';
|
||||||
import { NonNullableGraph } from '../../types/types';
|
import { NonNullableGraph } from '../../types/types';
|
||||||
import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER } from './constants';
|
import {
|
||||||
|
CANVAS_COHERENCE_DENOISE_LATENTS,
|
||||||
|
IP_ADAPTER,
|
||||||
|
METADATA_ACCUMULATOR,
|
||||||
|
} from './constants';
|
||||||
|
|
||||||
export const addIPAdapterToLinearGraph = (
|
export const addIPAdapterToLinearGraph = (
|
||||||
state: RootState,
|
state: RootState,
|
||||||
@ -10,9 +17,9 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
): void => {
|
): void => {
|
||||||
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
|
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
|
||||||
|
|
||||||
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
|
||||||
// | MetadataAccumulatorInvocation
|
| MetadataAccumulatorInvocation
|
||||||
// | undefined;
|
| undefined;
|
||||||
|
|
||||||
if (isIPAdapterEnabled && ipAdapterInfo.model) {
|
if (isIPAdapterEnabled && ipAdapterInfo.model) {
|
||||||
const ipAdapterNode: IPAdapterInvocation = {
|
const ipAdapterNode: IPAdapterInvocation = {
|
||||||
@ -30,23 +37,29 @@ export const addIPAdapterToLinearGraph = (
|
|||||||
|
|
||||||
if (ipAdapterInfo.adapterImage) {
|
if (ipAdapterInfo.adapterImage) {
|
||||||
ipAdapterNode.image = {
|
ipAdapterNode.image = {
|
||||||
image_name: ipAdapterInfo.adapterImage.image_name,
|
image_name: ipAdapterInfo.adapterImage,
|
||||||
};
|
};
|
||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
|
||||||
|
if (metadataAccumulator?.ipAdapters) {
|
||||||
|
const ipAdapterField = {
|
||||||
|
image: {
|
||||||
|
image_name: ipAdapterInfo.adapterImage,
|
||||||
|
},
|
||||||
|
ip_adapter_model: {
|
||||||
|
base_model: ipAdapterInfo.model?.base_model,
|
||||||
|
model_name: ipAdapterInfo.model?.model_name,
|
||||||
|
},
|
||||||
|
weight: ipAdapterInfo.weight,
|
||||||
|
begin_step_percent: ipAdapterInfo.beginStepPct,
|
||||||
|
end_step_percent: ipAdapterInfo.endStepPct,
|
||||||
|
};
|
||||||
|
|
||||||
// if (metadataAccumulator?.ip_adapters) {
|
metadataAccumulator.ipAdapters.push(ipAdapterField);
|
||||||
// // metadata accumulator only needs the ip_adapter field - not the whole node
|
}
|
||||||
// // extract what we need and add to the accumulator
|
|
||||||
// const ipAdapterField = omit(ipAdapterNode, [
|
|
||||||
// 'id',
|
|
||||||
// 'type',
|
|
||||||
// ]) as IPAdapterField;
|
|
||||||
// metadataAccumulator.ip_adapters.push(ipAdapterField);
|
|
||||||
// }
|
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
|
||||||
|
@ -327,6 +327,7 @@ export const buildCanvasImageToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.image_name,
|
init_image: initialImage.image_name,
|
||||||
|
@ -338,6 +338,7 @@ export const buildCanvasSDXLImageToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.image_name,
|
init_image: initialImage.image_name,
|
||||||
};
|
};
|
||||||
|
@ -320,6 +320,7 @@ export const buildCanvasSDXLTextToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
};
|
};
|
||||||
|
|
||||||
graph.edges.push({
|
graph.edges.push({
|
||||||
|
@ -308,6 +308,7 @@ export const buildCanvasTextToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -328,6 +328,7 @@ export const buildLinearImageToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
strength,
|
strength,
|
||||||
init_image: initialImage.imageName,
|
init_image: initialImage.imageName,
|
||||||
|
@ -348,6 +348,7 @@ export const buildLinearSDXLImageToImageGraph = (
|
|||||||
vae: undefined,
|
vae: undefined,
|
||||||
controlnets: [],
|
controlnets: [],
|
||||||
loras: [],
|
loras: [],
|
||||||
|
ipAdapters: [],
|
||||||
strength: strength,
|
strength: strength,
|
||||||
init_image: initialImage.imageName,
|
init_image: initialImage.imageName,
|
||||||
positive_style_prompt: positiveStylePrompt,
|
positive_style_prompt: positiveStylePrompt,
|
||||||
|
@ -242,6 +242,7 @@ export const buildLinearSDXLTextToImageGraph = (
|
|||||||
vae: undefined,
|
vae: undefined,
|
||||||
controlnets: [],
|
controlnets: [],
|
||||||
loras: [],
|
loras: [],
|
||||||
|
ipAdapters: [],
|
||||||
positive_style_prompt: positiveStylePrompt,
|
positive_style_prompt: positiveStylePrompt,
|
||||||
negative_style_prompt: negativeStylePrompt,
|
negative_style_prompt: negativeStylePrompt,
|
||||||
};
|
};
|
||||||
|
@ -250,6 +250,7 @@ export const buildLinearTextToImageGraph = (
|
|||||||
vae: undefined, // option; set in addVAEToGraph
|
vae: undefined, // option; set in addVAEToGraph
|
||||||
controlnets: [], // populated in addControlNetToLinearGraph
|
controlnets: [], // populated in addControlNetToLinearGraph
|
||||||
loras: [], // populated in addLoRAsToGraph
|
loras: [], // populated in addLoRAsToGraph
|
||||||
|
ipAdapters: [], // populated in addIPAdapterToLinearGraph
|
||||||
clip_skip: clipSkip,
|
clip_skip: clipSkip,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import {
|
|||||||
CoreMetadata,
|
CoreMetadata,
|
||||||
LoRAMetadataItem,
|
LoRAMetadataItem,
|
||||||
ControlNetMetadataItem,
|
ControlNetMetadataItem,
|
||||||
|
IPAdapterMetadataItem,
|
||||||
} from 'features/nodes/types/types';
|
} from 'features/nodes/types/types';
|
||||||
import {
|
import {
|
||||||
refinerModelChanged,
|
refinerModelChanged,
|
||||||
@ -23,16 +24,22 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import { ImageDTO } from 'services/api/types';
|
import { ImageDTO } from 'services/api/types';
|
||||||
import {
|
import {
|
||||||
controlNetModelsAdapter,
|
controlNetModelsAdapter,
|
||||||
|
ipAdapterModelsAdapter,
|
||||||
|
useGetIPAdapterModelsQuery,
|
||||||
loraModelsAdapter,
|
loraModelsAdapter,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
} from '../../../services/api/endpoints/models';
|
} from '../../../services/api/endpoints/models';
|
||||||
import {
|
import {
|
||||||
ControlNetConfig,
|
ControlNetConfig,
|
||||||
|
IPAdapterConfig,
|
||||||
controlNetEnabled,
|
controlNetEnabled,
|
||||||
controlNetRecalled,
|
controlNetRecalled,
|
||||||
controlNetReset,
|
controlNetReset,
|
||||||
initialControlNet,
|
initialControlNet,
|
||||||
|
initialIPAdapterState,
|
||||||
|
ipAdapterRecalled,
|
||||||
|
isIPAdapterEnabledChanged,
|
||||||
} from '../../controlNet/store/controlNetSlice';
|
} from '../../controlNet/store/controlNetSlice';
|
||||||
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
import { loraRecalled, lorasCleared } from '../../lora/store/loraSlice';
|
||||||
import { initialImageSelected, modelSelected } from '../store/actions';
|
import { initialImageSelected, modelSelected } from '../store/actions';
|
||||||
@ -52,6 +59,7 @@ import {
|
|||||||
isValidHeight,
|
isValidHeight,
|
||||||
isValidLoRAModel,
|
isValidLoRAModel,
|
||||||
isValidControlNetModel,
|
isValidControlNetModel,
|
||||||
|
isValidIPAdapterModel,
|
||||||
isValidMainModel,
|
isValidMainModel,
|
||||||
isValidNegativePrompt,
|
isValidNegativePrompt,
|
||||||
isValidPositivePrompt,
|
isValidPositivePrompt,
|
||||||
@ -512,8 +520,6 @@ export const useRecallParameters = () => {
|
|||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
dispatch(controlNetEnabled());
|
|
||||||
|
|
||||||
parameterSetToast();
|
parameterSetToast();
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
@ -524,6 +530,92 @@ export const useRecallParameters = () => {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Recall IP Adapter with toast
|
||||||
|
*/
|
||||||
|
|
||||||
|
const { ipAdapters } = useGetIPAdapterModelsQuery(undefined, {
|
||||||
|
selectFromResult: (result) => ({
|
||||||
|
ipAdapters: result.data
|
||||||
|
? ipAdapterModelsAdapter.getSelectors().selectAll(result.data)
|
||||||
|
: [],
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const prepareIPAdapterMetadataItem = useCallback(
|
||||||
|
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||||
|
if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) {
|
||||||
|
return { ipAdapter: null, error: 'Invalid IP Adapter model' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const {
|
||||||
|
image,
|
||||||
|
ip_adapter_model,
|
||||||
|
weight,
|
||||||
|
begin_step_percent,
|
||||||
|
end_step_percent,
|
||||||
|
} = ipAdapterMetadataItem;
|
||||||
|
|
||||||
|
const matchingIPAdapterModel = ipAdapters.find(
|
||||||
|
(c) =>
|
||||||
|
c.base_model === ip_adapter_model?.base_model &&
|
||||||
|
c.model_name === ip_adapter_model?.model_name
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!matchingIPAdapterModel) {
|
||||||
|
return { ipAdapter: null, error: 'IP Adapter model is not installed' };
|
||||||
|
}
|
||||||
|
|
||||||
|
const isCompatibleBaseModel =
|
||||||
|
matchingIPAdapterModel?.base_model === model?.base_model;
|
||||||
|
|
||||||
|
if (!isCompatibleBaseModel) {
|
||||||
|
return {
|
||||||
|
ipAdapter: null,
|
||||||
|
error: 'IP Adapter incompatible with currently-selected model',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const ipAdapter: IPAdapterConfig = {
|
||||||
|
adapterImage: image?.image_name ?? null,
|
||||||
|
model: matchingIPAdapterModel,
|
||||||
|
weight: weight ?? initialIPAdapterState.weight,
|
||||||
|
beginStepPct: begin_step_percent ?? initialIPAdapterState.beginStepPct,
|
||||||
|
endStepPct: end_step_percent ?? initialIPAdapterState.endStepPct,
|
||||||
|
};
|
||||||
|
|
||||||
|
return { ipAdapter, error: null };
|
||||||
|
},
|
||||||
|
[ipAdapters, model?.base_model]
|
||||||
|
);
|
||||||
|
|
||||||
|
const recallIPAdapter = useCallback(
|
||||||
|
(ipAdapterMetadataItem: IPAdapterMetadataItem) => {
|
||||||
|
const result = prepareIPAdapterMetadataItem(ipAdapterMetadataItem);
|
||||||
|
|
||||||
|
if (!result.ipAdapter) {
|
||||||
|
parameterNotSetToast(result.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
ipAdapterRecalled({
|
||||||
|
...result.ipAdapter,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
dispatch(isIPAdapterEnabledChanged(true));
|
||||||
|
|
||||||
|
parameterSetToast();
|
||||||
|
},
|
||||||
|
[
|
||||||
|
prepareIPAdapterMetadataItem,
|
||||||
|
dispatch,
|
||||||
|
parameterSetToast,
|
||||||
|
parameterNotSetToast,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Sets image as initial image with toast
|
* Sets image as initial image with toast
|
||||||
*/
|
*/
|
||||||
@ -563,6 +655,7 @@ export const useRecallParameters = () => {
|
|||||||
refiner_start,
|
refiner_start,
|
||||||
loras,
|
loras,
|
||||||
controlnets,
|
controlnets,
|
||||||
|
ipAdapters,
|
||||||
} = metadata;
|
} = metadata;
|
||||||
|
|
||||||
if (isValidCfgScale(cfg_scale)) {
|
if (isValidCfgScale(cfg_scale)) {
|
||||||
@ -653,7 +746,9 @@ export const useRecallParameters = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
dispatch(controlNetReset());
|
dispatch(controlNetReset());
|
||||||
|
if (controlnets?.length) {
|
||||||
dispatch(controlNetEnabled());
|
dispatch(controlNetEnabled());
|
||||||
|
}
|
||||||
controlnets?.forEach((controlnet) => {
|
controlnets?.forEach((controlnet) => {
|
||||||
const result = prepareControlNetMetadataItem(controlnet);
|
const result = prepareControlNetMetadataItem(controlnet);
|
||||||
if (result.controlnet) {
|
if (result.controlnet) {
|
||||||
@ -661,6 +756,16 @@ export const useRecallParameters = () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (ipAdapters?.length) {
|
||||||
|
dispatch(isIPAdapterEnabledChanged(true));
|
||||||
|
}
|
||||||
|
ipAdapters?.forEach((ipAdapter) => {
|
||||||
|
const result = prepareIPAdapterMetadataItem(ipAdapter);
|
||||||
|
if (result.ipAdapter) {
|
||||||
|
dispatch(ipAdapterRecalled(result.ipAdapter));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
allParameterSetToast();
|
allParameterSetToast();
|
||||||
},
|
},
|
||||||
[
|
[
|
||||||
@ -669,6 +774,7 @@ export const useRecallParameters = () => {
|
|||||||
dispatch,
|
dispatch,
|
||||||
prepareLoRAMetadataItem,
|
prepareLoRAMetadataItem,
|
||||||
prepareControlNetMetadataItem,
|
prepareControlNetMetadataItem,
|
||||||
|
prepareIPAdapterMetadataItem,
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -688,6 +794,7 @@ export const useRecallParameters = () => {
|
|||||||
recallStrength,
|
recallStrength,
|
||||||
recallLoRA,
|
recallLoRA,
|
||||||
recallControlNet,
|
recallControlNet,
|
||||||
|
recallIPAdapter,
|
||||||
recallAllParameters,
|
recallAllParameters,
|
||||||
sendToImageToImage,
|
sendToImageToImage,
|
||||||
};
|
};
|
||||||
|
@ -343,6 +343,12 @@ export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
|||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
*/
|
*/
|
||||||
|
/**
|
||||||
|
* Validates/type-guards a value as a model parameter
|
||||||
|
*/
|
||||||
|
export const isValidIPAdapterModel = (
|
||||||
|
val: unknown
|
||||||
|
): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success;
|
||||||
export const zStrength = z.number().min(0).max(1);
|
export const zStrength = z.number().min(0).max(1);
|
||||||
/**
|
/**
|
||||||
* Type alias for l2l strength parameter, inferred from its zod schema
|
* Type alias for l2l strength parameter, inferred from its zod schema
|
||||||
|
@ -2115,6 +2115,11 @@ export type components = {
|
|||||||
* @description The ControlNets used for inference
|
* @description The ControlNets used for inference
|
||||||
*/
|
*/
|
||||||
controlnets: components["schemas"]["ControlField"][];
|
controlnets: components["schemas"]["ControlField"][];
|
||||||
|
/**
|
||||||
|
* Loras
|
||||||
|
* @description The LoRAs used for inference
|
||||||
|
*/
|
||||||
|
ipAdapters: components["schemas"]["IPAdapterField"][];
|
||||||
/**
|
/**
|
||||||
* Loras
|
* Loras
|
||||||
* @description The LoRAs used for inference
|
* @description The LoRAs used for inference
|
||||||
@ -3178,7 +3183,7 @@ export type components = {
|
|||||||
* Image Encoder Model
|
* Image Encoder Model
|
||||||
* @description The name of the CLIP image encoder model.
|
* @description The name of the CLIP image encoder model.
|
||||||
*/
|
*/
|
||||||
image_encoder_model: components["schemas"]["CLIPVisionModelField"];
|
image_encoder_model?: components["schemas"]["CLIPVisionModelField"];
|
||||||
/**
|
/**
|
||||||
* Weight
|
* Weight
|
||||||
* @description The weight given to the ControlNet
|
* @description The weight given to the ControlNet
|
||||||
@ -5814,6 +5819,11 @@ export type components = {
|
|||||||
* @description The LoRAs used for inference
|
* @description The LoRAs used for inference
|
||||||
*/
|
*/
|
||||||
loras?: components["schemas"]["LoRAMetadataField"][];
|
loras?: components["schemas"]["LoRAMetadataField"][];
|
||||||
|
/**
|
||||||
|
* Strength
|
||||||
|
* @description The strength used for latents-to-latents
|
||||||
|
*/
|
||||||
|
ipAdapters?: components["schemas"]["IPAdapterField"][];
|
||||||
/**
|
/**
|
||||||
* Strength
|
* Strength
|
||||||
* @description The strength used for latents-to-latents
|
* @description The strength used for latents-to-latents
|
||||||
|
Loading…
Reference in New Issue
Block a user