don't use cpu state_dict for model unpatching when executing on cpu (#6631)

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-07-18 15:34:01 -04:00 committed by GitHub
parent 0583101c1c
commit 97a7f51721
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -167,7 +167,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(self.logger, model) size = calc_model_size_by_data(self.logger, model)
self.make_room(size) self.make_room(size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None running_on_cpu = self.execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)