mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge pull request #105 from shusso/select-device
Move torch.device selection to it's own function
This commit is contained in:
commit
800132970e
@ -475,17 +475,21 @@ class T2I:
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def _get_device(self):
|
||||
if torch.cuda.is_available():
|
||||
return torch.device('cuda')
|
||||
elif torch.backends.mps.is_available():
|
||||
return torch.device('mps')
|
||||
else:
|
||||
return torch.device('cpu')
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(self.seed)
|
||||
try:
|
||||
config = OmegaConf.load(self.config)
|
||||
self.device = (
|
||||
torch.device(self.device)
|
||||
if torch.cuda.is_available()
|
||||
else torch.device('cpu')
|
||||
)
|
||||
self.device = self._get_device()
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(self.embedding_path)
|
||||
|
Loading…
Reference in New Issue
Block a user