This commit is contained in:
Brandon Rising 2024-08-15 10:27:42 -04:00
parent 46d5107ff1
commit 46b6314482
4 changed files with 5 additions and 17 deletions

View File

@ -1,14 +1,9 @@
from pathlib import Path
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from optimum.quanto import qfloat8
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -40,7 +35,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
# compatible with other ConditioningOutputs.
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
t5_embeddings, clip_embeddings = self._encode_prompt(context)
conditioning_data = ConditioningFieldData(
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
@ -48,7 +42,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name)
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
# TODO: Determine the T5 max sequence length based on the model.
# if self.model == "flux-schnell":

View File

@ -1,13 +1,9 @@
from pathlib import Path
from typing import Literal
from pydantic import Field
import accelerate
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from invokeai.app.invocations.model import TransformerField, VAEField
from optimum.quanto import qfloat8
from PIL import Image
from safetensors.torch import load_file
from transformers.models.auto import AutoModelForTextEncoding
@ -20,8 +16,8 @@ from invokeai.app.invocations.fields import (
InputField,
WithBoard,
WithMetadata,
UIType,
)
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@ -75,7 +71,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Load the conditioning data.
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
assert len(cond_data.conditionings) == 1

View File

@ -3,6 +3,7 @@ import os
from typing import Union
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
CONFIG_NAME,
SAFE_WEIGHTS_INDEX_NAME,
@ -12,7 +13,6 @@ from diffusers.utils import (
)
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from invokeai.backend.requantize import requantize

View File

@ -1,14 +1,13 @@
import json
import os
import torch
from typing import Union
from optimum.quanto.models import QuantizedTransformersModel
from optimum.quanto.models.shared_dict import ShardedStateDict
from transformers import AutoConfig
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from transformers.models.auto import AutoModelForTextEncoding
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
from invokeai.backend.requantize import requantize