Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
[util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices
* add unit test for device abstraction
* fix ruff
* convert TorchDeviceSelect into a stateless class
* move logic to select context-specific execution device into context API
* add mock hardware environments to pytest
* remove dangling mocker fixture
* fix unit test for running on non-CUDA systems
* remove unimplemented get_execution_device() call
* remove autocast precision
* Multiple changes:
  1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to context.models.get_execution_device().
  2. Rename TorchDeviceSelect to TorchDevice.
  3. Added back the legacy public API defined in `invocation_api`, including choose_precision().
  4. Added a config file migration script to accommodate removal of precision=autocast.
* add deprecation warnings to choose_torch_device() and choose_precision()
* fix test crash
* remove app_config argument from choose_torch_device() and choose_torch_dtype()

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
parent 5a8489bbfc
commit e93f4d632d
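In practice the new class gives call sites a single import for device, dtype, device-name and cache handling. A minimal usage sketch (assuming an environment that includes this commit; the variable names are illustrative):

```python
import torch

from invokeai.backend.util.devices import TorchDevice

# "auto" in the config resolves to CUDA, then MPS, then CPU.
device: torch.device = TorchDevice.choose_torch_device()

# Precision follows the config; "auto" means float16 on CUDA/MPS and float32 on CPU.
dtype: torch.dtype = TorchDevice.choose_torch_dtype(device)

print(f"Using torch device: {TorchDevice.get_torch_device_name()} ({dtype})")

# One call frees both the CUDA and the MPS cache; call sites no longer branch on device type.
TorchDevice.empty_cache()
```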
@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice

 from ..backend.util.logging import InvokeAILogger
 from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")

-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
 logger.info(f"Using torch device: {torch_device_name}")

@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import CLIPField
@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
         )

@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()


 @invocation_output("scheduler_output")
@@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             vae.disable_tiling()

         # clear memory as vae decode can request a lot
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         with torch.inference_mode():
             # copied from diffusers pipeline
@@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

         image = VaeImageProcessor.numpy_to_pil(np_image)[0]

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=image)

@@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)
-
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),
@@ -1096,9 +1085,8 @@ class ScaleLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         # resizing
         resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation):
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")

-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         def slerp(
             t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?
@@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation):

         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)
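Every cleanup site in this file changes the same way: the CUDA-only `torch.cuda.empty_cache()` plus the MPS special case collapses into one `TorchDevice.empty_cache()` call. A standalone sketch of the two shapes (`old_cleanup`/`new_cleanup` are illustrative names, not InvokeAI functions):

```python
import torch

from invokeai.backend.util.devices import TorchDevice


def old_cleanup(result_latents: torch.Tensor) -> torch.Tensor:
    # Pattern removed by this commit: each call site branched on the device type.
    result_latents = result_latents.to("cpu")
    torch.cuda.empty_cache()
    if torch.backends.mps.is_available():  # previously spelled as choose_torch_device() == torch.device("mps")
        torch.mps.empty_cache()
    return result_latents


def new_cleanup(result_latents: torch.Tensor) -> torch.Tensor:
    # Pattern introduced by this commit: one call clears whichever GPU caches exist.
    result_latents = result_latents.to("cpu")
    TorchDevice.empty_cache()
    return result_latents
```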
@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX

-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
             height // downsampling_factor,
             width // downsampling_factor,
         ],
-        dtype=torch_dtype(device),
+        dtype=TorchDevice.choose_torch_dtype(device=device),
        device=noise_device_type,
        generator=generator,
    ).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):

     @field_validator("seed", mode="before")
     def modulo_seed(cls, v):
-        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)

     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
-            device=choose_torch_device(),
+            device=TorchDevice.choose_torch_device(),
             seed=self.seed,
             use_cpu=self.use_cpu,
         )
@@ -4,7 +4,6 @@ from typing import Literal

 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from pydantic import ConfigDict

@@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata
@@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
     "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
 }

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-

 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
 class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
         upscaled_image = upscaler.upscale(cv2_image)
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=pil_image)

@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
 DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"


 def get_default_ram_cache_size() -> float:
@@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings):
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
             # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
             if k == "max_vram_cache_size" and "vram" not in category_dict:
                 parsed_config_dict["vram"] = v
+            # autocast was removed in v4.0.1
+            if k == "precision" and v == "autocast":
+                parsed_config_dict["precision"] = "auto"
             if k == "conf_path":
                 parsed_config_dict["legacy_models_yaml_path"] = v
             if k == "legacy_conf_dir":
@@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
     return config


+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.0 config dictionary to a current config object.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.0 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        # autocast was removed from precision in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
+        else:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
+
+
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.

@@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
             raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
         migrated_config.write_file(config_path)
         return migrated_config
-    else:
-        # Attempt to load as a v4 config file
-        try:
-            # Meta is not included in the model fields, so we need to validate it separately
-            config = InvokeAIAppConfig.model_validate(loaded_config_dict)
-            assert (
-                config.schema_version == CONFIG_SCHEMA_VERSION
-            ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
-            return config
-        except Exception as e:
-            raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
+
+    if loaded_config_dict["schema_version"] == "4.0.0":
+        loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+        loaded_config_dict.write_file(config_path)
+
+    # Attempt to load as a v4 config file
+    try:
+        # Meta is not included in the model fields, so we need to validate it separately
+        config = InvokeAIAppConfig.model_validate(loaded_config_dict)
+        assert (
+            config.schema_version == CONFIG_SCHEMA_VERSION
+        ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
+        return config
+    except Exception as e:
+        raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e


 @lru_cache(maxsize=1)
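The migration itself is mechanical: any `precision: autocast` entry becomes `auto`, and the schema version is bumped from 4.0.0 to 4.0.1. A standalone sketch of that behaviour on a plain dict (`migrate_v4_0_0_settings` is an illustrative stand-in, not the InvokeAI function, which also validates the result into an `InvokeAIAppConfig`):

```python
from typing import Any

CONFIG_SCHEMA_VERSION = "4.0.1"  # value set in the config module by this commit


def migrate_v4_0_0_settings(config_dict: dict[str, Any]) -> dict[str, Any]:
    migrated: dict[str, Any] = {}
    for k, v in config_dict.items():
        if k == "precision" and v == "autocast":
            # "autocast" is no longer a valid precision; fall back to automatic selection.
            migrated["precision"] = "auto"
        else:
            migrated[k] = v
        if k == "schema_version":
            migrated[k] = CONFIG_SCHEMA_VERSION
    return migrated


old = {"schema_version": "4.0.0", "precision": "autocast", "device": "auto"}
assert migrate_v4_0_0_settings(old) == {"schema_version": "4.0.1", "precision": "auto", "device": "auto"}
```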
@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
 from typing import Any, Dict, List, Optional, Union

+import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,
@@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase):
             self._next_job_id += 1
         return id

-    @staticmethod
-    def _guess_variant() -> Optional[ModelRepoVariant]:
+    def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else None
+        precision = TorchDevice.choose_torch_dtype()
+        return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
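Note the knock-on change in `_guess_variant()`: `TorchDevice.choose_torch_dtype()` returns a `torch.dtype`, so the comparison must now be against `torch.float16` rather than the string `"float16"`. A small sketch of the pitfall (assuming an environment with this commit installed):

```python
import torch

from invokeai.backend.util.devices import TorchDevice

precision = TorchDevice.choose_torch_dtype()

# The new helper returns a torch.dtype, not a string, so the old comparison would always be False:
assert precision != "float16"            # a torch.dtype never equals the string "float16"
wants_fp16 = precision == torch.float16  # the comparison _guess_variant() uses after this commit
print(f"precision={precision!r}, prefer fp16 repo variant: {wants_fp16}")
```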
@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_vram_cache_size=app_config.vram,
             lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device,
+            execution_device=execution_device or TorchDevice.choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
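The signature change matters because the old default (`execution_device: torch.device = choose_torch_device()`) was evaluated once, when the module defining it was imported; resolving `None` inside the body picks the device at call time instead. A sketch of the pattern (`build_cache` is an illustrative name, not the service API):

```python
from typing import Optional

import torch

from invokeai.backend.util.devices import TorchDevice


def build_cache(execution_device: Optional[torch.device] = None) -> torch.device:
    # Resolving the default inside the function body means the device is chosen when the
    # service is built, not when the module defining the default argument was imported.
    return execution_device or TorchDevice.choose_torch_device()


print(build_cache())                     # resolved from the current config / hardware
print(build_cache(torch.device("cpu")))  # an explicit override still wins
```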
@@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()
@@ -56,7 +56,7 @@ class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

     def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"]
@@ -81,7 +81,7 @@ class DepthAnythingDetector:
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()

-        self.model.to(choose_torch_device())
+        self.model.to(self.device)
         return self.model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@@ -94,7 +94,7 @@ class DepthAnythingDetector:

         image_height, image_width = np_image.shape[:2]
         np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)

         with torch.no_grad():
             depth = self.model(tensor_image)
@@ -7,7 +7,7 @@ import onnxruntime as ort

 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 from .onnxdet import inference_detector
 from .onnxpose import inference_pose
@@ -28,9 +28,9 @@ config = get_config()

 class Wholebody:
     def __init__(self):
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

-        providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

         DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"]
         download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
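The second change in this hunk fixes the provider check: `choose_torch_device()` returns a `torch.device`, so the robust test is on its `.type` attribute rather than comparing the device object against a string. A minimal sketch of the updated selection (the onnxruntime session setup itself is omitted):

```python
import torch

# TorchDevice.choose_torch_device() returns a torch.device, so the updated Wholebody.__init__
# inspects device.type instead of comparing the device against the string "cuda".
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
print(f"{device} -> onnxruntime providers: {providers}")
```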
@@ -8,7 +8,7 @@ from PIL import Image
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.util.download_with_progress import download_with_progress_bar
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice


 def norm_img(np_img):
@@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device):

 class LaMA:
     def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         model_location = get_config().models_path / "core/misc/lama/lama.pt"

         if not model_location.exists():
@@ -11,7 +11,7 @@ from cv2.typing import MatLike
 from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 """
 Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@@ -65,7 +65,7 @@ class RealESRGAN:
         self.pre_pad = pre_pad
         self.mod_scale: Optional[int] = None
         self.half = half
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

         loadnet = torch.load(model_path, map_location=torch.device("cpu"))

@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@@ -51,7 +51,7 @@ class SafetyChecker:
             cls._load_safety_checker()
         if cls.safety_checker is None or cls.feature_extractor is None:
             return False
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
         cls.safety_checker.to(device)
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice


 # TO DO: The loader is not thread safe!
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = TorchDevice.choose_torch_dtype()

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
@@ -30,15 +30,12 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
             )

-            torch.cuda.empty_cache()
-            if choose_torch_device() == torch.device("mps"):
-                mps.empty_cache()
+            TorchDevice.empty_cache()

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device.
@@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             self.stats.cleared = models_cleared
         gc.collect()

-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
-
+        TorchDevice.empty_cache()
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging

 from invokeai.app.services.model_install import ModelInstallServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 from . import (
     AnyModelConfig,
@@ -43,6 +43,7 @@ class ModelMerger(object):
         Initialize a ModelMerger object with the model installer.
         """
         self._installer = installer
+        self._dtype = TorchDevice.choose_torch_dtype()

     def merge_diffusion_models(
         self,
@@ -68,7 +69,7 @@ class ModelMerger(object):
             warnings.simplefilter("ignore")
             verbosity = dlogging.get_verbosity()
             dlogging.set_verbosity_error()
-            dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+            dtype = torch.float16 if variant == "fp16" else self._dtype

             # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
             # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@@ -151,7 +152,7 @@ class ModelMerger(object):
         dump_path.mkdir(parents=True, exist_ok=True)
         dump_path = dump_path / merged_model_name

-        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+        dtype = torch.float16 if variant == "fp16" else self._dtype
         merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)

         # register model and get its unique key
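Caching the dtype once in `__init__` and falling back to it unless an fp16 variant was explicitly requested keeps both merge paths consistent. A toy sketch of that pattern (`MergerSketch` is illustrative, not the ModelMerger API):

```python
from typing import Optional

import torch

from invokeai.backend.util.devices import TorchDevice


class MergerSketch:
    """Toy stand-in for ModelMerger: the default dtype is resolved once, at construction time."""

    def __init__(self) -> None:
        self._dtype = TorchDevice.choose_torch_dtype()

    def dtype_for(self, variant: Optional[str]) -> torch.dtype:
        # fp16 repo variants are always merged in float16; otherwise use the cached default.
        return torch.float16 if variant == "fp16" else self._dtype


print(MergerSketch().dtype_for("fp16"))  # torch.float16
print(MergerSketch().dtype_for(None))    # whatever the config / hardware selects
```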
@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import normalize_device
+from invokeai.backend.util.devices import TorchDevice


 @dataclass
@@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
         elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+            mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
         else:
             raise ValueError(f"unrecognized device {self.unet.device}")
         # input tensor of [1, 4, h/8, w/8]
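`TorchDevice.normalize()` plays the role of the removed `normalize_device()`: it attaches an index to bare CUDA devices before they are handed to CUDA queries such as `torch.cuda.mem_get_info`. A hedged usage sketch (assuming this commit is installed):

```python
import torch

from invokeai.backend.util.devices import TorchDevice

if torch.cuda.is_available():
    # normalize() turns "cuda" into "cuda:<current index>"; non-CUDA devices pass through unchanged.
    device = TorchDevice.normalize(torch.device("cuda"))
    mem_free, mem_total = torch.cuda.mem_get_info(device)
    print(f"{device}: {mem_free / 2**30:.1f} GiB free of {mem_total / 2**30:.1f} GiB")
else:
    print("No CUDA device available; nothing to normalize.")
```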
@@ -2,7 +2,6 @@
 Initialization file for invokeai.backend.util
 """

-from .devices import choose_precision, choose_torch_device
 from .logging import InvokeAILogger
 from .util import GIG, Chdir, directory_size

@@ -11,6 +10,4 @@ __all__ = [
     "directory_size",
     "Chdir",
     "InvokeAILogger",
-    "choose_precision",
-    "choose_torch_device",
 ]
@@ -1,89 +1,110 @@
-from __future__ import annotations
-
-from contextlib import nullcontext
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 import torch
-from torch import autocast
+from deprecated import deprecated

-from invokeai.app.services.config.config_default import PRECISION, get_config
+from invokeai.app.services.config.config_default import get_config

+# legacy APIs
+TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")


-def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
-    config = get_config()
-    if config.device == "auto":
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        else:
-            return CPU_DEVICE
-    else:
-        return torch.device(config.device)
-
-
-def get_torch_device_name() -> str:
-    device = choose_torch_device()
-    return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
-
-
-def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
-    """Return an appropriate precision for the given torch device."""
-    app_config = get_config()
-    if device.type == "cuda":
-        device_name = torch.cuda.get_device_name(device)
-        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
-            # These GPUs have limited support for float16
-            return "float32"
-        elif app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for CUDA devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    elif device.type == "mps":
-        if app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for MPS devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    # CPU / safe fallback
-    return "float32"
-
-
-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
-    device = device or choose_torch_device()
-    precision = choose_precision(device)
-    if precision == "float16":
-        return torch.float16
-    if precision == "bfloat16":
-        return torch.bfloat16
-    else:
-        # "auto", "autocast", "float32"
-        return torch.float32
-
-
-def choose_autocast(precision: PRECISION):
-    """Returns an autocast context or nullcontext for the given precision string"""
-    # float16 currently requires autocast to avoid errors like:
-    # 'expected scalar type Half but found Float'
-    if precision == "autocast" or precision == "float16":
-        return autocast
-    return nullcontext
-
-
-def normalize_device(device: Union[str, torch.device]) -> torch.device:
-    """Ensure device has a device index defined, if appropriate."""
-    device = torch.device(device)
-    if device.index is None:
-        # cuda might be the only torch backend that currently uses the device index?
-        # I don't see anything like `current_device` for cpu or mps.
-        if device.type == "cuda":
-            device = torch.device(device.type, torch.cuda.current_device())
-    return device
+
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def choose_precision(device: torch.device) -> TorchPrecisionNames:
+    """Return the string representation of the recommended torch device."""
+    torch_dtype = TorchDevice.choose_torch_dtype(device)
+    return PRECISION_TO_NAME[torch_dtype]
+
+
+@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
+def choose_torch_device() -> torch.device:
+    """Return the torch.device to use for accelerated inference."""
+    return TorchDevice.choose_torch_device()
+
+
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def torch_dtype(device: torch.device) -> torch.dtype:
+    """Return the torch precision for the recommended torch device."""
+    return TorchDevice.choose_torch_dtype(device)
+
+
+NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
+
+
+class TorchDevice:
+    """Abstraction layer for torch devices."""
+
+    @classmethod
+    def choose_torch_device(cls) -> torch.device:
+        """Return the torch.device to use for accelerated inference."""
+        app_config = get_config()
+        if app_config.device != "auto":
+            device = torch.device(app_config.device)
+        elif torch.cuda.is_available():
+            device = CUDA_DEVICE
+        elif torch.backends.mps.is_available():
+            device = MPS_DEVICE
+        else:
+            device = CPU_DEVICE
+        return cls.normalize(device)
+
+    @classmethod
+    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return the precision to use for accelerated inference."""
+        device = device or cls.choose_torch_device()
+        config = get_config()
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device)
+            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+                # These GPUs have limited support for float16
+                return cls._to_dtype("float32")
+            elif config.precision == "auto":
+                # Default to float16 for CUDA devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+
+        elif device.type == "mps" and torch.backends.mps.is_available():
+            if config.precision == "auto":
+                # Default to float16 for MPS devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        # CPU / safe fallback
+        return cls._to_dtype("float32")
+
+    @classmethod
+    def get_torch_device_name(cls) -> str:
+        """Return the device name for the current torch device."""
+        device = cls.choose_torch_device()
+        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
+
+    @classmethod
+    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
+        """Add the device index to CUDA devices."""
+        device = torch.device(device)
+        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
+            device = torch.device(device.type, torch.cuda.current_device())
+        return device
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Clear the GPU device cache."""
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @classmethod
+    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
+        return NAME_TO_PRECISION[precision_name]
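The legacy helpers remain importable but now warn on use. A quick way to see the deprecation in action (assuming the `Deprecated` package is installed, as the new module requires):

```python
import warnings

from invokeai.backend.util.devices import choose_precision, choose_torch_device

# The legacy helpers still work, but each call now raises a DeprecationWarning via the
# `deprecated` package, pointing callers at the TorchDevice equivalents.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    device = choose_torch_device()
    precision = choose_precision(device)

print(device, precision)
print([str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)])
```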
tests/backend/util/test_devices.py — new file (132 lines)
@@ -0,0 +1,132 @@
"""
Test abstract device class.
"""

from unittest.mock import patch

import pytest
import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype

devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]


@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
    config = get_config()
    config.device = device_name
    torch_device = TorchDevice.choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=True),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
    with (
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        config.precision = "float32"
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == torch.float32


def test_normalize():
    assert (
        TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("mps") == torch.device("mps")
    assert TorchDevice.normalize("cpu") == torch.device("cpu")


@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
    config = get_config()
    config.device = device_name
    with pytest.deprecated_call():
        torch_device = choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        with pytest.deprecated_call():
            torch_device = choose_torch_device()
            returned_dtype = torch_dtype(torch_device)
        assert returned_dtype == dtype


def test_legacy_precision_name():
    config = get_config()
    config.precision = "auto"
    with (
        pytest.deprecated_call(),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        assert "float16" == choose_precision(torch.device("cuda"))
        assert "float16" == choose_precision(torch.device("mps"))
        assert "float32" == choose_precision(torch.device("cpu"))
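The tests fake each hardware combination by patching `torch.cuda.is_available`, `torch.backends.mps.is_available` and `torch.cuda.get_device_name`. The same trick works outside pytest; a minimal sketch (assumes the default `precision: auto` setting and an environment with this commit installed):

```python
from unittest.mock import patch

import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice

# Pretend CUDA is present (with a card that fully supports float16) and MPS is absent,
# then ask TorchDevice what it would pick for the configured device.
config = get_config()
config.device = "cuda:0"

with (
    patch("torch.cuda.is_available", return_value=True),
    patch("torch.cuda.get_device_name", return_value="RTX4070"),
    patch("torch.backends.mps.is_available", return_value=False),
):
    assert TorchDevice.choose_torch_device() == torch.device("cuda:0")
    assert TorchDevice.choose_torch_dtype() == torch.float16
```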