mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move torch.device selection to it's own function
This commit is contained in:
parent
4f02b72c9c
commit
ed72ff3268
@ -470,17 +470,21 @@ class T2I:
|
|||||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||||
return self.seed
|
return self.seed
|
||||||
|
|
||||||
|
def _get_device(self):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return torch.device('cuda')
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
return torch.device('mps')
|
||||||
|
else:
|
||||||
|
return torch.device('cpu')
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
seed_everything(self.seed)
|
seed_everything(self.seed)
|
||||||
try:
|
try:
|
||||||
config = OmegaConf.load(self.config)
|
config = OmegaConf.load(self.config)
|
||||||
self.device = (
|
self.device = self._get_device()
|
||||||
torch.device(self.device)
|
|
||||||
if torch.cuda.is_available()
|
|
||||||
else torch.device('cpu')
|
|
||||||
)
|
|
||||||
model = self._load_model_from_config(config, self.weights)
|
model = self._load_model_from_config(config, self.weights)
|
||||||
if self.embedding_path is not None:
|
if self.embedding_path is not None:
|
||||||
model.embedding_manager.load(self.embedding_path)
|
model.embedding_manager.load(self.embedding_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user