diff --git a/invokeai/app/invocations/flux_text_to_image.py b/invokeai/app/invocations/flux_text_to_image.py
index 2f78713b0c..2efa76b4ec 100644
--- a/invokeai/app/invocations/flux_text_to_image.py
+++ b/invokeai/app/invocations/flux_text_to_image.py
@@ -33,7 +33,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
     use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
+        default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
@@ -56,7 +56,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
         return ImageOutput.build(image_dto)
 
     def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence lenght based on the model.
+        # Determine the T5 max sequence length based on the model.
         if self.model == "flux-schnell":
             max_seq_len = 256
         # elif self.model == "flux-dev":
@@ -118,7 +118,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     ):
         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
 
-        # HACK(ryand): Manually empty the cache.
+        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
+        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
+        # if the cache is not empty.
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
         transformer_path = flux_model_dir / "transformer"
@@ -137,7 +139,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 transformer=transformer,
             )
 
-            return flux_pipeline_with_transformer(
+            latents = flux_pipeline_with_transformer(
                 height=self.height,
                 width=self.width,
                 num_inference_steps=self.num_steps,
@@ -149,6 +151,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 return_dict=False,
             )[0]
 
+        assert isinstance(latents, torch.Tensor)
+        return latents
+
     def _run_vae_decoding(
         self,
         context: InvocationContext,
@@ -201,9 +206,14 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                with torch.device("meta"):
-                    model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
-                    assert isinstance(model, FluxTransformer2DModel)
+                # TODO(ryand): Make loading from quantized model work properly.
+                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
+                model = FluxTransformer2DModel.from_pretrained(
+                    path,
+                    local_files_only=True,
+                )
+                assert isinstance(model, FluxTransformer2DModel)
+                model = model.to(device=torch.device("meta"))
 
                 state_dict = load_file(model_8bit_weights_path)
                 with open(model_8bit_map_path, "r") as f:
@@ -211,6 +221,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
             else:
                 # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+                # here.
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
@@ -222,9 +235,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
                 with open(model_8bit_map_path, "w") as f:
                     json.dump(quantization_map(model), f)
         else:
-            model = FluxTransformer2DModel.from_pretrained(
-                path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
-            )
+            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
             assert isinstance(model, FluxTransformer2DModel)
 
         return model
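
For reference, a minimal self-contained sketch of the quantize-once, requantize-on-reload flow that `_load_flux_transformer` implements above, assuming the `optimum-quanto` and `safetensors` APIs used in the surrounding code. The helper name, cache layout, and the `qfloat8` weight type are illustrative assumptions and are not part of this diff:

```python
# Illustrative sketch only; names and cache layout are assumptions, not InvokeAI code.
import json
from pathlib import Path

import torch
from diffusers import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
from safetensors.torch import load_file, save_file


def load_or_quantize_transformer(path: Path, cache_dir: Path) -> FluxTransformer2DModel:
    weights_path = cache_dir / "weights.safetensors"
    map_path = cache_dir / "quantization_map.json"

    model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
    assert isinstance(model, FluxTransformer2DModel)

    if weights_path.exists() and map_path.exists():
        # A quantized copy was cached on a previous run: re-apply its saved
        # quantization map and weights to the freshly loaded model.
        state_dict = load_file(weights_path)
        with open(map_path, "r") as f:
            quant_map = json.load(f)
        requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
    else:
        # First run: quantize the bfloat16 weights to 8-bit, freeze them, and
        # cache the result so later runs can skip the quantization step.
        quantize(model, weights=qfloat8)
        freeze(model)
        cache_dir.mkdir(parents=True, exist_ok=True)
        save_file(model.state_dict(), weights_path)
        with open(map_path, "w") as f:
            json.dump(quantization_map(model), f)

    return model
```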