Setup flux model loading in the UI

This commit is contained in:
Brandon Rising
2024-08-12 14:04:23 -04:00
committed by Brandon
parent 1fa6bddc89
commit 5f59a828f9
22 changed files with 463 additions and 18 deletions

View File

@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
# Kandinsky2_1 = "kandinsky-2.1"
@ -74,6 +75,7 @@ class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"

View File

@ -95,6 +95,7 @@ class ModelProbe(object):
}
CLASS2TYPE = {
"FluxPipeline": ModelType.Main,
"StableDiffusionPipeline": ModelType.Main,
"StableDiffusionInpaintPipeline": ModelType.Main,
"StableDiffusionXLPipeline": ModelType.Main,
@ -626,6 +627,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(f"{self.model_path}/model_index.json", "r") as file:
conf = json.load(file)
if "_class_name" in conf and conf.get("_class_name") == "FluxPipeline":
return BaseModelType.Flux
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768: