mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Textual Inversion for M1
Update main.py Update ddpm.py Update personalized.py Update personalized_style.py Update v1-finetune.yaml Update environment-mac.yaml Rename v1-finetune.yaml to v1-m1-finetune.yaml Create v1-finetune.yaml Update main.py Update main.py Update environment-mac.yaml Update v1-inference.yaml
This commit is contained in:
parent
d2b57029c8
commit
e19aab4a9b
@ -107,4 +107,4 @@ lightning:
|
||||
benchmark: True
|
||||
max_steps: 4000000
|
||||
# max_steps: 4000
|
||||
|
||||
|
||||
|
@ -30,9 +30,9 @@ model:
|
||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
placeholder_strings: ["*"]
|
||||
initializer_words: ["sculpture"]
|
||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
||||
per_image_tokens: false
|
||||
num_vectors_per_token: 1
|
||||
num_vectors_per_token: 6
|
||||
progressive_words: False
|
||||
|
||||
unet_config:
|
||||
|
110
configs/stable-diffusion/v1-m1-finetune.yaml
Normal file
110
configs/stable-diffusion/v1-m1-finetune.yaml
Normal file
@ -0,0 +1,110 @@
|
||||
model:
|
||||
base_learning_rate: 5.0e-03
|
||||
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||
params:
|
||||
linear_start: 0.00085
|
||||
linear_end: 0.0120
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: image
|
||||
cond_stage_key: caption
|
||||
image_size: 64
|
||||
channels: 4
|
||||
cond_stage_trainable: true # Note: different from the one we trained before
|
||||
conditioning_key: crossattn
|
||||
monitor: val/loss_simple_ema
|
||||
scale_factor: 0.18215
|
||||
use_ema: False
|
||||
embedding_reg_weight: 0.0
|
||||
|
||||
personalization_config:
|
||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||
params:
|
||||
placeholder_strings: ["*"]
|
||||
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
||||
per_image_tokens: false
|
||||
num_vectors_per_token: 6
|
||||
progressive_words: False
|
||||
|
||||
unet_config:
|
||||
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||
params:
|
||||
image_size: 32 # unused
|
||||
in_channels: 4
|
||||
out_channels: 4
|
||||
model_channels: 320
|
||||
attention_resolutions: [ 4, 2, 1 ]
|
||||
num_res_blocks: 2
|
||||
channel_mult: [ 1, 2, 4, 4 ]
|
||||
num_heads: 8
|
||||
use_spatial_transformer: True
|
||||
transformer_depth: 1
|
||||
context_dim: 768
|
||||
use_checkpoint: True
|
||||
legacy: False
|
||||
|
||||
first_stage_config:
|
||||
target: ldm.models.autoencoder.AutoencoderKL
|
||||
params:
|
||||
embed_dim: 4
|
||||
monitor: val/rec_loss
|
||||
ddconfig:
|
||||
double_z: true
|
||||
z_channels: 4
|
||||
resolution: 256
|
||||
in_channels: 3
|
||||
out_ch: 3
|
||||
ch: 128
|
||||
ch_mult:
|
||||
- 1
|
||||
- 2
|
||||
- 4
|
||||
- 4
|
||||
num_res_blocks: 2
|
||||
attn_resolutions: []
|
||||
dropout: 0.0
|
||||
lossconfig:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
|
||||
data:
|
||||
target: main.DataModuleFromConfig
|
||||
params:
|
||||
batch_size: 1
|
||||
num_workers: 2
|
||||
wrap: false
|
||||
train:
|
||||
target: ldm.data.personalized.PersonalizedBase
|
||||
params:
|
||||
size: 512
|
||||
set: train
|
||||
per_image_tokens: false
|
||||
repeats: 100
|
||||
validation:
|
||||
target: ldm.data.personalized.PersonalizedBase
|
||||
params:
|
||||
size: 512
|
||||
set: val
|
||||
per_image_tokens: false
|
||||
repeats: 10
|
||||
|
||||
lightning:
|
||||
modelcheckpoint:
|
||||
params:
|
||||
every_n_train_steps: 500
|
||||
callbacks:
|
||||
image_logger:
|
||||
target: main.ImageLogger
|
||||
params:
|
||||
batch_frequency: 500
|
||||
max_images: 5
|
||||
increase_log_steps: False
|
||||
|
||||
trainer:
|
||||
benchmark: False
|
||||
max_steps: 6200
|
||||
# max_steps: 4000
|
||||
|
@ -32,13 +32,13 @@ dependencies:
|
||||
- omegaconf==2.1.1
|
||||
- onnx==1.12.0
|
||||
- onnxruntime==1.12.1
|
||||
- protobuf==3.20.1
|
||||
- protobuf==3.19.5
|
||||
- pudb==2022.1
|
||||
- pytorch-lightning==1.6.5
|
||||
- pytorch-lightning==1.7.5
|
||||
- scipy==1.9.1
|
||||
- streamlit==1.12.2
|
||||
- sympy==1.10.1
|
||||
- tensorboard==2.9.0
|
||||
- tensorboard==2.10.0
|
||||
- torchmetrics==0.9.3
|
||||
- pip:
|
||||
- flask==2.1.3
|
||||
|
@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
|
||||
]
|
||||
|
||||
# self._length = len(self.image_paths)
|
||||
|
@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
|
||||
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
|
||||
]
|
||||
|
||||
# self._length = len(self.image_paths)
|
||||
|
@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
@rank_zero_only
|
||||
@torch.no_grad()
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
|
||||
# only for very first batch
|
||||
if (
|
||||
self.scale_by_std
|
||||
@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
|
||||
N=8,
|
||||
n_row=4,
|
||||
sample=True,
|
||||
ddim_steps=200,
|
||||
ddim_steps=50,
|
||||
ddim_eta=1.0,
|
||||
return_keys=None,
|
||||
quantize_denoised=True,
|
||||
|
@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
|
||||
placeholder_embedding.shape[0], max_step_tokens
|
||||
)
|
||||
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
if torch.cuda.is_available():
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token.to(device)
|
||||
)
|
||||
else:
|
||||
placeholder_rows, placeholder_cols = torch.where(
|
||||
tokenized_text == placeholder_token
|
||||
)
|
||||
|
||||
if placeholder_rows.nelement() == 0:
|
||||
continue
|
||||
|
45
main.py
45
main.py
@ -25,6 +25,23 @@ from pytorch_lightning.utilities import rank_zero_info
|
||||
from ldm.data.base import Txt2ImgIterableBaseDataset
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
def new_func(*args, **kw):
|
||||
device = kw.get("device", "mps")
|
||||
kw["device"]="cpu"
|
||||
return orig(*args, **kw).to(device)
|
||||
return new_func
|
||||
return orig
|
||||
|
||||
torch.rand = fix_func(torch.rand)
|
||||
torch.rand_like = fix_func(torch.rand_like)
|
||||
torch.randn = fix_func(torch.randn)
|
||||
torch.randn_like = fix_func(torch.randn_like)
|
||||
torch.randint = fix_func(torch.randint)
|
||||
torch.randint_like = fix_func(torch.randint_like)
|
||||
torch.bernoulli = fix_func(torch.bernoulli)
|
||||
torch.multinomial = fix_func(torch.multinomial)
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=False):
|
||||
print(f'Loading model from {ckpt}')
|
||||
@ -422,9 +439,7 @@ class ImageLogger(Callback):
|
||||
self.rescale = rescale
|
||||
self.batch_freq = batch_frequency
|
||||
self.max_images = max_images
|
||||
self.logger_log_images = {
|
||||
pl.loggers.TestTubeLogger: self._testtube,
|
||||
}
|
||||
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
|
||||
self.log_steps = [
|
||||
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
|
||||
]
|
||||
@ -527,7 +542,7 @@ class ImageLogger(Callback):
|
||||
return False
|
||||
|
||||
def on_train_batch_end(
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
|
||||
):
|
||||
if not self.disabled and (
|
||||
pl_module.global_step > 0 or self.log_first_step
|
||||
@ -535,7 +550,7 @@ class ImageLogger(Callback):
|
||||
self.log_img(pl_module, batch, batch_idx, split='train')
|
||||
|
||||
def on_validation_batch_end(
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
|
||||
):
|
||||
if not self.disabled and pl_module.global_step > 0:
|
||||
self.log_img(pl_module, batch, batch_idx, split='val')
|
||||
@ -555,7 +570,7 @@ class CUDACallback(Callback):
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
||||
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize(trainer.root_gpu)
|
||||
epoch_time = time.time() - self.start_time
|
||||
@ -736,6 +751,12 @@ if __name__ == '__main__':
|
||||
trainer_kwargs = dict()
|
||||
|
||||
# default logger configs
|
||||
if torch.cuda.is_available():
|
||||
def_logger = 'testtube'
|
||||
def_logger_target = 'TestTubeLogger'
|
||||
else:
|
||||
def_logger = 'csv'
|
||||
def_logger_target = 'CSVLogger'
|
||||
default_logger_cfgs = {
|
||||
'wandb': {
|
||||
'target': 'pytorch_lightning.loggers.WandbLogger',
|
||||
@ -746,15 +767,15 @@ if __name__ == '__main__':
|
||||
'id': nowname,
|
||||
},
|
||||
},
|
||||
'testtube': {
|
||||
'target': 'pytorch_lightning.loggers.TestTubeLogger',
|
||||
def_logger: {
|
||||
'target': 'pytorch_lightning.loggers.' + def_logger_target,
|
||||
'params': {
|
||||
'name': 'testtube',
|
||||
'name': def_logger,
|
||||
'save_dir': logdir,
|
||||
},
|
||||
},
|
||||
}
|
||||
default_logger_cfg = default_logger_cfgs['testtube']
|
||||
default_logger_cfg = default_logger_cfgs[def_logger]
|
||||
if 'logger' in lightning_config:
|
||||
logger_cfg = lightning_config.logger
|
||||
else:
|
||||
@ -868,6 +889,10 @@ if __name__ == '__main__':
|
||||
]
|
||||
trainer_kwargs['max_steps'] = trainer_opt.max_steps
|
||||
|
||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||
trainer_opt.accelerator = 'mps'
|
||||
trainer_opt.detect_anomaly = False
|
||||
|
||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||
trainer.logdir = logdir ###
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user