feat(mm): use same pattern for vae converter as others

Add `dump_path` arg to the converter function & save the model to disk inside the conversion function. This is the same pattern as in the other conversion functions.
This commit is contained in:
psychedelicious 2024-04-01 12:24:58 +11:00
parent 13f410478a
commit 59b4a23479
2 changed files with 11 additions and 5 deletions

View File

@ -3,10 +3,10 @@
"""Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path
from typing import Dict, Optional
from typing import Optional
import torch
from diffusers import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
@ -19,9 +19,10 @@ from . import AnyModel
def convert_ldm_vae_to_diffusers(
checkpoint: Dict[str, torch.Tensor],
checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
@ -30,7 +31,12 @@ def convert_ldm_vae_to_diffusers(
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae.to(precision)
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
def convert_ckpt_to_diffusers(

View File

@ -64,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
vae_model.save_pretrained(output_path, safe_serialization=True)
return vae_model