fix debug_image to not crash with non-RGB images.

This commit is contained in:
Kevin Turner 2022-12-04 20:12:47 -08:00
parent 875312080d
commit b2664e807e

View File

@ -1,17 +1,13 @@
import importlib
import torch
import numpy as np
import math
from collections import abc
from einops import rearrange
from functools import partial
import multiprocessing as mp
from threading import Thread
from queue import Queue
from collections import abc
from inspect import isfunction
from queue import Queue
from threading import Thread
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
@ -221,7 +217,7 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t*
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
rand_val = torch.rand(res[0]+1, res[1]+1)
angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
@ -249,8 +245,8 @@ def ask_user(question: str, answers: list):
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
if not debug_status:
return
image_copy = debug_image.copy()
image_copy = debug_image.copy().convert("RGBA")
ImageDraw.Draw(image_copy).text(
(5, 5),
debug_text,
@ -261,4 +257,4 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
image_copy.show()
if debug_result:
return image_copy
return image_copy