diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 12b8c7cdd6..73a640e04b 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,10 +1,9 @@ import copy -from typing import List, Literal, Optional, Union +from typing import List, Literal, Optional from pydantic import BaseModel, Field from ...backend.model_management import BaseModelType, ModelType, SubModelType -from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) @@ -32,7 +31,6 @@ class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") - class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" @@ -223,3 +221,53 @@ class LoraLoaderInvocation(BaseInvocation): return output +class VAEModelField(BaseModel): + """Vae model field""" + + model_name: str = Field(description="Name of the model") + base_model: BaseModelType = Field(description="Base model") + +class VaeLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + #fmt: off + type: Literal["vae_loader_output"] = "vae_loader_output" + + vae: VaeField = Field(default=None, description="Vae model") + #fmt: on + +class VaeLoaderInvocation(BaseInvocation): + """Loads a VAE model, outputting a VaeLoaderOutput""" + type: Literal["vae_loader"] = "vae_loader" + + vae_model: VAEModelField = Field(description="The VAE to load") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["model", "loader"], + "type_hints": { + "vae_model": "vae_model" + } + }, + } + + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: + base_model = self.vae.base_model + model_name = self.vae.model_name + model_type = ModelType.vae + + if not context.services.model_manager.model_exists( + base_model=base_model, + model_name=model_name, + model_type=model_type, + ): + raise Exception(f"Unkown vae name: {model_name}!") + return VaeLoaderOutput( + vae=VaeField( + model_name = model_name, + base_model = base_model, + model_type = model_type, + ) + ) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 9d8ac90535..41058e3cd3 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -1068,12 +1068,6 @@ export type components = { nodes?: { [key: string]: | ( - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -1087,31 +1081,38 @@ export type components = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation'] @@ -1177,20 +1178,21 @@ export type components = { results: { [key: string]: | ( - | components['schemas']['IntCollectionOutput'] - | components['schemas']['FloatCollectionOutput'] - | components['schemas']['ModelLoaderOutput'] - | components['schemas']['LoraLoaderOutput'] - | components['schemas']['CompelOutput'] | components['schemas']['ImageOutput'] | components['schemas']['MaskOutput'] + | components['schemas']['ModelLoaderOutput'] + | components['schemas']['LoraLoaderOutput'] + | components['schemas']['VaeLoaderOutput'] + | components['schemas']['CompelOutput'] | components['schemas']['ControlOutput'] | components['schemas']['LatentsOutput'] | components['schemas']['IntOutput'] | components['schemas']['FloatOutput'] - | components['schemas']['NoiseOutput'] + | components['schemas']['IntCollectionOutput'] + | components['schemas']['FloatCollectionOutput'] | components['schemas']['PromptOutput'] | components['schemas']['PromptCollectionOutput'] + | components['schemas']['NoiseOutput'] | components['schemas']['GraphInvocationOutput'] | components['schemas']['IterateInvocationOutput'] | components['schemas']['CollectInvocationOutput'] @@ -3267,14 +3269,14 @@ export type components = { ModelsList: { /** Models */ models: ( - | components['schemas']['StableDiffusion1ModelCheckpointConfig'] | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] | components['schemas']['VaeModelConfig'] | components['schemas']['LoRAModelConfig'] | components['schemas']['ControlNetModelConfig'] | components['schemas']['TextualInversionModelConfig'] - | components['schemas']['StableDiffusion2ModelDiffusersConfig'] | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig'] )[]; }; /** @@ -4539,6 +4541,19 @@ export type components = { */ level?: 2 | 4; }; + /** + * VAEModelField + * @description Vae model field + */ + VAEModelField: { + /** + * Model Name + * @description Name of the model + */ + model_name: string; + /** @description Base model */ + base_model: components['schemas']['BaseModelType']; + }; /** VaeField */ VaeField: { /** @@ -4547,6 +4562,51 @@ export type components = { */ vae: components['schemas']['ModelInfo']; }; + /** + * VaeLoaderInvocation + * @description Loads a VAE model, outputting a VaeLoaderOutput + */ + VaeLoaderInvocation: { + /** + * Id + * @description The id of this node. Must be unique among all nodes. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this node is an intermediate node. + * @default false + */ + is_intermediate?: boolean; + /** + * Type + * @default vae_loader + * @enum {string} + */ + type?: 'vae_loader'; + /** + * Vae Model + * @description The VAE to load + */ + vae_model: components['schemas']['VAEModelField']; + }; + /** + * VaeLoaderOutput + * @description Model loader output + */ + VaeLoaderOutput: { + /** + * Type + * @default vae_loader_output + * @enum {string} + */ + type?: 'vae_loader_output'; + /** + * Vae + * @description Vae model + */ + vae?: components['schemas']['VaeField']; + }; /** VaeModelConfig */ VaeModelConfig: { /** Name */ @@ -4625,18 +4685,18 @@ export type components = { */ image?: components['schemas']['ImageField']; }; - /** - * StableDiffusion1ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers'; /** * StableDiffusion2ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion2ModelFormat: 'checkpoint' | 'diffusers'; + /** + * StableDiffusion1ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers'; }; responses: never; parameters: never; @@ -4747,12 +4807,6 @@ export type operations = { requestBody: { content: { 'application/json': - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -4766,31 +4820,38 @@ export type operations = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation'] @@ -4847,12 +4908,6 @@ export type operations = { requestBody: { content: { 'application/json': - | components['schemas']['RangeInvocation'] - | components['schemas']['RangeOfSizeInvocation'] - | components['schemas']['RandomRangeInvocation'] - | components['schemas']['MainModelLoaderInvocation'] - | components['schemas']['LoraLoaderInvocation'] - | components['schemas']['CompelInvocation'] | components['schemas']['LoadImageInvocation'] | components['schemas']['ShowImageInvocation'] | components['schemas']['ImageCropInvocation'] @@ -4866,31 +4921,38 @@ export type operations = { | components['schemas']['ImageScaleInvocation'] | components['schemas']['ImageLerpInvocation'] | components['schemas']['ImageInverseLerpInvocation'] + | components['schemas']['CvInpaintInvocation'] + | components['schemas']['RestoreFaceInvocation'] + | components['schemas']['MainModelLoaderInvocation'] + | components['schemas']['LoraLoaderInvocation'] + | components['schemas']['VaeLoaderInvocation'] + | components['schemas']['CompelInvocation'] | components['schemas']['ControlNetInvocation'] | components['schemas']['ImageProcessorInvocation'] - | components['schemas']['CvInpaintInvocation'] | components['schemas']['TextToLatentsInvocation'] | components['schemas']['LatentsToImageInvocation'] | components['schemas']['ResizeLatentsInvocation'] | components['schemas']['ScaleLatentsInvocation'] | components['schemas']['ImageToLatentsInvocation'] | components['schemas']['InpaintInvocation'] - | components['schemas']['InfillColorInvocation'] - | components['schemas']['InfillTileInvocation'] - | components['schemas']['InfillPatchMatchInvocation'] | components['schemas']['AddInvocation'] | components['schemas']['SubtractInvocation'] | components['schemas']['MultiplyInvocation'] | components['schemas']['DivideInvocation'] | components['schemas']['RandomIntInvocation'] - | components['schemas']['NoiseInvocation'] | components['schemas']['ParamIntInvocation'] | components['schemas']['ParamFloatInvocation'] + | components['schemas']['UpscaleInvocation'] + | components['schemas']['RangeInvocation'] + | components['schemas']['RangeOfSizeInvocation'] + | components['schemas']['RandomRangeInvocation'] + | components['schemas']['DynamicPromptInvocation'] + | components['schemas']['InfillColorInvocation'] + | components['schemas']['InfillTileInvocation'] + | components['schemas']['InfillPatchMatchInvocation'] + | components['schemas']['NoiseInvocation'] | components['schemas']['FloatLinearRangeInvocation'] | components['schemas']['StepParamEasingInvocation'] - | components['schemas']['DynamicPromptInvocation'] - | components['schemas']['RestoreFaceInvocation'] - | components['schemas']['UpscaleInvocation'] | components['schemas']['GraphInvocation'] | components['schemas']['IterateInvocation'] | components['schemas']['CollectInvocation']