diff --git a/scripts/preload_models.py b/scripts/preload_models.py index 2ef344f8c3..b23bec11f3 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -120,7 +120,12 @@ try: from models.clipseg import CLIPDensePredT model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, ) model.eval() - model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), strict=False) + model.load_state_dict( + torch.load('src/clipseg/weights/rd64-uni-refined.pth'), + model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), + map_location=torch.device('cpu'), + strict=False, + ) except Exception: print('Error installing clipseg model:') print(traceback.format_exc())