2022-08-23 22:26:28 +00:00
import os
import numpy as np
import PIL
from PIL import Image
from torch . utils . data import Dataset
from torchvision import transforms
import random
# Caption templates for the single-token case: the one "{}" slot receives the
# placeholder token. Entries are kept in their original order, repeats
# included: "a cropped painting" and "a close-up painting" each appear twice,
# so random.choice() picks them twice as often as the other phrases.
imagenet_templates_small = [
    f'{subject} in the style of {{}}'
    for subject in (
        'a painting',
        'a rendering',
        'a cropped painting',
        'the painting',
        'a clean painting',
        'a dirty painting',
        'a dark painting',
        'a picture',
        'a cool painting',
        'a close-up painting',
        'a bright painting',
        'a cropped painting',
        'a good painting',
        'a close-up painting',
        'a rendition',
        'a nice painting',
        'a small painting',
        'a weird painting',
        'a large painting',
    )
]
# Caption templates used when per-image tokens are enabled. Each entry is
# filled with ``.format(placeholder_token, per_image_token)``: argument 0 is
# the shared style placeholder, argument 1 the image-specific token.
imagenet_dual_templates_small = [
    'a painting in the style of {} with {}',
    'a rendering in the style of {} with {}',
    'a cropped painting in the style of {} with {}',
    'the painting in the style of {} with {}',
    'a clean painting in the style of {} with {}',
    'a dirty painting in the style of {} with {}',
    'a dark painting in the style of {} with {}',
    'a cool painting in the style of {} with {}',
    'a close-up painting in the style of {} with {}',
    'a bright painting in the style of {} with {}',
    'a cropped painting in the style of {} with {}',
    'a good painting in the style of {} with {}',
    # Explicit indices keep the style placeholder in the "in the style of"
    # slot like every other template; with plain "{}" the two tokens were
    # swapped for this entry.
    'a painting of one {1} in the style of {0}',
    'a nice painting in the style of {} with {}',
    'a small painting in the style of {} with {}',
    'a weird painting in the style of {} with {}',
    'a large painting in the style of {} with {}',
]
# Single-character tokens for per-image-token mode: the 22 base letters of the
# Hebrew alphabet in alphabetical order (final forms are not included). Image
# k (modulo the number of images) is paired with entry k, so this list's
# length caps how many images can use per-image tokens.
per_img_token_list = list('אבגדהוזחטיכלמנסעפצקרשת')
class PersonalizedBase(Dataset):
    """Image-folder dataset for learning a style placeholder token.

    Every entry of ``data_root`` except ``.DS_Store`` is treated as an image
    path. ``__getitem__(i)`` returns a dict with:

    * ``'caption'``: a prompt from ``imagenet_templates_small`` filled with the
      placeholder token or, when ``per_image_tokens`` is set, 25% of the time a
      prompt from ``imagenet_dual_templates_small`` filled with the placeholder
      plus the image's entry in ``per_img_token_list``.
    * ``'image'``: the RGB image as a ``float32`` array scaled to ``[-1, 1]``,
      optionally center-cropped to a square, resized to ``size x size`` and
      randomly flipped horizontally.

    Args:
        data_root: Directory containing the training images.
        size: Output side length; ``None`` keeps the (cropped) image size.
        repeats: Multiplier on the dataset length when ``set == 'train'``.
        interpolation: ``'linear'``, ``'bilinear'``, ``'bicubic'`` or
            ``'lanczos'``; any other value raises ``KeyError``.
        flip_p: Probability of a horizontal flip.
        set: Split name; only ``'train'`` applies ``repeats``.
        placeholder_token: Token substituted into the caption templates.
        per_image_tokens: Also mix in one token per image (see above).
        center_crop: Crop to the largest centered square before resizing.

    Raises:
        ValueError: If ``per_image_tokens`` is set and there are more images
            than entries in ``per_img_token_list``.
    """

    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation='bicubic',
        flip_p=0.5,
        set='train',
        placeholder_token='*',
        per_image_tokens=False,
        center_crop=False,
    ):
        self.data_root = data_root
        # Sorted so the image order -- and therefore the image -> per-image
        # token pairing in __getitem__ -- does not depend on os.listdir()'s
        # arbitrary, filesystem-specific ordering.
        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in sorted(os.listdir(self.data_root))
            if file_path != '.DS_Store'
        ]
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token
        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        # Image k uses per_img_token_list[k], so exactly
        # len(per_img_token_list) images is still valid. Raise explicitly:
        # an assert would be stripped under `python -O`.
        if per_image_tokens and self.num_images > len(per_img_token_list):
            raise ValueError(
                "Can't use per-image tokens when the training set contains "
                f"more than {len(per_img_token_list)} images. To enable "
                "larger sets, add more tokens to 'per_img_token_list'."
            )

        if set == 'train':
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was only an alias of BILINEAR and was removed in
        # Pillow 10; since this dict is built eagerly, referencing it broke
        # construction for every interpolation choice.
        self.interpolation = {
            'linear': PIL.Image.BILINEAR,
            'bilinear': PIL.Image.BILINEAR,
            'bicubic': PIL.Image.BICUBIC,
            'lanczos': PIL.Image.LANCZOS,
        }[interpolation]

        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        # i ranges over [0, len(self)); the modulo cycles through the images
        # when repeats > 1. The with-block closes the file handle; convert()
        # returns a new image, so it remains usable after the file is closed.
        with Image.open(self.image_paths[i % self.num_images]) as raw:
            image = raw.convert('RGB')

        if self.per_image_tokens and np.random.uniform() < 0.25:
            text = random.choice(imagenet_dual_templates_small).format(
                self.placeholder_token, per_img_token_list[i % self.num_images]
            )
        else:
            text = random.choice(imagenet_templates_small).format(
                self.placeholder_token
            )
        example['caption'] = text

        # default to score-sde preprocessing
        img = np.array(image).astype(np.uint8)

        if self.center_crop:
            h, w = img.shape[:2]
            crop = min(h, w)
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize(
                (self.size, self.size), resample=self.interpolation
            )

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] to float32 [-1, 1].
        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        return example