From 8e42502dfdee44da62c8d6e133fcc589899797c5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 10 Jul 2023 20:18:30 -0400 Subject: [PATCH] partial implementation of SDXL model loader --- invokeai/app/api/routers/models.py | 9 +- invokeai/app/invocations/model.py | 104 +++++++++++++++++- .../system/components/ModelSelect.tsx | 2 + 3 files changed, 111 insertions(+), 4 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 8dbeaa3d05..2f222bfeef 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -32,11 +32,16 @@ class ModelsList(BaseModel): responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), + base_models: Optional[List[Union[BaseModelType,None]]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" - models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) + if base_models and len(base_models)>0: + models_raw = list() + for base_model in base_models: + models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)) + else: + models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type) models = parse_obj_as(ModelsList, { "models": models_raw }) return models diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 34836eabd2..76cd5b81e3 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -33,7 +33,6 @@ class ClipField(BaseModel): skipped_layers: int = Field(description="Number of skipped layers in text_encoder") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") - class VaeField(BaseModel): # TODO: better naming? 
vae: ModelInfo = Field(description="Info to load vae submodel") @@ -50,6 +49,18 @@ class ModelLoaderOutput(BaseInvocationOutput): vae: VaeField = Field(default=None, description="Vae submodel") # fmt: on +class SDXLModelLoaderOutput(BaseInvocationOutput): + """SDXL model loader output""" + + # fmt: off + type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" + + unet: UNetField = Field(default=None, description="UNet submodel") + clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") + clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (2nd set)") + vae: VaeField = Field(default=None, description="Vae submodel") + # fmt: on + class MainModelField(BaseModel): """Main model field""" @@ -64,7 +75,6 @@ class LoRAModelField(BaseModel): model_name: str = Field(description="Name of the LoRA model") base_model: BaseModelType = Field(description="Base model") - class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -167,7 +177,97 @@ class MainModelLoaderInvocation(BaseInvocation): ), ) +class SDXLMainModelLoaderInvocation(BaseInvocation): + """Loads an SDXL main model, outputting its submodels.""" + type: Literal["sdxl_main_model_loader"] = "sdxl_main_model_loader" + + model: MainModelField = Field(description="The SDXL model to load") + # TODO: precision? 
+ + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "title": "SDXL Model Loader", + "tags": ["model", "loader", "sdxl"], + "type_hints": {"model": "model"}, + }, + } + + def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: + base_model = self.model.base_model + model_name = self.model.model_name + model_type = ModelType.Main + + # TODO: not found exceptions + if not context.services.model_manager.model_exists( + model_name=model_name, + base_model=base_model, + model_type=model_type, + ): + raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") + + return SDXLModelLoaderOutput( + unet=UNetField( + unet=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.UNet, + ), + scheduler=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.Scheduler, + ), + loras=[], + ), + clip=ClipField( + tokenizer=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.Tokenizer, + ), + text_encoder=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.TextEncoder, + ), + loras=[], + skipped_layers=0, + ), + clip2=ClipField( + tokenizer=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.Tokenizer2, + ), + text_encoder=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.TextEncoder2, + ), + loras=[], + skipped_layers=0, + ), + vae=VaeField( + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, + submodel=SubModelType.Vae, + ), + ), + ) + + + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx 
b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 6b5aa830d9..40a6a1203b 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -13,6 +13,8 @@ import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', 'sd-2': 'Stable Diffusion 2.x', + sdxl: 'Stable Diffusion XL', + 'sdxl-refiner': 'Stable Diffusion XL Refiner', }; const ModelSelect = () => {