mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run ruff, setup initial text to image node
This commit is contained in:
parent
436f18ff55
commit
1bd90e0fd4
@ -1,8 +1,4 @@
|
||||
import torch
|
||||
|
||||
|
||||
from einops import repeat
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
@ -10,9 +6,8 @@ from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
|
||||
from invokeai.app.invocations.model import CLIPField, T5EncoderField
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.flux.modules.conditioner import HFEncoder
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -1,12 +1,6 @@
|
||||
from typing import Literal
|
||||
|
||||
import accelerate
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
|
||||
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from transformers.models.auto import AutoModelForTextEncoding
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
@ -20,23 +14,12 @@ from invokeai.app.invocations.fields import (
|
||||
from invokeai.app.invocations.model import TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
|
||||
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
|
||||
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
TFluxModelKeys = Literal["flux-schnell"]
|
||||
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
|
||||
|
||||
|
||||
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
|
||||
base_class = FluxTransformer2DModel
|
||||
|
||||
|
||||
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
|
||||
auto_class = AutoModelForTextEncoding
|
||||
|
||||
|
||||
@invocation(
|
||||
"flux_text_to_image",
|
||||
@ -78,7 +61,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
assert isinstance(flux_conditioning, FLUXConditioningInfo)
|
||||
|
||||
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
|
||||
image = self._run_vae_decoding(context, latents)
|
||||
image = self._run_vae_decoding(context, flux_ae_path, latents)
|
||||
image_dto = context.images.save(image=image)
|
||||
return ImageOutput.build(image_dto)
|
||||
|
||||
@ -89,14 +72,40 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
t5_embeddings: torch.Tensor,
|
||||
):
|
||||
transformer_info = context.models.load(self.transformer.transformer)
|
||||
inference_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
# Prepare input noise.
|
||||
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
|
||||
# CPU RNG?
|
||||
x = get_noise(
|
||||
num_samples=1,
|
||||
height=self.height,
|
||||
width=self.width,
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
dtype=inference_dtype,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
img, img_ids = self._prepare_latent_img_patches(x)
|
||||
|
||||
# HACK(ryand): Find a better way to determine if this is a schnell model or not.
|
||||
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
|
||||
timesteps = get_schedule(
|
||||
num_steps=self.num_steps,
|
||||
image_seq_len=img.shape[1],
|
||||
shift=not is_schnell,
|
||||
)
|
||||
|
||||
bs, t5_seq_len, _ = t5_embeddings.shape
|
||||
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
|
||||
|
||||
# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
|
||||
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
|
||||
# if the cache is not empty.
|
||||
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
|
||||
|
||||
with transformer_info as transformer:
|
||||
assert isinstance(transformer, FluxTransformer2DModel)
|
||||
assert isinstance(transformer, Flux)
|
||||
|
||||
x = denoise(
|
||||
model=transformer,
|
||||
@ -144,21 +153,13 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
) -> Image.Image:
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
with vae_info as vae:
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
assert isinstance(vae, AutoEncoder)
|
||||
# TODO(ryand): Test that this works with both float16 and bfloat16.
|
||||
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
|
||||
img = vae.decode(latents)
|
||||
|
||||
img.clamp(-1, 1)
|
||||
img = rearrange(img[0], "c h w -> h w c")
|
||||
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
|
||||
|
||||
latents = flux_pipeline_with_vae._unpack_latents(
|
||||
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
|
||||
)
|
||||
latents = (
|
||||
latents / flux_pipeline_with_vae.vae.config.scaling_factor
|
||||
) + flux_pipeline_with_vae.vae.config.shift_factor
|
||||
latents = latents.to(dtype=vae.dtype)
|
||||
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
|
||||
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
|
||||
|
||||
assert isinstance(image, Image.Image)
|
||||
return image
|
||||
return img_pil
|
||||
|
@ -1,6 +1,6 @@
|
||||
import copy
|
||||
from time import sleep
|
||||
from typing import List, Optional, Literal, Dict
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -12,10 +12,10 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.app.services.model_records import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@ -132,31 +132,22 @@ class ModelIdentifierInvocation(BaseInvocation):
|
||||
|
||||
return ModelIdentifierOutput(model=self.model)
|
||||
|
||||
T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
|
||||
|
||||
T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
|
||||
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
|
||||
"base": {
|
||||
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
|
||||
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
|
||||
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
|
||||
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
|
||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
|
||||
"name": "t5_base_encoder",
|
||||
"format": ModelFormat.T5Encoder,
|
||||
},
|
||||
"8b_quantized": {
|
||||
"text_encoder_repo": "hf_repo1",
|
||||
"tokenizer_repo": "hf_repo1",
|
||||
"text_encoder_name": "hf_repo1",
|
||||
"tokenizer_name": "hf_repo1",
|
||||
"format": ModelFormat.T5Encoder8b,
|
||||
},
|
||||
"4b_quantized": {
|
||||
"text_encoder_repo": "hf_repo2",
|
||||
"tokenizer_repo": "hf_repo2",
|
||||
"text_encoder_name": "hf_repo2",
|
||||
"tokenizer_name": "hf_repo2",
|
||||
"format": ModelFormat.T5Encoder8b,
|
||||
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
|
||||
"name": "t5_8b_quantized_encoder",
|
||||
"format": ModelFormat.T5Encoder,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@invocation_output("flux_model_loader_output")
|
||||
class FluxModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Flux base model loader output"""
|
||||
@ -189,7 +180,15 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
|
||||
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
|
||||
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
|
||||
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
|
||||
vae = self._install_model(
|
||||
context,
|
||||
SubModelType.VAE,
|
||||
"FLUX.1-schnell_ae",
|
||||
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
|
||||
ModelFormat.Checkpoint,
|
||||
ModelType.VAE,
|
||||
BaseModelType.Flux,
|
||||
)
|
||||
|
||||
return FluxModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer),
|
||||
@ -199,20 +198,43 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
|
||||
match(submodel):
|
||||
match submodel:
|
||||
case SubModelType.Transformer:
|
||||
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
|
||||
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
|
||||
case SubModelType.TextEncoder2:
|
||||
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
|
||||
case SubModelType.Tokenizer2:
|
||||
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
|
||||
return self._install_model(
|
||||
context,
|
||||
submodel,
|
||||
"clip-vit-large-patch14",
|
||||
"openai/clip-vit-large-patch14",
|
||||
ModelFormat.Diffusers,
|
||||
ModelType.CLIPEmbed,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
|
||||
return self._install_model(
|
||||
context,
|
||||
submodel,
|
||||
T5_ENCODER_MAP[self.t5_encoder]["name"],
|
||||
T5_ENCODER_MAP[self.t5_encoder]["repo"],
|
||||
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
|
||||
ModelType.T5Encoder,
|
||||
BaseModelType.Any,
|
||||
)
|
||||
case _:
|
||||
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
|
||||
|
||||
def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
|
||||
if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
|
||||
def _install_model(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
submodel: SubModelType,
|
||||
name: str,
|
||||
repo_id: str,
|
||||
format: ModelFormat,
|
||||
type: ModelType,
|
||||
base: BaseModelType,
|
||||
):
|
||||
if models := context.models.search_by_attrs(name=name, base=base, type=type):
|
||||
if len(models) != 1:
|
||||
raise Exception(f"Multiple models detected for selected model with name {name}")
|
||||
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
|
||||
@ -224,7 +246,10 @@ class FluxModelLoaderInvocation(BaseInvocation):
|
||||
sleep(0.01)
|
||||
if not model_install_job.config_out:
|
||||
raise Exception(f"Failed to install {name}")
|
||||
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
|
||||
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
|
||||
update={"submodel_type": submodel}
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"main_model_loader",
|
||||
|
@ -301,7 +301,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
for row in result:
|
||||
try:
|
||||
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
|
||||
except pydantic.ValidationError as e:
|
||||
except pydantic.ValidationError:
|
||||
# We catch this error so that the app can still run if there are invalid model configs in the database.
|
||||
# One reason that an invalid model config might be in the database is if someone had to rollback from a
|
||||
# newer version of the app that added a new model type.
|
||||
|
@ -476,7 +476,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
"""
|
||||
if not model_path.exists():
|
||||
raise Exception("Models provided to import_local_model must already exist on disk")
|
||||
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
|
||||
return self._services.model_manager.install.heuristic_import(
|
||||
str(model_path), config=config, access_token=access_token, inplace=inplace
|
||||
)
|
||||
|
||||
def load_local_model(
|
||||
self,
|
||||
|
@ -3,9 +3,15 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
||||
MLPEmbedder, SingleStreamBlock,
|
||||
timestep_embedding)
|
||||
from invokeai.backend.flux.modules.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxParams:
|
||||
@ -35,9 +41,7 @@ class Flux(nn.Module):
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
|
@ -1,5 +1,6 @@
|
||||
from torch import Tensor, nn
|
||||
from transformers import (PreTrainedModel, PreTrainedTokenizer)
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
class HFEncoder(nn.Module):
|
||||
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
|
||||
|
@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
||||
t.device
|
||||
)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
|
134
invokeai/backend/flux/sampling.py
Normal file
134
invokeai/backend/flux/sampling.py
Normal file
@ -0,0 +1,134 @@
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import Tensor
|
||||
|
||||
from .model import Flux
|
||||
from .modules.conditioner import HFEncoder
|
||||
|
||||
|
||||
def get_noise(
|
||||
num_samples: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
):
|
||||
return torch.randn(
|
||||
num_samples,
|
||||
16,
|
||||
# allow for packing
|
||||
2 * math.ceil(height / 16),
|
||||
2 * math.ceil(width / 16),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
generator=torch.Generator(device=device).manual_seed(seed),
|
||||
)
|
||||
|
||||
|
||||
def prepare(t5: HFEncoder, clip: HFEncoder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
|
||||
bs, c, h, w = img.shape
|
||||
if bs == 1 and not isinstance(prompt, str):
|
||||
bs = len(prompt)
|
||||
|
||||
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
if img.shape[0] == 1 and bs > 1:
|
||||
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
img_ids = torch.zeros(h // 2, w // 2, 3)
|
||||
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
||||
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
txt = t5(prompt)
|
||||
if txt.shape[0] == 1 and bs > 1:
|
||||
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
||||
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
||||
|
||||
vec = clip(prompt)
|
||||
if vec.shape[0] == 1 and bs > 1:
|
||||
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
||||
|
||||
return {
|
||||
"img": img,
|
||||
"img_ids": img_ids.to(img.device),
|
||||
"txt": txt.to(img.device),
|
||||
"txt_ids": txt_ids.to(img.device),
|
||||
"vec": vec.to(img.device),
|
||||
}
|
||||
|
||||
|
||||
def time_shift(mu: float, sigma: float, t: Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
|
||||
def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15) -> Callable[[float], float]:
|
||||
m = (y2 - y1) / (x2 - x1)
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
|
||||
def get_schedule(
|
||||
num_steps: int,
|
||||
image_seq_len: int,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
shift: bool = True,
|
||||
) -> list[float]:
|
||||
# extra step for zero
|
||||
timesteps = torch.linspace(1, 0, num_steps + 1)
|
||||
|
||||
# shifting the schedule to favor high timesteps for higher signal images
|
||||
if shift:
|
||||
# eastimate mu based on linear estimation between two points
|
||||
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
return timesteps.tolist()
|
||||
|
||||
|
||||
def denoise(
|
||||
model: Flux,
|
||||
# model input
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
vec: Tensor,
|
||||
# sampling parameters
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=math.ceil(height / 16),
|
||||
w=math.ceil(width / 16),
|
||||
ph=2,
|
||||
pw=2,
|
||||
)
|
@ -1,14 +1,17 @@
|
||||
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
|
||||
"""Class for Flux model loading in InvokeAI."""
|
||||
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
from dataclasses import fields
|
||||
from safetensors.torch import load_file
|
||||
from typing import Optional, Any
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import yaml
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.model import Flux, FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -19,20 +22,15 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
MainCheckpointConfig,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
MainCheckpointConfig,
|
||||
T5EncoderConfig,
|
||||
VAECheckpointConfig,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.flux.model import Flux, FluxParams
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams, AutoEncoder
|
||||
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
||||
T5Tokenizer)
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@ -58,7 +56,7 @@ class FluxVAELoader(GenericDiffusersLoader):
|
||||
raise
|
||||
|
||||
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf['params']['ae_params'].items() if k in dataclass_fields}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
|
||||
params = AutoEncoderParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
@ -92,6 +90,7 @@ class ClipCheckpointModel(GenericDiffusersLoader):
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder)
|
||||
class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
@ -106,9 +105,9 @@ class T5EncoderCheckpointModel(GenericDiffusersLoader):
|
||||
|
||||
match submodel_type:
|
||||
case SubModelType.Tokenizer2:
|
||||
return T5Tokenizer.from_pretrained(Path(config.path), max_length=512)
|
||||
return T5Tokenizer.from_pretrained(Path(config.path) / "encoder", max_length=512)
|
||||
case SubModelType.TextEncoder2:
|
||||
return T5EncoderModel.from_pretrained(Path(config.path))
|
||||
return T5EncoderModel.from_pretrained(Path(config.path) / "tokenizer")
|
||||
|
||||
raise Exception("Only Checkpoint Flux models are currently supported.")
|
||||
|
||||
@ -148,7 +147,7 @@ class FluxCheckpointModel(GenericDiffusersLoader):
|
||||
params = None
|
||||
model_path = Path(config.path)
|
||||
dataclass_fields = {f.name for f in fields(FluxParams)}
|
||||
filtered_data = {k: v for k, v in flux_conf['params'].items() if k in dataclass_fields}
|
||||
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
|
||||
params = FluxParams(**filtered_data)
|
||||
|
||||
with SilenceWarnings():
|
||||
|
@ -39,11 +39,15 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
|
@ -9,7 +9,7 @@ from typing import Optional
|
||||
import torch
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, T5Tokenizer
|
||||
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
|
||||
|
||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
|
||||
@ -52,7 +52,10 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
return model.calc_size()
|
||||
elif isinstance(
|
||||
model,
|
||||
(T5TokenizerFast,T5Tokenizer,),
|
||||
(
|
||||
T5TokenizerFast,
|
||||
T5Tokenizer,
|
||||
),
|
||||
):
|
||||
return len(model)
|
||||
else:
|
||||
|
@ -56,7 +56,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
||||
},
|
||||
BaseModelType.StableDiffusionXLRefiner: {
|
||||
ModelVariantType.Normal: "sd_xl_refiner.yaml",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -727,6 +727,7 @@ class T5EncoderFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> ModelFormat:
|
||||
return ModelFormat.T5Encoder
|
||||
|
||||
|
||||
class ONNXFolderProbe(PipelineFolderProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
# Due to the way the installer is set up, the configuration file for safetensors
|
||||
|
Loading…
Reference in New Issue
Block a user