Move torch.device selection to it's own function

This commit is contained in:
Samuel Husso 2022-08-26 14:43:18 +03:00
parent 4f02b72c9c
commit ed72ff3268

View File

@ -470,17 +470,21 @@ class T2I:
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
def _get_device(self):
if torch.cuda.is_available():
return torch.device('cuda')
elif torch.backends.mps.is_available():
return torch.device('mps')
else:
return torch.device('cpu')
def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
self.device = (
torch.device(self.device)
if torch.cuda.is_available()
else torch.device('cpu')
)
self.device = self._get_device()
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(self.embedding_path)