2022-08-23 22:26:28 +00:00
import os
2023-03-03 06:02:00 +00:00
import random
2022-08-23 22:26:28 +00:00
import numpy as np
import PIL
from PIL import Image
from torch . utils . data import Dataset
from torchvision import transforms
# Minimal caption template set: a single phrasing with one "{}" slot.
imagenet_templates_smallest = ["a photo of a {}"]
# Single-slot caption templates. PersonalizedBase picks one at random and
# fills the "{}" slot with the placeholder string (optionally prefixed by the
# coarse class text).
imagenet_templates_small = [
    "a photo of a {}",
    "a rendering of a {}",
    "a cropped photo of the {}",
    "the photo of a {}",
    "a photo of a clean {}",
    "a photo of a dirty {}",
    "a dark photo of the {}",
    "a photo of my {}",
    "a photo of the cool {}",
    "a close-up photo of a {}",
    "a bright photo of the {}",
    "a cropped photo of a {}",
    "a photo of the {}",
    "a good photo of the {}",
    "a photo of one {}",
    "a close-up photo of the {}",
    "a rendition of the {}",
    "a photo of the clean {}",
    "a rendition of a {}",
    "a photo of a nice {}",
    "a good photo of a {}",
    "a photo of the nice {}",
    "a photo of the small {}",
    "a photo of the weird {}",
    "a photo of the large {}",
    "a photo of a cool {}",
    "a photo of a small {}",
]
# Two-slot caption templates: each single-slot phrasing extended with
# " with {}". PersonalizedBase formats them with the placeholder string and a
# per-image token, in that order.
imagenet_dual_templates_small = [
    phrasing + " with {}"
    for phrasing in (
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a photo of a dirty {}",
        "a dark photo of the {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    )
]
# Per-image tokens: the 22 Hebrew letters in alphabetical order (final forms
# excluded), one character per entry. In PersonalizedBase the image at index j
# is paired with the token at index j.
per_img_token_list = list("אבגדהוזחטיכלמנסעפצקרשת")
class PersonalizedBase(Dataset):
    """Image/caption dataset for placeholder-token personalization.

    Every regular file in ``data_root`` except ``.DS_Store`` is treated as a
    training image. Each item pairs one image (optionally center-cropped,
    resized and randomly flipped) with a caption drawn at random from the
    prompt templates, with the placeholder string substituted in.
    """

    def __init__(
        self,
        data_root,
        size=None,
        repeats=100,
        interpolation="bicubic",
        flip_p=0.5,
        set="train",  # NOTE: shadows the builtin `set`; name kept for callers.
        placeholder_token="*",
        per_image_tokens=False,
        center_crop=False,
        mixing_prob=0.25,
        coarse_class_text=None,
    ):
        """Index the images in ``data_root`` and store preprocessing options.

        Args:
            data_root: Directory whose regular files are the training images.
            size: If not None, images are resized to ``(size, size)``.
            repeats: For ``set == "train"``, the dataset length is
                ``num_images * repeats``; otherwise it is ``num_images``.
            interpolation: Resampling filter name: "linear", "bilinear",
                "bicubic" or "lanczos" ("linear" is the same as "bilinear").
            flip_p: Probability of a random horizontal flip.
            set: Split name; only "train" applies ``repeats``.
            placeholder_token: String substituted into the caption templates.
            per_image_tokens: If True, a caption sometimes also names a
                token unique to the image (see ``mixing_prob``).
            center_crop: If True, crop the largest centered square first.
            mixing_prob: Probability of using a two-slot (per-image-token)
                caption when ``per_image_tokens`` is True.
            coarse_class_text: Optional text placed before the placeholder.

        Raises:
            ValueError: If ``per_image_tokens`` is True and there are more
                images than entries in ``per_img_token_list``.
            KeyError: If ``interpolation`` is not one of the names above.
        """
        self.data_root = data_root

        # Sorted so the index-based image -> per-image-token pairing is stable;
        # os.listdir returns entries in arbitrary order. Non-files (e.g.
        # subdirectories) are skipped because Image.open cannot read them.
        candidate_paths = (
            os.path.join(self.data_root, file_name)
            for file_name in sorted(os.listdir(self.data_root))
            if file_name != ".DS_Store"
        )
        self.image_paths = [path for path in candidate_paths if os.path.isfile(path)]

        self.num_images = len(self.image_paths)
        self._length = self.num_images

        self.placeholder_token = placeholder_token
        self.per_image_tokens = per_image_tokens
        self.center_crop = center_crop
        self.mixing_prob = mixing_prob
        self.coarse_class_text = coarse_class_text

        # __getitem__ indexes per_img_token_list with i % num_images, so a set
        # of exactly len(per_img_token_list) images is valid. Raise rather
        # than assert so the check survives `python -O`.
        if per_image_tokens and self.num_images > len(per_img_token_list):
            raise ValueError(
                f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
            )

        if set == "train":
            self._length = self.num_images * repeats

        self.size = size
        # Image.LINEAR was only an alias of Image.BILINEAR and was removed in
        # Pillow 10; since this dict is built eagerly, referencing it broke
        # construction for every mode. Map "linear" to BILINEAR directly.
        self.interpolation = {
            "linear": PIL.Image.BILINEAR,
            "bilinear": PIL.Image.BILINEAR,
            "bicubic": PIL.Image.BICUBIC,
            "lanczos": PIL.Image.LANCZOS,
        }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

    def __len__(self):
        """Return the dataset length (``num_images * repeats`` for "train")."""
        return self._length

    def __getitem__(self, i):
        """Build the example for index ``i``.

        Indices wrap around the image list (``i % num_images``).

        Returns:
            dict with ``"caption"`` (str) and ``"image"``: an (H, W, 3)
            float32 array scaled from [0, 255] to [-1, 1].
        """
        example = {}

        # The context manager closes the file handle; np.array() forces the
        # pixel data to be read while the file is still open. Rebinding
        # `image` to the RGB copy does not affect which object is closed.
        with Image.open(self.image_paths[i % self.num_images]) as image:
            if image.mode != "RGB":
                image = image.convert("RGB")
            img = np.array(image).astype(np.uint8)

        placeholder_string = self.placeholder_token
        if self.coarse_class_text:
            placeholder_string = f"{self.coarse_class_text} {placeholder_string}"

        if self.per_image_tokens and np.random.uniform() < self.mixing_prob:
            text = random.choice(imagenet_dual_templates_small).format(
                placeholder_string, per_img_token_list[i % self.num_images]
            )
        else:
            text = random.choice(imagenet_templates_small).format(placeholder_string)
        example["caption"] = text

        # default to score-sde preprocessing
        if self.center_crop:
            h, w = img.shape[0], img.shape[1]
            crop = min(h, w)
            img = img[
                (h - crop) // 2 : (h + crop) // 2,
                (w - crop) // 2 : (w + crop) // 2,
            ]

        image = Image.fromarray(img)
        if self.size is not None:
            image = image.resize((self.size, self.size), resample=self.interpolation)

        image = self.flip(image)
        image = np.array(image).astype(np.uint8)
        # Map uint8 [0, 255] linearly onto float32 [-1, 1].
        example["image"] = (image / 127.5 - 1.0).astype(np.float32)
        return example