Fixes torch.load() for MPS/CPU

This commit is contained in:
psychedelicious 2022-10-20 12:47:46 +08:00 committed by Lincoln Stein
parent ed9307f469
commit bfa65560eb

View File

@ -121,11 +121,11 @@ try:
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
torch.load('src/clipseg/weights/rd64-uni-refined.pth'),
model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'),
map_location=torch.device('cpu'),
strict=False,
)
torch.load(
'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu')
),
strict=False,
)
except Exception:
print('Error installing clipseg model:')