diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 1981b4eacb..f188d9b23d 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -48,6 +48,10 @@ class Img2Img(Generator): torch.tensor([t_enc]).to(self.model.device), noise=x_T ) + + if self.free_gpu_mem and self.model.model.device != self.model.device: + self.model.model.to(self.model.device) + # decode it samples = sampler.decode( z_enc, @@ -61,6 +65,9 @@ class Img2Img(Generator): all_timesteps_count = steps ) + if self.free_gpu_mem: + self.model.model.to("cpu") + return self.sample_to_image(samples) return make_image @@ -87,4 +94,4 @@ class Img2Img(Generator): image = torch.from_numpy(image) if normalize: image = 2.0 * image - 1.0 - return image.to(self.model.device) + return image.to(self.model.device)