mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
changes to get dream.py working on M1
- move all device init logic to T2I.__init__ - handle m1 specific edge case with autocast device type - check torch.cuda.is_available before using cuda
This commit is contained in:
parent
70119602a0
commit
bf50ab9dd6
@ -28,7 +28,7 @@ dependencies:
|
||||
- kornia==0.6.0
|
||||
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
|
||||
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
|
||||
- -e .
|
||||
variables:
|
||||
PYTORCH_ENABLE_MPS_FALLBACK: 1
|
||||
|
@ -154,7 +154,10 @@ class T2I:
|
||||
self.model = None # empty for now
|
||||
self.sampler = None
|
||||
self.latent_diffusion_weights = latent_diffusion_weights
|
||||
self.device = device
|
||||
if device == 'cuda' and not torch.cuda.is_available():
|
||||
device = choose_torch_device()
|
||||
print("cuda not available, using device", device)
|
||||
self.device = torch.device(device)
|
||||
|
||||
# for VRAM usage statistics
|
||||
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
|
||||
@ -279,7 +282,8 @@ class T2I:
|
||||
self._set_sampler()
|
||||
|
||||
tic = time.time()
|
||||
torch.cuda.torch.cuda.reset_peak_memory_stats()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.torch.cuda.reset_peak_memory_stats()
|
||||
results = list()
|
||||
|
||||
try:
|
||||
@ -311,7 +315,10 @@ class T2I:
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
with scope(self.device.type), self.model.ema_scope():
|
||||
device_type = self.device.type # this returns 'mps' on M1
|
||||
if device_type != 'cuda' or device_type != 'cpu':
|
||||
device_type = 'cpu'
|
||||
with scope(device_type), self.model.ema_scope():
|
||||
for n in trange(iterations, desc='Generating'):
|
||||
seed_everything(seed)
|
||||
image = next(images_iterator)
|
||||
@ -523,17 +530,12 @@ class T2I:
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
return self.seed
|
||||
|
||||
def _get_device(self):
|
||||
device_type = choose_torch_device()
|
||||
return torch.device(device_type)
|
||||
|
||||
def load_model(self):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if self.model is None:
|
||||
seed_everything(self.seed)
|
||||
try:
|
||||
config = OmegaConf.load(self.config)
|
||||
self.device = self._get_device()
|
||||
model = self._load_model_from_config(config, self.weights)
|
||||
if self.embedding_path is not None:
|
||||
model.embedding_manager.load(
|
||||
|
Loading…
x
Reference in New Issue
Block a user