feat(mm): use same pattern for vae converter as others

Add an optional `dump_path` argument to the VAE converter function and, when a path is provided, save the converted model to disk inside the conversion function itself. This matches the pattern already used by the other conversion functions.
This commit is contained in:
psychedelicious 2024-04-01 12:24:58 +11:00
parent 13f410478a
commit 59b4a23479
2 changed files with 11 additions and 5 deletions

View File

@ -3,10 +3,10 @@
"""Conversion script for the Stable Diffusion checkpoints.""" """Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Optional
import torch import torch
from diffusers import AutoencoderKL from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint, convert_ldm_vae_checkpoint,
create_vae_diffusers_config, create_vae_diffusers_config,
@ -19,9 +19,10 @@ from . import AnyModel
def convert_ldm_vae_to_diffusers( def convert_ldm_vae_to_diffusers(
checkpoint: Dict[str, torch.Tensor], checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig, vae_config: DictConfig,
image_size: int, image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16, precision: torch.dtype = torch.float16,
) -> AutoencoderKL: ) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE""" """Convert a checkpoint-style VAE into a Diffusers VAE"""
@ -30,7 +31,12 @@ def convert_ldm_vae_to_diffusers(
vae = AutoencoderKL(**vae_config) vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint) vae.load_state_dict(converted_vae_checkpoint)
return vae.to(precision) vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
def convert_ckpt_to_diffusers( def convert_ckpt_to_diffusers(

View File

@ -64,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
vae_config=ckpt_config, vae_config=ckpt_config,
image_size=512, image_size=512,
precision=self._torch_dtype, precision=self._torch_dtype,
dump_path=output_path,
) )
vae_model.save_pretrained(output_path, safe_serialization=True)
return vae_model return vae_model