mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into ip-adapter-style-comp
This commit is contained in:
@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
@ -56,7 +56,7 @@ class DepthAnythingDetector:
|
||||
def __init__(self) -> None:
|
||||
self.model = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = choose_torch_device()
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
|
||||
DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
|
||||
@ -81,7 +81,7 @@ class DepthAnythingDetector:
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(choose_torch_device())
|
||||
self.model.to(self.device)
|
||||
return self.model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
@ -94,7 +94,7 @@ class DepthAnythingDetector:
|
||||
|
||||
image_height, image_width = np_image.shape[:2]
|
||||
np_image = transform({"image": np_image})["image"]
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
|
||||
tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
depth = self.model(tensor_image)
|
||||
|
@ -7,7 +7,7 @@ import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
@ -28,9 +28,9 @@ config = get_config()
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
|
||||
providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
|
||||
download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def norm_img(np_img):
|
||||
@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):
|
||||
|
||||
class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
|
@ -11,7 +11,7 @@ from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
"""
|
||||
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
||||
@ -65,7 +65,7 @@ class RealESRGAN:
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale: Optional[int] = None
|
||||
self.half = half
|
||||
self.device = choose_torch_device()
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
|
@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||
@ -51,7 +51,7 @@ class SafetyChecker:
|
||||
cls._load_safety_checker()
|
||||
if cls.safety_checker is None or cls.feature_extractor is None:
|
||||
return False
|
||||
device = choose_torch_device()
|
||||
device = TorchDevice.choose_torch_device()
|
||||
features = cls.feature_extractor([image], return_tensors="pt")
|
||||
features.to(device)
|
||||
cls.safety_checker.to(device)
|
||||
|
@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
|
||||
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
||||
self._torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
|
@ -30,15 +30,12 @@ import torch
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, SubModelType
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
# Maximum size of the cache, in gigs
|
||||
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
|
||||
DEFAULT_MAX_CACHE_SIZE = 6.0
|
||||
@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.stats.cleared = models_cleared
|
||||
gc.collect()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
|
@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from . import (
|
||||
AnyModelConfig,
|
||||
@ -43,6 +43,7 @@ class ModelMerger(object):
|
||||
Initialize a ModelMerger object with the model installer.
|
||||
"""
|
||||
self._installer = installer
|
||||
self._dtype = TorchDevice.choose_torch_dtype()
|
||||
|
||||
def merge_diffusion_models(
|
||||
self,
|
||||
@ -68,7 +69,7 @@ class ModelMerger(object):
|
||||
warnings.simplefilter("ignore")
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||
|
||||
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
|
||||
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
|
||||
@ -151,7 +152,7 @@ class ModelMerger(object):
|
||||
dump_path.mkdir(parents=True, exist_ok=True)
|
||||
dump_path = dump_path / merged_model_name
|
||||
|
||||
dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
|
||||
dtype = torch.float16 if variant == "fp16" else self._dtype
|
||||
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
|
||||
|
||||
# register model and get its unique key
|
||||
|
@ -25,7 +25,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdap
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
from invokeai.backend.util.attention import auto_detect_slice_size
|
||||
from invokeai.backend.util.devices import normalize_device
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -255,7 +255,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif self.unet.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
|
||||
mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {self.unet.device}")
|
||||
# input tensor of [1, 4, h/8, w/8]
|
||||
|
@ -2,7 +2,6 @@
|
||||
Initialization file for invokeai.backend.util
|
||||
"""
|
||||
|
||||
from .devices import choose_precision, choose_torch_device
|
||||
from .logging import InvokeAILogger
|
||||
from .util import GIG, Chdir, directory_size
|
||||
|
||||
@ -11,6 +10,4 @@ __all__ = [
|
||||
"directory_size",
|
||||
"Chdir",
|
||||
"InvokeAILogger",
|
||||
"choose_precision",
|
||||
"choose_torch_device",
|
||||
]
|
||||
|
@ -1,89 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import nullcontext
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import autocast
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config.config_default import PRECISION, get_config
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
CUDA_DEVICE = torch.device("cuda")
|
||||
MPS_DEVICE = torch.device("mps")
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||
def choose_precision(device: torch.device) -> TorchPrecisionNames:
|
||||
"""Return the string representation of the recommended torch device."""
|
||||
torch_dtype = TorchDevice.choose_torch_dtype(device)
|
||||
return PRECISION_TO_NAME[torch_dtype]
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
|
||||
def choose_torch_device() -> torch.device:
|
||||
"""Convenience routine for guessing which GPU device to run model on"""
|
||||
config = get_config()
|
||||
if config.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
return torch.device("mps")
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
return TorchDevice.choose_torch_device()
|
||||
|
||||
|
||||
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||
def torch_dtype(device: torch.device) -> torch.dtype:
|
||||
"""Return the torch precision for the recommended torch device."""
|
||||
return TorchDevice.choose_torch_dtype(device)
|
||||
|
||||
|
||||
NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
|
||||
|
||||
|
||||
class TorchDevice:
|
||||
"""Abstraction layer for torch devices."""
|
||||
|
||||
@classmethod
|
||||
def choose_torch_device(cls) -> torch.device:
|
||||
"""Return the torch.device to use for accelerated inference."""
|
||||
app_config = get_config()
|
||||
if app_config.device != "auto":
|
||||
device = torch.device(app_config.device)
|
||||
elif torch.cuda.is_available():
|
||||
device = CUDA_DEVICE
|
||||
elif torch.backends.mps.is_available():
|
||||
device = MPS_DEVICE
|
||||
else:
|
||||
return CPU_DEVICE
|
||||
else:
|
||||
return torch.device(config.device)
|
||||
device = CPU_DEVICE
|
||||
return cls.normalize(device)
|
||||
|
||||
@classmethod
|
||||
def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
|
||||
"""Return the precision to use for accelerated inference."""
|
||||
device = device or cls.choose_torch_device()
|
||||
config = get_config()
|
||||
if device.type == "cuda" and torch.cuda.is_available():
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||
# These GPUs have limited support for float16
|
||||
return cls._to_dtype("float32")
|
||||
elif config.precision == "auto":
|
||||
# Default to float16 for CUDA devices
|
||||
return cls._to_dtype("float16")
|
||||
else:
|
||||
# Use the user-defined precision
|
||||
return cls._to_dtype(config.precision)
|
||||
|
||||
def get_torch_device_name() -> str:
|
||||
device = choose_torch_device()
|
||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||
elif device.type == "mps" and torch.backends.mps.is_available():
|
||||
if config.precision == "auto":
|
||||
# Default to float16 for MPS devices
|
||||
return cls._to_dtype("float16")
|
||||
else:
|
||||
# Use the user-defined precision
|
||||
return cls._to_dtype(config.precision)
|
||||
# CPU / safe fallback
|
||||
return cls._to_dtype("float32")
|
||||
|
||||
@classmethod
|
||||
def get_torch_device_name(cls) -> str:
|
||||
"""Return the device name for the current torch device."""
|
||||
device = cls.choose_torch_device()
|
||||
return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
|
||||
|
||||
def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
|
||||
"""Return an appropriate precision for the given torch device."""
|
||||
app_config = get_config()
|
||||
if device.type == "cuda":
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
|
||||
# These GPUs have limited support for float16
|
||||
return "float32"
|
||||
elif app_config.precision == "auto" or app_config.precision == "autocast":
|
||||
# Default to float16 for CUDA devices
|
||||
return "float16"
|
||||
else:
|
||||
# Use the user-defined precision
|
||||
return app_config.precision
|
||||
elif device.type == "mps":
|
||||
if app_config.precision == "auto" or app_config.precision == "autocast":
|
||||
# Default to float16 for MPS devices
|
||||
return "float16"
|
||||
else:
|
||||
# Use the user-defined precision
|
||||
return app_config.precision
|
||||
# CPU / safe fallback
|
||||
return "float32"
|
||||
|
||||
|
||||
def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
|
||||
device = device or choose_torch_device()
|
||||
precision = choose_precision(device)
|
||||
if precision == "float16":
|
||||
return torch.float16
|
||||
if precision == "bfloat16":
|
||||
return torch.bfloat16
|
||||
else:
|
||||
# "auto", "autocast", "float32"
|
||||
return torch.float32
|
||||
|
||||
|
||||
def choose_autocast(precision: PRECISION):
|
||||
"""Returns an autocast context or nullcontext for the given precision string"""
|
||||
# float16 currently requires autocast to avoid errors like:
|
||||
# 'expected scalar type Half but found Float'
|
||||
if precision == "autocast" or precision == "float16":
|
||||
return autocast
|
||||
return nullcontext
|
||||
|
||||
|
||||
def normalize_device(device: Union[str, torch.device]) -> torch.device:
|
||||
"""Ensure device has a device index defined, if appropriate."""
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
# cuda might be the only torch backend that currently uses the device index?
|
||||
# I don't see anything like `current_device` for cpu or mps.
|
||||
if device.type == "cuda":
|
||||
@classmethod
|
||||
def normalize(cls, device: Union[str, torch.device]) -> torch.device:
|
||||
"""Add the device index to CUDA devices."""
|
||||
device = torch.device(device)
|
||||
if device.index is None and device.type == "cuda" and torch.cuda.is_available():
|
||||
device = torch.device(device.type, torch.cuda.current_device())
|
||||
return device
|
||||
return device
|
||||
|
||||
@classmethod
|
||||
def empty_cache(cls) -> None:
|
||||
"""Clear the GPU device cache."""
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
|
||||
return NAME_TO_PRECISION[precision_name]
|
||||
|
Reference in New Issue
Block a user