fix debug_image to not crash with non-RGB images.

This commit is contained in:
Kevin Turner 2022-12-04 20:12:47 -08:00
parent 875312080d
commit b2664e807e

View File

@ -1,17 +1,13 @@
import importlib import importlib
import torch
import numpy as np
import math import math
from collections import abc
from einops import rearrange
from functools import partial
import multiprocessing as mp import multiprocessing as mp
from threading import Thread from collections import abc
from queue import Queue
from inspect import isfunction from inspect import isfunction
from queue import Queue
from threading import Thread
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
@ -221,7 +217,7 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t*
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1 grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
rand_val = torch.rand(res[0]+1, res[1]+1) rand_val = torch.rand(res[0]+1, res[1]+1)
angles = 2*math.pi*rand_val angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device) gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
@ -249,8 +245,8 @@ def ask_user(question: str, answers: list):
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ): def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
if not debug_status: if not debug_status:
return return
image_copy = debug_image.copy() image_copy = debug_image.copy().convert("RGBA")
ImageDraw.Draw(image_copy).text( ImageDraw.Draw(image_copy).text(
(5, 5), (5, 5),
debug_text, debug_text,
@ -261,4 +257,4 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
image_copy.show() image_copy.show()
if debug_result: if debug_result:
return image_copy return image_copy