mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add controlnet field to nodes
This commit is contained in:
parent
29b2e59e65
commit
5ac114576f
@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
|||||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||||
|
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||||
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||||
|
return (
|
||||||
|
<ControlNetModelInputFieldComponent
|
||||||
|
nodeId={nodeId}
|
||||||
|
field={field}
|
||||||
|
template={template}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
if (type === 'array' && template.type === 'array') {
|
if (type === 'array' && template.type === 'array') {
|
||||||
return (
|
return (
|
||||||
<ArrayInputFieldComponent
|
<ArrayInputFieldComponent
|
||||||
|
@ -0,0 +1,102 @@
|
|||||||
|
import { SelectItem } from '@mantine/core';
|
||||||
|
import { useAppDispatch } from 'app/store/storeHooks';
|
||||||
|
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||||
|
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||||
|
import {
|
||||||
|
ControlNetModelInputFieldTemplate,
|
||||||
|
ControlNetModelInputFieldValue,
|
||||||
|
} from 'features/nodes/types/types';
|
||||||
|
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||||
|
import { forEach, isString } from 'lodash-es';
|
||||||
|
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||||
|
import { FieldComponentProps } from './types';
|
||||||
|
|
||||||
|
const ControlNetModelInputFieldComponent = (
|
||||||
|
props: FieldComponentProps<
|
||||||
|
ControlNetModelInputFieldValue,
|
||||||
|
ControlNetModelInputFieldTemplate
|
||||||
|
>
|
||||||
|
) => {
|
||||||
|
const { nodeId, field } = props;
|
||||||
|
|
||||||
|
const dispatch = useAppDispatch();
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||||
|
|
||||||
|
const selectedModel = useMemo(
|
||||||
|
() => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]],
|
||||||
|
[controlNetModels?.entities, controlNetModels?.ids, field.value]
|
||||||
|
);
|
||||||
|
|
||||||
|
const data = useMemo(() => {
|
||||||
|
if (!controlNetModels) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const data: SelectItem[] = [];
|
||||||
|
|
||||||
|
forEach(controlNetModels.entities, (model, id) => {
|
||||||
|
if (!model) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
data.push({
|
||||||
|
value: id,
|
||||||
|
label: model.name,
|
||||||
|
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
return data;
|
||||||
|
}, [controlNetModels]);
|
||||||
|
|
||||||
|
const handleValueChanged = useCallback(
|
||||||
|
(v: string | null) => {
|
||||||
|
if (!v) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch(
|
||||||
|
fieldValueChanged({
|
||||||
|
nodeId,
|
||||||
|
fieldName: field.name,
|
||||||
|
value: v,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
},
|
||||||
|
[dispatch, field.name, nodeId]
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (field.value && controlNetModels?.ids.includes(field.value)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstLora = controlNetModels?.ids[0];
|
||||||
|
|
||||||
|
if (!isString(firstLora)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
handleValueChanged(firstLora);
|
||||||
|
}, [field.value, handleValueChanged, controlNetModels?.ids]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<IAIMantineSelect
|
||||||
|
tooltip={selectedModel?.description}
|
||||||
|
label={
|
||||||
|
selectedModel?.base_model &&
|
||||||
|
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||||
|
}
|
||||||
|
value={field.value}
|
||||||
|
placeholder="Pick one"
|
||||||
|
data={data}
|
||||||
|
onChange={handleValueChanged}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default memo(ControlNetModelInputFieldComponent);
|
@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
|||||||
model: 'model',
|
model: 'model',
|
||||||
vae_model: 'vae_model',
|
vae_model: 'vae_model',
|
||||||
lora_model: 'lora_model',
|
lora_model: 'lora_model',
|
||||||
|
controlnet_model: 'controlnet_model',
|
||||||
|
ControlNetModelField: 'controlnet_model',
|
||||||
array: 'array',
|
array: 'array',
|
||||||
item: 'item',
|
item: 'item',
|
||||||
ColorField: 'color',
|
ColorField: 'color',
|
||||||
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
|||||||
title: 'LoRA',
|
title: 'LoRA',
|
||||||
description: 'Models are models.',
|
description: 'Models are models.',
|
||||||
},
|
},
|
||||||
|
controlnet_model: {
|
||||||
|
color: 'teal',
|
||||||
|
colorCssVar: getColorTokenCssVariable('teal'),
|
||||||
|
title: 'ControlNet',
|
||||||
|
description: 'Models are models.',
|
||||||
|
},
|
||||||
array: {
|
array: {
|
||||||
color: 'gray',
|
color: 'gray',
|
||||||
colorCssVar: getColorTokenCssVariable('gray'),
|
colorCssVar: getColorTokenCssVariable('gray'),
|
||||||
|
@ -71,6 +71,7 @@ export type FieldType =
|
|||||||
| 'model'
|
| 'model'
|
||||||
| 'vae_model'
|
| 'vae_model'
|
||||||
| 'lora_model'
|
| 'lora_model'
|
||||||
|
| 'controlnet_model'
|
||||||
| 'array'
|
| 'array'
|
||||||
| 'item'
|
| 'item'
|
||||||
| 'color'
|
| 'color'
|
||||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
|||||||
| MainModelInputFieldValue
|
| MainModelInputFieldValue
|
||||||
| VaeModelInputFieldValue
|
| VaeModelInputFieldValue
|
||||||
| LoRAModelInputFieldValue
|
| LoRAModelInputFieldValue
|
||||||
|
| ControlNetModelInputFieldValue
|
||||||
| ArrayInputFieldValue
|
| ArrayInputFieldValue
|
||||||
| ItemInputFieldValue
|
| ItemInputFieldValue
|
||||||
| ColorInputFieldValue
|
| ColorInputFieldValue
|
||||||
@ -127,6 +129,7 @@ export type InputFieldTemplate =
|
|||||||
| ModelInputFieldTemplate
|
| ModelInputFieldTemplate
|
||||||
| VaeModelInputFieldTemplate
|
| VaeModelInputFieldTemplate
|
||||||
| LoRAModelInputFieldTemplate
|
| LoRAModelInputFieldTemplate
|
||||||
|
| ControlNetModelInputFieldTemplate
|
||||||
| ArrayInputFieldTemplate
|
| ArrayInputFieldTemplate
|
||||||
| ItemInputFieldTemplate
|
| ItemInputFieldTemplate
|
||||||
| ColorInputFieldTemplate
|
| ColorInputFieldTemplate
|
||||||
@ -249,6 +252,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
|||||||
value?: LoRAModelParam;
|
value?: LoRAModelParam;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||||
|
type: 'controlnet_model';
|
||||||
|
value?: string;
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldValue = FieldValueBase & {
|
export type ArrayInputFieldValue = FieldValueBase & {
|
||||||
type: 'array';
|
type: 'array';
|
||||||
value?: (string | number)[];
|
value?: (string | number)[];
|
||||||
@ -368,6 +376,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
|||||||
type: 'lora_model';
|
type: 'lora_model';
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
|
default: string;
|
||||||
|
type: 'controlnet_model';
|
||||||
|
};
|
||||||
|
|
||||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||||
default: [];
|
default: [];
|
||||||
type: 'array';
|
type: 'array';
|
||||||
|
@ -9,6 +9,7 @@ import {
|
|||||||
ColorInputFieldTemplate,
|
ColorInputFieldTemplate,
|
||||||
ConditioningInputFieldTemplate,
|
ConditioningInputFieldTemplate,
|
||||||
ControlInputFieldTemplate,
|
ControlInputFieldTemplate,
|
||||||
|
ControlNetModelInputFieldTemplate,
|
||||||
EnumInputFieldTemplate,
|
EnumInputFieldTemplate,
|
||||||
FieldType,
|
FieldType,
|
||||||
FloatInputFieldTemplate,
|
FloatInputFieldTemplate,
|
||||||
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
|
|||||||
return template;
|
return template;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const buildControlNetModelInputFieldTemplate = ({
|
||||||
|
schemaObject,
|
||||||
|
baseField,
|
||||||
|
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
|
||||||
|
const template: ControlNetModelInputFieldTemplate = {
|
||||||
|
...baseField,
|
||||||
|
type: 'controlnet_model',
|
||||||
|
inputRequirement: 'always',
|
||||||
|
inputKind: 'direct',
|
||||||
|
default: schemaObject.default ?? undefined,
|
||||||
|
};
|
||||||
|
|
||||||
|
return template;
|
||||||
|
};
|
||||||
|
|
||||||
const buildImageInputFieldTemplate = ({
|
const buildImageInputFieldTemplate = ({
|
||||||
schemaObject,
|
schemaObject,
|
||||||
baseField,
|
baseField,
|
||||||
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
|
|||||||
if (['lora_model'].includes(fieldType)) {
|
if (['lora_model'].includes(fieldType)) {
|
||||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
if (['controlnet_model'].includes(fieldType)) {
|
||||||
|
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
|
||||||
|
}
|
||||||
if (['enum'].includes(fieldType)) {
|
if (['enum'].includes(fieldType)) {
|
||||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||||
}
|
}
|
||||||
|
@ -83,6 +83,10 @@ export const buildInputFieldValue = (
|
|||||||
if (template.type === 'lora_model') {
|
if (template.type === 'lora_model') {
|
||||||
fieldValue.value = undefined;
|
fieldValue.value = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (template.type === 'controlnet_model') {
|
||||||
|
fieldValue.value = undefined;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return fieldValue;
|
return fieldValue;
|
||||||
|
@ -0,0 +1,14 @@
|
|||||||
|
import { BaseModelType, ControlNetModelField } from 'services/api/types';
|
||||||
|
|
||||||
|
export const modelIdToControlNetModelField = (
|
||||||
|
controlNetModelId: string
|
||||||
|
): ControlNetModelField => {
|
||||||
|
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||||
|
|
||||||
|
const field: ControlNetModelField = {
|
||||||
|
base_model: base_model as BaseModelType,
|
||||||
|
model_name,
|
||||||
|
};
|
||||||
|
|
||||||
|
return field;
|
||||||
|
};
|
@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
|
|||||||
export type MainModelField = components['schemas']['MainModelField'];
|
export type MainModelField = components['schemas']['MainModelField'];
|
||||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||||
|
export type ControlNetModelField =
|
||||||
|
components['schemas']['ControlNetModelField'];
|
||||||
export type ModelsList = components['schemas']['ModelsList'];
|
export type ModelsList = components['schemas']['ModelsList'];
|
||||||
export type ControlField = components['schemas']['ControlField'];
|
export type ControlField = components['schemas']['ControlField'];
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user