model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent b8e875bb73
commit 8ba5360269
29 changed files with 2382 additions and 237 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Union
from typing import Union, Optional
import torch
from torch import autocast
@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
return "float32"
def torch_dtype(device: torch.device) -> torch.dtype:
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
if precision == "float16":
return torch.float16