Merge pull request #105 from shusso/select-device

Move torch.device selection to it's own function
This commit is contained in:
Lincoln Stein 2022-08-26 12:23:21 -04:00 committed by GitHub
commit 800132970e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -475,17 +475,21 @@ class T2I:
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
def _get_device(self):
if torch.cuda.is_available():
return torch.device('cuda')
elif torch.backends.mps.is_available():
return torch.device('mps')
else:
return torch.device('cpu')
def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
self.device = (
torch.device(self.device)
if torch.cuda.is_available()
else torch.device('cpu')
)
self.device = self._get_device()
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path)