add config variable to suppress loading of sd3 text_encoder_3 T5 model

Lincoln Stein
2024-06-16 16:28:39 -04:00
parent f65d50a4dd
commit 423057a2e8
4 changed files with 56 additions and 58 deletions
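
Note: the invocation below reads the new flag as `context.config.get().load_sd3_encoder_3`. The config-side declaration lives in one of the other changed files and is not shown in this excerpt; a minimal sketch of what such a field could look like, assuming InvokeAI's pydantic-based `InvokeAIAppConfig` conventions (the field name comes from the diff; the class name, default, and description text are assumptions):

from pydantic import BaseModel, Field

class AppConfigSketch(BaseModel):
    # Hypothetical stand-in for InvokeAIAppConfig, showing only the new field.
    # A default of True would preserve the old behavior of always loading the
    # T5 encoder; the actual default is not visible in this excerpt.
    load_sd3_encoder_3: bool = Field(
        default=True,
        description="Whether to load the SD3 text_encoder_3 (T5) model.",
    )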


@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 from typing import cast
 
 import torch
@@ -23,7 +24,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.model_manager.config import SubModelType
 from invokeai.backend.model_manager.load.load_base import LoadedModel
-from invokeai.backend.util.devices import TorchDevice
 
 sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
 transformer_info: Optional[LoadedModel] = None
@@ -148,39 +148,35 @@ class StableDiffusion3Invocation(BaseInvocation):
         return v % (SEED_MAX + 1)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         global sd3_pipeline, transformer_info, tokenizer_1_info, tokenizer_2_info, tokenizer_3_info, text_encoder_1_info, text_encoder_2_info, text_encoder_3_info
 
-        if not transformer_info:
-            transformer_info = context.models.load(self.transformer.transformer)
-        if not tokenizer_1_info:
-            tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
-        if not tokenizer_2_info:
-            tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
-        if not tokenizer_3_info:
-            tokenizer_3_info = context.models.load(self.clip.tokenizer_3)
-        if not text_encoder_1_info:
-            text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
-        if not text_encoder_2_info:
-            text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
-        if not text_encoder_3_info:
-            text_encoder_3_info = context.models.load(self.clip.text_encoder_3)
+        app_config = context.config.get()
+        load_te3 = app_config.load_sd3_encoder_3
+        transformer_info = context.models.load(self.transformer.transformer)
+        tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
+        tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
+        text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
+        text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
 
-        with (
-            tokenizer_1_info as tokenizer_1,
-            tokenizer_2_info as tokenizer_2,
-            tokenizer_3_info as tokenizer_3,
-            text_encoder_1_info as text_encoder_1,
-            text_encoder_2_info as text_encoder_2,
-            text_encoder_3_info as text_encoder_3,
-            transformer_info as transformer,
-        ):
+        with ExitStack() as stack:
+            tokenizer_1 = stack.enter_context(tokenizer_1_info)
+            tokenizer_2 = stack.enter_context(tokenizer_2_info)
+            text_encoder_1 = stack.enter_context(text_encoder_1_info)
+            text_encoder_2 = stack.enter_context(text_encoder_2_info)
+            transformer = stack.enter_context(transformer_info)
             assert isinstance(transformer, SD3Transformer2DModel)
             assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
             assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
-            assert isinstance(text_encoder_3, T5EncoderModel)
             assert isinstance(tokenizer_1, CLIPTokenizer)
             assert isinstance(tokenizer_2, CLIPTokenizer)
-            assert isinstance(tokenizer_3, T5TokenizerFast)
+            if load_te3:
+                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
+                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
+                assert isinstance(text_encoder_3, T5EncoderModel)
+                assert isinstance(tokenizer_3, T5TokenizerFast)
+            else:
+                tokenizer_3 = None
+                text_encoder_3 = None
 
             scheduler = get_scheduler(
                 context=context,
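
Note: the rewrite above replaces the parenthesized multi-context `with` statement with `contextlib.ExitStack`, which lets a context manager be entered conditionally while still guaranteeing cleanup of everything that was entered. A self-contained sketch of the pattern (the names here are illustrative, not from the InvokeAI codebase):

from contextlib import ExitStack, contextmanager

@contextmanager
def loaded(name: str):
    # Stand-in for context.models.load(...): acquire on entry, release on exit.
    print(f"load {name}")
    try:
        yield name
    finally:
        print(f"release {name}")

def run(load_optional: bool) -> None:
    with ExitStack() as stack:
        required = stack.enter_context(loaded("clip_encoder"))
        # Entered only when the flag is set; on block exit the stack unwinds
        # whatever was actually entered, in reverse order.
        optional = stack.enter_context(loaded("t5_encoder")) if load_optional else None
        print("using:", required, optional)

run(load_optional=False)  # t5_encoder is never loaded
run(load_optional=True)   # both are loaded and released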
@@ -189,21 +185,17 @@ class StableDiffusion3Invocation(BaseInvocation):
                 seed=self.seed,
             )
 
-            if not isinstance(sd3_pipeline, StableDiffusion3Pipeline):
-                sd3_pipeline = StableDiffusion3Pipeline(
-                    transformer=transformer,
-                    vae=FakeVae(),
-                    text_encoder=text_encoder_1,
-                    text_encoder_2=text_encoder_2,
-                    text_encoder_3=text_encoder_3,
-                    tokenizer=tokenizer_1,
-                    tokenizer_2=tokenizer_2,
-                    tokenizer_3=tokenizer_3,
-                    scheduler=scheduler,
-                )
-            else:
-                sd3_pipeline.components["scheduler"] = scheduler
-            sd3_pipeline.to(TorchDevice.choose_torch_device().type)
+            sd3_pipeline = StableDiffusion3Pipeline(
+                transformer=transformer,
+                vae=FakeVae(),
+                text_encoder=text_encoder_1,
+                text_encoder_2=text_encoder_2,
+                text_encoder_3=text_encoder_3,
+                tokenizer=tokenizer_1,
+                tokenizer_2=tokenizer_2,
+                tokenizer_3=tokenizer_3,
+                scheduler=scheduler,
+            )
 
             results = sd3_pipeline(
                 self.positive_prompt,
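
Note: when `load_te3` is off, the pipeline is constructed with `text_encoder_3=None` and `tokenizer_3=None`. diffusers' `StableDiffusion3Pipeline` supports this and substitutes zeroed T5 prompt embeddings, so generation proceeds with only the two CLIP encoders at some quality cost. A usage sketch against stock diffusers, independent of InvokeAI's model manager (repo id, device, and prompt are placeholders):

import torch
from diffusers import StableDiffusion3Pipeline

# Load SD3 without the large T5-XXL text encoder; the pipeline falls back
# to zero embeddings for the missing T5 branch.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")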