import torch from typing import Literal from pydantic import Field from ...backend.model_management import ModelType, SubModelType from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo class SDXLModelLoaderOutput(BaseInvocationOutput): """SDXL base model loader output""" # fmt: off type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): """SDXL refiner model loader output""" # fmt: off type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on # fmt: on class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" type: Literal["sdxl_model_loader"] = "sdxl_model_loader" model: MainModelField = Field(description="The model to load") # TODO: precision? # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { "title": "SDXL Model Loader", "tags": ["model", "loader", "sdxl"], "type_hints": {"model": "model"}, }, } def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions if not context.services.model_manager.model_exists( model_name=model_name, base_model=base_model, model_type=model_type, ): raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") return SDXLModelLoaderOutput( unet=UNetField( unet=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.UNet, ), scheduler=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Scheduler, ), loras=[], ), clip=ClipField( tokenizer=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Tokenizer, ), text_encoder=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.TextEncoder, ), loras=[], skipped_layers=0, ), clip2=ClipField( tokenizer=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Vae, ), ), ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" model: MainModelField = Field(description="The model to load") # TODO: precision? # Schema customisation class Config(InvocationConfig): schema_extra = { "ui": { "title": "SDXL Refiner Model Loader", "tags": ["model", "loader", "sdxl_refiner"], "type_hints": {"model": "refiner_model"}, }, } def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main # TODO: not found exceptions if not context.services.model_manager.model_exists( model_name=model_name, base_model=base_model, model_type=model_type, ): raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") return SDXLRefinerModelLoaderOutput( unet=UNetField( unet=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.UNet, ), scheduler=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Scheduler, ), loras=[], ), clip2=ClipField( tokenizer=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Tokenizer2, ), text_encoder=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.TextEncoder2, ), loras=[], skipped_layers=0, ), vae=VaeField( vae=ModelInfo( model_name=model_name, base_model=base_model, model_type=model_type, submodel=SubModelType.Vae, ), ), )