mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
move autocast device selection to a function
This commit is contained in:
parent
fa98601bfb
commit
09bd9fa47e
@ -8,4 +8,10 @@ def choose_torch_device() -> str:
|
||||
return 'mps'
|
||||
return 'cpu'
|
||||
|
||||
|
||||
def choose_autocast_device(device) -> str:
|
||||
'''Returns an autocast compatible device from a torch device'''
|
||||
device_type = device.type # this returns 'mps' on M1
|
||||
# autocast only supports cuda or cpu
|
||||
if device_type != 'cuda' or device_type != 'cpu':
|
||||
return 'cpu'
|
||||
return device_type
|
||||
|
@ -27,7 +27,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
from ldm.dream.devices import choose_autocast_device, choose_torch_device
|
||||
|
||||
"""Simplified text to image API for stable diffusion/latent diffusion
|
||||
|
||||
@ -315,9 +315,7 @@ class T2I:
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
device_type = self.device.type # this returns 'mps' on M1
|
||||
if device_type != 'cuda' or device_type != 'cpu':
|
||||
device_type = 'cpu'
|
||||
device_type = choose_autocast_device(self.device)
|
||||
with scope(device_type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
seed_everything(seed)
|
||||
|
Loading…
Reference in New Issue
Block a user