mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model loading and conversion implemented for vaes
This commit is contained in:
committed by
psychedelicious
parent
b8e875bb73
commit
8ba5360269
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Union
|
||||
from typing import Union, Optional
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
|
||||
return "float32"
|
||||
|
||||
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
||||
device = device or choose_torch_device()
|
||||
precision = choose_precision(device)
|
||||
if precision == "float16":
|
||||
return torch.float16
|
||||
|
Reference in New Issue
Block a user