[util] Add generic torch device class (#6174)

* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein
2024-04-15 09:12:49 -04:00
committed by GitHub
parent 5a8489bbfc
commit e93f4d632d
20 changed files with 327 additions and 176 deletions

View File

@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import get_torch_device_name
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
torch_device_name = get_torch_device_name()
torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")