mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run Ruff
This commit is contained in:
parent
9ed53af520
commit
2d9042fb93
@ -1,14 +1,9 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from optimum.quanto import qfloat8
|
|
||||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||||
from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
|
|
||||||
from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
|
|
||||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
@ -40,7 +35,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
# compatible with other ConditioningOutputs.
|
# compatible with other ConditioningOutputs.
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
|
|
||||||
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
t5_embeddings, clip_embeddings = self._encode_prompt(context)
|
||||||
conditioning_data = ConditioningFieldData(
|
conditioning_data = ConditioningFieldData(
|
||||||
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
|
||||||
@ -48,7 +42,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
|
|||||||
|
|
||||||
conditioning_name = context.conditioning.save(conditioning_data)
|
conditioning_name = context.conditioning.save(conditioning_data)
|
||||||
return ConditioningOutput.build(conditioning_name)
|
return ConditioningOutput.build(conditioning_name)
|
||||||
|
|
||||||
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# TODO: Determine the T5 max sequence length based on the model.
|
# TODO: Determine the T5 max sequence length based on the model.
|
||||||
# if self.model == "flux-schnell":
|
# if self.model == "flux-schnell":
|
||||||
|
@ -1,13 +1,9 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
import accelerate
|
import accelerate
|
||||||
import torch
|
import torch
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
|
||||||
from optimum.quanto import qfloat8
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
@ -20,8 +16,8 @@ from invokeai.app.invocations.fields import (
|
|||||||
InputField,
|
InputField,
|
||||||
WithBoard,
|
WithBoard,
|
||||||
WithMetadata,
|
WithMetadata,
|
||||||
UIType,
|
|
||||||
)
|
)
|
||||||
|
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||||
from invokeai.app.invocations.primitives import ImageOutput
|
from invokeai.app.invocations.primitives import ImageOutput
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||||
@ -75,7 +71,6 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
|
||||||
# Load the conditioning data.
|
# Load the conditioning data.
|
||||||
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
|
||||||
assert len(cond_data.conditionings) == 1
|
assert len(cond_data.conditionings) == 1
|
||||||
|
@ -3,6 +3,7 @@ import os
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from diffusers.models.model_loading_utils import load_state_dict
|
from diffusers.models.model_loading_utils import load_state_dict
|
||||||
|
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||||
from diffusers.utils import (
|
from diffusers.utils import (
|
||||||
CONFIG_NAME,
|
CONFIG_NAME,
|
||||||
SAFE_WEIGHTS_INDEX_NAME,
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
@ -12,7 +13,6 @@ from diffusers.utils import (
|
|||||||
)
|
)
|
||||||
from optimum.quanto.models import QuantizedDiffusersModel
|
from optimum.quanto.models import QuantizedDiffusersModel
|
||||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
|
||||||
|
|
||||||
from invokeai.backend.requantize import requantize
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from optimum.quanto.models import QuantizedTransformersModel
|
from optimum.quanto.models import QuantizedTransformersModel
|
||||||
from optimum.quanto.models.shared_dict import ShardedStateDict
|
from optimum.quanto.models.shared_dict import ShardedStateDict
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
|
||||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
|
||||||
from transformers.models.auto import AutoModelForTextEncoding
|
from transformers.models.auto import AutoModelForTextEncoding
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
|
||||||
|
|
||||||
from invokeai.backend.requantize import requantize
|
from invokeai.backend.requantize import requantize
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user