mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix invokeai_configure script to work with new mm; rename CLIs
This commit is contained in:
committed by
psychedelicious
parent
dfcf38be91
commit
d959276217
@ -37,7 +37,7 @@ from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
# name of the starter models file
|
||||
INITIAL_MODELS = "INITIAL_MODELS2.yaml"
|
||||
INITIAL_MODELS = "INITIAL_MODELS.yaml"
|
||||
|
||||
|
||||
def initialize_record_store(app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
|
@ -18,31 +18,30 @@ from argparse import Namespace
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from shutil import get_terminal_size
|
||||
from typing import Any, get_args, get_type_hints
|
||||
from typing import Any, Optional, Set, Tuple, Type, get_args, get_type_hints
|
||||
from urllib import request
|
||||
|
||||
import npyscreen
|
||||
import omegaconf
|
||||
import psutil
|
||||
import torch
|
||||
import transformers
|
||||
import yaml
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import AutoencoderKL, ModelMixin
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
from pydantic import ValidationError
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic.error_wrappers import ValidationError
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.configs as configs
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.install.install_helper import InstallHelper, InstallSelections
|
||||
from invokeai.backend.install.legacy_arg_parsing import legacy_parser
|
||||
from invokeai.backend.install.model_install_backend import InstallSelections, ModelInstall, hf_download_from_pretrained
|
||||
from invokeai.backend.model_management.model_probe import BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
from invokeai.backend.util import choose_precision, choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.frontend.install.model_install import addModelsForm, process_and_execute
|
||||
from invokeai.frontend.install.model_install import addModelsForm
|
||||
|
||||
# TO DO - Move all the frontend code into invokeai.frontend.install
|
||||
from invokeai.frontend.install.widgets import (
|
||||
@ -61,7 +60,7 @@ warnings.filterwarnings("ignore")
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
def get_literal_fields(field) -> list[Any]:
|
||||
def get_literal_fields(field: str) -> Tuple[Any]:
|
||||
return get_args(get_type_hints(InvokeAIAppConfig).get(field))
|
||||
|
||||
|
||||
@ -80,8 +79,7 @@ ATTENTION_SLICE_CHOICES = get_literal_fields("attention_slice_size")
|
||||
GENERATION_OPT_CHOICES = ["sequential_guidance", "force_tiled_decode", "lazy_offload"]
|
||||
GB = 1073741824 # GB in bytes
|
||||
HAS_CUDA = torch.cuda.is_available()
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0, 0)
|
||||
|
||||
_, MAX_VRAM = torch.cuda.mem_get_info() if HAS_CUDA else (0.0, 0.0)
|
||||
|
||||
MAX_VRAM /= GB
|
||||
MAX_RAM = psutil.virtual_memory().total / GB
|
||||
@ -96,13 +94,15 @@ logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class DummyWidgetValue(Enum):
|
||||
"""Dummy widget values."""
|
||||
|
||||
zero = 0
|
||||
true = True
|
||||
false = False
|
||||
|
||||
|
||||
# --------------------------------------------
|
||||
def postscript(errors: None):
|
||||
def postscript(errors: Set[str]) -> None:
|
||||
if not any(errors):
|
||||
message = f"""
|
||||
** INVOKEAI INSTALLATION SUCCESSFUL **
|
||||
@ -143,7 +143,7 @@ def yes_or_no(prompt: str, default_yes=True):
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> str:
|
||||
def HfLogin(access_token) -> None:
|
||||
"""
|
||||
Helper for logging in to Huggingface
|
||||
The stdout capture is needed to hide the irrelevant "git credential helper" warning
|
||||
@ -162,7 +162,7 @@ def HfLogin(access_token) -> str:
|
||||
|
||||
# -------------------------------------
|
||||
class ProgressBar:
|
||||
def __init__(self, model_name="file"):
|
||||
def __init__(self, model_name: str = "file"):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
@ -179,6 +179,22 @@ class ProgressBar:
|
||||
self.pbar.update(block_size)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_from_pretrained(model_class: Type[ModelMixin], model_name: str, destination: Path, **kwargs: Any):
|
||||
filter = lambda x: "fp16 is not a valid" not in x.getMessage() # noqa E731
|
||||
logger.addFilter(filter)
|
||||
try:
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model.save_pretrained(destination, safe_serialization=True)
|
||||
finally:
|
||||
logger.removeFilter(filter)
|
||||
return destination
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@ -249,6 +265,7 @@ def download_conversion_models():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# TO DO: use the download queue here.
|
||||
def download_realesrgan():
|
||||
logger.info("Installing ESRGAN Upscaling models...")
|
||||
URLs = [
|
||||
@ -288,18 +305,19 @@ def download_lama():
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_support_models():
|
||||
def download_support_models() -> None:
|
||||
download_realesrgan()
|
||||
download_lama()
|
||||
download_conversion_models()
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
def get_root(root: Optional[str] = None) -> str:
|
||||
if root:
|
||||
return root
|
||||
elif os.environ.get("INVOKEAI_ROOT"):
|
||||
return os.environ.get("INVOKEAI_ROOT")
|
||||
elif root := os.environ.get("INVOKEAI_ROOT"):
|
||||
assert root is not None
|
||||
return root
|
||||
else:
|
||||
return str(config.root_path)
|
||||
|
||||
@ -455,6 +473,25 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
max_width=110,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model disk conversion cache size (GB). This is used to cache safetensors files that need to be converted to diffusers..",
|
||||
begin_entry_at=0,
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely -= 1
|
||||
self.disk = self.add_widget_intelligent(
|
||||
npyscreen.Slider,
|
||||
value=clip(old_opts.convert_cache, range=(0, 100), step=0.5),
|
||||
out_of=100,
|
||||
lowest=0.0,
|
||||
step=0.5,
|
||||
relx=8,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.TitleFixedText,
|
||||
name="Model RAM cache size (GB). Make this at least large enough to hold a single full model (2GB for SD-1, 6GB for SDXL).",
|
||||
@ -495,6 +532,14 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
)
|
||||
else:
|
||||
self.vram = DummyWidgetValue.zero
|
||||
|
||||
self.nextrely += 1
|
||||
self.add_widget_intelligent(
|
||||
npyscreen.FixedText,
|
||||
value="Location of the database used to store model path and configuration information:",
|
||||
editable=False,
|
||||
color="CONTROL",
|
||||
)
|
||||
self.nextrely += 1
|
||||
self.outdir = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
@ -506,19 +551,21 @@ Use cursor arrows to make a checkbox selection, and space to toggle.
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=40,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.autoimport_dirs = {}
|
||||
self.autoimport_dirs["autoimport_dir"] = self.add_widget_intelligent(
|
||||
FileBox,
|
||||
name="Folder to recursively scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir),
|
||||
name="Optional folder to scan for new checkpoints, ControlNets, LoRAs and TI models",
|
||||
value=str(config.root_path / config.autoimport_dir) if config.autoimport_dir else "",
|
||||
select_dir=True,
|
||||
must_exist=False,
|
||||
use_two_lines=False,
|
||||
labelColor="GOOD",
|
||||
begin_entry_at=32,
|
||||
max_height=3,
|
||||
max_width=127,
|
||||
scroll_exit=True,
|
||||
)
|
||||
self.nextrely += 1
|
||||
@ -555,6 +602,10 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
self.attention_slice_label.hidden = not show
|
||||
self.attention_slice_size.hidden = not show
|
||||
|
||||
def show_hide_model_conf_override(self, value):
|
||||
self.model_conf_override.hidden = value
|
||||
self.model_conf_override.display()
|
||||
|
||||
def on_ok(self):
|
||||
options = self.marshall_arguments()
|
||||
if self.validate_field_values(options):
|
||||
@ -584,18 +635,21 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
else:
|
||||
return True
|
||||
|
||||
def marshall_arguments(self):
|
||||
def marshall_arguments(self) -> Namespace:
|
||||
new_opts = Namespace()
|
||||
|
||||
for attr in [
|
||||
"ram",
|
||||
"vram",
|
||||
"convert_cache",
|
||||
"outdir",
|
||||
]:
|
||||
if hasattr(self, attr):
|
||||
setattr(new_opts, attr, getattr(self, attr).value)
|
||||
|
||||
for attr in self.autoimport_dirs:
|
||||
if not self.autoimport_dirs[attr].value:
|
||||
continue
|
||||
directory = Path(self.autoimport_dirs[attr].value)
|
||||
if directory.is_relative_to(config.root_path):
|
||||
directory = directory.relative_to(config.root_path)
|
||||
@ -615,13 +669,14 @@ https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENS
|
||||
|
||||
|
||||
class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: Namespace):
|
||||
def __init__(self, program_opts: Namespace, invokeai_opts: InvokeAIAppConfig, install_helper: InstallHelper):
|
||||
super().__init__()
|
||||
self.program_opts = program_opts
|
||||
self.invokeai_opts = invokeai_opts
|
||||
self.user_cancelled = False
|
||||
self.autoload_pending = True
|
||||
self.install_selections = default_user_selections(program_opts)
|
||||
self.install_helper = install_helper
|
||||
self.install_selections = default_user_selections(program_opts, install_helper)
|
||||
|
||||
def onStart(self):
|
||||
npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
|
||||
@ -640,16 +695,10 @@ class EditOptApplication(npyscreen.NPSAppManaged):
|
||||
cycle_widgets=False,
|
||||
)
|
||||
|
||||
def new_opts(self):
|
||||
def new_opts(self) -> Namespace:
|
||||
return self.options.marshall_arguments()
|
||||
|
||||
|
||||
def edit_opts(program_opts: Namespace, invokeai_opts: Namespace) -> argparse.Namespace:
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp.run()
|
||||
return editApp.new_opts()
|
||||
|
||||
|
||||
def default_ramcache() -> float:
|
||||
"""Run a heuristic for the default RAM cache based on installed RAM."""
|
||||
|
||||
@ -660,27 +709,18 @@ def default_ramcache() -> float:
|
||||
) # 2.1 is just large enough for sd 1.5 ;-)
|
||||
|
||||
|
||||
def default_startup_options(init_file: Path) -> Namespace:
|
||||
def default_startup_options(init_file: Path) -> InvokeAIAppConfig:
|
||||
opts = InvokeAIAppConfig.get_config()
|
||||
opts.ram = opts.ram or default_ramcache()
|
||||
opts.ram = default_ramcache()
|
||||
return opts
|
||||
|
||||
|
||||
def default_user_selections(program_opts: Namespace) -> InstallSelections:
|
||||
try:
|
||||
installer = ModelInstall(config)
|
||||
except omegaconf.errors.ConfigKeyError:
|
||||
logger.warning("Your models.yaml file is corrupt or out of date. Reinitializing")
|
||||
initialize_rootdir(config.root_path, True)
|
||||
installer = ModelInstall(config)
|
||||
|
||||
models = installer.all_models()
|
||||
def default_user_selections(program_opts: Namespace, install_helper: InstallHelper) -> InstallSelections:
|
||||
default_model = install_helper.default_model()
|
||||
assert default_model is not None
|
||||
default_models = [default_model] if program_opts.default_only else install_helper.recommended_models()
|
||||
return InstallSelections(
|
||||
install_models=[models[installer.default_model()].path or models[installer.default_model()].repo_id]
|
||||
if program_opts.default_only
|
||||
else [models[x].path or models[x].repo_id for x in installer.recommended_models()]
|
||||
if program_opts.yes_to_all
|
||||
else [],
|
||||
install_models=default_models if program_opts.yes_to_all else [],
|
||||
)
|
||||
|
||||
|
||||
@ -716,21 +756,10 @@ def initialize_rootdir(root: Path, yes_to_all: bool = False):
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def maybe_create_models_yaml(root: Path):
|
||||
models_yaml = root / "configs" / "models.yaml"
|
||||
if models_yaml.exists():
|
||||
if OmegaConf.load(models_yaml).get("__metadata__"): # up to date
|
||||
return
|
||||
else:
|
||||
logger.info("Creating new models.yaml, original saved as models.yaml.orig")
|
||||
models_yaml.rename(models_yaml.parent / "models.yaml.orig")
|
||||
|
||||
with open(models_yaml, "w") as yaml_file:
|
||||
yaml_file.write(yaml.dump({"__metadata__": {"version": "3.0.0"}}))
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace, Namespace):
|
||||
def run_console_ui(
|
||||
program_opts: Namespace, initfile: Path, install_helper: InstallHelper
|
||||
) -> Tuple[Optional[Namespace], Optional[InstallSelections]]:
|
||||
invokeai_opts = default_startup_options(initfile)
|
||||
invokeai_opts.root = program_opts.root
|
||||
|
||||
@ -739,22 +768,16 @@ def run_console_ui(program_opts: Namespace, initfile: Path = None) -> (Namespace
|
||||
"Could not increase terminal size. Try running again with a larger window or smaller font size."
|
||||
)
|
||||
|
||||
# the install-models application spawns a subprocess to install
|
||||
# models, and will crash unless this is set before running.
|
||||
import torch
|
||||
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts)
|
||||
editApp = EditOptApplication(program_opts, invokeai_opts, install_helper)
|
||||
editApp.run()
|
||||
if editApp.user_cancelled:
|
||||
return (None, None)
|
||||
else:
|
||||
return (editApp.new_opts, editApp.install_selections)
|
||||
return (editApp.new_opts(), editApp.install_selections)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_opts(opts: Namespace, init_file: Path):
|
||||
def write_opts(opts: InvokeAIAppConfig, init_file: Path) -> None:
|
||||
"""
|
||||
Update the invokeai.yaml file with values from current settings.
|
||||
"""
|
||||
@ -762,7 +785,7 @@ def write_opts(opts: Namespace, init_file: Path):
|
||||
new_config = InvokeAIAppConfig.get_config()
|
||||
new_config.root = config.root
|
||||
|
||||
for key, value in opts.__dict__.items():
|
||||
for key, value in opts.model_dump().items():
|
||||
if hasattr(new_config, key):
|
||||
setattr(new_config, key, value)
|
||||
|
||||
@ -779,7 +802,7 @@ def default_output_dir() -> Path:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
def write_default_options(program_opts: Namespace, initfile: Path) -> None:
|
||||
opt = default_startup_options(initfile)
|
||||
write_opts(opt, initfile)
|
||||
|
||||
@ -789,16 +812,11 @@ def write_default_options(program_opts: Namespace, initfile: Path):
|
||||
# the legacy Args object in order to parse
|
||||
# the old init file and write out the new
|
||||
# yaml format.
|
||||
def migrate_init_file(legacy_format: Path):
|
||||
def migrate_init_file(legacy_format: Path) -> None:
|
||||
old = legacy_parser.parse_args([f"@{str(legacy_format)}"])
|
||||
new = InvokeAIAppConfig.get_config()
|
||||
|
||||
fields = [
|
||||
x
|
||||
for x, y in InvokeAIAppConfig.model_fields.items()
|
||||
if (y.json_schema_extra.get("category", None) if y.json_schema_extra else None) != "DEPRECATED"
|
||||
]
|
||||
for attr in fields:
|
||||
for attr in InvokeAIAppConfig.model_fields.keys():
|
||||
if hasattr(old, attr):
|
||||
try:
|
||||
setattr(new, attr, getattr(old, attr))
|
||||
@ -819,7 +837,7 @@ def migrate_init_file(legacy_format: Path):
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def migrate_models(root: Path):
|
||||
def migrate_models(root: Path) -> None:
|
||||
from invokeai.backend.install.migrate_to_3 import do_migrate
|
||||
|
||||
do_migrate(root, root)
|
||||
@ -838,7 +856,9 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
):
|
||||
logger.info("** Migrating invokeai.init to invokeai.yaml")
|
||||
migrate_init_file(old_init_file)
|
||||
config.parse_args(argv=[], conf=OmegaConf.load(new_init_file))
|
||||
omegaconf = OmegaConf.load(new_init_file)
|
||||
assert isinstance(omegaconf, DictConfig)
|
||||
config.parse_args(argv=[], conf=omegaconf)
|
||||
|
||||
if old_hub.exists():
|
||||
migrate_models(config.root_path)
|
||||
@ -849,7 +869,7 @@ def migrate_if_needed(opt: Namespace, root: Path) -> bool:
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def main() -> None:
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--skip-sd-weights",
|
||||
@ -908,6 +928,7 @@ def main() -> None:
|
||||
if opt.full_precision:
|
||||
invoke_args.extend(["--precision", "float32"])
|
||||
config.parse_args(invoke_args)
|
||||
config.precision = "float32" if opt.full_precision else choose_precision(torch.device(choose_torch_device()))
|
||||
logger = InvokeAILogger().get_logger(config=config)
|
||||
|
||||
errors = set()
|
||||
@ -921,14 +942,18 @@ def main() -> None:
|
||||
# run this unconditionally in case new directories need to be added
|
||||
initialize_rootdir(config.root_path, opt.yes_to_all)
|
||||
|
||||
models_to_download = default_user_selections(opt)
|
||||
# this will initialize the models.yaml file if not present
|
||||
install_helper = InstallHelper(config, logger)
|
||||
|
||||
models_to_download = default_user_selections(opt, install_helper)
|
||||
new_init_file = config.root_path / "invokeai.yaml"
|
||||
|
||||
if opt.yes_to_all:
|
||||
write_default_options(opt, new_init_file)
|
||||
init_options = Namespace(precision="float32" if opt.full_precision else "float16")
|
||||
|
||||
else:
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file)
|
||||
init_options, models_to_download = run_console_ui(opt, new_init_file, install_helper)
|
||||
if init_options:
|
||||
write_opts(init_options, new_init_file)
|
||||
else:
|
||||
@ -943,10 +968,12 @@ def main() -> None:
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
logger.warning("Skipping diffusion weights download per user request")
|
||||
|
||||
elif models_to_download:
|
||||
process_and_execute(opt, models_to_download)
|
||||
install_helper.add_or_delete(models_to_download)
|
||||
|
||||
postscript(errors=errors)
|
||||
|
||||
if not opt.yes_to_all:
|
||||
input("Press any key to continue...")
|
||||
except WindowTooSmallException as e:
|
||||
|
@ -19,7 +19,7 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
@ -31,7 +31,9 @@ def choose_torch_device() -> torch.device:
|
||||
|
||||
# We are in transition here from using a single global AppConfig to allowing multiple
|
||||
# configurations. It is strongly recommended to pass the app_config to this function.
|
||||
def choose_precision(device: torch.device, app_config: Optional[InvokeAIAppConfig] = None) -> str:
|
||||
def choose_precision(
|
||||
device: torch.device, app_config: Optional[InvokeAIAppConfig] = None
|
||||
) -> Literal["float32", "float16", "bfloat16"]:
|
||||
"""Return an appropriate precision for the given torch device."""
|
||||
app_config = app_config or config
|
||||
if device.type == "cuda":
|
||||
|
Reference in New Issue
Block a user