Fixes torch.load() for MPS/CPU

This commit is contained in:
psychedelicious 2022-10-20 12:47:46 +08:00 committed by Lincoln Stein
parent ed9307f469
commit bfa65560eb

View File

@ -121,11 +121,11 @@ try:
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, ) model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval() model.eval()
model.load_state_dict( model.load_state_dict(
torch.load('src/clipseg/weights/rd64-uni-refined.pth'), torch.load(
model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'), 'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu'), map_location=torch.device('cpu')
strict=False, ),
) strict=False,
) )
except Exception: except Exception:
print('Error installing clipseg model:') print('Error installing clipseg model:')