remove globals, args, generate and the legacy CLI

Lincoln Stein
2023-05-03 23:36:51 -04:00
parent 90054ddf0d
commit 15ffb53e59
21 changed files with 76 additions and 4679 deletions
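The change is mechanical across the backend: the mutable Globals namespace and its helper functions give way to an InvokeAIAppConfig settings object instantiated at module level. A minimal sketch of the pattern, using only attribute names that appear in the hunks below:

# Before (removed in this commit):
#   from invokeai.backend.globals import Globals, global_cache_dir
#   if Globals.try_patchmatch:
#       cache = global_cache_dir("hub")
#
# After (a sketch, not taken verbatim from any one file below):
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig()

if config.try_patchmatch:
    cache = config.cache_dir               # replaces global_cache_dir("hub")
offline = not config.internet_available    # replaces "not Globals.internet_available"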

View File

@ -1,7 +1,6 @@
"""
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGenerator,
@ -12,5 +11,3 @@ from .generator import (
)
from .model_management import ModelManager, SDModelComponent
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,135 +0,0 @@
"""
invokeai.backend.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains
the attributes:
- root - the root directory under which "models" and "outputs" can be found
- initfile - path to the initialization file
- try_patchmatch - option to globally disable loading of 'patchmatch' module
- always_use_cpu - force use of CPU even if GPU is available
"""
import os
import os.path as osp
from argparse import Namespace
from pathlib import Path
from typing import Union
from pydantic import BaseSettings
Globals = Namespace()
# Where to look for the initialization file and other key components
Globals.initfile = "invokeai.init"
Globals.models_file = "models.yaml"
Globals.models_dir = "models"
Globals.config_dir = "configs"
Globals.autoscan_dir = "weights"
Globals.converted_ckpts_dir = "converted_ckpts"
# Set the default root directory. This can be overwritten by explicitly
# passing the `--root <directory>` argument on the command line.
# logic is:
# 1) use INVOKEAI_ROOT environment variable (no check for this being a valid directory)
# 2) use VIRTUAL_ENV environment variable, with a check for initfile being there
# 3) use ~/invokeai
if os.environ.get("INVOKEAI_ROOT"):
Globals.root = osp.abspath(os.environ.get("INVOKEAI_ROOT"))
elif (
os.environ.get("VIRTUAL_ENV")
and Path(os.environ.get("VIRTUAL_ENV"), "..", Globals.initfile).exists()
):
Globals.root = osp.abspath(osp.join(os.environ.get("VIRTUAL_ENV"), ".."))
else:
Globals.root = osp.abspath(osp.expanduser("~/invokeai"))
# Try loading patchmatch
Globals.try_patchmatch = True
# Use CPU even if GPU is available (main use case is for debugging MPS issues)
Globals.always_use_cpu = False
# Whether the internet is reachable for dynamic downloads
# The CLI will test connectivity at startup time.
Globals.internet_available = True
# Whether to disable xformers
Globals.disable_xformers = False
# Low-memory tradeoff for guidance calculations.
Globals.sequential_guidance = False
# whether we are forcing full precision
Globals.full_precision = False
# whether we should convert ckpt files into diffusers models on the fly
Globals.ckpt_convert = True
# logging tokenization everywhere
Globals.log_tokenization = False
def global_config_file() -> Path:
return Path(Globals.root, Globals.config_dir, Globals.models_file)
def global_config_dir() -> Path:
return Path(Globals.root, Globals.config_dir)
def global_models_dir() -> Path:
return Path(Globals.root, Globals.models_dir)
def global_autoscan_dir() -> Path:
return Path(Globals.root, Globals.autoscan_dir)
def global_converted_ckpts_dir() -> Path:
return Path(global_models_dir(), Globals.converted_ckpts_dir)
def global_set_root(root_dir: Union[str, Path]):
Globals.root = root_dir
def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
"""
Returns Path to the model cache directory. If a subdirectory
is provided, it will be appended to the end of the path, allowing
for Hugging Face-style conventions. Currently, Hugging Face has
moved all models into the "hub" subfolder, so for any pretrained
HF model, use:
global_cache_dir('hub')
The legacy locations were global_cache_dir('transformers') for transformers
and global_cache_dir('diffusers') for diffusers.
"""
home: str = os.getenv("HF_HOME")
if home is None:
home = os.getenv("XDG_CACHE_HOME")
if home is not None:
# Set `home` to $XDG_CACHE_HOME/huggingface, which is the default location mentioned in Hugging Face Hub Client Library.
# See: https://huggingface.co/docs/huggingface_hub/main/en/package_reference/environment_variables#xdgcachehome
home += os.sep + "huggingface"
if home is not None:
return Path(home, subdir)
else:
return Path(Globals.root, "models", subdir)
def copy_conf_to_globals(conf: Union[dict,BaseSettings]):
'''
Given a dict or dict-like object, copy its keys and
values into the Globals Namespace. This is a transitional
workaround until we remove Globals entirely.
'''
if isinstance(conf,BaseSettings):
conf = conf.dict()
for key in conf.keys():
if key is not None:
setattr(Globals,key,conf[key])
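For orientation, a sketch of how the path helpers removed above map onto the new configuration object; the mappings are read off the hunks that follow rather than from the config class itself:

from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig()

hub_cache    = config.cache_dir                        # was global_cache_dir("hub")
legacy_confs = config.legacy_conf_path                 # was global_config_dir() / "stable-diffusion"
models_yaml  = config.model_conf_path                  # was global_config_file()
converted    = config.root / "models/converted_ckpts"  # was global_converted_ckpts_dir()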

View File

@ -6,7 +6,9 @@ be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig()
class PatchMatch:
"""
@ -23,7 +25,7 @@ class PatchMatch:
def _load_patch_match(self):
if self.tried_load:
return
if Globals.try_patchmatch:
if config.try_patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:

View File

@ -33,11 +33,11 @@ from PIL import Image, ImageOps
from transformers import AutoProcessor, CLIPSegForImageSegmentation
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir
from invokeai.app.services.config import InvokeAIAppConfig
CLIPSEG_MODEL = "CIDAS/clipseg-rd64-refined"
CLIPSEG_SIZE = 352
config = InvokeAIAppConfig()
class SegmentedGrayscale(object):
def __init__(self, image: Image, heatmap: torch.Tensor):
@ -88,10 +88,10 @@ class Txt2Mask(object):
# BUG: we are not doing anything with the device option at this time
self.device = device
self.processor = AutoProcessor.from_pretrained(
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
self.model = CLIPSegForImageSegmentation.from_pretrained(
CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
CLIPSEG_MODEL, cache_dir=config.cache_dir
)
@torch.no_grad()

View File

@ -26,7 +26,7 @@ import torch
from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir, global_config_dir
from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType
@ -73,6 +73,7 @@ from transformers import (
from ..stable_diffusion import StableDiffusionGeneratorPipeline
config = InvokeAIAppConfig()
def shave_segments(path, n_shave_prefix_segments=1):
"""
@ -842,7 +843,7 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=global_cache_dir("hub")
"openai/clip-vit-large-patch14", cache_dir=config.cache_dir
)
keys = list(checkpoint.keys())
@ -897,7 +898,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
cache_dir = global_cache_dir("hub")
cache_dir = config.cache_dir
config = CLIPVisionConfig.from_pretrained(
"openai/clip-vit-large-patch14", cache_dir=cache_dir
)
@ -969,7 +970,7 @@ def convert_paint_by_example_checkpoint(checkpoint):
def convert_open_clip_checkpoint(checkpoint):
cache_dir = global_cache_dir("hub")
cache_dir = config.cache_dir
text_model = CLIPTextModel.from_pretrained(
"stabilityai/stable-diffusion-2", subfolder="text_encoder", cache_dir=cache_dir
)
@ -1105,7 +1106,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
else:
checkpoint = load_file(checkpoint_path)
cache_dir = global_cache_dir("hub")
cache_dir = config.cache_dir
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
@ -1129,25 +1130,23 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if model_type == SDLegacyType.V2_v:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v2-inference-v.yaml"
config.legacy_conf_path / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v2-inference.yaml"
config.legacy_conf_path / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
global_config_dir()
/ "stable-diffusion"
/ "v1-inpainting-inference.yaml"
config.legacy_conf_path / "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
global_config_dir() / "stable-diffusion" / "v1-inference.yaml"
config.legacy_conf_path / "v1-inference.yaml"
)
else:
@ -1297,7 +1296,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker",
cache_dir=global_cache_dir("hub"),
cache_dir=config.cache_dir,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", cache_dir=cache_dir

View File

@ -36,8 +36,6 @@ from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from transformers import (
CLIPTextModel,
CLIPTokenizer,
@ -49,9 +47,9 @@ from diffusers.pipelines.stable_diffusion.safety_checker import (
from ..stable_diffusion import (
StableDiffusionGeneratorPipeline,
)
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CUDA_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = auto()
V1_INPAINT = auto()
@ -70,6 +68,7 @@ class SDModelComponent(Enum):
feature_extractor="feature_extractor"
DEFAULT_MAX_MODELS = 2
config = InvokeAIAppConfig()
class ModelManager(object):
"""
@ -292,7 +291,7 @@ class ModelManager(object):
"""
# if we are converting legacy files automatically, then
# there are no legacy ckpts!
if Globals.ckpt_convert:
if config.ckpt_convert:
return False
info = self.model_info(model_name)
if "weights" in info and info["weights"].endswith((".ckpt", ".safetensors")):
@ -502,13 +501,13 @@ class ModelManager(object):
# TODO: scan weights maybe?
pipeline_args: dict[str, Any] = dict(
safety_checker=None, local_files_only=not Globals.internet_available
safety_checker=None, local_files_only=not config.internet_available
)
if "vae" in mconfig and mconfig["vae"] is not None:
if vae := self._load_vae(mconfig["vae"]):
pipeline_args.update(vae=vae)
if not isinstance(name_or_path, Path):
pipeline_args.update(cache_dir=global_cache_dir("hub"))
pipeline_args.update(cache_dir=config.cache_dir)
if using_fp16:
pipeline_args.update(torch_dtype=torch.float16)
fp_args_list = [{"revision": "fp16"}, {}]
@ -561,9 +560,9 @@ class ModelManager(object):
height = mconfig.height
if not os.path.isabs(config):
config = os.path.join(Globals.root, config)
config = os.path.join(config.root, config)
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root, weights))
weights = os.path.normpath(os.path.join(config.root, weights))
# Convert to diffusers and return a diffusers pipeline
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
@ -581,7 +580,7 @@ class ModelManager(object):
vae_path = (
vae
if os.path.isabs(vae)
else os.path.normpath(os.path.join(Globals.root, vae))
else os.path.normpath(os.path.join(config.root, vae))
)
if self._has_cuda():
torch.cuda.empty_cache()
@ -616,7 +615,7 @@ class ModelManager(object):
if "path" in mconfig and mconfig["path"] is not None:
path = Path(mconfig["path"])
if not path.is_absolute():
path = Path(Globals.root, path).resolve()
path = Path(config.root, path).resolve()
return path
elif "repo_id" in mconfig:
return mconfig["repo_id"]
@ -864,25 +863,16 @@ class ModelManager(object):
model_type = self.probe_model_type(checkpoint)
if model_type == SDLegacyType.V1:
self.logger.debug("SD-v1 model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v1-inference.yaml"
)
model_config_file = config.legacy_conf_path / "v1-inference.yaml"
elif model_type == SDLegacyType.V1_INPAINT:
self.logger.debug("SD-v1 inpainting model detected")
model_config_file = Path(
Globals.root,
"configs/stable-diffusion/v1-inpainting-inference.yaml",
)
model_config_file = config.legacy_conf_path / "v1-inpainting-inference.yaml",
elif model_type == SDLegacyType.V2_v:
self.logger.debug("SD-v2-v model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
)
model_config_file = config.legacy_conf_path / "v2-inference-v.yaml"
elif model_type == SDLegacyType.V2_e:
self.logger.debug("SD-v2-e model detected")
model_config_file = Path(
Globals.root, "configs/stable-diffusion/v2-inference.yaml"
)
model_config_file = config.legacy_conf_path / "v2-inference.yaml"
elif model_type == SDLegacyType.V2:
self.logger.warning(
f"{thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
@ -909,9 +899,7 @@ class ModelManager(object):
self.logger.debug(f"Using VAE file {vae_path.name}")
vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
diffuser_path = Path(
Globals.root, "models", Globals.converted_ckpts_dir, model_path.stem
)
diffuser_path = config.root / "models/converted_ckpts" / model_path.stem
model_name = self.convert_and_import(
model_path,
diffusers_path=diffuser_path,
@ -1044,9 +1032,7 @@ class ModelManager(object):
"""
yaml_str = OmegaConf.to_yaml(self.config)
if not os.path.isabs(config_file_path):
config_file_path = os.path.normpath(
os.path.join(Globals.root, config_file_path)
)
config_file_path = config.model_conf_path
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
with open(tmpfile, "w", encoding="utf-8") as outfile:
outfile.write(self.preamble())
@ -1078,7 +1064,7 @@ class ModelManager(object):
"""
# Three transformer models to check: bert, clip and safety checker, and
# the diffusers as well
models_dir = Path(Globals.root, "models")
models_dir = config.root / "models"
legacy_locations = [
Path(
models_dir,
@ -1090,8 +1076,8 @@ class ModelManager(object):
"openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
),
]
legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
legacy_cache_dir = config.cache_dir / "../diffusers"
legacy_locations.extend(list(legacy_cache_dir.glob("*")))
legacy_layout = False
for model in legacy_locations:
legacy_layout = legacy_layout or model.exists()
@ -1113,7 +1099,7 @@ class ModelManager(object):
# transformer files get moved into the hub directory
if cls._is_huggingface_hub_directory_present():
hub = global_cache_dir("hub")
hub = config.cache_dir
else:
hub = models_dir / "hub"
@ -1152,12 +1138,12 @@ class ModelManager(object):
if str(source).startswith(("http:", "https:", "ftp:")):
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory = config.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
source = config.root / source
resolved_path = Path(source)
return resolved_path
@ -1208,7 +1194,7 @@ class ModelManager(object):
path = name_or_path
else:
owner, repo = name_or_path.split("/")
path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}")
path = Path(config.cache_dir / f"models--{owner}--{repo}")
if not path.exists():
return None
hashpath = path / "checksum.sha256"
@ -1269,8 +1255,8 @@ class ModelManager(object):
using_fp16 = self.precision == "float16"
vae_args.update(
cache_dir=global_cache_dir("hub"),
local_files_only=not Globals.internet_available,
cache_dir=config.cache_dir,
local_files_only=not config.internet_available,
)
self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
@ -1308,7 +1294,7 @@ class ModelManager(object):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(global_cache_dir("hub"))
cache_info = scan_cache_dir(config.cache_dir)
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
@ -1327,7 +1313,7 @@ class ModelManager(object):
def _abs_path(path: str | Path) -> Path:
if path is None or Path(path).is_absolute():
return path
return Path(Globals.root, path).resolve()
return Path(config.root, path).resolve()
@staticmethod
def _is_huggingface_hub_directory_present() -> bool:

View File

@ -19,11 +19,12 @@ from compel.prompt_parser import (
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
from ..stable_diffusion import InvokeAIDiffuserComponent
from ..util import torch_dtype
config = InvokeAIAppConfig()
def get_uc_and_c_and_ec(
prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False
@ -61,7 +62,7 @@ def get_uc_and_c_and_ec(
negative_prompt_string
)
if log_tokens or getattr(Globals, "log_tokenization", False):
if log_tokens or config.log_tokenization:
log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)

View File

@ -6,7 +6,8 @@ import numpy as np
import torch
import invokeai.backend.util.logging as logger
from ..globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig()
pretrained_model_url = (
"https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
@ -18,7 +19,7 @@ class CodeFormerRestoration:
self, codeformer_dir="models/codeformer", codeformer_model_path="codeformer.pth"
) -> None:
if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
codeformer_dir = os.path.join(config.root, codeformer_dir)
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path)
@ -72,7 +73,7 @@ class CodeFormerRestoration:
use_parse=True,
device=device,
model_rootpath=os.path.join(
Globals.root, "models", "gfpgan", "weights"
config.root, "models", "gfpgan", "weights"
),
)
face_helper.clean_all()

View File

@ -7,13 +7,14 @@ import torch
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig()
class GFPGAN:
def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path = os.path.abspath(
os.path.join(Globals.root, gfpgan_model_path)
os.path.join(config.root, gfpgan_model_path)
)
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path)
@ -33,7 +34,7 @@ class GFPGAN:
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
cwd = os.getcwd()
os.chdir(os.path.join(Globals.root, "models"))
os.chdir(os.path.join(config.root, "models"))
try:
from gfpgan import GFPGANer

View File

@ -15,9 +15,11 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from .globals import global_cache_dir
from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE
config = InvokeAIAppConfig()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
@ -29,7 +31,7 @@ class SafetyChecker(object):
try:
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
safety_model_path = config.cache_dir
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,

View File

@ -18,15 +18,15 @@ from huggingface_hub import (
)
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
config = InvokeAIAppConfig()
class HuggingFaceConceptsLibrary(object):
def __init__(self, root=None):
"""
Initialize the Concepts object. May optionally pass a root directory.
"""
self.root = root or Globals.root
self.root = root or config.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
@ -58,7 +58,7 @@ class HuggingFaceConceptsLibrary(object):
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
elif Globals.internet_available is True:
elif config.internet_available is True:
try:
models = self.hf_api.list_models(
filter=ModelFilter(model_name="sd-concepts-library/")

View File

@ -33,8 +33,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from typing_extensions import ParamSpec
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
from ..util import CPU_DEVICE, normalize_device
from .diffusion import (
AttentionMapSaver,
@ -44,6 +43,7 @@ from .diffusion import (
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
from .textual_inversion_manager import TextualInversionManager
config = InvokeAIAppConfig()
@dataclass
class PipelineIntermediateState:
@ -351,7 +351,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if (
torch.cuda.is_available()
and is_xformers_available()
and not Globals.disable_xformers
and not config.disable_xformers
):
self.enable_xformers_memory_efficient_attention()
else:

View File

@ -9,7 +9,7 @@ from diffusers.models.attention_processor import AttentionProcessor
from typing_extensions import TypeAlias
import invokeai.backend.util.logging as logger
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
from .cross_attention_control import (
Arguments,
@ -31,6 +31,7 @@ ModelForwardCallback: TypeAlias = Union[
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
]
config = InvokeAIAppConfig()
@dataclass(frozen=True)
class PostprocessingSettings:
@ -77,7 +78,7 @@ class InvokeAIDiffuserComponent:
self.is_running_diffusers = is_running_diffusers
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
self.sequential_guidance = config.sequential_guidance
@contextmanager
def custom_attention_context(

View File

@ -4,17 +4,16 @@ from contextlib import nullcontext
import torch
from torch import autocast
from invokeai.backend.globals import Globals
from invokeai.app.services.config import InvokeAIAppConfig
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
config = InvokeAIAppConfig()
def choose_torch_device() -> torch.device:
"""Convenience routine for guessing which GPU device to run model on"""
if Globals.always_use_cpu:
if config.always_use_cpu:
return CPU_DEVICE
if torch.cuda.is_available():
return torch.device("cuda")
@ -33,7 +32,7 @@ def choose_precision(device: torch.device) -> str:
def torch_dtype(device: torch.device) -> torch.dtype:
if Globals.full_precision:
if config.full_precision:
return torch.float32
if choose_precision(device) == "float16":
return torch.float16
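
To illustrate how the two helpers above behave after this change, a small usage sketch; the import path is an assumption, since the file name is not shown in this view:

import torch

# Assumed module path for the helpers patched above.
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

device = choose_torch_device()  # returns CPU_DEVICE when config.always_use_cpu is set
dtype = torch_dtype(device)     # torch.float32 when config.full_precision is set
tensor = torch.zeros(1, device=device, dtype=dtype)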