From 5ac114576fc57f05c7885a326bf1d955487338c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 8 Jul 2023 19:47:52 +1000 Subject: [PATCH] feat(ui): add controlnet field to nodes --- .../nodes/components/InputFieldComponent.tsx | 11 ++ .../ControlNetModelInputFieldComponent.tsx | 102 ++++++++++++++++++ .../web/src/features/nodes/types/constants.ts | 8 ++ .../web/src/features/nodes/types/types.ts | 13 +++ .../nodes/util/fieldTemplateBuilders.ts | 19 ++++ .../features/nodes/util/fieldValueBuilders.ts | 4 + .../util/modelIdToControlNetModelField.ts | 14 +++ .../frontend/web/src/services/api/types.d.ts | 2 + 8 files changed, 173 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/ControlNetModelInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index b179adff23..23effc5375 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -7,6 +7,7 @@ import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; +import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent'; import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; @@ -174,6 +175,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'controlnet_model' && template.type === 'controlnet_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: controlNetModels } = useGetControlNetModelsQuery(); + + const selectedModel = useMemo( + () => controlNetModels?.entities[field.value ?? controlNetModels.ids[0]], + [controlNetModels?.entities, controlNetModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!controlNetModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(controlNetModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [controlNetModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && controlNetModels?.ids.includes(field.value)) { + return; + } + + const firstLora = controlNetModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, controlNetModels?.ids]); + + return ( + + ); +}; + +export default memo(ControlNetModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 5fe780a286..3a70e52ee5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -19,6 +19,8 @@ export const FIELD_TYPE_MAP: Record = { model: 'model', vae_model: 'vae_model', lora_model: 'lora_model', + controlnet_model: 'controlnet_model', + ControlNetModelField: 'controlnet_model', array: 'array', item: 'item', ColorField: 'color', @@ -130,6 +132,12 @@ export const FIELDS: Record = { title: 'LoRA', description: 'Models are models.', }, + controlnet_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'ControlNet', + description: 'Models are models.', + }, array: { color: 'gray', colorCssVar: getColorTokenCssVariable('gray'), diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index 4c47c63068..18b837a98e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -71,6 +71,7 @@ export type FieldType = | 'model' | 'vae_model' | 'lora_model' + | 'controlnet_model' | 'array' | 'item' | 'color' @@ -100,6 +101,7 @@ export type InputFieldValue = | MainModelInputFieldValue | VaeModelInputFieldValue | LoRAModelInputFieldValue + | ControlNetModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -127,6 +129,7 @@ export type InputFieldTemplate = | ModelInputFieldTemplate | VaeModelInputFieldTemplate | LoRAModelInputFieldTemplate + | ControlNetModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -249,6 +252,11 @@ export type LoRAModelInputFieldValue = FieldValueBase & { value?: LoRAModelParam; }; +export type ControlNetModelInputFieldValue = FieldValueBase & { + type: 'controlnet_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -368,6 +376,11 @@ export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { type: 'lora_model'; }; +export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'controlnet_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index 1c2dbc0c3e..eaa7fe66fc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -9,6 +9,7 @@ import { ColorInputFieldTemplate, ConditioningInputFieldTemplate, ControlInputFieldTemplate, + ControlNetModelInputFieldTemplate, EnumInputFieldTemplate, FieldType, FloatInputFieldTemplate, @@ -207,6 +208,21 @@ const buildLoRAModelInputFieldTemplate = ({ return template; }; +const buildControlNetModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => { + const template: ControlNetModelInputFieldTemplate = { + ...baseField, + type: 'controlnet_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -479,6 +495,9 @@ export const buildInputFieldTemplate = ( if (['lora_model'].includes(fieldType)) { return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); } + if (['controlnet_model'].includes(fieldType)) { + return buildControlNetModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index 950038b691..f54a7640bd 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -83,6 +83,10 @@ export const buildInputFieldValue = ( if (template.type === 'lora_model') { fieldValue.value = undefined; } + + if (template.type === 'controlnet_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts new file mode 100644 index 0000000000..655d5cd5df --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToControlNetModelField.ts @@ -0,0 +1,14 @@ +import { BaseModelType, ControlNetModelField } from 'services/api/types'; + +export const modelIdToControlNetModelField = ( + controlNetModelId: string +): ControlNetModelField => { + const [base_model, model_type, model_name] = controlNetModelId.split('/'); + + const field: ControlNetModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index fcbbd1a6a0..37faae592f 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -32,6 +32,8 @@ export type BaseModelType = components['schemas']['BaseModelType']; export type MainModelField = components['schemas']['MainModelField']; export type VAEModelField = components['schemas']['VAEModelField']; export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ControlNetModelField = + components['schemas']['ControlNetModelField']; export type ModelsList = components['schemas']['ModelsList']; export type ControlField = components['schemas']['ControlField'];