Support for Flux and SDXL

This commit is contained in:
Billy
2025-06-23 13:51:16 +10:00
parent 4ee54eac1d
commit e1157f343b
3 changed files with 6 additions and 13 deletions

View File

@ -356,15 +356,11 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
base_str = base_str.lower()
if "stable-diffusion-v1" in base_str:
base = BaseModelType.StableDiffusion1
elif "stable-diffusion-v3" in base_str:
base = BaseModelType.StableDiffusion3
elif base_str == "stable-diffusion-xl-v1-base":
base = BaseModelType.StableDiffusionXL
elif "flux" in base_str:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")
raise InvalidModelConfigException(f"Unrecognised/unsupported base architecture for OMI LoRA: {base_str}")
return {"base": base}

View File

@ -41,8 +41,6 @@ from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ -78,7 +76,8 @@ class LoRALoader(ModelLoader):
else:
state_dict = torch.load(model_path, map_location="cpu")
if config.format == ModelFormat.OMI:
# At the time of writing, we support the OMI standard for base models Flux and SDXL
if config.format == ModelFormat.OMI and self._model_base in [BaseModelType.StableDiffusionXL, BaseModelType.Flux]:
state_dict = convert_from_omi(state_dict, config.base) # type: ignore
# Apply state_dict key conversions, if necessary.

View File

@ -12,9 +12,7 @@ def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
BaseModelType.Flux: convert_flux_lora_key_sets(),
BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(),
BaseModelType.StableDiffusion1: convert_sd_lora_key_sets(),
BaseModelType.StableDiffusion3: convert_sd3_lora_key_sets(),
}[base]
target = "diffusers" # alternatively, "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, "omi", target) # type: ignore
source = "omi"
target = "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore