mirror of https://github.com/invoke-ai/InvokeAI, synced 2025-07-26 05:17:55 +00:00
Support for Flux and SDXL
@@ -356,15 +356,11 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
base_str = base_str.lower()

if "stable-diffusion-v1" in base_str:
    base = BaseModelType.StableDiffusion1
elif "stable-diffusion-v3" in base_str:
    base = BaseModelType.StableDiffusion3
elif base_str == "stable-diffusion-xl-v1-base":
    base = BaseModelType.StableDiffusionXL
elif "flux" in base_str:
    base = BaseModelType.Flux
else:
    raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")
    raise InvalidModelConfigException(f"Unrecognised/unsupported base architecture for OMI LoRA: {base_str}")

return {"base": base}
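The config hunk above normalises the OMI architecture string and substring-matches it to a base model enum. A minimal, self-contained sketch of the same matching rules follows; Base, UnknownBase, and base_from_architecture are stand-in names, not InvokeAI's actual BaseModelType, InvalidModelConfigException, or config method.

    # Hedged sketch of the matching rules shown in the hunk above; all names
    # here are illustrative stand-ins for InvokeAI's real symbols.
    from enum import Enum, auto

    class Base(Enum):
        SD1 = auto()
        SD3 = auto()
        SDXL = auto()
        FLUX = auto()

    class UnknownBase(Exception):
        pass

    def base_from_architecture(base_str: str) -> Base:
        base_str = base_str.lower()
        if "stable-diffusion-v1" in base_str:
            return Base.SD1
        elif "stable-diffusion-v3" in base_str:
            return Base.SD3
        elif base_str == "stable-diffusion-xl-v1-base":
            return Base.SDXL
        elif "flux" in base_str:
            return Base.FLUX
        raise UnknownBase(f"Unrecognised/unsupported base architecture for OMI LoRA: {base_str}")

    # Example (hypothetical architecture string): anything containing "flux"
    # maps to the Flux base.
    assert base_from_architecture("Flux.1-dev") is Base.FLUX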
@@ -41,8 +41,6 @@ from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import

@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
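These decorators register a single loader class under several (base, type, format) combinations. A rough sketch of how such a decorator-based registry can work is below; it illustrates the pattern only and is not InvokeAI's ModelLoaderRegistry implementation (register, _REGISTRY, and DemoLoRALoader are hypothetical; the model_type/model_format parameters mirror the type=/format= keywords above).

    # Hedged sketch of a decorator-based loader registry keyed on
    # (base, model type, format); purely illustrative.
    _REGISTRY: dict[tuple[str, str, str], type] = {}

    def register(base: str, model_type: str, model_format: str):
        def decorator(cls):
            _REGISTRY[(base, model_type, model_format)] = cls
            return cls
        return decorator

    @register(base="sdxl", model_type="lora", model_format="omi")
    @register(base="flux", model_type="lora", model_format="omi")
    class DemoLoRALoader:
        pass

    assert _REGISTRY[("flux", "lora", "omi")] is DemoLoRALoader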
@@ -78,7 +76,8 @@ class LoRALoader(ModelLoader):
else:
    state_dict = torch.load(model_path, map_location="cpu")

if config.format == ModelFormat.OMI:
# At the time of writing, we support the OMI standard for base models Flux and SDXL
if config.format == ModelFormat.OMI and self._model_base in [BaseModelType.StableDiffusionXL, BaseModelType.Flux]:
    state_dict = convert_from_omi(state_dict, config.base)  # type: ignore

# Apply state_dict key conversions, if necessary.
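In outline, the loader reads the raw state dict (torch.load is the fallback branch shown above) and remaps OMI key names only when the format is OMI and the base is one of the supported ones (SDXL or Flux). A minimal sketch of that control flow, with the converter injected as a callable and the supported-base set expressed as plain strings, is:

    # Hedged sketch of the load-then-maybe-convert flow; 'convert' stands in
    # for convert_from_omi, and format/base values are simple strings here.
    from pathlib import Path

    import torch

    SUPPORTED_OMI_BASES = {"sdxl", "flux"}  # mirrors [StableDiffusionXL, Flux] above

    def load_lora_state_dict(model_path: Path, fmt: str, base: str, convert):
        if model_path.suffix == ".safetensors":
            # Presumably the real loader reads safetensors files here;
            # torch.load is the fallback shown in the diff.
            from safetensors.torch import load_file
            state_dict = load_file(model_path, device="cpu")
        else:
            state_dict = torch.load(model_path, map_location="cpu")
        if fmt == "omi" and base in SUPPORTED_OMI_BASES:
            state_dict = convert(state_dict, base)  # remap OMI key names
        return state_dict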
@@ -12,9 +12,7 @@ def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
    BaseModelType.Flux: convert_flux_lora_key_sets(),
    BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(),
    BaseModelType.StableDiffusion1: convert_sd_lora_key_sets(),
    BaseModelType.StableDiffusion3: convert_sd3_lora_key_sets(),
}[base]

target = "diffusers"  # alternatively, "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, "omi", target)  # type: ignore
source = "omi"
target = "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, source, target)  # type: ignore
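The conversion helper dispatches on the base model to pick a key-mapping set, then delegates to the OMI reference converter with "omi" as the source naming scheme and a diffusers-style target ("diffusers" or "legacy_diffusers" in the two variants shown). A rough sketch of that dispatch-then-delegate shape follows; keyset_builders and convert_keys are injected stand-ins for the convert_*_lora_key_sets() helpers and lora_util.__convert, and only the call shape visible in the hunk is assumed.

    # Hedged sketch of convert_from_omi's shape: look up a per-base key set,
    # then hand off to a converter with source/target naming schemes.
    def convert_from_omi_sketch(weights_sd, base, keyset_builders, convert_keys):
        keyset = keyset_builders[base]()  # KeyError here means an unsupported base
        source = "omi"
        target = "legacy_diffusers"
        return convert_keys(weights_sd, keyset, source, target)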