mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add a IPAdapterModelField for passing passing IP-Adapter models between nodes.
This commit is contained in:
parent
468253aa14
commit
a2777decd4
@ -154,6 +154,7 @@ class UIType(str, Enum):
|
|||||||
VaeModel = "VaeModelField"
|
VaeModel = "VaeModelField"
|
||||||
LoRAModel = "LoRAModelField"
|
LoRAModel = "LoRAModelField"
|
||||||
ControlNetModel = "ControlNetModelField"
|
ControlNetModel = "ControlNetModelField"
|
||||||
|
IPAdapterModel = "IPAdapterModelField"
|
||||||
UNet = "UNetField"
|
UNet = "UNetField"
|
||||||
Vae = "VaeField"
|
Vae = "VaeField"
|
||||||
CLIP = "ClipField"
|
CLIP = "ClipField"
|
||||||
|
@ -6,6 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
BaseInvocationOutput,
|
BaseInvocationOutput,
|
||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
InvocationContext,
|
InvocationContext,
|
||||||
OutputField,
|
OutputField,
|
||||||
@ -14,24 +15,22 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ImageField
|
from invokeai.app.invocations.primitives import ImageField
|
||||||
|
from invokeai.backend.model_management.models.base import BaseModelType
|
||||||
IP_ADAPTER_MODELS = Literal[
|
|
||||||
"ip-adapter_sd15",
|
|
||||||
"ip-adapter-plus_sd15",
|
|
||||||
"ip-adapter-plus-face_sd15",
|
|
||||||
"ip-adapter_sdxl",
|
|
||||||
]
|
|
||||||
|
|
||||||
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
|
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
|
||||||
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
|
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterModelField(BaseModel):
|
||||||
|
model_name: str = Field(description="Name of the IP-Adapter model")
|
||||||
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
||||||
|
|
||||||
# TODO(ryand): Create and use a custom `IpAdapterModelField`.
|
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||||
ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")
|
|
||||||
|
|
||||||
# TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
|
# TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
|
||||||
image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
|
image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
|
||||||
@ -51,10 +50,10 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
||||||
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
|
ip_adapter_model: IPAdapterModelField = InputField(
|
||||||
default="ip-adapter_sd15.bin",
|
description="The IP-Adapter model.",
|
||||||
description="The name of the IP-Adapter model.",
|
|
||||||
title="IP-Adapter Model",
|
title="IP-Adapter Model",
|
||||||
|
input=Input.Direct,
|
||||||
)
|
)
|
||||||
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
|
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
|
||||||
default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
|
default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
|
||||||
|
@ -412,9 +412,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
ip_adapter_model = exit_stack.enter_context(
|
ip_adapter_model = exit_stack.enter_context(
|
||||||
context.services.model_manager.get_model(
|
context.services.model_manager.get_model(
|
||||||
model_name=ip_adapter.ip_adapter_model,
|
model_name=ip_adapter.ip_adapter_model.model_name,
|
||||||
model_type=ModelType.IPAdapter,
|
model_type=ModelType.IPAdapter,
|
||||||
base_model=BaseModelType.StableDiffusion1, # HACK(ryand): Pass this in properly
|
base_model=ip_adapter.ip_adapter_model.base_model,
|
||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -15,6 +15,7 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
|||||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||||
import StringInputField from './inputs/StringInputField';
|
import StringInputField from './inputs/StringInputField';
|
||||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||||
|
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||||
|
|
||||||
type InputFieldProps = {
|
type InputFieldProps = {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -147,6 +148,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
field?.type === 'IPAdapterModelField' &&
|
||||||
|
fieldTemplate?.type === 'IPAdapterModelField'
|
||||||
|
) {
|
||||||
|
return (
|
||||||
|
<IPAdapterModelInputField
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
fieldTemplate={fieldTemplate}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||||
return (
|
return (
|
||||||
<ColorInputField
|
<ColorInputField
|
||||||
|
@ -0,0 +1,100 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
IPAdapterModelInputFieldTemplate,
|
||||||
|
IPAdapterModelInputFieldValue,
|
||||||
|
FieldComponentProps,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||||
|
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useMemo } from 'react';
|
||||||
|
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
|
||||||
|
const IPAdapterModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
IPAdapterModelInputFieldValue,
|
||||||
|
IPAdapterModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
const ipAdapterModel = field.value;
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
|
||||||
|
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||||
|
|
||||||
|
// grab the full model entity from the RTK Query cache
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() =>
|
||||||
|
ipAdapterModels?.entities[
|
||||||
|
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
|
||||||
|
] ?? null,
|
||||||
|
[
|
||||||
|
ipAdapterModel?.base_model,
|
||||||
|
ipAdapterModel?.model_name,
|
||||||
|
ipAdapterModels?.entities,
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!ipAdapterModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(ipAdapterModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.model_name,
|
||||||
|
group: MODEL_TYPE_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [ipAdapterModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
|
||||||
|
|
||||||
|
if (!newIPAdapterModel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldIPAdapterModelValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: newIPAdapterModel,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
className="nowheel nodrag"
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
value={selectedModel?.id ?? null}
|
||||||
|
placeholder="Pick one"
|
||||||
|
error={!selectedModel}
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
sx={{ width: '100%' }}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(IPAdapterModelInputFieldComponent);
|
@ -41,6 +41,7 @@ import {
|
|||||||
IntegerInputFieldValue,
|
IntegerInputFieldValue,
|
||||||
InvocationNodeData,
|
InvocationNodeData,
|
||||||
InvocationTemplate,
|
InvocationTemplate,
|
||||||
|
IPAdapterModelInputFieldValue,
|
||||||
isInvocationNode,
|
isInvocationNode,
|
||||||
isNotesNode,
|
isNotesNode,
|
||||||
LoRAModelInputFieldValue,
|
LoRAModelInputFieldValue,
|
||||||
@ -520,6 +521,12 @@ const nodesSlice = createSlice({
|
|||||||
) => {
|
) => {
|
||||||
fieldValueReducer(state, action);
|
fieldValueReducer(state, action);
|
||||||
},
|
},
|
||||||
|
fieldIPAdapterModelValueChanged: (
|
||||||
|
state,
|
||||||
|
action: FieldValueAction<IPAdapterModelInputFieldValue>
|
||||||
|
) => {
|
||||||
|
fieldValueReducer(state, action);
|
||||||
|
},
|
||||||
fieldEnumModelValueChanged: (
|
fieldEnumModelValueChanged: (
|
||||||
state,
|
state,
|
||||||
action: FieldValueAction<EnumInputFieldValue>
|
action: FieldValueAction<EnumInputFieldValue>
|
||||||
@ -866,6 +873,7 @@ export const {
|
|||||||
fieldLoRAModelValueChanged,
|
fieldLoRAModelValueChanged,
|
||||||
fieldEnumModelValueChanged,
|
fieldEnumModelValueChanged,
|
||||||
fieldControlNetModelValueChanged,
|
fieldControlNetModelValueChanged,
|
||||||
|
fieldIPAdapterModelValueChanged,
|
||||||
fieldRefinerModelValueChanged,
|
fieldRefinerModelValueChanged,
|
||||||
fieldSchedulerValueChanged,
|
fieldSchedulerValueChanged,
|
||||||
nodeIsOpenChanged,
|
nodeIsOpenChanged,
|
||||||
|
@ -40,6 +40,7 @@ export const POLYMORPHIC_TYPES = [
|
|||||||
];
|
];
|
||||||
|
|
||||||
export const MODEL_TYPES = [
|
export const MODEL_TYPES = [
|
||||||
|
'IPAdapterModelField',
|
||||||
'ControlNetModelField',
|
'ControlNetModelField',
|
||||||
'LoRAModelField',
|
'LoRAModelField',
|
||||||
'MainModelField',
|
'MainModelField',
|
||||||
@ -240,6 +241,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
description: 'IP-Adapter info passed between nodes.',
|
description: 'IP-Adapter info passed between nodes.',
|
||||||
title: 'IP-Adapter',
|
title: 'IP-Adapter',
|
||||||
},
|
},
|
||||||
|
IPAdapterModelField: {
|
||||||
|
color: 'teal.500', // TODO(ryand): Pick a color
|
||||||
|
description: 'IP-Adapter model',
|
||||||
|
title: 'IP-Adapter Model',
|
||||||
|
},
|
||||||
LatentsCollection: {
|
LatentsCollection: {
|
||||||
color: 'pink.500',
|
color: 'pink.500',
|
||||||
description: 'Latents may be passed between nodes.',
|
description: 'Latents may be passed between nodes.',
|
||||||
|
@ -94,6 +94,7 @@ export const zFieldType = z.enum([
|
|||||||
'IntegerCollection',
|
'IntegerCollection',
|
||||||
'IntegerPolymorphic',
|
'IntegerPolymorphic',
|
||||||
'IPAdapterField',
|
'IPAdapterField',
|
||||||
|
'IPAdapterModelField',
|
||||||
'LatentsCollection',
|
'LatentsCollection',
|
||||||
'LatentsField',
|
'LatentsField',
|
||||||
'LatentsPolymorphic',
|
'LatentsPolymorphic',
|
||||||
@ -389,9 +390,12 @@ export type ControlCollectionInputFieldValue = z.infer<
|
|||||||
typeof zControlCollectionInputFieldValue
|
typeof zControlCollectionInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zIPAdapterModel = zModelIdentifier;
|
||||||
|
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
|
||||||
|
|
||||||
export const zIPAdapterField = z.object({
|
export const zIPAdapterField = z.object({
|
||||||
image: zImageField,
|
image: zImageField,
|
||||||
ip_adapter_model: z.string().trim().min(1),
|
ip_adapter_model: zIPAdapterModel,
|
||||||
image_encoder_model: z.string().trim().min(1),
|
image_encoder_model: z.string().trim().min(1),
|
||||||
weight: z.number(),
|
weight: z.number(),
|
||||||
});
|
});
|
||||||
@ -554,6 +558,17 @@ export type ControlNetModelInputFieldValue = z.infer<
|
|||||||
typeof zControlNetModelInputFieldValue
|
typeof zControlNetModelInputFieldValue
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
export const zIPAdapterModelField = zModelIdentifier;
|
||||||
|
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||||
|
|
||||||
|
export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
|
||||||
|
type: z.literal('IPAdapterModelField'),
|
||||||
|
value: zIPAdapterModelField.optional(),
|
||||||
|
});
|
||||||
|
export type IPAdapterModelInputFieldValue = z.infer<
|
||||||
|
typeof zIPAdapterModelInputFieldValue
|
||||||
|
>;
|
||||||
|
|
||||||
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||||
type: z.literal('Collection'),
|
type: z.literal('Collection'),
|
||||||
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
||||||
@ -637,6 +652,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
|||||||
zIntegerPolymorphicInputFieldValue,
|
zIntegerPolymorphicInputFieldValue,
|
||||||
zIntegerInputFieldValue,
|
zIntegerInputFieldValue,
|
||||||
zIPAdapterInputFieldValue,
|
zIPAdapterInputFieldValue,
|
||||||
|
zIPAdapterModelInputFieldValue,
|
||||||
zLatentsInputFieldValue,
|
zLatentsInputFieldValue,
|
||||||
zLatentsCollectionInputFieldValue,
|
zLatentsCollectionInputFieldValue,
|
||||||
zLatentsPolymorphicInputFieldValue,
|
zLatentsPolymorphicInputFieldValue,
|
||||||
@ -881,6 +897,11 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'ControlNetModelField';
|
type: 'ControlNetModelField';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'IPAdapterModelField';
|
||||||
|
};
|
||||||
|
|
||||||
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'Collection';
|
type: 'Collection';
|
||||||
@ -953,6 +974,7 @@ export type InputFieldTemplate =
|
|||||||
| IntegerPolymorphicInputFieldTemplate
|
| IntegerPolymorphicInputFieldTemplate
|
||||||
| IntegerInputFieldTemplate
|
| IntegerInputFieldTemplate
|
||||||
| IPAdapterInputFieldTemplate
|
| IPAdapterInputFieldTemplate
|
||||||
|
| IPAdapterModelInputFieldTemplate
|
||||||
| LatentsInputFieldTemplate
|
| LatentsInputFieldTemplate
|
||||||
| LatentsCollectionInputFieldTemplate
|
| LatentsCollectionInputFieldTemplate
|
||||||
| LatentsPolymorphicInputFieldTemplate
|
| LatentsPolymorphicInputFieldTemplate
|
||||||
|
@ -61,6 +61,7 @@ import {
|
|||||||
LatentsField,
|
LatentsField,
|
||||||
ConditioningField,
|
ConditioningField,
|
||||||
IPAdapterInputFieldTemplate,
|
IPAdapterInputFieldTemplate,
|
||||||
|
IPAdapterModelInputFieldTemplate,
|
||||||
} from '../types/types';
|
} from '../types/types';
|
||||||
import { ControlField } from 'services/api/types';
|
import { ControlField } from 'services/api/types';
|
||||||
|
|
||||||
@ -436,6 +437,19 @@ const buildControlNetModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildIPAdapterModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
|
||||||
|
const template: IPAdapterModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'IPAdapterModelField',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -866,6 +880,7 @@ const TEMPLATE_BUILDER_MAP = {
|
|||||||
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||||
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||||
IPAdapterField: buildIPAdapterInputFieldTemplate,
|
IPAdapterField: buildIPAdapterInputFieldTemplate,
|
||||||
|
IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
|
||||||
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||||
LatentsField: buildLatentsInputFieldTemplate,
|
LatentsField: buildLatentsInputFieldTemplate,
|
||||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||||
|
@ -30,6 +30,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
|
|||||||
IntegerCollection: [],
|
IntegerCollection: [],
|
||||||
IntegerPolymorphic: 0,
|
IntegerPolymorphic: 0,
|
||||||
IPAdapterField: undefined,
|
IPAdapterField: undefined,
|
||||||
|
IPAdapterModelField: undefined,
|
||||||
LatentsCollection: [],
|
LatentsCollection: [],
|
||||||
LatentsField: undefined,
|
LatentsField: undefined,
|
||||||
LatentsPolymorphic: undefined,
|
LatentsPolymorphic: undefined,
|
||||||
|
@ -323,7 +323,17 @@ export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
|||||||
export const isValidControlNetModel = (
|
export const isValidControlNetModel = (
|
||||||
val: unknown
|
val: unknown
|
||||||
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||||
|
/**
|
||||||
|
* Zod schema for IP-Adapter models
|
||||||
|
*/
|
||||||
|
export const zIPAdapterModel = z.object({
|
||||||
|
model_name: z.string().min(1),
|
||||||
|
base_model: zBaseModel,
|
||||||
|
});
|
||||||
|
/**
|
||||||
|
* Type alias for model parameter, inferred from its zod schema
|
||||||
|
*/
|
||||||
|
export type zIPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
||||||
/**
|
/**
|
||||||
* Zod schema for l2l strength parameter
|
* Zod schema for l2l strength parameter
|
||||||
*/
|
*/
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
import { logger } from 'app/logging/logger';
|
||||||
|
import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||||
|
import { IPAdapterModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
export const modelIdToIPAdapterModelParam = (
|
||||||
|
ipAdapterModelId: string
|
||||||
|
): IPAdapterModelField | undefined => {
|
||||||
|
const log = logger('models');
|
||||||
|
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
|
||||||
|
|
||||||
|
const result = zIPAdapterModel.safeParse({
|
||||||
|
base_model,
|
||||||
|
model_name,
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!result.success) {
|
||||||
|
log.error(
|
||||||
|
{
|
||||||
|
ipAdapterModelId,
|
||||||
|
errors: result.error.format(),
|
||||||
|
},
|
||||||
|
'Failed to parse IP-Adapter model id'
|
||||||
|
);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.data;
|
||||||
|
};
|
@ -5,6 +5,7 @@ import {
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
CheckpointModelConfig,
|
CheckpointModelConfig,
|
||||||
ControlNetModelConfig,
|
ControlNetModelConfig,
|
||||||
|
IPAdapterModelConfig,
|
||||||
DiffusersModelConfig,
|
DiffusersModelConfig,
|
||||||
ImportModelConfig,
|
ImportModelConfig,
|
||||||
LoRAModelConfig,
|
LoRAModelConfig,
|
||||||
@ -36,6 +37,10 @@ export type ControlNetModelConfigEntity = ControlNetModelConfig & {
|
|||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
|
||||||
|
id: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
|
||||||
id: string;
|
id: string;
|
||||||
};
|
};
|
||||||
@ -47,6 +52,7 @@ type AnyModelConfigEntity =
|
|||||||
| OnnxModelConfigEntity
|
| OnnxModelConfigEntity
|
||||||
| LoRAModelConfigEntity
|
| LoRAModelConfigEntity
|
||||||
| ControlNetModelConfigEntity
|
| ControlNetModelConfigEntity
|
||||||
|
| IPAdapterModelConfigEntity
|
||||||
| TextualInversionModelConfigEntity
|
| TextualInversionModelConfigEntity
|
||||||
| VaeModelConfigEntity;
|
| VaeModelConfigEntity;
|
||||||
|
|
||||||
@ -135,6 +141,10 @@ export const controlNetModelsAdapter =
|
|||||||
createEntityAdapter<ControlNetModelConfigEntity>({
|
createEntityAdapter<ControlNetModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
});
|
});
|
||||||
|
export const ipAdapterModelsAdapter =
|
||||||
|
createEntityAdapter<IPAdapterModelConfigEntity>({
|
||||||
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
|
});
|
||||||
export const textualInversionModelsAdapter =
|
export const textualInversionModelsAdapter =
|
||||||
createEntityAdapter<TextualInversionModelConfigEntity>({
|
createEntityAdapter<TextualInversionModelConfigEntity>({
|
||||||
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
|
||||||
@ -435,6 +445,37 @@ export const modelsApi = api.injectEndpoints({
|
|||||||
);
|
);
|
||||||
},
|
},
|
||||||
}),
|
}),
|
||||||
|
getIPAdapterModels: build.query<
|
||||||
|
EntityState<IPAdapterModelConfigEntity>,
|
||||||
|
void
|
||||||
|
>({
|
||||||
|
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
|
||||||
|
providesTags: (result) => {
|
||||||
|
const tags: ApiFullTagDescription[] = [
|
||||||
|
{ type: 'IPAdapterModel', id: LIST_TAG },
|
||||||
|
];
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
tags.push(
|
||||||
|
...result.ids.map((id) => ({
|
||||||
|
type: 'IPAdapterModel' as const,
|
||||||
|
id,
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return tags;
|
||||||
|
},
|
||||||
|
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
|
||||||
|
const entities = createModelEntities<IPAdapterModelConfigEntity>(
|
||||||
|
response.models
|
||||||
|
);
|
||||||
|
return ipAdapterModelsAdapter.setAll(
|
||||||
|
ipAdapterModelsAdapter.getInitialState(),
|
||||||
|
entities
|
||||||
|
);
|
||||||
|
},
|
||||||
|
}),
|
||||||
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
|
||||||
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
|
||||||
providesTags: (result) => {
|
providesTags: (result) => {
|
||||||
@ -533,6 +574,7 @@ export const {
|
|||||||
useGetMainModelsQuery,
|
useGetMainModelsQuery,
|
||||||
useGetOnnxModelsQuery,
|
useGetOnnxModelsQuery,
|
||||||
useGetControlNetModelsQuery,
|
useGetControlNetModelsQuery,
|
||||||
|
useGetIPAdapterModelsQuery,
|
||||||
useGetLoRAModelsQuery,
|
useGetLoRAModelsQuery,
|
||||||
useGetTextualInversionModelsQuery,
|
useGetTextualInversionModelsQuery,
|
||||||
useGetVaeModelsQuery,
|
useGetVaeModelsQuery,
|
||||||
|
@ -2429,9 +2429,9 @@ export type components = {
|
|||||||
image: components["schemas"]["ImageField"];
|
image: components["schemas"]["ImageField"];
|
||||||
/**
|
/**
|
||||||
* Ip Adapter Model
|
* Ip Adapter Model
|
||||||
* @description The name of the IP-Adapter model.
|
* @description The IP-Adapter model to use.
|
||||||
*/
|
*/
|
||||||
ip_adapter_model: string;
|
ip_adapter_model: components["schemas"]["IPAdapterModelField"];
|
||||||
/**
|
/**
|
||||||
* Image Encoder Model
|
* Image Encoder Model
|
||||||
* @description The name of the CLIP image encoder model.
|
* @description The name of the CLIP image encoder model.
|
||||||
@ -2472,11 +2472,9 @@ export type components = {
|
|||||||
image?: components["schemas"]["ImageField"];
|
image?: components["schemas"]["ImageField"];
|
||||||
/**
|
/**
|
||||||
* IP-Adapter Model
|
* IP-Adapter Model
|
||||||
* @description The name of the IP-Adapter model.
|
* @description The IP-Adapter model.
|
||||||
* @default ip-adapter_sd15.bin
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
*/
|
||||||
ip_adapter_model?: "ip-adapter_sd15" | "ip-adapter-plus_sd15" | "ip-adapter-plus-face_sd15" | "ip-adapter_sdxl";
|
ip_adapter_model: components["schemas"]["IPAdapterModelField"];
|
||||||
/**
|
/**
|
||||||
* Image Encoder Model
|
* Image Encoder Model
|
||||||
* @description The name of the CLIP image encoder model.
|
* @description The name of the CLIP image encoder model.
|
||||||
@ -2518,6 +2516,16 @@ export type components = {
|
|||||||
model_format: "checkpoint";
|
model_format: "checkpoint";
|
||||||
error?: components["schemas"]["ModelError"];
|
error?: components["schemas"]["ModelError"];
|
||||||
};
|
};
|
||||||
|
/** IPAdapterModelField */
|
||||||
|
IPAdapterModelField: {
|
||||||
|
/**
|
||||||
|
* Model Name
|
||||||
|
* @description Name of the IP-Adapter model
|
||||||
|
*/
|
||||||
|
model_name: string;
|
||||||
|
/** @description Base model */
|
||||||
|
base_model: components["schemas"]["BaseModelType"];
|
||||||
|
};
|
||||||
/**
|
/**
|
||||||
* IPAdapterOutput
|
* IPAdapterOutput
|
||||||
* @description Base class for all invocation outputs.
|
* @description Base class for all invocation outputs.
|
||||||
@ -7188,7 +7196,7 @@ export type components = {
|
|||||||
* If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes.
|
* If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField";
|
UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "MetadataField";
|
||||||
/**
|
/**
|
||||||
* UIComponent
|
* UIComponent
|
||||||
* @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type.
|
* @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type.
|
||||||
@ -7227,12 +7235,6 @@ export type components = {
|
|||||||
/** Ui Order */
|
/** Ui Order */
|
||||||
ui_order?: number;
|
ui_order?: number;
|
||||||
};
|
};
|
||||||
/**
|
|
||||||
* StableDiffusion2ModelFormat
|
|
||||||
* @description An enumeration.
|
|
||||||
* @enum {string}
|
|
||||||
*/
|
|
||||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
|
||||||
/**
|
/**
|
||||||
* ControlNetModelFormat
|
* ControlNetModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -7246,11 +7248,17 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
IPAdapterModelFormat: "checkpoint";
|
IPAdapterModelFormat: "checkpoint";
|
||||||
/**
|
/**
|
||||||
* StableDiffusionXLModelFormat
|
* StableDiffusion2ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||||
|
/**
|
||||||
|
* StableDiffusionOnnxModelFormat
|
||||||
|
* @description An enumeration.
|
||||||
|
* @enum {string}
|
||||||
|
*/
|
||||||
|
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||||
/**
|
/**
|
||||||
* StableDiffusion1ModelFormat
|
* StableDiffusion1ModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
@ -7258,11 +7266,11 @@ export type components = {
|
|||||||
*/
|
*/
|
||||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||||
/**
|
/**
|
||||||
* StableDiffusionOnnxModelFormat
|
* StableDiffusionXLModelFormat
|
||||||
* @description An enumeration.
|
* @description An enumeration.
|
||||||
* @enum {string}
|
* @enum {string}
|
||||||
*/
|
*/
|
||||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||||
};
|
};
|
||||||
responses: never;
|
responses: never;
|
||||||
parameters: never;
|
parameters: never;
|
||||||
|
@ -60,6 +60,7 @@ export type OnnxModelField = s['OnnxModelField'];
|
|||||||
export type VAEModelField = s['VAEModelField'];
|
export type VAEModelField = s['VAEModelField'];
|
||||||
export type LoRAModelField = s['LoRAModelField'];
|
export type LoRAModelField = s['LoRAModelField'];
|
||||||
export type ControlNetModelField = s['ControlNetModelField'];
|
export type ControlNetModelField = s['ControlNetModelField'];
|
||||||
|
export type IPAdapterModelField = s['IPAdapterModelField'];
|
||||||
export type ModelsList = s['ModelsList'];
|
export type ModelsList = s['ModelsList'];
|
||||||
export type ControlField = s['ControlField'];
|
export type ControlField = s['ControlField'];
|
||||||
|
|
||||||
@ -73,6 +74,9 @@ export type ControlNetModelDiffusersConfig =
|
|||||||
export type ControlNetModelConfig =
|
export type ControlNetModelConfig =
|
||||||
| ControlNetModelCheckpointConfig
|
| ControlNetModelCheckpointConfig
|
||||||
| ControlNetModelDiffusersConfig;
|
| ControlNetModelDiffusersConfig;
|
||||||
|
export type IPAdapterModelCheckpointConfig =
|
||||||
|
s['IPAdapterModelCheckpointConfig'];
|
||||||
|
export type IPAdapterModelConfig = IPAdapterModelCheckpointConfig;
|
||||||
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
|
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
|
||||||
export type DiffusersModelConfig =
|
export type DiffusersModelConfig =
|
||||||
| s['StableDiffusion1ModelDiffusersConfig']
|
| s['StableDiffusion1ModelDiffusersConfig']
|
||||||
@ -88,6 +92,7 @@ export type AnyModelConfig =
|
|||||||
| LoRAModelConfig
|
| LoRAModelConfig
|
||||||
| VaeModelConfig
|
| VaeModelConfig
|
||||||
| ControlNetModelConfig
|
| ControlNetModelConfig
|
||||||
|
| IPAdapterModelConfig
|
||||||
| TextualInversionModelConfig
|
| TextualInversionModelConfig
|
||||||
| MainModelConfig
|
| MainModelConfig
|
||||||
| OnnxModelConfig;
|
| OnnxModelConfig;
|
||||||
|
Loading…
Reference in New Issue
Block a user