2022-08-23 22:26:28 +00:00
import os
import numpy as np
import PIL
from PIL import Image
from torch . utils . data import Dataset
from torchvision import transforms
import random
# Minimal caption-template set: a single generic prompt with one ``{}`` slot.
imagenet_templates_smallest = ['a photo of a {}']
# Caption templates, each with exactly one ``{}`` slot that is filled with the
# placeholder string when a training caption is generated.
imagenet_templates_small = [
    'a photo of a {}', 'a rendering of a {}', 'a cropped photo of the {}',
    'the photo of a {}', 'a photo of a clean {}', 'a photo of a dirty {}',
    'a dark photo of the {}', 'a photo of my {}', 'a photo of the cool {}',
    'a close-up photo of a {}', 'a bright photo of the {}',
    'a cropped photo of a {}', 'a photo of the {}', 'a good photo of the {}',
    'a photo of one {}', 'a close-up photo of the {}', 'a rendition of the {}',
    'a photo of the clean {}', 'a rendition of a {}', 'a photo of a nice {}',
    'a good photo of a {}', 'a photo of the nice {}', 'a photo of the small {}',
    'a photo of the weird {}', 'a photo of the large {}', 'a photo of a cool {}',
    'a photo of a small {}',
]
# Two-slot caption templates: the first ``{}`` receives the placeholder string
# and the second receives a per-image token.
imagenet_dual_templates_small = [
    'a photo of a {} with {}', 'a rendering of a {} with {}',
    'a cropped photo of the {} with {}', 'the photo of a {} with {}',
    'a photo of a clean {} with {}', 'a photo of a dirty {} with {}',
    'a dark photo of the {} with {}', 'a photo of my {} with {}',
    'a photo of the cool {} with {}', 'a close-up photo of a {} with {}',
    'a bright photo of the {} with {}', 'a cropped photo of a {} with {}',
    'a photo of the {} with {}', 'a good photo of the {} with {}',
    'a photo of one {} with {}', 'a close-up photo of the {} with {}',
    'a rendition of the {} with {}', 'a photo of the clean {} with {}',
    'a rendition of a {} with {}', 'a photo of a nice {} with {}',
    'a good photo of a {} with {}', 'a photo of the nice {} with {}',
    'a photo of the small {} with {}', 'a photo of the weird {} with {}',
    'a photo of the large {} with {}', 'a photo of a cool {} with {}',
    'a photo of a small {} with {}',
]
# Single-character tokens for the per-image-token mode; item i draws
# token i % num_images. These are the 22 letters of the Hebrew alphabet,
# in alphabetical order and without the final forms.
per_img_token_list = list('אבגדהוזחטיכלמנסעפצקרשת')
class PersonalizedBase(Dataset):
    """Dataset of personalization images paired with randomly templated captions.

    Every regular, non-hidden file directly inside ``data_root`` is treated as
    one image. Item ``i`` uses image ``i % num_images`` and is a dict with:

    * ``'caption'``: a random template from ``imagenet_templates_small``
      filled with the placeholder token, optionally prefixed by
      ``coarse_class_text``. When ``per_image_tokens`` is enabled and a
      uniform draw falls below ``mixing_prob``, the caption instead comes from
      ``imagenet_dual_templates_small`` and also carries the image's own token
      from ``per_img_token_list``.
    * ``'image'``: the image converted to RGB and optionally center-cropped
      to a square. It is then optionally resized to ``size`` x ``size`` and
      randomly flipped horizontally. It is returned as a float32 array
      scaled from [0, 255] to [-1, 1].

    Args:
        data_root: Directory containing the training images.
        size: If not None, images are resized to ``(size, size)``.
        repeats: Length multiplier applied only when ``set == 'train'``.
        interpolation: Resampling filter name: ``'linear'``, ``'bilinear'``,
            ``'bicubic'`` or ``'lanczos'``.
        flip_p: Probability of a random horizontal flip.
        set: Split name; only ``'train'`` applies ``repeats``.
        placeholder_token: Token substituted into the caption templates.
        per_image_tokens: Whether captions may include a per-image token.
        center_crop: Whether to crop to the largest centered square.
        mixing_prob: Probability of a dual-template caption when
            ``per_image_tokens`` is enabled.
        coarse_class_text: Optional text placed before the placeholder token.

    Raises:
        ValueError: If ``per_image_tokens`` is enabled and there are more
            images than entries in ``per_img_token_list``.
        KeyError: If ``interpolation`` is not one of the names above.
    """

    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation='bicubic',
        flip_p=0.5,
        set='train',
        placeholder_token='*',
        per_image_tokens=False,
        center_crop=False,
        mixing_prob=0.25,
        coarse_class_text=None,
    ):
        self.data_root = data_root
        # Sorted so that image order is independent of os.listdir()'s arbitrary
        # order. That order also decides which per-image token each file gets
        # (token i % num_images). Hidden files such as .DS_Store and
        # subdirectories are skipped because PIL cannot open them.
        self.image_paths = [
            os.path.join(self.data_root, file_name)
            for file_name in sorted(os.listdir(self.data_root))
            if not file_name.startswith('.')
            and os.path.isfile(os.path.join(self.data_root, file_name))
        ]
        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token
        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob
        self.coarse_class_text = coarse_class_text

        # Image k uses token k, so up to len(per_img_token_list) images are
        # supported. Raise instead of assert so the check survives ``python -O``.
        if per_image_tokens and self.num_images > len(per_img_token_list):
            raise ValueError(
                f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
            )

        if set == 'train':
            self._length = self.num_images * repeats

        self.size = size
        # 'linear' maps to BILINEAR: Image.LINEAR was only an alias for it and
        # was removed in Pillow 10, where referencing it raises AttributeError.
        self.interpolation = {
            'linear': Image.BILINEAR,
            'bilinear': Image.BILINEAR,
            'bicubic': Image.BICUBIC,
            'lanczos': Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        return self._length

    def __getitem__(self, i):
        example = {}
        example['caption'] = self._build_caption(i)
        example['image'] = self._load_image(i)
        return example

    def _build_caption(self, i):
        """Return a randomly templated caption for item ``i``."""
        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = f'{self.coarse_class_text} {placeholder_string}'

        if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
            return random.choice(imagenet_dual_templates_small).format(
                placeholder_string, per_img_token_list[i % self.num_images]
            )
        return random.choice(imagenet_templates_small).format(placeholder_string)

    def _load_image(self, i):
        """Load image ``i % num_images``, then crop, resize, flip and scale it to [-1, 1]."""
        # convert() always returns an independent, fully loaded image (a copy
        # when the mode is already RGB), so the file can be closed immediately.
        with Image.open(self.image_paths[i % self.num_images]) as raw:
            image = raw.convert('RGB')

        if self.center_crop:
            img = np.array(image).astype(np.uint8)
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[
                (h - crop) // 2:(h + crop) // 2,
                (w - crop) // 2:(w + crop) // 2,
            ]
            image = Image.fromarray(img)

        if self.size is not None:
            image = image.resize(
                (self.size, self.size), resample=self.interpolation
            )

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        return (image / 127.5 - 1.0).astype(np.float32)