Remove dependency on flux config files

This commit is contained in:
Brandon Rising 2024-08-25 02:41:13 -04:00 committed by Brandon
parent 56b9906e2e
commit 70c278c810
7 changed files with 33 additions and 95 deletions

View File

@ -11,6 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
@ -188,17 +189,13 @@ class FluxModelLoaderInvocation(BaseInvocation):
vae = self._get_model(context, SubModelType.VAE)
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=flux_conf["max_seq_len"],
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:

View File

@ -2,6 +2,7 @@
import os
from dataclasses import dataclass
from typing import Dict, Literal
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams
@ -18,6 +19,25 @@ class ModelSpec:
repo_ae: str | None
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-schnell": 256,
}
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",

View File

@ -1,19 +1,18 @@
# Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
"""Class for Flux model loading in InvokeAI."""
from dataclasses import fields
from pathlib import Path
from typing import Any, Optional
from typing import Optional
import accelerate
import torch
import yaml
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.model import Flux, FluxParams
from invokeai.backend.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import configs, ae_params
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -58,17 +57,9 @@ class FluxVAELoader(ModelLoader):
if not isinstance(config, VAECheckpointConfig):
raise ValueError("Only VAECheckpointConfig models are currently supported here.")
model_path = Path(config.path)
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)
with SilenceWarnings():
model = AutoEncoder(params)
model = AutoEncoder(ae_params)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
@ -182,14 +173,10 @@ class FluxCheckpointModel(ModelLoader):
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@ -198,16 +185,12 @@ class FluxCheckpointModel(ModelLoader):
def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)
with SilenceWarnings():
model = Flux(params)
model = Flux(configs[config.config_path].params)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
return model
@ -224,14 +207,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
) -> AnyModel:
if not isinstance(config, CheckpointConfigBase):
raise ValueError("Only CheckpointConfigBase models are currently supported here.")
legacy_config_path = app_config.legacy_conf_path / config.config_path
config_path = legacy_config_path.as_posix()
with open(config_path, "r") as stream:
flux_conf = yaml.safe_load(stream)
match submodel_type:
case SubModelType.Transformer:
return self._load_from_singlefile(config, flux_conf)
return self._load_from_singlefile(config)
raise ValueError(
f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
@ -240,7 +219,6 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
def _load_from_singlefile(
self,
config: AnyModelConfig,
flux_conf: Any,
) -> AnyModel:
assert isinstance(config, MainBnbQuantized4bCheckpointConfig)
if not bnb_available:
@ -248,13 +226,10 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
model_path = Path(config.path)
dataclass_fields = {f.name for f in fields(FluxParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = FluxParams(**filtered_data)
with SilenceWarnings():
with accelerate.init_empty_weights():
model = Flux(params)
model = Flux(configs[config.config_path].params)
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)

View File

@ -329,9 +329,9 @@ class ModelProbe(object):
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
state_dict = checkpoint.get("state_dict") or checkpoint
if "guidance_in.out_layer.weight" in state_dict:
config_file = "flux/flux1-dev.yaml"
config_file = "flux-dev"
else:
config_file = "flux/flux1-schnell.yaml"
config_file = "flux-schnell"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models

View File

@ -1,19 +0,0 @@
repo_id: "black-forest-labs/FLUX.1-dev"
repo_ae: "ae.safetensors"
max_seq_len: 512
params:
in_channels: 64
vec_in_dim: 768
context_in_dim: 4096
hidden_size: 3072
mlp_ratio: 4.0
num_heads: 24
depth: 19
depth_single_blocks: 38
axes_dim:
- 16
- 56
- 56
theta: 10_000
qkv_bias: True
guidance_embed: True

View File

@ -1,19 +0,0 @@
repo_id: "black-forest-labs/FLUX.1-schnell"
repo_ae: "ae.safetensors"
max_seq_len: 256
params:
in_channels: 64
vec_in_dim: 768
context_in_dim: 4096
hidden_size: 3072
mlp_ratio: 4.0
num_heads: 24
depth: 19
depth_single_blocks: 38
axes_dim:
- 16
- 56
- 56
theta: 10_000
qkv_bias: True
guidance_embed: False

View File

@ -1,16 +0,0 @@
repo_id: "black-forest-labs/FLUX.1-schnell"
repo_path: "ae.safetensors"
params:
resolution: 256
in_channels: 3
ch: 128
out_ch: 3
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
z_channels: 16
scale_factor: 0.3611
shift_factor: 0.1159