diff --git a/scripts/orig_scripts/img2img.py b/scripts/orig_scripts/img2img.py index 4bbafcad01..fcd0b8cdfa 100644 --- a/scripts/orig_scripts/img2img.py +++ b/scripts/orig_scripts/img2img.py @@ -200,7 +200,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") - device = choose_torch_device() + device = torch.device(choose_torch_device()) model = model.to(device) if opt.plms: