Select dev/schnell based on the state dict, use the correct max seq len and shift for dev/schnell during inference, and split the Flux VAE params into a separate config

commit a63f842a13 (parent 4bd7fda694)
Author:    Brandon Rising
Date:      2024-08-19 14:41:28 -04:00
Committer: Brandon
9 changed files with 170 additions and 66 deletions
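
Two of the headline changes, the per-variant max sequence length and the schedule shift, are not visible in the hunks below. A minimal sketch of that selection logic, with the function name hypothetical and the values assumed from the upstream flux reference code rather than this diff:

def flux_inference_settings(is_schnell: bool) -> tuple[int, bool]:
    # Assumed from the upstream flux reference code: schnell was distilled
    # with a 256-token T5 context and no timestep-schedule shift, while dev
    # uses 512 tokens and a resolution-dependent shift.
    max_seq_len = 256 if is_schnell else 512
    use_shift = not is_schnell
    return max_seq_len, use_shift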


@@ -32,7 +32,6 @@ from invokeai.backend.model_manager.config import (
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -60,7 +59,7 @@ class FluxVAELoader(GenericDiffusersLoader):
                     raise

         dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
         params = AutoEncoderParams(**filtered_data)

         with SilenceWarnings():
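
The hunk above assumes the new, flatter VAE config: the AutoEncoderParams fields now sit directly under params instead of under params.ae_params. A minimal runnable sketch of the filtering idiom, with the dataclass fields and hyperparameter values assumed from the upstream flux reference implementation rather than taken from this diff:

from dataclasses import dataclass, fields

# Assumed to mirror the upstream flux AutoEncoderParams dataclass.
@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float

# What a parsed flux1-vae.yaml might look like after this commit
# (values are the published flux VAE hyperparameters, stated as an assumption):
flux_conf = {
    "params": {
        "resolution": 256,
        "in_channels": 3,
        "ch": 128,
        "out_ch": 3,
        "ch_mult": [1, 2, 4, 4],
        "num_res_blocks": 2,
        "z_channels": 16,
        "scale_factor": 0.3611,
        "shift_factor": 0.1159,
        "some_unknown_key": "silently dropped",  # filtered out below
    }
}

dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)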


@@ -324,7 +324,12 @@ class ModelProbe(object):
         if model_type is ModelType.Main:
             if base_type == BaseModelType.Flux:
                 # TODO: Decide between dev/schnell
-                config_file = "flux/flux1-schnell.yaml"
+                checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
+                state_dict = checkpoint.get("state_dict") or checkpoint
+                if 'guidance_in.out_layer.weight' in state_dict:
+                    config_file = "flux/flux1-dev.yaml"
+                else:
+                    config_file = "flux/flux1-schnell.yaml"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
                 if isinstance(config_file, dict):  # need another tier for sd-2.x models
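
The detection rests on an architectural difference: flux-dev is guidance-distilled, so its transformer carries a guidance embedder whose weights appear as guidance_in.* keys in the checkpoint, while flux-schnell has no such module. The same heuristic extracted as a standalone helper (hypothetical name, not part of this diff):

def flux_config_for_state_dict(state_dict: dict) -> str:
    # flux-dev's guidance embedder leaves "guidance_in.*" keys in the
    # checkpoint; their absence implies flux-schnell.
    if "guidance_in.out_layer.weight" in state_dict:
        return "flux/flux1-dev.yaml"
    return "flux/flux1-schnell.yaml"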
@@ -338,7 +343,7 @@ class ModelProbe(object):
             )
         elif model_type is ModelType.VAE:
             config_file = (
-                "flux/flux1-schnell.yaml"
+                "flux/flux1-vae.yaml"
                 if base_type is BaseModelType.Flux
                 else "stable-diffusion/v1-inference.yaml"
                 if base_type is BaseModelType.StableDiffusion1