mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add config variable to suppress loading of sd3 text_encoder_3 T5 model
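The new flag is read below as app_config.load_sd3_encoder_3. As a rough sketch of how such a field might be declared on InvokeAI's app config — the surrounding class name, default, and description are assumptions for illustration; only the field name comes from this diff:

# Hypothetical sketch: not part of this commit. Only the field name
# `load_sd3_encoder_3` appears in the diff below; the class shape and
# the default value are assumed.
from pydantic import BaseModel, Field

class AppConfigSketch(BaseModel):
    load_sd3_encoder_3: bool = Field(
        default=True,
        description="Load the SD3 text_encoder_3 (T5) model; set false to skip it and save memory.",
    )

If it is wired up this way, a user's invokeai.yaml would presumably disable the T5 encoder with a line like load_sd3_encoder_3: false.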
@@ -1,3 +1,4 @@
+from contextlib import ExitStack
 from typing import cast
 
 import torch
@@ -23,7 +24,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.model_manager.config import SubModelType
 from invokeai.backend.model_manager.load.load_base import LoadedModel
-from invokeai.backend.util.devices import TorchDevice
 
 sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
 transformer_info: Optional[LoadedModel] = None
@@ -148,39 +148,35 @@ class StableDiffusion3Invocation(BaseInvocation):
         return v % (SEED_MAX + 1)
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        global sd3_pipeline, transformer_info, tokenizer_1_info, tokenizer_2_info, tokenizer_3_info, text_encoder_1_info, text_encoder_2_info, text_encoder_3_info
+        app_config = context.config.get()
+        load_te3 = app_config.load_sd3_encoder_3
 
-        if not transformer_info:
-            transformer_info = context.models.load(self.transformer.transformer)
-        if not tokenizer_1_info:
-            tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
-        if not tokenizer_2_info:
-            tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
-        if not tokenizer_3_info:
-            tokenizer_3_info = context.models.load(self.clip.tokenizer_3)
-        if not text_encoder_1_info:
-            text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
-        if not text_encoder_2_info:
-            text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
-        if not text_encoder_3_info:
-            text_encoder_3_info = context.models.load(self.clip.text_encoder_3)
+        transformer_info = context.models.load(self.transformer.transformer)
+        tokenizer_1_info = context.models.load(self.clip.tokenizer_1)
+        tokenizer_2_info = context.models.load(self.clip.tokenizer_2)
+        text_encoder_1_info = context.models.load(self.clip.text_encoder_1)
+        text_encoder_2_info = context.models.load(self.clip.text_encoder_2)
 
-        with (
-            tokenizer_1_info as tokenizer_1,
-            tokenizer_2_info as tokenizer_2,
-            tokenizer_3_info as tokenizer_3,
-            text_encoder_1_info as text_encoder_1,
-            text_encoder_2_info as text_encoder_2,
-            text_encoder_3_info as text_encoder_3,
-            transformer_info as transformer,
-        ):
+        with ExitStack() as stack:
+            tokenizer_1 = stack.enter_context(tokenizer_1_info)
+            tokenizer_2 = stack.enter_context(tokenizer_2_info)
+            text_encoder_1 = stack.enter_context(text_encoder_1_info)
+            text_encoder_2 = stack.enter_context(text_encoder_2_info)
+            transformer = stack.enter_context(transformer_info)
             assert isinstance(transformer, SD3Transformer2DModel)
             assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
             assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
-            assert isinstance(text_encoder_3, T5EncoderModel)
             assert isinstance(tokenizer_1, CLIPTokenizer)
             assert isinstance(tokenizer_2, CLIPTokenizer)
-            assert isinstance(tokenizer_3, T5TokenizerFast)
+
+            if load_te3:
+                tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
+                text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
+                assert isinstance(text_encoder_3, T5EncoderModel)
+                assert isinstance(tokenizer_3, T5TokenizerFast)
+            else:
+                tokenizer_3 = None
+                text_encoder_3 = None
 
             scheduler = get_scheduler(
                 context=context,
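The core change above swaps a single parenthesized with statement for contextlib.ExitStack. A with (a, b, c): header has to name every context manager up front, so none of them can be optional; an ExitStack enters a variable number of contexts at runtime and still unwinds all of them on exit, which is what lets tokenizer_3 and text_encoder_3 be loaded only when load_sd3_encoder_3 is set. A minimal, self-contained illustration of the pattern (toy context managers, not InvokeAI code):

from contextlib import ExitStack, contextmanager

@contextmanager
def loaded(name: str):
    # Stand-in for a model load; prints so enter/exit order is visible.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")

def run(load_optional: bool) -> None:
    with ExitStack() as stack:
        base = stack.enter_context(loaded("base"))
        # Entered only when requested; ExitStack still closes it on exit.
        extra = stack.enter_context(loaded("extra")) if load_optional else None
        print(f"using base={base}, extra={extra}")
    # Every context entered on the stack is exited here, in reverse order.

run(load_optional=False)  # enter base / using ... / exit base
run(load_optional=True)   # enter base / enter extra / using ... / exit extra / exit base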
@@ -189,21 +185,17 @@ class StableDiffusion3Invocation(BaseInvocation):
                 seed=self.seed,
             )
 
-            if not isinstance(sd3_pipeline, StableDiffusion3Pipeline):
-                sd3_pipeline = StableDiffusion3Pipeline(
-                    transformer=transformer,
-                    vae=FakeVae(),
-                    text_encoder=text_encoder_1,
-                    text_encoder_2=text_encoder_2,
-                    text_encoder_3=text_encoder_3,
-                    tokenizer=tokenizer_1,
-                    tokenizer_2=tokenizer_2,
-                    tokenizer_3=tokenizer_3,
-                    scheduler=scheduler,
-                )
-
-            sd3_pipeline.components["scheduler"] = scheduler
-            sd3_pipeline.to(TorchDevice.choose_torch_device().type)
+            sd3_pipeline = StableDiffusion3Pipeline(
+                transformer=transformer,
+                vae=FakeVae(),
+                text_encoder=text_encoder_1,
+                text_encoder_2=text_encoder_2,
+                text_encoder_3=text_encoder_3,
+                tokenizer=tokenizer_1,
+                tokenizer_2=tokenizer_2,
+                tokenizer_3=tokenizer_3,
+                scheduler=scheduler,
+            )
 
             results = sd3_pipeline(
                 self.positive_prompt,
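When load_te3 is false, the pipeline above is constructed with text_encoder_3=None and tokenizer_3=None. This relies on diffusers' StableDiffusion3Pipeline treating the T5-XXL encoder as optional: with it absent, the pipeline substitutes zeroed T5 embeddings, trading some prompt adherence for a much smaller memory footprint. Outside InvokeAI, the same trade-off looks roughly like this with stock diffusers (model id, dtype, and prompt are illustrative):

# Standalone sketch of dropping the T5 encoder with plain diffusers.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,  # skip loading the large T5-XXL encoder
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of a cat holding a sign", num_inference_steps=28).images[0]
image.save("sd3_no_t5.png")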