diff --git a/scripts/orig_scripts/img2img.py b/scripts/orig_scripts/img2img.py
index 4bbafcad01..fcd0b8cdfa 100644
--- a/scripts/orig_scripts/img2img.py
+++ b/scripts/orig_scripts/img2img.py
@@ -200,7 +200,7 @@ def main():
     config = OmegaConf.load(f"{opt.config}")
     model = load_model_from_config(config, f"{opt.ckpt}")
 
-    device = choose_torch_device()
+    device = torch.device(choose_torch_device())
     model = model.to(device)
 
     if opt.plms: