Merge branch 'v2.3' into enhance/curated-2.3.1-models

This commit is contained in:
Lincoln Stein 2023-02-24 10:30:42 -05:00 committed by GitHub
commit 4e446130d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 11 deletions

View File

@ -650,6 +650,8 @@ class Generate:
def clear_cuda_cache(self):
if self._has_cuda():
self.gather_cuda_stats()
# Run garbage collection prior to emptying the CUDA cache
gc.collect()
torch.cuda.empty_cache()
def clear_cuda_stats(self):

View File

@ -137,17 +137,9 @@ class Generator:
Given samples returned from a sampler, converts
it into a PIL Image
"""
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if len(x_samples) != 1:
raise Exception(
f'>> expected to get a single image, but got {len(x_samples)}')
x_sample = 255.0 * rearrange(
x_samples[0].cpu().numpy(), 'c h w -> h w c'
)
return Image.fromarray(x_sample.astype(np.uint8))
# write an approximate RGB image from latent samples for a single step to PNG
with torch.inference_mode():
image = self.model.decode_latents(samples)
return self.model.numpy_to_pil(image)[0]
def repaste_and_color_correct(self, result: Image.Image, init_image: Image.Image, init_mask: Image.Image, mask_blur_radius: int = 8) -> Image.Image:
if init_image is None or init_mask is None: