model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent 5c2884569e
commit 60aa3d4893
29 changed files with 2382 additions and 237 deletions

View File

@ -12,6 +12,14 @@ from .devices import ( # noqa: F401
torch_dtype,
)
from .logging import InvokeAILogger
from .util import Chdir, ask_user, download_with_resume, instantiate_from_config, url_attachment_name # noqa: F401
from .util import ( # TO DO: Clean this up; remove the unused symbols
GIG,
Chdir,
ask_user, # noqa
directory_size,
download_with_resume,
instantiate_from_config, # noqa
url_attachment_name, # noqa
)
__all__ = ["Chdir", "InvokeAILogger", "choose_precision", "choose_torch_device"]
__all__ = ["GIG", "directory_size","Chdir", "download_with_resume", "InvokeAILogger", "choose_precision", "choose_torch_device"]

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from contextlib import nullcontext
from typing import Union
from typing import Union, Optional
import torch
from torch import autocast
@ -43,7 +43,8 @@ def choose_precision(device: torch.device) -> str:
return "float32"
def torch_dtype(device: torch.device) -> torch.dtype:
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
device = device or choose_torch_device()
precision = choose_precision(device)
if precision == "float16":
return torch.float16

View File

@ -24,6 +24,20 @@ import invokeai.backend.util.logging as logger
from .devices import torch_dtype
# actual size of a gig
GIG = 1073741824
def directory_size(directory: Path) -> int:
"""
Return the aggregate size of all files in a directory (bytes).
"""
sum = 0
for root, dirs, files in os.walk(directory):
for f in files:
sum += Path(root, f).stat().st_size
for d in dirs:
sum += Path(root, d).stat().st_size
return sum
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)