Merge branch 'Any-Winter-4079-Textual_Inversion' of github.com:Any-Winter-4079/InvokeAI into development

Lincoln Stein 2022-10-03 19:43:46 -04:00
commit 0a217b5f15
9 changed files with 163 additions and 23 deletions


@@ -107,4 +107,4 @@ lightning:
     benchmark: True
     max_steps: 4000000
     # max_steps: 4000


@ -30,9 +30,9 @@ model:
target: ldm.modules.embedding_manager.EmbeddingManager target: ldm.modules.embedding_manager.EmbeddingManager
params: params:
placeholder_strings: ["*"] placeholder_strings: ["*"]
initializer_words: ["sculpture"] initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false per_image_tokens: false
num_vectors_per_token: 1 num_vectors_per_token: 6
progressive_words: False progressive_words: False
unet_config: unet_config:
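
Aside: the bump from 1 to 6 vectors per token is not free. Each learned vector occupies a slot in CLIP's fixed text-context window, so longer prompts lose headroom. A rough budget check, illustrative numbers only:

# CLIP's text encoder accepts at most 77 tokens (including the BOS/EOS
# specials), so a placeholder expanded to 6 learned vectors uses 6 of those
# slots instead of 1, leaving less room for the rest of the prompt.
CLIP_MAX_TOKENS = 77
num_vectors_per_token = 6
prompt_budget = CLIP_MAX_TOKENS - 2 - num_vectors_per_token  # ~69 slots remain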


@@ -0,0 +1,110 @@
model:
  base_learning_rate: 5.0e-03
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 64
    channels: 4
    cond_stage_trainable: true   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    embedding_reg_weight: 0.0

    personalization_config:
      target: ldm.modules.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['face', 'man', 'photo', 'africanmale']
        per_image_tokens: false
        num_vectors_per_token: 6
        progressive_words: False

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 2
    wrap: false
    train:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: train
        per_image_tokens: false
        repeats: 100
    validation:
      target: ldm.data.personalized.PersonalizedBase
      params:
        size: 512
        set: val
        per_image_tokens: false
        repeats: 10

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 500
        max_images: 5
        increase_log_steps: False

  trainer:
    benchmark: False
    max_steps: 6200
    # max_steps: 4000
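
For orientation: main.py loads a config like this with OmegaConf and builds each component from its target/params pair. A minimal sketch; the path below is hypothetical, since the new file's name is not shown in this diff view:

from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Illustrative path; substitute wherever this new finetune config lives.
config = OmegaConf.load('configs/stable-diffusion/v1-finetune.yaml')
model = instantiate_from_config(config.model)  # builds the LatentDiffusion model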


@@ -33,13 +33,13 @@ dependencies:
   - openh264==2.3.0
   - onnx==1.12.0
   - onnxruntime==1.12.1
-  - protobuf==3.20.1
+  - protobuf==3.19.4
   - pudb==2022.1
-  - pytorch-lightning==1.6.5
+  - pytorch-lightning==1.7.5
   - scipy==1.9.1
   - streamlit==1.12.2
   - sympy==1.10.1
-  - tensorboard==2.9.0
+  - tensorboard==2.10.0
   - torchmetrics==0.9.3
   - pip:
     - flask==2.1.3
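
A quick, illustrative way to confirm the re-pinned packages actually resolved in the active environment:

# Illustrative sanity check; package names mirror the pins above.
from importlib.metadata import version

for pkg in ('protobuf', 'pytorch-lightning', 'tensorboard'):
    print(pkg, version(pkg))   # expect 3.19.4, 1.7.5, 2.10.0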


@@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
         self.image_paths = [
             os.path.join(self.data_root, file_path)
-            for file_path in os.listdir(self.data_root)
+            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
         ]
         # self._length = len(self.image_paths)


@@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
         self.image_paths = [
             os.path.join(self.data_root, file_path)
-            for file_path in os.listdir(self.data_root)
+            for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
         ]
         # self._length = len(self.image_paths)
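
Both dataset classes get the same one-token fix: macOS drops a hidden .DS_Store file into the image folder, and without the guard PIL would try to decode it as an image. A standalone sketch of the idea, in a slightly broader variant that skips any dotfile:

import os

data_root = '/path/to/training/images'   # illustrative
image_paths = [
    os.path.join(data_root, f)
    for f in os.listdir(data_root)
    if not f.startswith('.')              # covers .DS_Store and other dotfiles
]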


@@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
         # only for very first batch
         if (
             self.scale_by_std

@@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
         N=8,
         n_row=4,
         sample=True,
-        ddim_steps=200,
+        ddim_steps=50,
         ddim_eta=1.0,
         return_keys=None,
         quantize_denoised=True,
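
The dataloader_idx=None defaults here (and in the ImageLogger/CUDACallback hooks in main.py below) appear to track the pytorch-lightning 1.6.5 -> 1.7.5 bump in the environment file: 1.7.x invokes these hooks without a dataloader_idx argument, so defaulting the parameter keeps one signature working on both versions. A minimal sketch of the pattern:

class Compat:
    # Callable both by PL 1.6.x (which passes dataloader_idx)
    # and by PL 1.7.x (which does not).
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
        pass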


@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
placeholder_embedding.shape[0], max_step_tokens placeholder_embedding.shape[0], max_step_tokens
) )
placeholder_rows, placeholder_cols = torch.where( if torch.cuda.is_available():
tokenized_text == placeholder_token.to(device) placeholder_rows, placeholder_cols = torch.where(
) tokenized_text == placeholder_token.to(device)
)
else:
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token
)
if placeholder_rows.nelement() == 0: if placeholder_rows.nelement() == 0:
continue continue
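
The guard avoids calling .to(device) when there is no CUDA device to move to (the module-level device here is presumably 'cuda'). A self-contained sketch of the same comparison on a CPU-only box:

import torch

tokenized_text = torch.tensor([[0, 42, 7], [42, 1, 2]])   # toy token ids
placeholder_token = torch.tensor(42)

if torch.cuda.is_available():
    rows, cols = torch.where(tokenized_text.cuda() == placeholder_token.cuda())
else:
    rows, cols = torch.where(tokenized_text == placeholder_token)
print(rows, cols)   # tensor([0, 1]) tensor([1, 0])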

main.py

@@ -25,6 +25,23 @@ from pytorch_lightning.utilities import rank_zero_info
 from ldm.data.base import Txt2ImgIterableBaseDataset
 from ldm.util import instantiate_from_config

+def fix_func(orig):
+    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+        def new_func(*args, **kw):
+            device = kw.get("device", "mps")
+            kw["device"] = "cpu"
+            return orig(*args, **kw).to(device)
+        return new_func
+    return orig
+
+torch.rand = fix_func(torch.rand)
+torch.rand_like = fix_func(torch.rand_like)
+torch.randn = fix_func(torch.randn)
+torch.randn_like = fix_func(torch.randn_like)
+torch.randint = fix_func(torch.randint)
+torch.randint_like = fix_func(torch.randint_like)
+torch.bernoulli = fix_func(torch.bernoulli)
+torch.multinomial = fix_func(torch.multinomial)

 def load_model_from_config(config, ckpt, verbose=False):
     print(f'Loading model from {ckpt}')
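
What the monkey-patch buys: several RNG ops were unimplemented or unreliable on the MPS backend at the time, so sampling is redirected to the CPU and the result moved back to the requested device. Hypothetical usage on an Apple Silicon machine, after the patches above are applied:

import torch

# Sampled on the CPU under the hood, then .to('mps') on the way out.
x = torch.randn(4, 4, device='mps')
print(x.device)   # mps:0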
@@ -422,9 +439,7 @@ class ImageLogger(Callback):
         self.rescale = rescale
         self.batch_freq = batch_frequency
         self.max_images = max_images
-        self.logger_log_images = {
-            pl.loggers.TestTubeLogger: self._testtube,
-        }
+        self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
         self.log_steps = [
             2**n for n in range(int(np.log2(self.batch_freq)) + 1)
         ]
@@ -527,7 +542,7 @@ class ImageLogger(Callback):
         return False

     def on_train_batch_end(
-        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
     ):
         if not self.disabled and (
             pl_module.global_step > 0 or self.log_first_step
@@ -535,7 +550,7 @@ class ImageLogger(Callback):
             self.log_img(pl_module, batch, batch_idx, split='train')

     def on_validation_batch_end(
-        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
     ):
         if not self.disabled and pl_module.global_step > 0:
             self.log_img(pl_module, batch, batch_idx, split='val')
@@ -555,7 +570,7 @@ class CUDACallback(Callback):
             torch.cuda.synchronize(trainer.root_gpu)
         self.start_time = time.time()

-    def on_train_epoch_end(self, trainer, pl_module, outputs):
+    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
         if torch.cuda.is_available():
             torch.cuda.synchronize(trainer.root_gpu)
         epoch_time = time.time() - self.start_time
@@ -736,6 +751,12 @@ if __name__ == '__main__':
         trainer_kwargs = dict()

         # default logger configs
+        if torch.cuda.is_available():
+            def_logger = 'testtube'
+            def_logger_target = 'TestTubeLogger'
+        else:
+            def_logger = 'csv'
+            def_logger_target = 'CSVLogger'
         default_logger_cfgs = {
             'wandb': {
                 'target': 'pytorch_lightning.loggers.WandbLogger',
@@ -746,15 +767,15 @@ if __name__ == '__main__':
                     'id': nowname,
                 },
             },
-            'testtube': {
-                'target': 'pytorch_lightning.loggers.TestTubeLogger',
+            def_logger: {
+                'target': 'pytorch_lightning.loggers.' + def_logger_target,
                 'params': {
-                    'name': 'testtube',
+                    'name': def_logger,
                     'save_dir': logdir,
                 },
             },
         }
-        default_logger_cfg = default_logger_cfgs['testtube']
+        default_logger_cfg = default_logger_cfgs[def_logger]
         if 'logger' in lightning_config:
             logger_cfg = lightning_config.logger
         else:
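
Equivalently, here is what the def_logger indirection resolves to at runtime, sketched with direct constructor calls: TestTubeLogger is kept only where CUDA is available, while CSVLogger needs no extra dependency and works on any machine.

import torch
import pytorch_lightning as pl

if torch.cuda.is_available():
    logger = pl.loggers.TestTubeLogger(save_dir='logs', name='testtube')
else:
    logger = pl.loggers.CSVLogger(save_dir='logs', name='csv')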
@@ -868,6 +889,10 @@ if __name__ == '__main__':
             ]
             trainer_kwargs['max_steps'] = trainer_opt.max_steps

+        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+            trainer_opt.accelerator = 'mps'
+            trainer_opt.detect_anomaly = False
+
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
         trainer.logdir = logdir  ###
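
Taken together with the trainer_opt overrides above, the effective Trainer construction on an M1/M2 machine amounts to roughly the following sketch (assuming pytorch-lightning 1.7.x, where the 'mps' accelerator first appeared; the diff also switches off detect_anomaly, presumably because anomaly detection was problematic on MPS at the time):

import torch
from pytorch_lightning import Trainer

if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    trainer = Trainer(accelerator='mps', detect_anomaly=False, max_steps=6200)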