From 29c8ddfb884e5ab9562b27145fe926b3feaff798 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 28 Jun 2024 18:03:09 -0400 Subject: [PATCH] WIP - A bunch of boilerplate to support Spandrel Image-to-Image models throughout the model manager and the frontend. --- invokeai/app/invocations/fields.py | 2 + invokeai/backend/model_manager/config.py | 12 ++ invokeai/backend/model_manager/probe.py | 19 ++- .../Invocation/fields/InputFieldRenderer.tsx | 8 + ...elImageToImageModelFieldInputComponent.tsx | 56 +++++++ .../src/features/nodes/store/nodesSlice.ts | 6 + .../web/src/features/nodes/types/common.ts | 1 + .../web/src/features/nodes/types/constants.ts | 2 + .../web/src/features/nodes/types/field.ts | 33 ++++ .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 13 ++ .../nodes/util/workflow/validateWorkflow.ts | 1 + .../src/services/api/hooks/modelsByType.ts | 2 + .../frontend/web/src/services/api/schema.ts | 144 ++++++++++++++++-- .../frontend/web/src/services/api/types.ts | 6 + 15 files changed, 287 insertions(+), 19 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index b792453b47..f341039fe0 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum): ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" T2IAdapterModel = "T2IAdapterModelField" + SpandrelImageToImageModel = "SpandrelImageToImageModelField" # endregion # region Misc Field Types @@ -134,6 +135,7 @@ class FieldDescriptions: sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + spandrel_image_to_image_model = "Spandrel Image-to-Image model" lora_weight = "The weight at which the LoRA is applied to each model" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" raw_prompt = "Raw prompt text (no parsing)" diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 9a33cc502e..3579a0c7b2 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -373,6 +373,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase): return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}") +class SpandrelImageToImageConfig(ModelConfigBase): + """Model config for Spandrel Image to Image models.""" + + type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage + format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}") + + def get_model_discriminator_value(v: Any) -> str: """ Computes the discriminator value for a model config. @@ -409,6 +420,7 @@ AnyModelConfig = Annotated[ Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], + Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], ], Discriminator(get_model_discriminator_value), diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 8ba63f0db5..53da5fc152 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -243,10 +243,14 @@ class ModelProbe(object): # Check if the model can be loaded as a SpandrelImageToImageModel. try: - _ = SpandrelImageToImageModel.load_from_state_dict(ckpt) + # TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected. + # _ = SpandrelImageToImageModel.load_from_state_dict(ckpt) + _ = SpandrelImageToImageModel.load_from_file(model_path) return ModelType.SpandrelImageToImage - except Exception: + except Exception as e: # TODO(ryand): Catch a more specific exception type here if we can. + # TODO(ryand): Delete this print statement. + print(e) pass raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") @@ -579,9 +583,9 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase): raise NotImplementedError() -class SpandrelImageToImageModelProbe(CheckpointProbeBase): +class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase): def get_base_type(self) -> BaseModelType: - raise NotImplementedError() + return BaseModelType.Any ######################################################## @@ -791,6 +795,11 @@ class CLIPVisionFolderProbe(FolderProbeBase): return BaseModelType.Any +class SpandrelImageToImageFolderProbe(FolderProbeBase): + def get_base_type(self) -> BaseModelType: + raise NotImplementedError() + + class T2IAdapterFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" @@ -820,6 +829,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) @@ -829,5 +839,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 99937ceec4..b67439eb70 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -32,6 +32,8 @@ import { isSDXLMainModelFieldInputTemplate, isSDXLRefinerModelFieldInputInstance, isSDXLRefinerModelFieldInputTemplate, + isSpandrelImageToImageModelFieldInputInstance, + isSpandrelImageToImageModelFieldInputTemplate, isStringFieldInputInstance, isStringFieldInputTemplate, isT2IAdapterModelFieldInputInstance, @@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; +import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; @@ -125,6 +128,11 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) { return ; } + + if (isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx new file mode 100644 index 0000000000..fbb23caa90 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx @@ -0,0 +1,56 @@ +import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldSpandrelImageToImageModelValueChanged, } from 'features/nodes/store/nodesSlice'; +import type { + SpandrelImageToImageModelFieldInputInstance, + SpandrelImageToImageModelFieldInputTemplate, +} from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType'; +import type { SpandrelImageToImageModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +const SpandrelImageToImageModelFieldInputComponent = ( + props: FieldComponentProps +) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + + const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels(); + + const _onChange = useCallback( + (value: SpandrelImageToImageModelConfig | null) => { + if (!value) { + return; + } + dispatch( + + fieldSpandrelImageToImageModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const { options, value, onChange } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + selectedModel: field.value, + isLoading, + }); + + return ( + + + + + + ); +}; + +export default memo(SpandrelImageToImageModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5ebc5de147..e1a74b947d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -19,6 +19,7 @@ import type { ModelIdentifierFieldValue, SchedulerFieldValue, SDXLRefinerModelFieldValue, + SpandrelImageToImageModelFieldValue, StatefulFieldValue, StringFieldValue, T2IAdapterModelFieldValue, @@ -39,6 +40,7 @@ import { zModelIdentifierFieldValue, zSchedulerFieldValue, zSDXLRefinerModelFieldValue, + zSpandrelImageToImageModelFieldValue, zStatefulFieldValue, zStringFieldValue, zT2IAdapterModelFieldValue, @@ -333,6 +335,9 @@ export const nodesSlice = createSlice({ fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zT2IAdapterModelFieldValue); }, + fieldSpandrelImageToImageModelValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue); + }, fieldEnumModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zEnumFieldValue); }, @@ -384,6 +389,7 @@ export const { fieldImageValueChanged, fieldIPAdapterModelValueChanged, fieldT2IAdapterModelValueChanged, + fieldSpandrelImageToImageModelValueChanged, fieldLabelChanged, fieldLoRAModelValueChanged, fieldModelIdentifierValueChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 54e126af3a..2ea8900281 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -66,6 +66,7 @@ const zModelType = z.enum([ 'embedding', 'onnx', 'clip_vision', + 'spandrel_image_to_image', ]); const zSubModelType = z.enum([ 'unet', diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 4ede5cd479..05697c384c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -38,6 +38,7 @@ export const MODEL_TYPES = [ 'VAEField', 'CLIPField', 'T2IAdapterModelField', + 'SpandrelImageToImageModelField', ]; /** @@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = { MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', + SpandrelImageToImageModelField: 'teal.500', StringField: 'yellow.500', T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index e2a84e3390..ba9078bec2 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ name: z.literal('T2IAdapterModelField'), originalType: zStatelessFieldType.optional(), }); +const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SpandrelImageToImageModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), @@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([ zControlNetModelFieldType, zIPAdapterModelFieldType, zT2IAdapterModelFieldType, + zSpandrelImageToImageModelFieldType, zColorFieldType, zSchedulerFieldType, ]); @@ -581,6 +586,30 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda zT2IAdapterModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region SpandrelModelToModelField + +export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional(); +const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zSpandrelImageToImageModelFieldValue, +}); +const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSpandrelImageToImageModelFieldType, + originalType: zFieldType.optional(), + default: zSpandrelImageToImageModelFieldValue, +}); +const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSpandrelImageToImageModelFieldType, +}); +export type SpandrelImageToImageModelFieldValue = z.infer; +export type SpandrelImageToImageModelFieldInputInstance = z.infer; +export type SpandrelImageToImageModelFieldInputTemplate = z.infer; +export const isSpandrelImageToImageModelFieldInputInstance = (val: unknown): val is SpandrelImageToImageModelFieldInputInstance => + zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success; +export const isSpandrelImageToImageModelFieldInputTemplate = (val: unknown): val is SpandrelImageToImageModelFieldInputTemplate => + zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success; +// #endregion + + // #region SchedulerField export const zSchedulerFieldValue = zSchedulerField.optional(); @@ -667,6 +696,7 @@ export const zStatefulFieldValue = z.union([ zControlNetModelFieldValue, zIPAdapterModelFieldValue, zT2IAdapterModelFieldValue, + zSpandrelImageToImageModelFieldValue, zColorFieldValue, zSchedulerFieldValue, ]); @@ -694,6 +724,7 @@ const zStatefulFieldInputInstance = z.union([ zControlNetModelFieldInputInstance, zIPAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance, + zSpandrelImageToImageModelFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, ]); @@ -722,6 +753,7 @@ const zStatefulFieldInputTemplate = z.union([ zControlNetModelFieldInputTemplate, zIPAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate, + zSpandrelImageToImageModelFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, zStatelessFieldInputTemplate, @@ -751,6 +783,7 @@ const zStatefulFieldOutputTemplate = z.union([ zControlNetModelFieldOutputTemplate, zIPAdapterModelFieldOutputTemplate, zT2IAdapterModelFieldOutputTemplate, + zSpandrelImageToImageModelFieldOutputTemplate, zColorFieldOutputTemplate, zSchedulerFieldOutputTemplate, ]); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 597779fd61..a5a2d89f03 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, + SpandrelImageToImageModelField: undefined, VAEModelField: undefined, ControlNetModelField: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 2b77274526..8478415cd1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -17,6 +17,7 @@ import type { SchedulerFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, + SpandrelImageToImageModelFieldInputTemplate, StatefulFieldType, StatelessFieldInputTemplate, StringFieldInputTemplate, @@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, fieldType }) => { + const template: SpandrelImageToImageModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; const buildBoardFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'spandrel_image_to_image'; +} + export const isControlAdapterModelConfig = ( config: AnyModelConfig ): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {