proposed fix to work on mps systems

This commit is contained in:
Lincoln Stein
2022-10-12 11:08:27 -04:00
parent b537e92789
commit aa6aa68753
2 changed files with 6 additions and 3 deletions

View File

@ -195,15 +195,17 @@ class ModelCache(object):
torch.cuda.empty_cache()
def _model_to_cpu(self,model):
if self._has_cuda():
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'
model.first_stage_model.to('cpu')
model.cond_stage_model.to('cpu')
model.model.to('cpu')
return model.to('cpu')
return model.to('cpu')
else:
return model
def _model_from_cpu(self,model):
if self._has_cuda():
if self.device != 'cpu':
model.to(self.device)
model.first_stage_model.to(self.device)
model.cond_stage_model.to(self.device)