diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 8aad3557af..070f5ea8e2 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -470,17 +470,21 @@ class T2I: self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed + def _get_device(self): + if torch.cuda.is_available(): + return torch.device('cuda') + elif torch.backends.mps.is_available(): + return torch.device('mps') + else: + return torch.device('cpu') + def load_model(self): """Load and initialize the model from configuration variables passed at object creation time""" if self.model is None: seed_everything(self.seed) try: config = OmegaConf.load(self.config) - self.device = ( - torch.device(self.device) - if torch.cuda.is_available() - else torch.device('cpu') - ) + self.device = self._get_device() model = self._load_model_from_config(config, self.weights) if self.embedding_path is not None: model.embedding_manager.load(self.embedding_path)