From 5ac114576fc57f05c7885a326bf1d955487338c0 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 8 Jul 2023 19:47:52 +1000
Subject: [PATCH] feat(ui): add controlnet field to nodes
---
.../nodes/components/InputFieldComponent.tsx | 11 ++
.../ControlNetModelInputFieldComponent.tsx | 102 ++++++++++++++++++
.../web/src/features/nodes/types/constants.ts | 8 ++
.../web/src/features/nodes/types/types.ts | 13 +++
.../nodes/util/fieldTemplateBuilders.ts | 19 ++++
.../features/nodes/util/fieldValueBuilders.ts | 4 +
.../util/modelIdToControlNetModelField.ts | 14 +++
.../frontend/web/src/services/api/types.d.ts | 2 +
8 files changed, 173 insertions(+)
create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx
create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts
diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
index b179adff23..23effc5375 100644
--- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx
@@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent';
import ColorInputFieldComponent from './fields/ColorInputFieldComponent';
import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent';
import ControlInputFieldComponent from './fields/ControlInputFieldComponent';
+import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent';
import EnumInputFieldComponent from './fields/EnumInputFieldComponent';
import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent';
import ImageInputFieldComponent from './fields/ImageInputFieldComponent';
@@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => {
);
}
+ if (type === 'controlnet_model' && template.type === 'controlnet_model') {
+ return (
+
+ );
+ }
+
if (type === 'array' && template.type === 'array') {
return (
+) => {
+ const { nodeId, field } = props;
+
+ const dispatch = useAppDispatch();
+ const { t } = useTranslation();
+
+ const { data: controlNetModels } = useGetControlNetModelsQuery();
+
+ const selectedModel = useMemo(
+ () => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]],
+ [controlNetModels?.entities, controlNetModels?.ids, field.value]
+ );
+
+ const data = useMemo(() => {
+ if (!controlNetModels) {
+ return [];
+ }
+
+ const data: SelectItem[] = [];
+
+ forEach(controlNetModels.entities, (model, id) => {
+ if (!model) {
+ return;
+ }
+
+ data.push({
+ value: id,
+ label: model.name,
+ group: BASE_MODEL_NAME_MAP[model.base_model],
+ });
+ });
+
+ return data;
+ }, [controlNetModels]);
+
+ const handleValueChanged = useCallback(
+ (v: string | null) => {
+ if (!v) {
+ return;
+ }
+
+ dispatch(
+ fieldValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value: v,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ useEffect(() => {
+ if (field.value && controlNetModels?.ids.includes(field.value)) {
+ return;
+ }
+
+ const firstLora = controlNetModels?.ids[0];
+
+ if (!isString(firstLora)) {
+ return;
+ }
+
+ handleValueChanged(firstLora);
+ }, [field.value, handleValueChanged, controlNetModels?.ids]);
+
+ return (
+
+ );
+};
+
+export default memo(ControlNetModelInputFieldComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index 5fe780a286..3a70e52ee5 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record = {
model: 'model',
vae_model: 'vae_model',
lora_model: 'lora_model',
+ controlnet_model: 'controlnet_model',
+ ControlNetModelField: 'controlnet_model',
array: 'array',
item: 'item',
ColorField: 'color',
@@ -130,6 +132,12 @@ export const FIELDS: Record = {
title: 'LoRA',
description: 'Models are models.',
},
+ controlnet_model: {
+ color: 'teal',
+ colorCssVar: getColorTokenCssVariable('teal'),
+ title: 'ControlNet',
+ description: 'Models are models.',
+ },
array: {
color: 'gray',
colorCssVar: getColorTokenCssVariable('gray'),
diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts
index 4c47c63068..18b837a98e 100644
--- a/invokeai/frontend/web/src/features/nodes/types/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/types.ts
@@ -71,6 +71,7 @@ export type FieldType =
| 'model'
| 'vae_model'
| 'lora_model'
+ | 'controlnet_model'
| 'array'
| 'item'
| 'color'
@@ -100,6 +101,7 @@ export type InputFieldValue =
| MainModelInputFieldValue
| VaeModelInputFieldValue
| LoRAModelInputFieldValue
+ | ControlNetModelInputFieldValue
| ArrayInputFieldValue
| ItemInputFieldValue
| ColorInputFieldValue
@@ -127,6 +129,7 @@ export type InputFieldTemplate =
| ModelInputFieldTemplate
| VaeModelInputFieldTemplate
| LoRAModelInputFieldTemplate
+ | ControlNetModelInputFieldTemplate
| ArrayInputFieldTemplate
| ItemInputFieldTemplate
| ColorInputFieldTemplate
@@ -249,6 +252,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & {
value?: LoRAModelParam;
};
+export type ControlNetModelInputFieldValue = FieldValueBase & {
+ type: 'controlnet_model';
+ value?: string;
+};
+
export type ArrayInputFieldValue = FieldValueBase & {
type: 'array';
value?: (string | number)[];
@@ -368,6 +376,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'lora_model';
};
+export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
+ default: string;
+ type: 'controlnet_model';
+};
+
export type ArrayInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'array';
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
index 1c2dbc0c3e..eaa7fe66fc 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts
@@ -9,6 +9,7 @@ import {
ColorInputFieldTemplate,
ConditioningInputFieldTemplate,
ControlInputFieldTemplate,
+ ControlNetModelInputFieldTemplate,
EnumInputFieldTemplate,
FieldType,
FloatInputFieldTemplate,
@@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({
return template;
};
+const buildControlNetModelInputFieldTemplate = ({
+ schemaObject,
+ baseField,
+}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => {
+ const template: ControlNetModelInputFieldTemplate = {
+ ...baseField,
+ type: 'controlnet_model',
+ inputRequirement: 'always',
+ inputKind: 'direct',
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
+
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@@ -479,6 +495,9 @@ export const buildInputFieldTemplate = (
if (['lora_model'].includes(fieldType)) {
return buildLoRAModelInputFieldTemplate({ schemaObject, baseField });
}
+ if (['controlnet_model'].includes(fieldType)) {
+ return buildControlNetModelInputFieldTemplate({ schemaObject, baseField });
+ }
if (['enum'].includes(fieldType)) {
return buildEnumInputFieldTemplate({ schemaObject, baseField });
}
diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
index 950038b691..f54a7640bd 100644
--- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts
@@ -83,6 +83,10 @@ export const buildInputFieldValue = (
if (template.type === 'lora_model') {
fieldValue.value = undefined;
}
+
+ if (template.type === 'controlnet_model') {
+ fieldValue.value = undefined;
+ }
}
return fieldValue;
diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts
new file mode 100644
index 0000000000..655d5cd5df
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts
@@ -0,0 +1,14 @@
+import { BaseModelType, ControlNetModelField } from 'services/api/types';
+
+export const modelIdToControlNetModelField = (
+ controlNetModelId: string
+): ControlNetModelField => {
+ const [base_model, model_type, model_name] = controlNetModelId.split('/');
+
+ const field: ControlNetModelField = {
+ base_model: base_model as BaseModelType,
+ model_name,
+ };
+
+ return field;
+};
diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts
index fcbbd1a6a0..37faae592f 100644
--- a/invokeai/frontend/web/src/services/api/types.d.ts
+++ b/invokeai/frontend/web/src/services/api/types.d.ts
@@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType'];
export type MainModelField = components['schemas']['MainModelField'];
export type VAEModelField = components['schemas']['VAEModelField'];
export type LoRAModelField = components['schemas']['LoRAModelField'];
+export type ControlNetModelField =
+ components['schemas']['ControlNetModelField'];
export type ModelsList = components['schemas']['ModelsList'];
export type ControlField = components['schemas']['ControlField'];