From ed72ff3268313500bc00d26faba328ea0ba6303c Mon Sep 17 00:00:00 2001 From: Samuel Husso Date: Fri, 26 Aug 2022 14:43:18 +0300 Subject: [PATCH] Move torch.device selection to it's own function --- ldm/simplet2i.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 8aad3557af..070f5ea8e2 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -470,17 +470,21 @@ class T2I: self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed + def _get_device(self): + if torch.cuda.is_available(): + return torch.device('cuda') + elif torch.backends.mps.is_available(): + return torch.device('mps') + else: + return torch.device('cpu') + def load_model(self): """Load and initialize the model from configuration variables passed at object creation time""" if self.model is None: seed_everything(self.seed) try: config = OmegaConf.load(self.config) - self.device = ( - torch.device(self.device) - if torch.cuda.is_available() - else torch.device('cpu') - ) + self.device = self._get_device() model = self._load_model_from_config(config, self.weights) if self.embedding_path is not None: model.embedding_manager.load(self.embedding_path)