add vae lodaer

This commit is contained in:
Lincoln Stein 2023-06-30 18:15:04 -04:00 committed by psychedelicious
parent 630f3c8b0b
commit fa8a5838d3
2 changed files with 169 additions and 59 deletions

View File

@ -1,10 +1,9 @@
import copy
from typing import List, Literal, Optional, Union
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
@ -32,7 +31,6 @@ class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
class ModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
@ -223,3 +221,53 @@ class LoraLoaderInvocation(BaseInvocation):
return output
class VAEModelField(BaseModel):
"""Vae model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
class VaeLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
#fmt: off
type: Literal["vae_loader_output"] = "vae_loader_output"
vae: VaeField = Field(default=None, description="Vae model")
#fmt: on
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""
type: Literal["vae_loader"] = "vae_loader"
vae_model: VAEModelField = Field(description="The VAE to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"vae_model": "vae_model"
}
},
}
def invoke(self, context: InvocationContext) -> VaeLoaderOutput:
base_model = self.vae.base_model
model_name = self.vae.model_name
model_type = ModelType.vae
if not context.services.model_manager.model_exists(
base_model=base_model,
model_name=model_name,
model_type=model_type,
):
raise Exception(f"Unkown vae name: {model_name}!")
return VaeLoaderOutput(
vae=VaeField(
model_name = model_name,
base_model = base_model,
model_type = model_type,
)
)

View File

@ -1068,12 +1068,6 @@ export type components = {
nodes?: {
[key: string]:
| (
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['LoadImageInvocation']
| components['schemas']['ShowImageInvocation']
| components['schemas']['ImageCropInvocation']
@ -1087,31 +1081,38 @@ export type components = {
| components['schemas']['ImageScaleInvocation']
| components['schemas']['ImageLerpInvocation']
| components['schemas']['ImageInverseLerpInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['VaeLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['ControlNetInvocation']
| components['schemas']['ImageProcessorInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['TextToLatentsInvocation']
| components['schemas']['LatentsToImageInvocation']
| components['schemas']['ResizeLatentsInvocation']
| components['schemas']['ScaleLatentsInvocation']
| components['schemas']['ImageToLatentsInvocation']
| components['schemas']['InpaintInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['AddInvocation']
| components['schemas']['SubtractInvocation']
| components['schemas']['MultiplyInvocation']
| components['schemas']['DivideInvocation']
| components['schemas']['RandomIntInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['ParamIntInvocation']
| components['schemas']['ParamFloatInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['FloatLinearRangeInvocation']
| components['schemas']['StepParamEasingInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['GraphInvocation']
| components['schemas']['IterateInvocation']
| components['schemas']['CollectInvocation']
@ -1177,20 +1178,21 @@ export type components = {
results: {
[key: string]:
| (
| components['schemas']['IntCollectionOutput']
| components['schemas']['FloatCollectionOutput']
| components['schemas']['ModelLoaderOutput']
| components['schemas']['LoraLoaderOutput']
| components['schemas']['CompelOutput']
| components['schemas']['ImageOutput']
| components['schemas']['MaskOutput']
| components['schemas']['ModelLoaderOutput']
| components['schemas']['LoraLoaderOutput']
| components['schemas']['VaeLoaderOutput']
| components['schemas']['CompelOutput']
| components['schemas']['ControlOutput']
| components['schemas']['LatentsOutput']
| components['schemas']['IntOutput']
| components['schemas']['FloatOutput']
| components['schemas']['NoiseOutput']
| components['schemas']['IntCollectionOutput']
| components['schemas']['FloatCollectionOutput']
| components['schemas']['PromptOutput']
| components['schemas']['PromptCollectionOutput']
| components['schemas']['NoiseOutput']
| components['schemas']['GraphInvocationOutput']
| components['schemas']['IterateInvocationOutput']
| components['schemas']['CollectInvocationOutput']
@ -3267,14 +3269,14 @@ export type components = {
ModelsList: {
/** Models */
models: (
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['StableDiffusion1ModelDiffusersConfig']
| components['schemas']['StableDiffusion1ModelCheckpointConfig']
| components['schemas']['VaeModelConfig']
| components['schemas']['LoRAModelConfig']
| components['schemas']['ControlNetModelConfig']
| components['schemas']['TextualInversionModelConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']
| components['schemas']['StableDiffusion2ModelCheckpointConfig']
| components['schemas']['StableDiffusion2ModelDiffusersConfig']
)[];
};
/**
@ -4539,6 +4541,19 @@ export type components = {
*/
level?: 2 | 4;
};
/**
* VAEModelField
* @description Vae model field
*/
VAEModelField: {
/**
* Model Name
* @description Name of the model
*/
model_name: string;
/** @description Base model */
base_model: components['schemas']['BaseModelType'];
};
/** VaeField */
VaeField: {
/**
@ -4547,6 +4562,51 @@ export type components = {
*/
vae: components['schemas']['ModelInfo'];
};
/**
* VaeLoaderInvocation
* @description Loads a VAE model, outputting a VaeLoaderOutput
*/
VaeLoaderInvocation: {
/**
* Id
* @description The id of this node. Must be unique among all nodes.
*/
id: string;
/**
* Is Intermediate
* @description Whether or not this node is an intermediate node.
* @default false
*/
is_intermediate?: boolean;
/**
* Type
* @default vae_loader
* @enum {string}
*/
type?: 'vae_loader';
/**
* Vae Model
* @description The VAE to load
*/
vae_model: components['schemas']['VAEModelField'];
};
/**
* VaeLoaderOutput
* @description Model loader output
*/
VaeLoaderOutput: {
/**
* Type
* @default vae_loader_output
* @enum {string}
*/
type?: 'vae_loader_output';
/**
* Vae
* @description Vae model
*/
vae?: components['schemas']['VaeField'];
};
/** VaeModelConfig */
VaeModelConfig: {
/** Name */
@ -4625,18 +4685,18 @@ export type components = {
*/
image?: components['schemas']['ImageField'];
};
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers';
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: 'checkpoint' | 'diffusers';
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: 'checkpoint' | 'diffusers';
};
responses: never;
parameters: never;
@ -4747,12 +4807,6 @@ export type operations = {
requestBody: {
content: {
'application/json':
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['LoadImageInvocation']
| components['schemas']['ShowImageInvocation']
| components['schemas']['ImageCropInvocation']
@ -4766,31 +4820,38 @@ export type operations = {
| components['schemas']['ImageScaleInvocation']
| components['schemas']['ImageLerpInvocation']
| components['schemas']['ImageInverseLerpInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['VaeLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['ControlNetInvocation']
| components['schemas']['ImageProcessorInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['TextToLatentsInvocation']
| components['schemas']['LatentsToImageInvocation']
| components['schemas']['ResizeLatentsInvocation']
| components['schemas']['ScaleLatentsInvocation']
| components['schemas']['ImageToLatentsInvocation']
| components['schemas']['InpaintInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['AddInvocation']
| components['schemas']['SubtractInvocation']
| components['schemas']['MultiplyInvocation']
| components['schemas']['DivideInvocation']
| components['schemas']['RandomIntInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['ParamIntInvocation']
| components['schemas']['ParamFloatInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['FloatLinearRangeInvocation']
| components['schemas']['StepParamEasingInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['GraphInvocation']
| components['schemas']['IterateInvocation']
| components['schemas']['CollectInvocation']
@ -4847,12 +4908,6 @@ export type operations = {
requestBody: {
content: {
'application/json':
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['LoadImageInvocation']
| components['schemas']['ShowImageInvocation']
| components['schemas']['ImageCropInvocation']
@ -4866,31 +4921,38 @@ export type operations = {
| components['schemas']['ImageScaleInvocation']
| components['schemas']['ImageLerpInvocation']
| components['schemas']['ImageInverseLerpInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['MainModelLoaderInvocation']
| components['schemas']['LoraLoaderInvocation']
| components['schemas']['VaeLoaderInvocation']
| components['schemas']['CompelInvocation']
| components['schemas']['ControlNetInvocation']
| components['schemas']['ImageProcessorInvocation']
| components['schemas']['CvInpaintInvocation']
| components['schemas']['TextToLatentsInvocation']
| components['schemas']['LatentsToImageInvocation']
| components['schemas']['ResizeLatentsInvocation']
| components['schemas']['ScaleLatentsInvocation']
| components['schemas']['ImageToLatentsInvocation']
| components['schemas']['InpaintInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['AddInvocation']
| components['schemas']['SubtractInvocation']
| components['schemas']['MultiplyInvocation']
| components['schemas']['DivideInvocation']
| components['schemas']['RandomIntInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['ParamIntInvocation']
| components['schemas']['ParamFloatInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['RangeInvocation']
| components['schemas']['RangeOfSizeInvocation']
| components['schemas']['RandomRangeInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['InfillColorInvocation']
| components['schemas']['InfillTileInvocation']
| components['schemas']['InfillPatchMatchInvocation']
| components['schemas']['NoiseInvocation']
| components['schemas']['FloatLinearRangeInvocation']
| components['schemas']['StepParamEasingInvocation']
| components['schemas']['DynamicPromptInvocation']
| components['schemas']['RestoreFaceInvocation']
| components['schemas']['UpscaleInvocation']
| components['schemas']['GraphInvocation']
| components['schemas']['IterateInvocation']
| components['schemas']['CollectInvocation']