changes to get dream.py working on M1

- move all device init logic to T2I.__init__
- handle m1 specific edge case with autocast device type
- check torch.cuda.is_available before using cuda
This commit is contained in:
Jason Toffaletti 2022-08-31 00:32:07 -07:00
parent 70119602a0
commit bf50ab9dd6
2 changed files with 11 additions and 9 deletions

View File

@ -28,7 +28,7 @@ dependencies:
- kornia==0.6.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e .
variables:
PYTORCH_ENABLE_MPS_FALLBACK: 1

View File

@ -154,7 +154,10 @@ class T2I:
self.model = None # empty for now
self.sampler = None
self.latent_diffusion_weights = latent_diffusion_weights
self.device = device
if device == 'cuda' and not torch.cuda.is_available():
device = choose_torch_device()
print("cuda not available, using device", device)
self.device = torch.device(device)
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
@ -279,7 +282,8 @@ class T2I:
self._set_sampler()
tic = time.time()
torch.cuda.torch.cuda.reset_peak_memory_stats()
if torch.cuda.is_available():
torch.cuda.torch.cuda.reset_peak_memory_stats()
results = list()
try:
@ -311,7 +315,10 @@ class T2I:
callback=step_callback,
)
with scope(self.device.type), self.model.ema_scope():
device_type = self.device.type # this returns 'mps' on M1
if device_type != 'cuda' or device_type != 'cpu':
device_type = 'cpu'
with scope(device_type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'):
seed_everything(seed)
image = next(images_iterator)
@ -523,17 +530,12 @@ class T2I:
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
return self.seed
def _get_device(self):
device_type = choose_torch_device()
return torch.device(device_type)
def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
self.device = self._get_device()
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:
model.embedding_manager.load(