mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices * add unit test for device abstraction * fix ruff * convert TorchDeviceSelect into a stateless class * move logic to select context-specific execution device into context API * add mock hardware environments to pytest * remove dangling mocker fixture * fix unit test for running on non-CUDA systems * remove unimplemented get_execution_device() call * remove autocast precision * Multiple changes: 1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to context.models.get_execution_device(). 2. Rename TorchDeviceSelect to TorchDevice 3. Added back the legacy public API defined in `invocation_api`, including choose_precision(). 4. Added a config file migration script to accommodate removal of precision=autocast. * add deprecation warnings to choose_torch_device() and choose_precision() * fix test crash * remove app_config argument from choose_torch_device() and choose_torch_dtype() --------- Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from ...backend.util.devices import choose_torch_device, torch_dtype
|
||||
from ...backend.util.devices import TorchDevice
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@ -46,7 +46,7 @@ def get_noise(
|
||||
height // downsampling_factor,
|
||||
width // downsampling_factor,
|
||||
],
|
||||
dtype=torch_dtype(device),
|
||||
dtype=TorchDevice.choose_torch_dtype(device=device),
|
||||
device=noise_device_type,
|
||||
generator=generator,
|
||||
).to("cpu")
|
||||
@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
|
||||
|
||||
@field_validator("seed", mode="before")
|
||||
def modulo_seed(cls, v):
|
||||
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> NoiseOutput:
|
||||
noise = get_noise(
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
device=choose_torch_device(),
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=self.seed,
|
||||
use_cpu=self.use_cpu,
|
||||
)
|
||||
|
Reference in New Issue
Block a user