mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add controlnet field to nodes
This commit is contained in:
parent
29b2e59e65
commit
5ac114576f
@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
|
||||
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
|
||||
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
|
||||
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
|
||||
import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
|
||||
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
|
||||
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
|
||||
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
|
||||
@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'controlnet_model' && template.type === 'controlnet_model') {
|
||||
return (
|
||||
<ControlNetModelInputFieldComponent
|
||||
nodeId={nodeId}
|
||||
field={field}
|
||||
template={template}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (type === 'array' && template.type === 'array') {
|
||||
return (
|
||||
<ArrayInputFieldComponent
|
||||
|
@ -0,0 +1,102 @@
|
||||
import { SelectItem } from '@mantine/core';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import IAIMantineSelect from 'common/components/IAIMantineSelect';
|
||||
import { fieldValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
ControlNetModelInputFieldTemplate,
|
||||
ControlNetModelInputFieldValue,
|
||||
} from 'features/nodes/types/types';
|
||||
import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/components/ModelSelect';
|
||||
import { forEach, isString } from 'lodash-es';
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetControlNetModelsQuery } from 'services/api/endpoints/models';
|
||||
import { FieldComponentProps } from './types';
|
||||
|
||||
const ControlNetModelInputFieldComponent = (
|
||||
props: FieldComponentProps<
|
||||
ControlNetModelInputFieldValue,
|
||||
ControlNetModelInputFieldTemplate
|
||||
>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { data: controlNetModels } = useGetControlNetModelsQuery();
|
||||
|
||||
const selectedModel = useMemo(
|
||||
() => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]],
|
||||
[controlNetModels?.entities, controlNetModels?.ids, field.value]
|
||||
);
|
||||
|
||||
const data = useMemo(() => {
|
||||
if (!controlNetModels) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const data: SelectItem[] = [];
|
||||
|
||||
forEach(controlNetModels.entities, (model, id) => {
|
||||
if (!model) {
|
||||
return;
|
||||
}
|
||||
|
||||
data.push({
|
||||
value: id,
|
||||
label: model.name,
|
||||
group: BASE_MODEL_NAME_MAP[model.base_model],
|
||||
});
|
||||
});
|
||||
|
||||
return data;
|
||||
}, [controlNetModels]);
|
||||
|
||||
const handleValueChanged = useCallback(
|
||||
(v: string | null) => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
|
||||
dispatch(
|
||||
fieldValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: v,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (field.value && controlNetModels?.ids.includes(field.value)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const firstLora = controlNetModels?.ids[0];
|
||||
|
||||
if (!isString(firstLora)) {
|
||||
return;
|
||||
}
|
||||
|
||||
handleValueChanged(firstLora);
|
||||
}, [field.value, handleValueChanged, controlNetModels?.ids]);
|
||||
|
||||
return (
|
||||
<IAIMantineSelect
|
||||
tooltip={selectedModel?.description}
|
||||
label={
|
||||
selectedModel?.base_model &&
|
||||
BASE_MODEL_NAME_MAP[selectedModel?.base_model]
|
||||
}
|
||||
value={field.value}
|
||||
placeholder="Pick one"
|
||||
data={data}
|
||||
onChange={handleValueChanged}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ControlNetModelInputFieldComponent);
|
@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record<string, FieldType> = {
|
||||
model: 'model',
|
||||
vae_model: 'vae_model',
|
||||
lora_model: 'lora_model',
|
||||
controlnet_model: 'controlnet_model',
|
||||
ControlNetModelField: 'controlnet_model',
|
||||
array: 'array',
|
||||
item: 'item',
|
||||
ColorField: 'color',
|
||||
@ -130,6 +132,12 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
|
||||
title: 'LoRA',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
controlnet_model: {
|
||||
color: 'teal',
|
||||
colorCssVar: getColorTokenCssVariable('teal'),
|
||||
title: 'ControlNet',
|
||||
description: 'Models are models.',
|
||||
},
|
||||
array: {
|
||||
color: 'gray',
|
||||
colorCssVar: getColorTokenCssVariable('gray'),
|
||||
|
@ -71,6 +71,7 @@ export type FieldType =
|
||||
| 'model'
|
||||
| 'vae_model'
|
||||
| 'lora_model'
|
||||
| 'controlnet_model'
|
||||
| 'array'
|
||||
| 'item'
|
||||
| 'color'
|
||||
@ -100,6 +101,7 @@ export type InputFieldValue =
|
||||
| MainModelInputFieldValue
|
||||
| VaeModelInputFieldValue
|
||||
| LoRAModelInputFieldValue
|
||||
| ControlNetModelInputFieldValue
|
||||
| ArrayInputFieldValue
|
||||
| ItemInputFieldValue
|
||||
| ColorInputFieldValue
|
||||
@ -127,6 +129,7 @@ export type InputFieldTemplate =
|
||||
| ModelInputFieldTemplate
|
||||
| VaeModelInputFieldTemplate
|
||||
| LoRAModelInputFieldTemplate
|
||||
| ControlNetModelInputFieldTemplate
|
||||
| ArrayInputFieldTemplate
|
||||
| ItemInputFieldTemplate
|
||||
| ColorInputFieldTemplate
|
||||
@ -249,6 +252,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
|
||||
value?: LoRAModelParam;
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldValue = FieldValueBase & {
|
||||
type: 'controlnet_model';
|
||||
value?: string;
|
||||
};
|
||||
|
||||
export type ArrayInputFieldValue = FieldValueBase & {
|
||||
type: 'array';
|
||||
value?: (string | number)[];
|
||||
@ -368,6 +376,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
type: 'lora_model';
|
||||
};
|
||||
|
||||
export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: string;
|
||||
type: 'controlnet_model';
|
||||
};
|
||||
|
||||
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
|
||||
default: [];
|
||||
type: 'array';
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
ColorInputFieldTemplate,
|
||||
ConditioningInputFieldTemplate,
|
||||
ControlInputFieldTemplate,
|
||||
ControlNetModelInputFieldTemplate,
|
||||
EnumInputFieldTemplate,
|
||||
FieldType,
|
||||
FloatInputFieldTemplate,
|
||||
@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildControlNetModelInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
|
||||
const template: ControlNetModelInputFieldTemplate = {
|
||||
...baseField,
|
||||
type: 'controlnet_model',
|
||||
inputRequirement: 'always',
|
||||
inputKind: 'direct',
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildImageInputFieldTemplate = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
|
||||
if (['lora_model'].includes(fieldType)) {
|
||||
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['controlnet_model'].includes(fieldType)) {
|
||||
return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
if (['enum'].includes(fieldType)) {
|
||||
return buildEnumInputFieldTemplate({ schemaObject, baseField });
|
||||
}
|
||||
|
@ -83,6 +83,10 @@ export const buildInputFieldValue = (
|
||||
if (template.type === 'lora_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
|
||||
if (template.type === 'controlnet_model') {
|
||||
fieldValue.value = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
return fieldValue;
|
||||
|
@ -0,0 +1,14 @@
|
||||
import { BaseModelType, ControlNetModelField } from 'services/api/types';
|
||||
|
||||
export const modelIdToControlNetModelField = (
|
||||
controlNetModelId: string
|
||||
): ControlNetModelField => {
|
||||
const [base_model, model_type, model_name] = controlNetModelId.split('/');
|
||||
|
||||
const field: ControlNetModelField = {
|
||||
base_model: base_model as BaseModelType,
|
||||
model_name,
|
||||
};
|
||||
|
||||
return field;
|
||||
};
|
@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
|
||||
export type MainModelField = components['schemas']['MainModelField'];
|
||||
export type VAEModelField = components['schemas']['VAEModelField'];
|
||||
export type LoRAModelField = components['schemas']['LoRAModelField'];
|
||||
export type ControlNetModelField =
|
||||
components['schemas']['ControlNetModelField'];
|
||||
export type ModelsList = components['schemas']['ModelsList'];
|
||||
export type ControlField = components['schemas']['ControlField'];
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user