Minor improvements to FLUX workflow.
Commit 4833746698 (parent: 8b9bf55bba)
```diff
@@ -33,7 +33,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
     use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
+        default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
```
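The description change narrows the `use_8bit` flag to the transformer alone; the T5 encoder is no longer mentioned. For context, here is a minimal sketch of how such a flag typically takes effect with optimum-quanto, the library implied by the `requantize()`/`quantization_map()` calls later in this diff. The function name is illustrative, not InvokeAI's API:

```python
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qint8, quantize


def quantize_transformer_to_8bit(model: FluxTransformer2DModel) -> FluxTransformer2DModel:
    # Replace Linear weights with 8-bit integer representations.
    quantize(model, weights=qint8)
    # freeze() materializes the quantized weights and drops the fp originals.
    freeze(model)
    return model
```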
```diff
@@ -56,7 +56,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         return ImageOutput.build(image_dto)
 
     def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence lenght based on the model.
+        # Determine the T5 max sequence length based on the model.
         if self.model == "flux-schnell":
             max_seq_len = 256
         # elif self.model == "flux-dev":
```
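The corrected comment concerns how the T5 max sequence length is chosen: flux-schnell is distilled for a 256-token T5 context, and the commented-out flux-dev branch would use the reference value of 512. A sketch of that mapping; the dict and helper names are illustrative:

```python
# Reference T5 sequence lengths for the two FLUX variants. The 512 value
# for flux-dev matches the commented-out branch above and the reference
# configuration; it is not active in this commit.
T5_MAX_SEQ_LEN = {
    "flux-schnell": 256,
    "flux-dev": 512,
}


def t5_max_seq_len(model_key: str) -> int:
    try:
        return T5_MAX_SEQ_LEN[model_key]
    except KeyError:
        raise ValueError(f"Unknown FLUX model: {model_key}")
```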
```diff
@@ -118,7 +118,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     ):
         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
 
-        # HACK(ryand): Manually empty the cache.
+        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
+        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
+        # if the cache is not empty.
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
         transformer_path = flux_model_dir / "transformer"
```
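`make_room(24 * 2**30)` asks InvokeAI's RAM cache to free 24 GiB before the transformer is loaded, and the expanded comment explains why: the loader does not yet check model size up front. A hypothetical sketch of what such an eviction helper could look like; InvokeAI's actual cache implementation differs, and every name below is illustrative:

```python
from collections import OrderedDict


class RamCacheSketch:
    """Illustrative LRU model cache; not InvokeAI's real implementation."""

    def __init__(self, max_bytes: int):
        self._max_bytes = max_bytes
        # Maps cache key -> (model, size in bytes), oldest entry first.
        self._entries: OrderedDict[str, tuple[object, int]] = OrderedDict()

    def put(self, key: str, model: object, size_bytes: int) -> None:
        self.make_room(size_bytes)
        self._entries[key] = (model, size_bytes)

    def make_room(self, needed_bytes: int) -> None:
        # Evict least-recently-used entries until needed_bytes would fit.
        # make_room(24 * 2**30) therefore clears space for a ~24 GiB model.
        while self._entries and self._used() + needed_bytes > self._max_bytes:
            self._entries.popitem(last=False)

    def _used(self) -> int:
        return sum(size for _, size in self._entries.values())
```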
```diff
@@ -137,7 +139,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             transformer=transformer,
         )
 
-        return flux_pipeline_with_transformer(
+        latents = flux_pipeline_with_transformer(
             height=self.height,
             width=self.width,
             num_inference_steps=self.num_steps,
```
```diff
@@ -149,6 +151,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             return_dict=False,
         )[0]
 
+        assert isinstance(latents, torch.Tensor)
+        return latents
+
     def _run_vae_decoding(
         self,
         context: InvocationContext,
```
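The two hunks above stop returning the pipeline output directly: the result is bound to `latents`, type-checked, and returned. With `return_dict=False`, a diffusers pipeline returns a plain tuple, so `[0]` selects the output. A hedged sketch of that calling convention; the `prompt` and `output_type="latent"` arguments are assumptions, since they are not visible in these hunks:

```python
import torch
from diffusers import FluxPipeline


def run_pipeline_to_latents(
    pipeline: FluxPipeline, prompt: str, height: int, width: int, steps: int
) -> torch.Tensor:
    # With return_dict=False the pipeline returns a plain tuple; element 0
    # holds the output. output_type="latent" (assumed here) skips VAE
    # decoding so latents come back instead of decoded images.
    latents = pipeline(
        prompt=prompt,
        height=height,
        width=width,
        num_inference_steps=steps,
        output_type="latent",
        return_dict=False,
    )[0]
    assert isinstance(latents, torch.Tensor)
    return latents
```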
```diff
@@ -201,9 +206,14 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                with torch.device("meta"):
-                    model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
+                # TODO(ryand): Make loading from quantized model work properly.
+                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
+                model = FluxTransformer2DModel.from_pretrained(
+                    path,
+                    local_files_only=True,
+                )
                 assert isinstance(model, FluxTransformer2DModel)
+                model = model.to(device=torch.device("meta"))
 
                 state_dict = load_file(model_8bit_weights_path)
                 with open(model_8bit_map_path, "r") as f:
```
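This hunk replaces the `with torch.device("meta"):` context around `from_pretrained()` with a normal load followed by `.to(device="meta")`, after which `requantize()` rebuilds the 8-bit weights from the saved state dict. A condensed sketch of that load path using the same optimum-quanto and safetensors APIs; the file names under `quant_path` are illustrative:

```python
import json
from pathlib import Path

import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import requantize
from safetensors.torch import load_file


def load_quantized_transformer(model_path: Path, quant_path: Path) -> FluxTransformer2DModel:
    model = FluxTransformer2DModel.from_pretrained(model_path, local_files_only=True)
    assert isinstance(model, FluxTransformer2DModel)
    # Moving to the meta device frees the fp weights just loaded; requantize()
    # then rebuilds real tensors from the saved 8-bit state dict.
    model = model.to(device=torch.device("meta"))
    state_dict = load_file(quant_path / "model.safetensors")
    with open(quant_path / "quantization_map.json", "r") as f:
        quant_map = json.load(f)
    requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
    return model
```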
```diff
@@ -211,6 +221,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
             else:
                 # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+                # here.
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
```
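The new TODO explains why this branch loads in bfloat16: quantizing from float16 produced NaNs. A counterpart sketch of the quantize-and-save path, mirroring the `quantization_map()` serialization visible in the next hunk; the `quantize()`/`freeze()` calls follow optimum-quanto's standard workflow, and the file names are again illustrative:

```python
import json
from pathlib import Path

import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qint8, quantization_map, quantize
from safetensors.torch import save_file


def quantize_and_save_transformer(model_path: Path, quant_path: Path) -> FluxTransformer2DModel:
    # bfloat16 is deliberate: per the TODO above, quantizing from float16
    # produces NaNs.
    model = FluxTransformer2DModel.from_pretrained(
        model_path, local_files_only=True, torch_dtype=torch.bfloat16
    )
    assert isinstance(model, FluxTransformer2DModel)
    quantize(model, weights=qint8)
    freeze(model)
    # Persist both the quantized weights and the layer-by-layer quantization
    # map needed by requantize() on the next load.
    quant_path.mkdir(parents=True, exist_ok=True)
    save_file(model.state_dict(), quant_path / "model.safetensors")
    with open(quant_path / "quantization_map.json", "w") as f:
        json.dump(quantization_map(model), f)
    return model
```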
```diff
@@ -222,9 +235,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 with open(model_8bit_map_path, "w") as f:
                     json.dump(quantization_map(model), f)
         else:
-            model = FluxTransformer2DModel.from_pretrained(
-                path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
-            )
+            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 
         assert isinstance(model, FluxTransformer2DModel)
         return model
```
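The final hunk replaces the old `TorchDevice.choose_torch_dtype()` call with a hard-coded `torch.bfloat16`, so the quantized and non-quantized paths now agree on dtype. A hedged sketch of a capability guard for hardware without bfloat16 support; this guard is not in the commit, which per the TODO above does not yet handle that case:

```python
import torch


def flux_transformer_dtype() -> torch.dtype:
    # Sketch only: the commit hard-codes bfloat16 unconditionally.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    # float16 is known to produce NaNs when quantizing this model, so a real
    # fallback would need a pre-quantized hosted checkpoint instead.
    raise RuntimeError("FLUX transformer currently requires bfloat16 support.")
```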