mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'Any-Winter-4079-Textual_Inversion' of github.com:Any-Winter-4079/InvokeAI into development
This commit is contained in:
commit
0a217b5f15
@ -30,9 +30,9 @@ model:
|
|||||||
target: ldm.modules.embedding_manager.EmbeddingManager
|
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||||
params:
|
params:
|
||||||
placeholder_strings: ["*"]
|
placeholder_strings: ["*"]
|
||||||
initializer_words: ["sculpture"]
|
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
||||||
per_image_tokens: false
|
per_image_tokens: false
|
||||||
num_vectors_per_token: 1
|
num_vectors_per_token: 6
|
||||||
progressive_words: False
|
progressive_words: False
|
||||||
|
|
||||||
unet_config:
|
unet_config:
|
||||||
|
110
configs/stable-diffusion/v1-m1-finetune.yaml
Normal file
110
configs/stable-diffusion/v1-m1-finetune.yaml
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-03
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
embedding_reg_weight: 0.0
|
||||||
|
|
||||||
|
personalization_config:
|
||||||
|
target: ldm.modules.embedding_manager.EmbeddingManager
|
||||||
|
params:
|
||||||
|
placeholder_strings: ["*"]
|
||||||
|
initializer_words: ['face', 'man', 'photo', 'africanmale']
|
||||||
|
per_image_tokens: false
|
||||||
|
num_vectors_per_token: 6
|
||||||
|
progressive_words: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 2
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: ldm.data.personalized.PersonalizedBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
set: train
|
||||||
|
per_image_tokens: false
|
||||||
|
repeats: 100
|
||||||
|
validation:
|
||||||
|
target: ldm.data.personalized.PersonalizedBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
set: val
|
||||||
|
per_image_tokens: false
|
||||||
|
repeats: 10
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 500
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 500
|
||||||
|
max_images: 5
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: False
|
||||||
|
max_steps: 6200
|
||||||
|
# max_steps: 4000
|
||||||
|
|
@ -33,13 +33,13 @@ dependencies:
|
|||||||
- openh264==2.3.0
|
- openh264==2.3.0
|
||||||
- onnx==1.12.0
|
- onnx==1.12.0
|
||||||
- onnxruntime==1.12.1
|
- onnxruntime==1.12.1
|
||||||
- protobuf==3.20.1
|
- protobuf==3.19.4
|
||||||
- pudb==2022.1
|
- pudb==2022.1
|
||||||
- pytorch-lightning==1.6.5
|
- pytorch-lightning==1.7.5
|
||||||
- scipy==1.9.1
|
- scipy==1.9.1
|
||||||
- streamlit==1.12.2
|
- streamlit==1.12.2
|
||||||
- sympy==1.10.1
|
- sympy==1.10.1
|
||||||
- tensorboard==2.9.0
|
- tensorboard==2.10.0
|
||||||
- torchmetrics==0.9.3
|
- torchmetrics==0.9.3
|
||||||
- pip:
|
- pip:
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
|
@ -117,7 +117,7 @@ class PersonalizedBase(Dataset):
|
|||||||
|
|
||||||
self.image_paths = [
|
self.image_paths = [
|
||||||
os.path.join(self.data_root, file_path)
|
os.path.join(self.data_root, file_path)
|
||||||
for file_path in os.listdir(self.data_root)
|
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
|
||||||
]
|
]
|
||||||
|
|
||||||
# self._length = len(self.image_paths)
|
# self._length = len(self.image_paths)
|
||||||
|
@ -93,7 +93,7 @@ class PersonalizedBase(Dataset):
|
|||||||
|
|
||||||
self.image_paths = [
|
self.image_paths = [
|
||||||
os.path.join(self.data_root, file_path)
|
os.path.join(self.data_root, file_path)
|
||||||
for file_path in os.listdir(self.data_root)
|
for file_path in os.listdir(self.data_root) if file_path != ".DS_Store"
|
||||||
]
|
]
|
||||||
|
|
||||||
# self._length = len(self.image_paths)
|
# self._length = len(self.image_paths)
|
||||||
|
@ -701,7 +701,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
@rank_zero_only
|
@rank_zero_only
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
def on_train_batch_start(self, batch, batch_idx, dataloader_idx=None):
|
||||||
# only for very first batch
|
# only for very first batch
|
||||||
if (
|
if (
|
||||||
self.scale_by_std
|
self.scale_by_std
|
||||||
@ -1890,7 +1890,7 @@ class LatentDiffusion(DDPM):
|
|||||||
N=8,
|
N=8,
|
||||||
n_row=4,
|
n_row=4,
|
||||||
sample=True,
|
sample=True,
|
||||||
ddim_steps=200,
|
ddim_steps=50,
|
||||||
ddim_eta=1.0,
|
ddim_eta=1.0,
|
||||||
return_keys=None,
|
return_keys=None,
|
||||||
quantize_denoised=True,
|
quantize_denoised=True,
|
||||||
|
@ -169,9 +169,14 @@ class EmbeddingManager(nn.Module):
|
|||||||
placeholder_embedding.shape[0], max_step_tokens
|
placeholder_embedding.shape[0], max_step_tokens
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
placeholder_rows, placeholder_cols = torch.where(
|
placeholder_rows, placeholder_cols = torch.where(
|
||||||
tokenized_text == placeholder_token.to(device)
|
tokenized_text == placeholder_token.to(device)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
placeholder_rows, placeholder_cols = torch.where(
|
||||||
|
tokenized_text == placeholder_token
|
||||||
|
)
|
||||||
|
|
||||||
if placeholder_rows.nelement() == 0:
|
if placeholder_rows.nelement() == 0:
|
||||||
continue
|
continue
|
||||||
|
45
main.py
45
main.py
@ -25,6 +25,23 @@ from pytorch_lightning.utilities import rank_zero_info
|
|||||||
from ldm.data.base import Txt2ImgIterableBaseDataset
|
from ldm.data.base import Txt2ImgIterableBaseDataset
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
def fix_func(orig):
    """Return a version of a torch sampling function that is usable under MPS.

    When the MPS backend is not available, ``orig`` is returned unchanged.
    Otherwise the returned wrapper runs ``orig`` on the CPU and moves the
    result to the target device afterwards.

    Args:
        orig: A torch random-sampling callable such as ``torch.randn``,
            ``torch.rand_like`` or ``torch.multinomial``.

    Returns:
        ``orig`` itself, or a wrapper that accepts the same arguments.
    """
    if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
        return orig

    import functools

    # torch.bernoulli and torch.multinomial draw from an input tensor and take
    # no `device` keyword, so injecting device="cpu" would raise TypeError.
    takes_device = getattr(orig, '__name__', '') not in ('bernoulli', 'multinomial')

    @functools.wraps(orig)
    def new_func(*args, **kw):
        if takes_device:
            # Factory-style samplers: sample on the CPU, then move the result
            # to the device the caller asked for (MPS when none was given).
            device = kw.get('device', 'mps')
            kw['device'] = 'cpu'
            return orig(*args, **kw).to(device)

        # Input-driven samplers: relocate tensor arguments to the CPU, sample
        # there, and return the result on the first tensor argument's device.
        # An `out=` tensor is deliberately left untouched so that results are
        # never written into a hidden CPU copy.
        tensors = [a for a in args if torch.is_tensor(a)]
        tensors += [v for k, v in kw.items() if k != 'out' and torch.is_tensor(v)]
        device = tensors[0].device if tensors else 'mps'
        cpu_args = tuple(a.cpu() if torch.is_tensor(a) else a for a in args)
        cpu_kw = {
            k: v.cpu() if k != 'out' and torch.is_tensor(v) else v
            for k, v in kw.items()
        }
        return orig(*cpu_args, **cpu_kw).to(device)

    return new_func
|
||||||
|
|
||||||
|
# Route torch's random-sampling entry points through fix_func, in a fixed
# order. fix_func hands back the original callable when MPS is unavailable,
# so on other machines these assignments leave torch unchanged.
for _sampler in (
    'rand',
    'rand_like',
    'randn',
    'randn_like',
    'randint',
    'randint_like',
    'bernoulli',
    'multinomial',
):
    setattr(torch, _sampler, fix_func(getattr(torch, _sampler)))
del _sampler
|
||||||
|
|
||||||
def load_model_from_config(config, ckpt, verbose=False):
|
def load_model_from_config(config, ckpt, verbose=False):
|
||||||
print(f'Loading model from {ckpt}')
|
print(f'Loading model from {ckpt}')
|
||||||
@ -422,9 +439,7 @@ class ImageLogger(Callback):
|
|||||||
self.rescale = rescale
|
self.rescale = rescale
|
||||||
self.batch_freq = batch_frequency
|
self.batch_freq = batch_frequency
|
||||||
self.max_images = max_images
|
self.max_images = max_images
|
||||||
self.logger_log_images = {
|
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
|
||||||
pl.loggers.TestTubeLogger: self._testtube,
|
|
||||||
}
|
|
||||||
self.log_steps = [
|
self.log_steps = [
|
||||||
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
|
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
|
||||||
]
|
]
|
||||||
@ -527,7 +542,7 @@ class ImageLogger(Callback):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def on_train_batch_end(
|
def on_train_batch_end(
|
||||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
|
||||||
):
|
):
|
||||||
if not self.disabled and (
|
if not self.disabled and (
|
||||||
pl_module.global_step > 0 or self.log_first_step
|
pl_module.global_step > 0 or self.log_first_step
|
||||||
@ -535,7 +550,7 @@ class ImageLogger(Callback):
|
|||||||
self.log_img(pl_module, batch, batch_idx, split='train')
|
self.log_img(pl_module, batch, batch_idx, split='train')
|
||||||
|
|
||||||
def on_validation_batch_end(
|
def on_validation_batch_end(
|
||||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
|
||||||
):
|
):
|
||||||
if not self.disabled and pl_module.global_step > 0:
|
if not self.disabled and pl_module.global_step > 0:
|
||||||
self.log_img(pl_module, batch, batch_idx, split='val')
|
self.log_img(pl_module, batch, batch_idx, split='val')
|
||||||
@ -555,7 +570,7 @@ class CUDACallback(Callback):
|
|||||||
torch.cuda.synchronize(trainer.root_gpu)
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
self.start_time = time.time()
|
self.start_time = time.time()
|
||||||
|
|
||||||
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
def on_train_epoch_end(self, trainer, pl_module, outputs=None):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.synchronize(trainer.root_gpu)
|
torch.cuda.synchronize(trainer.root_gpu)
|
||||||
epoch_time = time.time() - self.start_time
|
epoch_time = time.time() - self.start_time
|
||||||
@ -736,6 +751,12 @@ if __name__ == '__main__':
|
|||||||
trainer_kwargs = dict()
|
trainer_kwargs = dict()
|
||||||
|
|
||||||
# default logger configs
|
# default logger configs
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
def_logger = 'testtube'
|
||||||
|
def_logger_target = 'TestTubeLogger'
|
||||||
|
else:
|
||||||
|
def_logger = 'csv'
|
||||||
|
def_logger_target = 'CSVLogger'
|
||||||
default_logger_cfgs = {
|
default_logger_cfgs = {
|
||||||
'wandb': {
|
'wandb': {
|
||||||
'target': 'pytorch_lightning.loggers.WandbLogger',
|
'target': 'pytorch_lightning.loggers.WandbLogger',
|
||||||
@ -746,15 +767,15 @@ if __name__ == '__main__':
|
|||||||
'id': nowname,
|
'id': nowname,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
'testtube': {
|
def_logger: {
|
||||||
'target': 'pytorch_lightning.loggers.TestTubeLogger',
|
'target': 'pytorch_lightning.loggers.' + def_logger_target,
|
||||||
'params': {
|
'params': {
|
||||||
'name': 'testtube',
|
'name': def_logger,
|
||||||
'save_dir': logdir,
|
'save_dir': logdir,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
default_logger_cfg = default_logger_cfgs['testtube']
|
default_logger_cfg = default_logger_cfgs[def_logger]
|
||||||
if 'logger' in lightning_config:
|
if 'logger' in lightning_config:
|
||||||
logger_cfg = lightning_config.logger
|
logger_cfg = lightning_config.logger
|
||||||
else:
|
else:
|
||||||
@ -868,6 +889,10 @@ if __name__ == '__main__':
|
|||||||
]
|
]
|
||||||
trainer_kwargs['max_steps'] = trainer_opt.max_steps
|
trainer_kwargs['max_steps'] = trainer_opt.max_steps
|
||||||
|
|
||||||
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
|
trainer_opt.accelerator = 'mps'
|
||||||
|
trainer_opt.detect_anomaly = False
|
||||||
|
|
||||||
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
|
||||||
trainer.logdir = logdir ###
|
trainer.logdir = logdir ###
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user