mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add a IPAdapterModelField for passing passing IP-Adapter models between nodes.
This commit is contained in:
@ -15,6 +15,7 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
|
||||
import SchedulerInputField from './inputs/SchedulerInputField';
|
||||
import StringInputField from './inputs/StringInputField';
|
||||
import VaeModelInputField from './inputs/VaeModelInputField';
|
||||
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
|
||||
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
@ -147,6 +148,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
field?.type === 'IPAdapterModelField' &&
|
||||
fieldTemplate?.type === 'IPAdapterModelField'
|
||||
) {
|
||||
return (
|
||||
<IPAdapterModelInputField
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
|
||||
return (
|
||||
<ColorInputField
|
||||
|
@ -0,0 +1,100 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
FieldComponentProps,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
|
||||
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
const IPAdapterModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
IPAdapterModelInputFieldValue,
|
||||
IPAdapterModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const ipAdapterModel = field.value;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
|
||||
|
||||
// grab the full model entity from the RTK Query cache
|
||||
const selectedModel = useMemo(
|
||||
() =>
|
||||
ipAdapterModels?.entities[
|
||||
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
|
||||
] ?? null,
|
||||
[
|
||||
ipAdapterModel?.base_model,
|
||||
ipAdapterModel?.model_name,
|
||||
ipAdapterModels?.entities,
|
||||
]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!ipAdapterModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(ipAdapterModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.model_name,
|
||||
group: MODEL_TYPE_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [ipAdapterModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
|
||||
|
||||
if (!newIPAdapterModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldIPAdapterModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: newIPAdapterModel,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
className="nowheel nodrag"
|
||||
tooltip={selectedModel?.description}
|
||||
value={selectedModel?.id ?? null}
|
||||
placeholder="Pick one"
|
||||
error={!selectedModel}
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
sx={{ width: '100%' }}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(IPAdapterModelInputFieldComponent);
|
@ -41,6 +41,7 @@ import {
|
||||
IntegerInputFieldValue,
|
||||
InvocationNodeData,
|
||||
InvocationTemplate,
|
||||
IPAdapterModelInputFieldValue,
|
||||
isInvocationNode,
|
||||
isNotesNode,
|
||||
LoRAModelInputFieldValue,
|
||||
@ -520,6 +521,12 @@ const nodesSlice = createSlice({
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldIPAdapterModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<IPAdapterModelInputFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action);
|
||||
},
|
||||
fieldEnumModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<EnumInputFieldValue>
|
||||
@ -866,6 +873,7 @@ export const {
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
nodeIsOpenChanged,
|
||||
|
@ -40,6 +40,7 @@ export const POLYMORPHIC_TYPES = [
|
||||
];
|
||||
|
||||
export const MODEL_TYPES = [
|
||||
'IPAdapterModelField',
|
||||
'ControlNetModelField',
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
@ -240,6 +241,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
description: 'IP-Adapter info passed between nodes.',
|
||||
title: 'IP-Adapter',
|
||||
},
|
||||
IPAdapterModelField: {
|
||||
color: 'teal.500', // TODO(ryand): Pick a color
|
||||
description: 'IP-Adapter model',
|
||||
title: 'IP-Adapter Model',
|
||||
},
|
||||
LatentsCollection: {
|
||||
color: 'pink.500',
|
||||
description: 'Latents may be passed between nodes.',
|
||||
|
@ -94,6 +94,7 @@ export const zFieldType = z.enum([
|
||||
'IntegerCollection',
|
||||
'IntegerPolymorphic',
|
||||
'IPAdapterField',
|
||||
'IPAdapterModelField',
|
||||
'LatentsCollection',
|
||||
'LatentsField',
|
||||
'LatentsPolymorphic',
|
||||
@ -389,9 +390,12 @@ export type ControlCollectionInputFieldValue = z.infer<
|
||||
typeof zControlCollectionInputFieldValue
|
||||
>;
|
||||
|
||||
export const zIPAdapterModel = zModelIdentifier;
|
||||
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
|
||||
|
||||
export const zIPAdapterField = z.object({
|
||||
image: zImageField,
|
||||
ip_adapter_model: z.string().trim().min(1),
|
||||
ip_adapter_model: zIPAdapterModel,
|
||||
image_encoder_model: z.string().trim().min(1),
|
||||
weight: z.number(),
|
||||
});
|
||||
@ -554,6 +558,17 @@ export type ControlNetModelInputFieldValue = z.infer<
|
||||
typeof zControlNetModelInputFieldValue
|
||||
>;
|
||||
|
||||
export const zIPAdapterModelField = zModelIdentifier;
|
||||
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
|
||||
|
||||
export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('IPAdapterModelField'),
|
||||
value: zIPAdapterModelField.optional(),
|
||||
});
|
||||
export type IPAdapterModelInputFieldValue = z.infer<
|
||||
typeof zIPAdapterModelInputFieldValue
|
||||
>;
|
||||
|
||||
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
|
||||
type: z.literal('Collection'),
|
||||
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
|
||||
@ -637,6 +652,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
|
||||
zIntegerPolymorphicInputFieldValue,
|
||||
zIntegerInputFieldValue,
|
||||
zIPAdapterInputFieldValue,
|
||||
zIPAdapterModelInputFieldValue,
|
||||
zLatentsInputFieldValue,
|
||||
zLatentsCollectionInputFieldValue,
|
||||
zLatentsPolymorphicInputFieldValue,
|
||||
@ -881,6 +897,11 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'ControlNetModelField';
|
||||
};
|
||||
|
||||
export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'IPAdapterModelField';
|
||||
};
|
||||
|
||||
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'Collection';
|
||||
@ -953,6 +974,7 @@ export type InputFieldTemplate =
|
||||
| IntegerPolymorphicInputFieldTemplate
|
||||
| IntegerInputFieldTemplate
|
||||
| IPAdapterInputFieldTemplate
|
||||
| IPAdapterModelInputFieldTemplate
|
||||
| LatentsInputFieldTemplate
|
||||
| LatentsCollectionInputFieldTemplate
|
||||
| LatentsPolymorphicInputFieldTemplate
|
||||
|
@ -61,6 +61,7 @@ import {
|
||||
LatentsField,
|
||||
ConditioningField,
|
||||
IPAdapterInputFieldTemplate,
|
||||
IPAdapterModelInputFieldTemplate,
|
||||
} from '../types/types';
|
||||
import { ControlField } from 'services/api/types';
|
||||
|
||||
@ -436,6 +437,19 @@ const buildControlNetModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIPAdapterModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
|
||||
const template: IPAdapterModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'IPAdapterModelField',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -866,6 +880,7 @@ const TEMPLATE_BUILDER_MAP = {
|
||||
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
|
||||
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
|
||||
IPAdapterField: buildIPAdapterInputFieldTemplate,
|
||||
IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
|
||||
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
|
||||
LatentsField: buildLatentsInputFieldTemplate,
|
||||
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
|
||||
|
@ -30,6 +30,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
|
||||
IntegerCollection: [],
|
||||
IntegerPolymorphic: 0,
|
||||
IPAdapterField: undefined,
|
||||
IPAdapterModelField: undefined,
|
||||
LatentsCollection: [],
|
||||
LatentsField: undefined,
|
||||
LatentsPolymorphic: undefined,
|
||||
|
@ -323,7 +323,17 @@ export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
|
||||
export const isValidControlNetModel = (
|
||||
val: unknown
|
||||
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
|
||||
|
||||
/**
|
||||
* Zod schema for IP-Adapter models
|
||||
*/
|
||||
export const zIPAdapterModel = z.object({
|
||||
model_name: z.string().min(1),
|
||||
base_model: zBaseModel,
|
||||
});
|
||||
/**
|
||||
* Type alias for model parameter, inferred from its zod schema
|
||||
*/
|
||||
export type zIPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
|
||||
/**
|
||||
* Zod schema for l2l strength parameter
|
||||
*/
|
||||
|
@ -0,0 +1,29 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas';
|
||||
import { IPAdapterModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToIPAdapterModelParam = (
|
||||
ipAdapterModelId: string
|
||||
): IPAdapterModelField | undefined => {
|
||||
const log = logger('models');
|
||||
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
|
||||
|
||||
const result = zIPAdapterModel.safeParse({
|
||||
base_model,
|
||||
model_name,
|
||||
});
|
||||
|
||||
if (!result.success) {
|
||||
log.error(
|
||||
{
|
||||
ipAdapterModelId,
|
||||
errors: result.error.format(),
|
||||
},
|
||||
'Failed to parse IP-Adapter model id'
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
return result.data;
|
||||
};
|
Reference in New Issue
Block a user