mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit e680cf76f6 (parent 253b2b1dc6): Address minor review comments.
@@ -135,7 +135,7 @@ class FieldDescriptions:
     vae_model = "VAE model to load"
     lora_model = "LoRA model to load"
     main_model = "Main model (UNet, VAE, CLIP) to load"
-    flux_model = "Flux model (Transformer, VAE, CLIP) to load"
+    flux_model = "Flux model (Transformer) to load"
     sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
     sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
     onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -15,8 +15,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
 @invocation(
     "flux_text_encoder",
     title="FLUX Text Encoding",
-    tags=["image", "flux"],
-    category="image",
+    tags=["prompt", "conditioning", "flux"],
+    category="conditioning",
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
@@ -32,7 +32,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         description=FieldDescriptions.t5_encoder,
         input=Input.Connection,
     )
-    max_seq_len: Literal[256, 512] = InputField(description="Max sequence length for the desired flux model")
+    t5_max_seq_len: Literal[256, 512] = InputField(
+        description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
+    )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")

     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
@@ -48,8 +50,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
         return ConditioningOutput.build(conditioning_name)

     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        max_seq_len = self.max_seq_len
-
         # Load CLIP.
         clip_tokenizer_info = context.models.load(self.clip.tokenizer)
         clip_text_encoder_info = context.models.load(self.clip.text_encoder)
@@ -70,7 +70,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
         assert isinstance(t5_tokenizer, T5Tokenizer)

         clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
-        t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, max_seq_len)
+        t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

         prompt = [self.positive_prompt]
         prompt_embeds = t5_encoder(prompt)
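For context, a rough sketch of what a call like HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len) amounts to. This is illustrative only and not part of the commit: the encode_t5 helper below is hypothetical, and the real HFEncoder in the InvokeAI FLUX backend may differ in detail.

import torch
from transformers import T5EncoderModel, T5Tokenizer

# Hypothetical helper (not in the InvokeAI codebase): encode prompts with T5 using a
# capped sequence length, roughly the non-CLIP path of the wrapper used above.
# max_seq_len would be 256 for FLUX schnell or 512 for FLUX dev.
def encode_t5(
    text_encoder: T5EncoderModel,
    tokenizer: T5Tokenizer,
    prompts: list[str],
    max_seq_len: int,
) -> torch.Tensor:
    # Tokenize, truncating/padding every prompt to exactly max_seq_len tokens.
    batch = tokenizer(
        prompts,
        truncation=True,
        max_length=max_seq_len,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        output = text_encoder(input_ids=batch["input_ids"].to(text_encoder.device))
    return output.last_hidden_state  # shape: (batch, max_seq_len, hidden_dim)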
@@ -33,7 +33,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""

     transformer: TransformerField = InputField(
-        description=FieldDescriptions.unet,
+        description=FieldDescriptions.flux_model,
         input=Input.Connection,
         title="Transformer",
     )
@@ -46,10 +46,12 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
-    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
+    num_steps: int = InputField(
+        default=4, description="Number of diffusion steps. Recommend values are schnell: 4, dev: 50."
+    )
     guidance: float = InputField(
         default=4.0,
-        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
+        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
     )
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")

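Taken together, the updated field descriptions imply different settings per FLUX variant. A small summary sketch follows; the constant name is made up for this note and is not part of the commit.

# Values taken from the updated field descriptions above; guidance only applies
# to FLUX dev and is ignored for schnell.
FLUX_VARIANT_DEFAULTS = {
    "schnell": {"t5_max_seq_len": 256, "num_steps": 4},
    "dev": {"t5_max_seq_len": 512, "num_steps": 50, "guidance": 4.0},
}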