mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add differentiated sdxl and sdxl_refiner model loaders
This commit is contained in:
@ -46,7 +46,6 @@ class ModelLoaderOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
|
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
import copy
|
|
||||||
import torch
|
import torch
|
||||||
import inspect
|
import inspect
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, validator
|
from pydantic import Field, validator
|
||||||
|
|
||||||
from ...backend.model_management import BaseModelType, ModelType, SubModelType
|
from ...backend.model_management import ModelType, SubModelType
|
||||||
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
|
||||||
InvocationConfig, InvocationContext)
|
InvocationConfig, InvocationContext)
|
||||||
|
|
||||||
@ -14,6 +13,136 @@ from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
|
|||||||
from .compel import ConditioningField
|
from .compel import ConditioningField
|
||||||
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
|
||||||
|
|
||||||
|
class SDXLModelLoaderOutput(BaseInvocationOutput):
|
||||||
|
"""SDXL base model loader output"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"
|
||||||
|
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
|
||||||
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
|
||||||
|
"""SDXL refiner model loader output"""
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||||
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
|
type: Literal["sdxl_model_loader"] = "sdxl_main_model_loader"
|
||||||
|
|
||||||
|
model: MainModelField = Field(description="The model to load")
|
||||||
|
# TODO: precision?
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Model Loader",
|
||||||
|
"tags": ["model", "loader", "sdxl"],
|
||||||
|
"type_hints": {"model": "model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _output_class(cls):
|
||||||
|
return SDXLModelLoaderOutput
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||||
|
base_model = self.model.base_model
|
||||||
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.Main
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
|
return self._output_class(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Tokenizer,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.TextEncoder,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
clip2=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Tokenizer2,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.TextEncoder2,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Vae,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
|
||||||
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
|
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||||
|
|
||||||
|
# Schema customisation
|
||||||
|
class Config(InvocationConfig):
|
||||||
|
schema_extra = {
|
||||||
|
"ui": {
|
||||||
|
"title": "SDXL Refiner Model Loader",
|
||||||
|
"tags": ["model", "loader", "sdxl_refiner"],
|
||||||
|
"type_hints": {"model": "model"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _output_class(cls):
|
||||||
|
return SDXLRefinerModelLoaderOutput
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
Reference in New Issue
Block a user