diff --git a/environment-mac.yaml b/environment-mac.yaml
index be63a05540..42d2d5eaaf 100644
--- a/environment-mac.yaml
+++ b/environment-mac.yaml
@@ -28,7 +28,7 @@ dependencies:
     - kornia==0.6.0
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-    - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
+    - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
     - -e .
 variables:
   PYTORCH_ENABLE_MPS_FALLBACK: 1
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 230aa74c28..cdc72f2fba 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -154,7 +154,10 @@ class T2I:
         self.model = None  # empty for now
         self.sampler = None
         self.latent_diffusion_weights = latent_diffusion_weights
-        self.device = device
+        if device == 'cuda' and not torch.cuda.is_available():
+            device = choose_torch_device()
+            print("cuda not available, using device", device)
+        self.device = torch.device(device)

         # for VRAM usage statistics
         self.session_peakmem = torch.cuda.max_memory_allocated() if self.device == 'cuda' else None
@@ -279,7 +282,8 @@ class T2I:
         self._set_sampler()

         tic = time.time()
-        torch.cuda.torch.cuda.reset_peak_memory_stats()
+        if torch.cuda.is_available():
+            torch.cuda.torch.cuda.reset_peak_memory_stats()
         results = list()

         try:
@@ -311,7 +315,10 @@ class T2I:
                 callback=step_callback,
             )

-            with scope(self.device.type), self.model.ema_scope():
+            device_type = self.device.type  # this returns 'mps' on M1
+            if device_type != 'cuda' and device_type != 'cpu':
+                device_type = 'cpu'
+            with scope(device_type), self.model.ema_scope():
                 for n in trange(iterations, desc='Generating'):
                     seed_everything(seed)
                     image = next(images_iterator)
@@ -523,17 +530,12 @@ class T2I:
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed

-    def _get_device(self):
-        device_type = choose_torch_device()
-        return torch.device(device_type)
-
     def load_model(self):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if self.model is None:
             seed_everything(self.seed)
             try:
                 config = OmegaConf.load(self.config)
-                self.device = self._get_device()
                 model = self._load_model_from_config(config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
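For context: `choose_torch_device()` is a helper defined elsewhere in the repository rather than in the hunks above; the diff only assumes it prefers `cuda`, then `mps`, then `cpu`. Below is a minimal sketch of such a helper; the name comes from the diff, but the body and location are assumptions (and the MPS check requires PyTorch 1.12+).

```python
import torch

def choose_torch_device() -> str:
    """Sketch: pick the best available torch device name, preferring cuda, then mps, then cpu."""
    if torch.cuda.is_available():
        return 'cuda'
    # torch.backends.mps only exists in PyTorch >= 1.12, so guard the lookup
    mps = getattr(torch.backends, 'mps', None)
    if mps is not None and mps.is_available():
        return 'mps'
    return 'cpu'
```

The `device_type = 'cpu'` fallback before `scope(device_type)` exists because `scope` is presumably `torch.autocast`, which at the time accepted only the `'cuda'` and `'cpu'` device types; on Apple Silicon the sampling loop therefore runs under a CPU autocast scope while the tensors themselves stay on the `mps` device.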