mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): use same pattern for vae converter as others
Add `dump_path` arg to the converter function & save the model to disk inside the conversion function. This is the same pattern as in the other conversion functions.
This commit is contained in:
parent
13f410478a
commit
59b4a23479
@ -3,10 +3,10 @@
|
||||
"""Conversion script for the Stable Diffusion checkpoints."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
convert_ldm_vae_checkpoint,
|
||||
create_vae_diffusers_config,
|
||||
@ -19,9 +19,10 @@ from . import AnyModel
|
||||
|
||||
|
||||
def convert_ldm_vae_to_diffusers(
|
||||
checkpoint: Dict[str, torch.Tensor],
|
||||
checkpoint: torch.Tensor | dict[str, torch.Tensor],
|
||||
vae_config: DictConfig,
|
||||
image_size: int,
|
||||
dump_path: Optional[Path] = None,
|
||||
precision: torch.dtype = torch.float16,
|
||||
) -> AutoencoderKL:
|
||||
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
|
||||
@ -30,7 +31,12 @@ def convert_ldm_vae_to_diffusers(
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
return vae.to(precision)
|
||||
vae.to(precision)
|
||||
|
||||
if dump_path:
|
||||
vae.save_pretrained(dump_path, safe_serialization=True)
|
||||
|
||||
return vae
|
||||
|
||||
|
||||
def convert_ckpt_to_diffusers(
|
||||
|
@ -64,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
|
||||
vae_config=ckpt_config,
|
||||
image_size=512,
|
||||
precision=self._torch_dtype,
|
||||
dump_path=output_path,
|
||||
)
|
||||
vae_model.save_pretrained(output_path, safe_serialization=True)
|
||||
return vae_model
|
||||
|
Loading…
Reference in New Issue
Block a user