Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
feat(mm): use same pattern for vae converter as others
Add a `dump_path` arg to the converter function and save the model to disk inside the conversion function. This is the same pattern as in the other conversion functions.
parent 13f410478a
commit 59b4a23479
@@ -3,10 +3,10 @@
 """Conversion script for the Stable Diffusion checkpoints."""
 
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
-from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
     convert_ldm_vae_checkpoint,
     create_vae_diffusers_config,
@@ -19,9 +19,10 @@ from . import AnyModel
 
 
 def convert_ldm_vae_to_diffusers(
-    checkpoint: Dict[str, torch.Tensor],
+    checkpoint: torch.Tensor | dict[str, torch.Tensor],
     vae_config: DictConfig,
     image_size: int,
+    dump_path: Optional[Path] = None,
     precision: torch.dtype = torch.float16,
 ) -> AutoencoderKL:
     """Convert a checkpoint-style VAE into a Diffusers VAE"""
@@ -30,7 +31,12 @@ def convert_ldm_vae_to_diffusers(
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)
-    return vae.to(precision)
+    vae.to(precision)
+
+    if dump_path:
+        vae.save_pretrained(dump_path, safe_serialization=True)
+
+    return vae
 
 
 def convert_ckpt_to_diffusers(
@@ -64,6 +64,6 @@ class VAELoader(GenericDiffusersLoader):
             vae_config=ckpt_config,
             image_size=512,
             precision=self._torch_dtype,
+            dump_path=output_path,
         )
-        vae_model.save_pretrained(output_path, safe_serialization=True)
         return vae_model
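For illustration, a runnable toy sketch of the pattern adopted above: the converter persists the model to `dump_path` itself and still returns it, so callers no longer call `save_pretrained` afterwards. The `convert_vae` name and the tiny VAE config are illustrative stand-ins, not InvokeAI's actual code.

from pathlib import Path
from typing import Optional

import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

# Tiny, fast-to-build VAE config for demonstration only (not a real
# Stable Diffusion VAE config).
TOY_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "block_out_channels": (32,),
    "layers_per_block": 1,
    "latent_channels": 4,
}


def convert_vae(
    state_dict: dict[str, torch.Tensor],
    dump_path: Optional[Path] = None,
    precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
    vae = AutoencoderKL(**TOY_CONFIG)
    vae.load_state_dict(state_dict)
    vae.to(precision)
    if dump_path:
        # Saving happens inside the converter, mirroring the other converters.
        vae.save_pretrained(dump_path, safe_serialization=True)
    return vae


if __name__ == "__main__":
    source = AutoencoderKL(**TOY_CONFIG)
    vae = convert_vae(source.state_dict(), dump_path=Path("toy-vae"))
    print(type(vae).__name__, "saved to toy-vae/")

Keeping the save inside the converter means every conversion path writes the diffusers-format model to disk the same way, instead of each caller repeating the `save_pretrained` call.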