from contextlib import ExitStack
from typing import Optional, cast

import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from pydantic import field_validator
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Input,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import get_scheduler
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel

# Module-level placeholders, kept for backward compatibility.
# NOTE(review): nothing in the visible code reads these. `invoke()` below binds
# same-named *locals* (there is no `global` statement), so these stay None.
sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
transformer_info: Optional[LoadedModel] = None
tokenizer_1_info: Optional[LoadedModel] = None
tokenizer_2_info: Optional[LoadedModel] = None
tokenizer_3_info: Optional[LoadedModel] = None
text_encoder_1_info: Optional[LoadedModel] = None
text_encoder_2_info: Optional[LoadedModel] = None
text_encoder_3_info: Optional[LoadedModel] = None


class FakeVae:
    """Minimal stand-in passed as the pipeline's `vae` so no real VAE is loaded.

    It exposes only `config.block_out_channels`.
    """

    # NOTE(review): the pipeline presumably derives its VAE scale factor from
    # len(config.block_out_channels); confirm that a single-entry list gives
    # the intended latent resolution.

    class FakeVaeConfig:
        """Config object carrying the single attribute set below."""

        def __init__(self) -> None:
            self.block_out_channels = [0]

    def __init__(self) -> None:
        self.config = FakeVae.FakeVaeConfig()


@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
    """Stable Diffusion 3 base model loader output"""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
class SD3ModelLoaderInvocation(BaseInvocation):
    """Loads an SD3 base model, outputting its submodels."""

    model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)

    def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
        """Build submodel identifiers for the selected SD3 model.

        No weights are loaded here: each output field is a copy of
        `self.model` with `submodel_type` set to the relevant submodel.

        Raises:
            Exception: if `context.models.exists` reports the model key as unknown.
        """
        model_key = self.model.key

        if not context.models.exists(model_key):
            raise Exception(f"Unknown model: {model_key}")

        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
        tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
        tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
        text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
        tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
        text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})

        return SD3ModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, scheduler=scheduler),
            clip=SD3CLIPField(
                tokenizer_1=tokenizer_1,
                text_encoder_1=text_encoder_1,
                tokenizer_2=tokenizer_2,
                text_encoder_2=text_encoder_2,
                tokenizer_3=tokenizer_3,
                text_encoder_3=text_encoder_3,
            ),
            vae=VAEField(vae=vae),
        )


@invocation(
    "sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
)
class StableDiffusion3Invocation(BaseInvocation):
    """Generates an image using Stable Diffusion 3."""

    transformer: TransformerField = InputField(
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
        ui_order=0,
    )
    clip: SD3CLIPField = InputField(
        description=FieldDescriptions.clip,
        input=Input.Connection,
        title="CLIP",
        ui_order=1,
    )
    # NOTE(review): `noise` is accepted as an input but `invoke()` never reads it.
    noise: Optional[LatentsField] = InputField(
        default=None,
        description=FieldDescriptions.noise,
        input=Input.Connection,
        ui_order=2,
    )
    scheduler: SCHEDULER_NAME_VALUES = InputField(
        default="euler_f",
        description=FieldDescriptions.scheduler,
        ui_type=UIType.Scheduler,
    )
    positive_prompt: str = InputField(default="", title="Positive Prompt")
    negative_prompt: str = InputField(default="", title="Negative Prompt")
    steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
    guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description=FieldDescriptions.seed,
    )
    width: int = InputField(
        default=1024,
        multiple_of=LATENT_SCALE_FACTOR,
        gt=0,
        description=FieldDescriptions.width,
    )
    height: int = InputField(
        default=1024,
        multiple_of=LATENT_SCALE_FACTOR,
        gt=0,
        description=FieldDescriptions.height,
    )

    @field_validator("seed", mode="before")
    @classmethod
    def modulo_seed(cls, v: int) -> int:
        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
        return v % (SEED_MAX + 1)

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the SD3 pipeline and return the resulting latents.

        Loads the transformer, the two CLIP tokenizer/encoder pairs and,
        when the app config's `load_sd3_encoder_3` is truthy, the T5
        tokenizer/encoder. It then assembles a `StableDiffusion3Pipeline`
        with a `FakeVae` and requests latent output, so no image is decoded
        here.

        Returns:
            A `LatentsOutput` built from the saved latents and `self.seed`.
        """
        app_config = context.config.get()
        load_te3 = app_config.load_sd3_encoder_3

        transformer_info = context.models.load(self.transformer.transformer)
        tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
        tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
        text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
        text_encoder_2_info = context.models.load(self.clip.text_encoder_2)

        # ExitStack lets the optional third encoder join the same context
        # scope as the always-present models.
        with ExitStack() as stack:
            tokenizer_1 = stack.enter_context(tokenizer_1_info)
            tokenizer_2 = stack.enter_context(tokenizer_2_info)
            text_encoder_1 = stack.enter_context(text_encoder_1_info)
            text_encoder_2 = stack.enter_context(text_encoder_2_info)
            transformer = stack.enter_context(transformer_info)

            # Narrow the loaded objects to the concrete types the pipeline expects.
            assert isinstance(transformer, SD3Transformer2DModel)
            assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
            assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
            assert isinstance(tokenizer_1, CLIPTokenizer)
            assert isinstance(tokenizer_2, CLIPTokenizer)

            if load_te3:
                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
                assert isinstance(text_encoder_3, T5EncoderModel)
                assert isinstance(tokenizer_3, T5TokenizerFast)
            else:
                tokenizer_3 = None
                text_encoder_3 = None

            scheduler = get_scheduler(
                context=context,
                scheduler_info=self.transformer.scheduler,
                scheduler_name=self.scheduler,
                seed=self.seed,
            )

            sd3_pipeline = StableDiffusion3Pipeline(
                transformer=transformer,
                vae=FakeVae(),
                text_encoder=text_encoder_1,
                text_encoder_2=text_encoder_2,
                text_encoder_3=text_encoder_3,
                tokenizer=tokenizer_1,
                tokenizer_2=tokenizer_2,
                tokenizer_3=tokenizer_3,
                scheduler=scheduler,
            )

            # Seed the pipeline's initial-noise sampling. Without an explicit
            # generator the `seed` input did not control the sampled noise,
            # although the returned LatentsOutput reports that seed.
            generator = torch.Generator(device="cpu").manual_seed(self.seed)

            results = sd3_pipeline(
                self.positive_prompt,
                negative_prompt=self.negative_prompt,
                num_inference_steps=self.steps,
                guidance_scale=self.guidance_scale,
                generator=generator,
                output_type="latent",
                width=self.width,
                height=self.height,
            )

        # Take the first result and add a leading dimension before saving.
        latents = cast(torch.Tensor, results.images[0])
        latents = latents.unsqueeze(0)
        latents_name = context.tensors.save(latents)
        return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)