fix(mm): typing issues in vae loader

This commit is contained in:
psychedelicious 2024-04-01 12:21:48 +11:00
parent 25ff0bf80f
commit 13f410478a

View File

@ -2,6 +2,7 @@
"""Class for VAE model loading in InvokeAI.""" """Class for VAE model loading in InvokeAI."""
from pathlib import Path from pathlib import Path
from typing import Optional
import torch import torch
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
@ -13,7 +14,7 @@ from invokeai.backend.model_manager import (
ModelFormat, ModelFormat,
ModelType, ModelType,
) )
from invokeai.backend.model_manager.config import CheckpointConfigBase from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry from .. import ModelLoaderRegistry
@ -38,7 +39,7 @@ class VAELoader(GenericDiffusersLoader):
else: else:
return True return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path: def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert. # TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}: if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}") raise Exception(f"VAE conversion not supported for model type: {config.base}")