Textual Inversion for M1

Update main.py

Update ddpm.py

Update personalized.py

Update personalized_style.py

Update v1-finetune.yaml

Update environment-mac.yaml

Rename v1-finetune.yaml to v1-m1-finetune.yaml

Create v1-finetune.yaml

Update main.py

Update main.py

Update environment-mac.yaml

Update v1-inference.yaml
This commit is contained in:
Any-Winter-4079 2022-09-25 19:12:11 +02:00
parent d2b57029c8
commit e19aab4a9b
9 changed files with 163 additions and 23 deletions

View File

@ -107,4 +107,4 @@ lightning:
benchmark: True
max_steps: 4000000
# max_steps: 4000

View File

@ -30,9 +30,9 @@ model:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 1
num_vectors_per_token: 6
progressive_words: False
unet_config:

View File

@ -0,0 +1,110 @@
model:
base_learning_rate: 5.0e-03
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ['face', 'man', 'photo', 'africanmale']
per_image_tokens: false
num_vectors_per_token: 6
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 5
increase_log_steps: False
trainer:
benchmark: False
max_steps: 6200
# max_steps: 4000

View File

@ -32,13 +32,13 @@ dependencies:
- omegaconf==2.1.1
- onnx==1.12.0
- onnxruntime==1.12.1
- protobuf==3.20.1
- protobuf==3.19.5
- pudb==2022.1
- pytorch-lightning==1.6.5
- pytorch-lightning==1.7.5
- scipy==1.9.1
- streamlit==1.12.2
- sympy==1.10.1
- tensorboard==2.9.0
- tensorboard==2.10.0
- torchmetrics==0.9.3
- pip:
- flask==2.1.3

View File

@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)

View File

@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
self.image_paths = [
os.path.join(self.data_root, file_path)
for file_path in os.listdir(self.data_root)
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
]
# self._length = len(self.image_paths)

View File

@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
# only for very first batch
if (
self.scale_by_std
@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
N=8,
n_row=4,
sample=True,
ddim_steps=200,
ddim_steps=50,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,

View File

@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
placeholder_embedding.shape[0], max_step_tokens
)
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
if torch.cuda.is_available():
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token.to(device)
)
else:
placeholder_rows, placeholder_cols = torch.where(
tokenized_text == placeholder_token
)
if placeholder_rows.nelement() == 0:
continue

45
main.py
View File

@ -25,6 +25,23 @@ from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def load_model_from_config(config, ckpt, verbose=False):
print(f'Loading model from {ckpt}')
@ -422,9 +439,7 @@ class ImageLogger(Callback):
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.TestTubeLogger: self._testtube,
}
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
]
@ -527,7 +542,7 @@ class ImageLogger(Callback):
return False
def on_train_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and (
pl_module.global_step > 0 or self.log_first_step
@ -535,7 +550,7 @@ class ImageLogger(Callback):
self.log_img(pl_module, batch, batch_idx, split='train')
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split='val')
@ -555,7 +570,7 @@ class CUDACallback(Callback):
torch.cuda.synchronize(trainer.root_gpu)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
if torch.cuda.is_available():
torch.cuda.synchronize(trainer.root_gpu)
epoch_time = time.time() - self.start_time
@ -736,6 +751,12 @@ if __name__ == '__main__':
trainer_kwargs = dict()
# default logger configs
if torch.cuda.is_available():
def_logger = 'testtube'
def_logger_target = 'TestTubeLogger'
else:
def_logger = 'csv'
def_logger_target = 'CSVLogger'
default_logger_cfgs = {
'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger',
@ -746,15 +767,15 @@ if __name__ == '__main__':
'id': nowname,
},
},
'testtube': {
'target': 'pytorch_lightning.loggers.TestTubeLogger',
def_logger: {
'target': 'pytorch_lightning.loggers.' + def_logger_target,
'params': {
'name': 'testtube',
'name': def_logger,
'save_dir': logdir,
},
},
}
default_logger_cfg = default_logger_cfgs['testtube']
default_logger_cfg = default_logger_cfgs[def_logger]
if 'logger' in lightning_config:
logger_cfg = lightning_config.logger
else:
@ -868,6 +889,10 @@ if __name__ == '__main__':
]
trainer_kwargs['max_steps'] = trainer_opt.max_steps
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
trainer_opt.accelerator = 'mps'
trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
trainer.logdir = logdir ###