Remove dependency on flux config files

Repository: https://github.com/invoke-ai/InvokeAI
Commit: 70c278c810 (parent: 56b9906e2e)
@@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
+from invokeai.backend.flux.util import max_seq_lengths
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
@@ -188,17 +189,13 @@ class FluxModelLoaderInvocation(BaseInvocation):
         vae = self._get_model(context, SubModelType.VAE)
         transformer_config = context.models.get_config(transformer)
         assert isinstance(transformer_config, CheckpointConfigBase)
-        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
             t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
-            max_seq_len=flux_conf["max_seq_len"],
+            max_seq_len=max_seq_lengths[transformer_config.config_path],
         )

     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
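Net effect of the two hunks above: FluxModelLoaderInvocation no longer opens and parses a legacy YAML file to learn the model's maximum sequence length; it indexes a static table keyed by the same config_path string already stored on the model config. A minimal sketch of the idea (table contents taken from this commit; the invocation plumbing is elided):

    from typing import Dict, Literal

    # The lookup table added to invokeai.backend.flux.util in this commit.
    max_seq_lengths: Dict[str, Literal[256, 512]] = {
        "flux-dev": 512,
        "flux-schnell": 256,
    }

    # Before: read legacy_conf_path / config_path, yaml.safe_load() it, then
    # index flux_conf["max_seq_len"]. After: one in-memory dict lookup.
    max_seq_len = max_seq_lengths["flux-schnell"]  # -> 256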
@@ -2,6 +2,7 @@

 import os
 from dataclasses import dataclass
+from typing import Dict, Literal

 from invokeai.backend.flux.model import FluxParams
 from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@@ -18,6 +19,25 @@ class ModelSpec:
     repo_ae: str | None


+max_seq_lengths: Dict[str, Literal[256, 512]] = {
+    "flux-dev": 512,
+    "flux-schnell": 256,
+}
+
+
+ae_params = AutoEncoderParams(
+    resolution=256,
+    in_channels=3,
+    ch=128,
+    out_ch=3,
+    ch_mult=[1, 2, 4, 4],
+    num_res_blocks=2,
+    z_channels=16,
+    scale_factor=0.3611,
+    shift_factor=0.1159,
+)
+
+
 configs = {
     "flux-dev": ModelSpec(
         repo_id="black-forest-labs/FLUX.1-dev",
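With these additions, invokeai.backend.flux.util becomes the single in-code source of FLUX hyperparameters. A short usage sketch, assuming the module layout shown above (that ModelSpec.params holds a FluxParams instance is implied by the loader changes below):

    from invokeai.backend.flux.util import ae_params, configs, max_seq_lengths

    seq_len = max_seq_lengths["flux-dev"]     # 512; "flux-schnell" gives 256
    flux_params = configs["flux-dev"].params  # transformer hyperparameters
    vae_params = ae_params                    # AutoEncoderParams, shared by both models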
@@ -1,19 +1,18 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""

-from dataclasses import fields
 from pathlib import Path
-from typing import Any, Optional
+from typing import Optional

 import accelerate
 import torch
-import yaml
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.flux.model import Flux, FluxParams
-from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.util import configs, ae_params
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -58,17 +57,9 @@ class FluxVAELoader(ModelLoader):
         if not isinstance(config, VAECheckpointConfig):
             raise ValueError("Only VAECheckpointConfig models are currently supported here.")
         model_path = Path(config.path)
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)
-
-        dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = AutoEncoderParams(**filtered_data)

         with SilenceWarnings():
-            model = AutoEncoder(params)
+            model = AutoEncoder(ae_params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             model.to(dtype=self._torch_dtype)
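The deleted lines used a YAML-to-dataclass idiom worth noting: filter the parsed mapping down to the dataclass's declared fields before unpacking it, so unknown keys in the file cannot break construction. The pattern in isolation (a sketch; Params and the dict contents are stand-ins, not InvokeAI code):

    from dataclasses import dataclass, fields

    @dataclass
    class Params:  # stand-in for AutoEncoderParams / FluxParams
        resolution: int
        z_channels: int

    raw = {"resolution": 256, "z_channels": 16, "comment": "ignored"}

    # Keep only keys that name real dataclass fields, then unpack.
    allowed = {f.name for f in fields(Params)}
    params = Params(**{k: v for k, v in raw.items() if k in allowed})

Hardcoding ae_params removes both the file read and this filtering step.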
@@ -182,14 +173,10 @@ class FluxCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -198,16 +185,12 @@ class FluxCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainCheckpointConfig)
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
-            model = Flux(params)
+            model = Flux(configs[config.config_path].params)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
             return model
@@ -224,14 +207,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     ) -> AnyModel:
         if not isinstance(config, CheckpointConfigBase):
             raise ValueError("Only CheckpointConfigBase models are currently supported here.")
-        legacy_config_path = app_config.legacy_conf_path / config.config_path
-        config_path = legacy_config_path.as_posix()
-        with open(config_path, "r") as stream:
-            flux_conf = yaml.safe_load(stream)

         match submodel_type:
             case SubModelType.Transformer:
-                return self._load_from_singlefile(config, flux_conf)
+                return self._load_from_singlefile(config)

         raise ValueError(
             f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@@ -240,7 +219,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
     def _load_from_singlefile(
         self,
         config: AnyModelConfig,
-        flux_conf: Any,
     ) -> AnyModel:
         assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
         if not bnb_available:
@@ -248,13 +226,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
                 "The bnb modules are not available. Please install bitsandbytes if available on your platform."
             )
         model_path = Path(config.path)
-        dataclass_fields = {f.name for f in fields(FluxParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
-        params = FluxParams(**filtered_data)

         with SilenceWarnings():
             with accelerate.init_empty_weights():
-                model = Flux(params)
+                model = Flux(configs[config.config_path].params)
             model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
             sd = load_file(model_path)
             model.load_state_dict(sd, assign=True)
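Unchanged by this commit but central to both checkpoint loaders: the model skeleton is built under accelerate.init_empty_weights(), which places parameters on the meta device without allocating memory, and load_state_dict(..., assign=True) then replaces those placeholders with the checkpoint's tensors instead of copying into them. The pattern in miniature (a sketch with a stand-in module, not the FLUX transformer):

    import accelerate
    import torch

    source = torch.nn.Linear(8, 8)  # stand-in for tensors loaded from a checkpoint

    with accelerate.init_empty_weights():
        skeleton = torch.nn.Linear(8, 8)  # parameters live on the "meta" device

    # assign=True swaps in the loaded tensors; copying into meta tensors
    # (the default behavior) would fail because they have no storage.
    skeleton.load_state_dict(source.state_dict(), assign=True)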
@@ -329,9 +329,9 @@ class ModelProbe(object):
             checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
             state_dict = checkpoint.get("state_dict") or checkpoint
             if "guidance_in.out_layer.weight" in state_dict:
-                config_file = "flux/flux1-dev.yaml"
+                config_file = "flux-dev"
             else:
-                config_file = "flux/flux1-schnell.yaml"
+                config_file = "flux-schnell"
         else:
             config_file = LEGACY_CONFIGS[base_type][variant_type]
             if isinstance(config_file, dict): # need another tier for sd-2.x models
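The probe's config_file value changes meaning here: it is no longer a path under the legacy config directory but the bare key shared by the configs and max_seq_lengths tables in invokeai.backend.flux.util. A sketch of that contract (state_dict is a hypothetical checkpoint mapping; the guidance_in key distinguishes dev from schnell checkpoints, as above):

    from invokeai.backend.flux.util import configs, max_seq_lengths

    state_dict = {"guidance_in.out_layer.weight": object()}  # hypothetical dev checkpoint

    key = "flux-dev" if "guidance_in.out_layer.weight" in state_dict else "flux-schnell"
    flux_params = configs[key].params  # consumed by the checkpoint loaders
    seq_len = max_seq_lengths[key]     # consumed by FluxModelLoaderInvocation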
@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-dev"
-repo_ae: "ae.safetensors"
-max_seq_len: 512
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: True
@@ -1,19 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_ae: "ae.safetensors"
-max_seq_len: 256
-params:
-  in_channels: 64
-  vec_in_dim: 768
-  context_in_dim: 4096
-  hidden_size: 3072
-  mlp_ratio: 4.0
-  num_heads: 24
-  depth: 19
-  depth_single_blocks: 38
-  axes_dim:
-    - 16
-    - 56
-    - 56
-  theta: 10_000
-  qkv_bias: True
-  guidance_embed: False
@@ -1,16 +0,0 @@
-repo_id: "black-forest-labs/FLUX.1-schnell"
-repo_path: "ae.safetensors"
-params:
-  resolution: 256
-  in_channels: 3
-  ch: 128
-  out_ch: 3
-  ch_mult:
-    - 1
-    - 2
-    - 4
-    - 4
-  num_res_blocks: 2
-  z_channels: 16
-  scale_factor: 0.3611
-  shift_factor: 0.1159
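These three deleted YAML files are the "flux config files" of the commit title, and their contents are not lost: the VAE block reappears verbatim as ae_params in invokeai.backend.flux.util, and the transformer blocks map onto the FluxParams entries in the configs table. For reference, the flux1-dev block expressed as the dataclass the old code built via FluxParams(**filtered_data) (a sketch; the field names follow the YAML keys):

    from invokeai.backend.flux.model import FluxParams

    flux_dev_params = FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=True,  # False for flux-schnell; the other fields are identical
    )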