resolved conflicts with main

Lincoln Stein
2023-07-27 15:11:25 -04:00
275 changed files with 11706 additions and 8208 deletions

View File

@@ -4,6 +4,6 @@
import warnings
from invokeai.frontend.install import invokeai_configure as configure
if __name__ == '__main__':
if __name__ == "__main__":
warnings.warn("configure_invokeai.py is deprecated, running 'invokeai-configure'...", DeprecationWarning)
configure()

View File

@@ -28,11 +28,12 @@ canny_image.show()
print("loading base model stable-diffusion-1.5")
model_config_path = os.getcwd() + "/../configs/models.yaml"
model_manager = ModelManager(model_config_path)
model = model_manager.get_model('stable-diffusion-1.5')
model = model_manager.get_model("stable-diffusion-1.5")
print("loading control model lllyasviel/sd-controlnet-canny")
canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny",
torch_dtype=torch.float16).to("cuda")
canny_controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(
"cuda"
)
print("testing Txt2Img() constructor with control_model arg")
txt2img_canny = Txt2Img(model, control_model=canny_controlnet)
@@ -49,6 +50,3 @@ outputs = txt2img_canny.generate(
generate_output = next(outputs)
out_image = generate_output.image
out_image.show()
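
For context, the canny_image displayed above is usually produced with OpenCV edge detection before being fed to the ControlNet pipeline; a minimal sketch, assuming OpenCV is installed (the file name input.png is hypothetical):

import cv2
import numpy as np
from PIL import Image

source = np.array(Image.open("input.png"))
edges = cv2.Canny(source, 100, 200)      # low/high hysteresis thresholds
edges = np.stack([edges] * 3, axis=-1)   # replicate the single channel to RGB
canny_image = Image.fromarray(edges)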

View File

@@ -3,8 +3,9 @@
import warnings
from invokeai.frontend.CLI import invokeai_command_line_interface as main
warnings.warn("dream.py is being deprecated, please run invoke.py for the "
"new UI/API or legacy_api.py for the old API",
DeprecationWarning)
main()
warnings.warn(
"dream.py is being deprecated, please run invoke.py for the " "new UI/API or legacy_api.py for the old API",
DeprecationWarning,
)
main()

View File

@@ -1,12 +1,14 @@
#!/usr/bin/env python
'''This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py'''
"""This script reads the "Invoke" Stable Diffusion prompt embedded in files generated by invoke.py"""
import sys
from PIL import Image,PngImagePlugin
from PIL import Image, PngImagePlugin
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
print("This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them.")
print(
"This script opens up the indicated invoke.py-generated PNG file(s) and prints out the prompt used to generate them."
)
exit(-1)
filenames = sys.argv[1:]
@@ -14,17 +16,13 @@ for f in filenames:
try:
im = Image.open(f)
try:
prompt = im.text['Dream']
prompt = im.text["Dream"]
except KeyError:
prompt = ''
print(f'{f}: {prompt}')
prompt = ""
print(f"{f}: {prompt}")
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
sys.stderr.write(f"{f} not found\n")
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
continue
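
The "Dream" key read above is ordinary PNG text metadata; a minimal sketch of reading it directly, assuming an invoke.py-generated file (sample.png is hypothetical):

from PIL import Image

im = Image.open("sample.png")
print(im.text.get("Dream", ""))  # the embedded prompt, or an empty string if absent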

View File

@@ -3,18 +3,22 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main():
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# TODO: Parse some top-level args here.
from invokeai.app.cli_app import invoke_cli
invoke_cli()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,3 +1,3 @@
from invokeai.frontend.install.model_install import main
main()

View File

@@ -3,18 +3,21 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
logging.getLogger("xformers").addFilter(lambda record: "A matching Triton is not available" not in record.getMessage())
import os
import sys
def main():
# Change working directory to the repo root
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from invokeai.app.api_app import invoke_api
invoke_api()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,23 +1,24 @@
#!/usr/bin/env python
'''
"""
This script is used at release time to generate a markdown table describing the
starter models. This text is then manually copied into 050_INSTALL_MODELS.md.
'''
"""
from omegaconf import OmegaConf
from pathlib import Path
def main():
initial_models_file = Path(__file__).parent / '../invokeai/configs/INITIAL_MODELS.yaml'
initial_models_file = Path(__file__).parent / "../invokeai/configs/INITIAL_MODELS.yaml"
models = OmegaConf.load(initial_models_file)
print('|Model Name | HuggingFace Repo ID | Description | URL |')
print('|---------- | ---------- | ----------- | --- |')
print("|Model Name | HuggingFace Repo ID | Description | URL |")
print("|---------- | ---------- | ----------- | --- |")
for model in models:
repo_id = models[model].repo_id
url = f'https://huggingface.co/{repo_id}'
print(f'|{model}|{repo_id}|{models[model].description}|{url} |')
url = f"https://huggingface.co/{repo_id}"
print(f"|{model}|{repo_id}|{models[model].description}|{url} |")
if __name__ == '__main__':
if __name__ == "__main__":
main()
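
Run against a typical INITIAL_MODELS.yaml, the script prints rows like the following (the entry shown is illustrative, not taken from the release file):

|Model Name | HuggingFace Repo ID | Description | URL |
|---------- | ---------- | ----------- | --- |
|stable-diffusion-1.5|runwayml/stable-diffusion-v1-5|Stable Diffusion version 1.5|https://huggingface.co/runwayml/stable-diffusion-v1-5 |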

View File

@@ -18,7 +18,7 @@ from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
@@ -55,7 +55,7 @@ def load_img(path):
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.*image - 1.
return 2.0 * image - 1.0
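
The final line maps pixel values from [0, 1] into the [-1, 1] range the first-stage encoder expects; the decode path later inverts it, as the clamp further down in this file shows. A minimal sketch of the round trip:

x = 2.0 * image - 1.0                                    # [0, 1] -> [-1, 1]
image = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)   # inverse, back to [0, 1]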
def main():
@@ -66,33 +66,24 @@ def main():
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--init-img",
type=str,
nargs="?",
help="path to the input image"
)
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/img2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
action="store_true",
help="do not save indiviual samples. For speed measurements.",
)
@@ -105,12 +96,12 @@ def main():
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--fixed_code",
action='store_true',
action="store_true",
help="if enabled, uses the same starting code across all samples ",
)
@@ -187,11 +178,7 @@ def main():
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
@@ -232,18 +219,18 @@ def main():
assert os.path.isfile(opt.init_img)
init_image = load_img(opt.init_img).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
@@ -259,37 +246,42 @@ def main():
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,)
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":

View File

@@ -8,25 +8,26 @@ from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.devices import choose_torch_device
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
image = image.astype(np.float32)/255.0
image = image[None].transpose(0,3,1,2)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask = np.array(Image.open(mask).convert("L"))
mask = mask.astype(np.float32)/255.0
mask = mask[None,None]
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = (1-mask)*image
masked_image = (1 - mask) * image
batch = {"image": image, "mask": mask, "masked_image": masked_image}
for k in batch:
batch[k] = batch[k].to(device=device)
batch[k] = batch[k]*2.0-1.0
batch[k] = batch[k] * 2.0 - 1.0
return batch
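
A minimal usage sketch for make_batch, mirroring how the __main__ block below consumes it (file names hypothetical): the mask is binarized at 0.5, masked_image blanks the region to be inpainted, and every tensor is rescaled to [-1, 1].

batch = make_batch("photo.png", "photo_mask.png", device="cuda")
masked = batch["masked_image"]  # source image with the masked region blanked before the [-1, 1] rescale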
@@ -58,11 +59,10 @@ if __name__ == "__main__":
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)
device = choose_torch_device()
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
@@ -74,25 +74,19 @@ if __name__ == "__main__":
# encode masked image and concat downsampled mask
c = model.cond_stage_model.encode(batch["masked_image"])
cc = torch.nn.functional.interpolate(batch["mask"],
size=c.shape[-2:])
cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
c = torch.cat((c, cc), dim=1)
shape = (c.shape[1]-1,)+c.shape[2:]
samples_ddim, _ = sampler.sample(S=opt.steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False)
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim, _ = sampler.sample(
S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
image = torch.clamp((batch["image"]+1.0)/2.0,
min=0.0, max=1.0)
mask = torch.clamp((batch["mask"]+1.0)/2.0,
min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
min=0.0, max=1.0)
image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1-mask)*image+mask*predicted_image
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)

View File

@@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False):
class Searcher(object):
def __init__(self, database, retriever_version='ViT-L/14'):
def __init__(self, database, retriever_version="ViT-L/14"):
assert database in DATABASES
# self.database = self.load_database(database)
self.database_name = database
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
self.retriever = self.load_retriever(version=retriever_version)
self.database = {'embedding': [],
'img_id': [],
'patch_coords': []}
self.database = {"embedding": [], "img_id": [], "patch_coords": []}
self.load_database()
self.load_searcher()
def train_searcher(self, k,
metric='dot_product',
searcher_savedir=None):
print('Start training searcher')
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
k, metric)
def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
print("Start training searcher")
searcher = scann.scann_ops_pybind.builder(
self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
)
self.searcher = searcher.score_brute_force().build()
print('Finish training searcher')
print("Finish training searcher")
if searcher_savedir is not None:
print(f'Save trained searcher under "{searcher_savedir}"')
@@ -91,36 +86,40 @@ class Searcher(object):
def load_single_file(self, saved_embeddings):
compressed = np.load(saved_embeddings)
self.database = {key: compressed[key] for key in compressed.files}
print('Finished loading of clip embeddings.')
print("Finished loading of clip embeddings.")
def load_multi_files(self, data_archive):
out_data = {key: [] for key in self.database}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
out_data[key].append(d[key])
return out_data
def load_database(self):
print(f'Load saved patch embedding from "{self.database_path}"')
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
file_content = glob.glob(os.path.join(self.database_path, "*.npz"))
if len(file_content) == 1:
self.load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
prefetched_data = parallel_data_prefetch(
self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
self.database}
self.database = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
}
else:
raise ValueError(f'No npz-files in specified path "{self.database_path}"; does this directory exist?')
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
def load_retriever(self, version='ViT-L/14', ):
def load_retriever(
self,
version="ViT-L/14",
):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
@@ -128,14 +127,14 @@ class Searcher(object):
return model
def load_searcher(self):
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
print('Finished loading searcher.')
print("Finished loading searcher.")
def search(self, x, k):
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, "Cannot search with uninitialized searcher"
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if len(x.shape) == 3:
@@ -146,17 +145,19 @@ class Searcher(object):
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
end = time.time()
out_embeddings = self.database['embedding'][nns]
out_img_ids = self.database['img_id'][nns]
out_pc = self.database['patch_coords'][nns]
out_embeddings = self.database["embedding"][nns]
out_img_ids = self.database["img_id"][nns]
out_pc = self.database["patch_coords"][nns]
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
'img_ids': out_img_ids,
'patch_coords': out_pc,
'queries': x,
'exec_time': end - start,
'nns': nns,
'q_embeddings': query_embeddings}
out = {
"nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
"img_ids": out_img_ids,
"patch_coords": out_pc,
"queries": x,
"exec_time": end - start,
"nns": nns,
"q_embeddings": query_embeddings,
}
return out
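
A hypothetical end-to-end sketch, mirroring how the __main__ block below calls the searcher (searcher(c, opt.knn)); the returned nn_embeddings are unit-normalized and get concatenated onto the text conditioning:

searcher = Searcher("artbench-surrealism")   # the default database used below
nn_dict = searcher(text_embeddings, 10)      # query with CLIP embeddings; k=10 is illustrative
neighbors = nn_dict["nn_embeddings"]         # unit-normalized neighbor embeddings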
@@ -173,20 +174,16 @@ if __name__ == "__main__":
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
@@ -206,7 +203,7 @@ if __name__ == "__main__":
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
@@ -287,14 +284,14 @@ if __name__ == "__main__":
parser.add_argument(
"--database",
type=str,
default='artbench-surrealism',
default="artbench-surrealism",
choices=DATABASES,
help="The database used for the search, only applied when --use_neighbors=True",
)
parser.add_argument(
"--use_neighbors",
default=False,
action='store_true',
action="store_true",
help="Include neighbors in addition to text prompt for conditioning",
)
parser.add_argument(
@@ -358,41 +355,43 @@ if __name__ == "__main__":
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
prompts = list(prompts)
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")

View File

@@ -25,15 +25,19 @@ from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
kw["device"] = "cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
@@ -43,18 +47,19 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
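
The fix_func wrapper above works around the historically unreliable MPS random number generator: on Apple Silicon machines, sampling happens on the CPU and the result is moved to the requested device; elsewhere the original function is returned untouched. Illustratively (behavior, not new code):

noise = torch.randn(4, 4, device="mps")  # after the patches: sampled on CPU, then .to("mps")
# on machines without MPS support, torch.randn is left unmodified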
def load_model_from_config(config, ckpt, verbose=False):
print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')
sd = pl_sd['state_dict']
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print('missing keys:')
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print('unexpected keys:')
print("unexpected keys:")
print(u)
if torch.cuda.is_available():
@@ -66,132 +71,130 @@ def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
'-n',
'--name',
"-n",
"--name",
type=str,
const=True,
default='',
nargs='?',
help='postfix for logdir',
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
'-r',
'--resume',
"-r",
"--resume",
type=str,
const=True,
default='',
nargs='?',
help='resume from logdir or checkpoint in logdir',
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
'-b',
'--base',
nargs='*',
metavar='base_config.yaml',
help='paths to base configs. Loaded from left-to-right. '
'Parameters can be overwritten or added with command-line options of the form `--key value`.',
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
'-t',
'--train',
"-t",
"--train",
type=str2bool,
const=True,
default=False,
nargs='?',
help='train',
nargs="?",
help="train",
)
parser.add_argument(
'--no-test',
"--no-test",
type=str2bool,
const=True,
default=False,
nargs='?',
help='disable test',
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
'-p', '--project', help='name of new or path to existing project'
)
parser.add_argument(
'-d',
'--debug',
"-d",
"--debug",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=False,
help='enable post-mortem debugging',
help="enable post-mortem debugging",
)
parser.add_argument(
'-s',
'--seed',
"-s",
"--seed",
type=int,
default=23,
help='seed for seed_everything',
help="seed for seed_everything",
)
parser.add_argument(
'-f',
'--postfix',
"-f",
"--postfix",
type=str,
default='',
help='post-postfix for default name',
default="",
help="post-postfix for default name",
)
parser.add_argument(
'-l',
'--logdir',
"-l",
"--logdir",
type=str,
default='logs',
help='directory for logging dat shit',
default="logs",
help="directory for logging dat shit",
)
parser.add_argument(
'--scale_lr',
"--scale_lr",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=True,
help='scale base-lr by ngpu * batch_size * n_accumulate',
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
'--datadir_in_name',
"--datadir_in_name",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=True,
help='Prepend the final directory in the data_root to the output directory name',
help="Prepend the final directory in the data_root to the output directory name",
)
parser.add_argument(
'--actual_resume',
"--actual_resume",
type=str,
default='',
help='Path to model to actually resume from',
default="",
help="Path to model to actually resume from",
)
parser.add_argument(
'--data_root',
"--data_root",
type=str,
required=True,
help='Path to directory with training images',
help="Path to directory with training images",
)
parser.add_argument(
'--embedding_manager_ckpt',
"--embedding_manager_ckpt",
type=str,
default='',
help='Initialize embedding manager from a checkpoint',
default="",
help="Initialize embedding manager from a checkpoint",
)
parser.add_argument(
'--init_word',
"--init_word",
type=str,
help='Word to use as source for initial token embedding.',
help="Word to use as source for initial token embedding.",
)
return parser
@@ -226,9 +229,7 @@ def worker_init_fn(_):
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
dataset.sample_ids = dataset.valid_ids[
worker_id * split_size : (worker_id + 1) * split_size
]
dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
@@ -252,25 +253,19 @@ class DataModuleFromConfig(pl.LightningDataModule):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = (
num_workers if num_workers is not None else batch_size * 2
)
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs['train'] = train
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs['validation'] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs['test'] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs['predict'] = predict
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
@@ -279,24 +274,19 @@ class DataModuleFromConfig(pl.LightningDataModule):
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict(
(k, instantiate_from_config(self.dataset_configs[k]))
for k in self.dataset_configs
)
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = isinstance(
self.datasets['train'], Txt2ImgIterableBaseDataset
)
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['train'],
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
@@ -304,15 +294,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset)
or self.use_worker_init_fn
):
if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['validation'],
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -320,9 +307,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(
self.datasets['train'], Txt2ImgIterableBaseDataset
)
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
@@ -332,7 +317,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets['test'],
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -340,15 +325,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset)
or self.use_worker_init_fn
):
if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['predict'],
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -356,9 +338,7 @@
class SetupCallback(Callback):
def __init__(
self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config
):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
self.now = now
@@ -370,8 +350,8 @@ class SetupCallback(Callback):
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
print('Summoning checkpoint.')
ckpt_path = os.path.join(self.ckptdir, 'last.ckpt')
print("Summoning checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_pretrain_routine_start(self, trainer, pl_module):
@@ -381,36 +361,31 @@ class SetupCallback(Callback):
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
if 'callbacks' in self.lightning_config:
if (
'metrics_over_trainsteps_checkpoint'
in self.lightning_config['callbacks']
):
if "callbacks" in self.lightning_config:
if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
os.makedirs(
os.path.join(self.ckptdir, 'trainstep_checkpoints'),
os.path.join(self.ckptdir, "trainstep_checkpoints"),
exist_ok=True,
)
print('Project config')
print("Project config")
print(OmegaConf.to_yaml(self.config))
OmegaConf.save(
self.config,
os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)),
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
)
print('Lightning config')
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
OmegaConf.save(
OmegaConf.create({'lightning': self.lightning_config}),
os.path.join(
self.cfgdir, '{}-lightning.yaml'.format(self.now)
),
OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, 'child_runs', name)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
@@ -435,10 +410,8 @@ class ImageLogger(Callback):
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = { }
self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
]
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
@@ -448,10 +421,8 @@ class ImageLogger(Callback):
self.log_first_step = log_first_step
@rank_zero_only
def log_local(
self, save_dir, split, images, global_step, current_epoch, batch_idx
):
root = os.path.join(save_dir, 'images', split)
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
@@ -459,22 +430,16 @@ class ImageLogger(Callback):
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format(
k, global_step, current_epoch, batch_idx
)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split='train'):
check_idx = (
batch_idx if self.log_on_batch_idx else pl_module.global_step
)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (
self.check_frequency(check_idx)
and hasattr( # batch_idx % self.batch_freq == 0
pl_module, 'log_images'
)
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
and callable(pl_module.log_images)
and self.max_images > 0
):
@@ -485,9 +450,7 @@ class ImageLogger(Callback):
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
@@ -506,18 +469,16 @@ class ImageLogger(Callback):
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (
(check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
) and (check_idx > 0 or self.log_first_step):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step
):
try:
self.log_steps.pop(0)
except IndexError as e:
@@ -526,23 +487,15 @@ class ImageLogger(Callback):
return True
return False
def on_train_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and (
pl_module.global_step > 0 or self.log_first_step
):
self.log_img(pl_module, batch, batch_idx, split='train')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split='val')
if hasattr(pl_module, 'calibrate_grad_norm'):
if (
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
) and batch_idx > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
@@ -562,19 +515,17 @@ class CUDACallback(Callback):
try:
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
if torch.cuda.is_available():
max_memory = (
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
)
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
max_memory = trainer.training_type_plugin.reduce(max_memory)
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
class ModeSwapCallback(Callback):
def __init__(self, swap_step=2000):
super().__init__()
self.is_frozen = False
@@ -589,7 +540,8 @@ class ModeSwapCallback(Callback):
self.is_frozen = False
trainer.optimizers = [pl_module.configure_opt_model()]
if __name__ == '__main__':
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
@@ -631,7 +583,7 @@ if __name__ == '__main__':
# params:
# key: value
now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
@@ -644,50 +596,47 @@ if __name__ == '__main__':
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
'-n/--name and -r/--resume cannot be specified both.'
'If you want to resume training in a new log folder, '
'use -n/--name in combination with --resume_from_checkpoint'
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError('Cannot find {}'.format(opt.resume))
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split('/')
paths = opt.resume.split("/")
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = '/'.join(paths[:-2])
logdir = "/".join(paths[:-2])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip('/')
ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt')
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(
glob.glob(os.path.join(logdir, 'configs/*.yaml'))
)
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split('/')
_tmp = logdir.split("/")
nowname = _tmp[-1]
else:
if opt.name:
name = '_' + opt.name
name = "_" + opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = '_' + cfg_name
name = "_" + cfg_name
else:
name = ''
name = ""
if opt.datadir_in_name:
now = os.path.basename(os.path.normpath(opt.data_root)) + now
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
ckptdir = os.path.join(logdir, 'checkpoints')
cfgdir = os.path.join(logdir, 'configs')
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
@@ -695,19 +644,19 @@ if __name__ == '__main__':
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop('lightning', OmegaConf.create())
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get('trainer', OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
trainer_config['accelerator'] = 'auto'
trainer_config["accelerator"] = "auto"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not 'gpus' in trainer_config:
del trainer_config['accelerator']
if not "gpus" in trainer_config:
del trainer_config["accelerator"]
cpu = True
else:
gpuinfo = trainer_config['gpus']
print(f'Running on GPUs {gpuinfo}')
gpuinfo = trainer_config["gpus"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@@ -715,9 +664,7 @@ if __name__ == '__main__':
# model
# config.model.params.personalization_config.params.init_word = opt.init_word
config.model.params.personalization_config.params.embedding_manager_ckpt = (
opt.embedding_manager_ckpt
)
config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt
if opt.init_word:
config.model.params.personalization_config.params.initializer_words = [opt.init_word]
@@ -731,142 +678,128 @@ if __name__ == '__main__':
trainer_kwargs = dict()
# default logger configs
def_logger = 'csv'
def_logger_target = 'CSVLogger'
def_logger = "csv"
def_logger_target = "CSVLogger"
default_logger_cfgs = {
'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger',
'params': {
'name': nowname,
'save_dir': logdir,
'offline': opt.debug,
'id': nowname,
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
},
},
def_logger: {
'target': 'pytorch_lightning.loggers.' + def_logger_target,
'params': {
'name': def_logger,
'save_dir': logdir,
"target": "pytorch_lightning.loggers." + def_logger_target,
"params": {
"name": def_logger,
"save_dir": logdir,
},
},
}
default_logger_cfg = default_logger_cfgs[def_logger]
if 'logger' in lightning_config:
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs['logger'] = instantiate_from_config(logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
'dirpath': ckptdir,
'filename': '{epoch:06}',
'verbose': True,
'save_last': True,
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
},
}
if hasattr(model, 'monitor'):
print(f'Monitoring {model.monitor} as checkpoint metric.')
default_modelckpt_cfg['params']['monitor'] = model.monitor
default_modelckpt_cfg['params']['save_top_k'] = 1
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 1
if 'modelcheckpoint' in lightning_config:
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}')
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs['checkpoint_callback'] = instantiate_from_config(
modelckpt_cfg
)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse("1.4.0"):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
'setup_callback': {
'target': 'main.SetupCallback',
'params': {
'resume': opt.resume,
'now': now,
'logdir': logdir,
'ckptdir': ckptdir,
'cfgdir': cfgdir,
'config': config,
'lightning_config': lightning_config,
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
},
},
'image_logger': {
'target': 'main.ImageLogger',
'params': {
'batch_frequency': 750,
'max_images': 4,
'clamp': True,
"image_logger": {
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,
"max_images": 4,
"clamp": True,
},
},
'learning_rate_logger': {
'target': 'main.LearningRateMonitor',
'params': {
'logging_interval': 'step',
"learning_rate_logger": {
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
'cuda_callback': {'target': 'main.CUDACallback'},
"cuda_callback": {"target": "main.CUDACallback"},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update(
{'checkpoint_callback': modelckpt_cfg}
)
if version.parse(pl.__version__) >= version.parse("1.4.0"):
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
if 'callbacks' in lightning_config:
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
)
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
'dirpath': os.path.join(
ckptdir, 'trainstep_checkpoints'
),
'filename': '{epoch:06}-{step:09}',
'verbose': True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True,
"metrics_over_trainsteps_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
"save_top_k": -1,
"every_n_train_steps": 10000,
"save_weights_only": True,
},
}
}
default_callbacks_cfg.update(
default_metrics_over_trainsteps_ckpt_dict
)
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(
trainer_opt, 'resume_from_checkpoint'
):
callbacks_cfg.ignore_keys_callback.params[
'ckpt_path'
] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"):
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint
elif "ignore_keys_callback" in callbacks_cfg:
del callbacks_cfg["ignore_keys_callback"]
trainer_kwargs['callbacks'] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
trainer_kwargs['max_steps'] = trainer_opt.max_steps
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs["max_steps"] = trainer_opt.max_steps
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
trainer_opt.accelerator = 'mps'
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
trainer_opt.accelerator = "mps"
trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
@@ -882,11 +815,9 @@ if __name__ == '__main__':
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print('#### Data #####')
print("#### Data #####")
for k in data.datasets:
print(
f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}'
)
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = (
@@ -894,24 +825,20 @@ if __name__ == '__main__':
config.model.base_learning_rate,
)
if not cpu:
gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
gpus = str(lightning_config.trainer.gpus).strip(", ").split(",")
ngpu = len(gpus)
else:
ngpu = 1
if 'accumulate_grad_batches' in lightning_config.trainer:
accumulate_grad_batches = (
lightning_config.trainer.accumulate_grad_batches
)
if "accumulate_grad_batches" in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f'accumulate_grad_batches = {accumulate_grad_batches}')
lightning_config.trainer.accumulate_grad_batches = (
accumulate_grad_batches
)
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate,
accumulate_grad_batches,
ngpu,
@@ -921,15 +848,15 @@ if __name__ == '__main__':
)
else:
model.learning_rate = base_lr
print('++++ NOT USING LR SCALING ++++')
print(f'Setting learning rate to {model.learning_rate:.2e}')
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print('Summoning checkpoint.')
ckpt_path = os.path.join(ckptdir, 'last.ckpt')
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
@@ -964,7 +891,7 @@ if __name__ == '__main__':
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, 'debug_runs', name)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
# if trainer.global_rank == 0:

View File

@@ -7,21 +7,30 @@ from functools import partial
import torch
def get_placeholder_loop(placeholder_string, embedder, use_bert):
new_placeholder = None
while True:
if new_placeholder is None:
new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
new_placeholder = input(
f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
)
else:
new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
new_placeholder = input(
f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
)
token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
token = (
get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
if use_bert
else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
)
if token is not None:
return new_placeholder, token
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(
string,
@@ -30,7 +39,7 @@ def get_clip_token_for_string(tokenizer, string):
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt"
return_tensors="pt",
)
tokens = batch_encoding["input_ids"]
@@ -40,6 +49,7 @@ def get_clip_token_for_string(tokenizer, string):
return None
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
if torch.count_nonzero(token) == 3:
@@ -49,22 +59,17 @@ def get_bert_token_for_string(tokenizer, string):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--root_dir",
type=str,
default='.',
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
default=".",
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
)
parser.add_argument(
"--manager_ckpts",
type=str,
nargs="+",
required=True,
help="Paths to a set of embedding managers to be merged."
"--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
)
parser.add_argument(
@@ -75,13 +80,14 @@ if __name__ == "__main__":
)
parser.add_argument(
"-sd", "--use_bert",
"-sd",
"--use_bert",
action="store_true",
help="Flag to denote that we are not merging stable diffusion embeddings"
help="Flag to denote that we are not merging stable diffusion embeddings",
)
args = parser.parse_args()
Globals.root=args.root_dir
Globals.root = args.root_dir
if args.use_bert:
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()

View File

@@ -10,12 +10,13 @@ from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.) / 2.
rescale = lambda x: (x + 1.0) / 2.0
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.) / 2.
x = torch.clamp(x, -1.0, 1.0)
x = (x + 1.0) / 2.0
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
@@ -51,49 +52,51 @@ def logs2pil(logs, keys=["sample"]):
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(
None, shape, verbose=True
)
return model.progressive_denoising(None, shape, verbose=True)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0
):
def convsample_ddim(model, steps, shape, eta=1.0):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
samples, intermediates = ddim.sample(
steps,
batch_size=bs,
shape=shape,
eta=eta,
verbose=False,
)
return samples, intermediates
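
An illustrative call, assuming a loaded unconditional LDM bound to model (shape values are hypothetical); eta=0.0 makes DDIM sampling deterministic, as the --eta help text below notes:

samples, intermediates = convsample_ddim(model, steps=50, shape=[4, 3, 64, 64], eta=0.0)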
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
def make_convolutional_sample(
model,
batch_size,
vanilla=False,
custom_steps=None,
eta=1.0,
):
log = dict()
shape = [batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size]
shape = [
batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size,
]
with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape,
make_prog_row=True)
sample, progrow = convsample(model, shape, make_prog_row=True)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta)
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
t1 = time.time()
@@ -101,32 +104,32 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non
log["sample"] = x_sample
log["time"] = t1 - t0
log['throughput'] = sample.shape[0] / (t1 - t0)
log["throughput"] = sample.shape[0] / (t1 - t0)
print(f'Throughput for this batch: {log["throughput"]}')
return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
else:
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1
# path = logdir
if model.cond_stage_model is None:
all_images = []
print(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
logs = make_convolutional_sample(model, batch_size=batch_size,
vanilla=vanilla, custom_steps=custom_steps,
eta=eta)
logs = make_convolutional_sample(
model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
print(f'Finish after generating {n_saved} samples')
print(f"Finish after generating {n_saved} samples")
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
@ -135,7 +138,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None
np.savez(nppath, all_img)
else:
raise NotImplementedError('Currently only sampling for unconditional models supported.')
raise NotImplementedError("Currently only sampling for unconditional models supported.")
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
@ -168,58 +171,33 @@ def get_parser():
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
default=1.0,
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
action="store_true",
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="the bs",
default=10
"-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
)
parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
return parser
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
@ -233,8 +211,7 @@ def load_model(config, ckpt, gpu, eval_mode):
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"])
model = load_model_from_config(config.model, pl_sd["state_dict"])
return model, global_step
@ -253,9 +230,9 @@ if __name__ == "__main__":
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = '/'.join(opt.resume.split('/')[:-1])
logdir = "/".join(opt.resume.split("/")[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f'Logdir is {logdir}')
print(f"Logdir is {logdir}")
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
@ -278,7 +255,8 @@ if __name__ == "__main__":
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "": locallog = logdir.split(os.sep)[-2]
if locallog == "":
locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
@ -301,13 +279,19 @@ if __name__ == "__main__":
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, 'w') as f:
with open(sampling_file, "w") as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
run(model, imglogdir, eta=opt.eta,
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir)
run(
model,
imglogdir,
eta=opt.eta,
vanilla=opt.vanilla_sample,
n_samples=opt.n_samples,
custom_steps=opt.custom_steps,
batch_size=opt.batch_size,
nplog=numpylogdir,
)
print("done.")

View File

@ -13,21 +13,26 @@ def search_bruteforce(searcher):
return searcher.score_brute_force().build()
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search):
return searcher.tree(num_leaves=num_leaves,
num_leaves_to_search=num_leaves_to_search,
training_sample_size=partioning_trainsize). \
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
def search_partioned_ah(
searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
return (
searcher.tree(
num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize
)
.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
.reorder(reorder_k)
.build()
)
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
reorder_k).build()
return (
searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
)
def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
@ -35,23 +40,26 @@ def load_datapool(dpath):
def load_multi_files(data_archive):
database = {key: [] for key in data_archive[0].files}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
database[key].append(d[key])
return database
print(f'Load saved patch embedding from "{dpath}"')
file_content = glob.glob(os.path.join(dpath, '*.npz'))
file_content = glob.glob(os.path.join(dpath, "*.npz"))
if len(file_content) == 1:
data_pool = load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
prefetched_data = parallel_data_prefetch(
load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
data_pool = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
}
else:
raise ValueError(f'No npz-files in specified path "{dpath}". Does this directory exist?')
@ -59,16 +67,17 @@ def load_datapool(dpath):
return data_pool
def train_searcher(opt,
metric='dot_product',
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,):
def train_searcher(
opt,
metric="dot_product",
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,
):
data_pool = load_datapool(opt.database)
k = opt.knn
@ -77,71 +86,83 @@ def train_searcher(opt,
# normalize
# embeddings =
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
pool_size = data_pool['embedding'].shape[0]
searcher = scann.scann_ops_pybind.builder(
data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric
)
pool_size = data_pool["embedding"].shape[0]
print(*(['#'] * 100))
print('Initializing scaNN searcher with the following values:')
print(f'k: {k}')
print(f'metric: {metric}')
print(f'reorder_k: {reorder_k}')
print(f'anisotropic_quantization_threshold: {aiq_thld}')
print(f'dims_per_block: {dims_per_block}')
print(*(['#'] * 100))
print('Start training searcher....')
print(f'N samples in pool is {pool_size}')
print(*(["#"] * 100))
print("Initializing scaNN searcher with the following values:")
print(f"k: {k}")
print(f"metric: {metric}")
print(f"reorder_k: {reorder_k}")
print(f"anisotropic_quantization_threshold: {aiq_thld}")
print(f"dims_per_block: {dims_per_block}")
print(*(["#"] * 100))
print("Start training searcher....")
print(f"N samples in pool is {pool_size}")
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if pool_size < 2e4:
print('Using brute force search.')
print("Using brute force search.")
searcher = search_bruteforce(searcher)
elif 2e4 <= pool_size and pool_size < 1e5:
print('Using asymmetric hashing search and reordering.')
print("Using asymmetric hashing search and reordering.")
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
else:
print('Using using partioning, asymmetric hashing search and reordering.')
print("Using using partioning, asymmetric hashing search and reordering.")
if not partioning_trainsize:
partioning_trainsize = data_pool['embedding'].shape[0] // 10
partioning_trainsize = data_pool["embedding"].shape[0] // 10
if not num_leaves:
num_leaves = int(np.sqrt(pool_size))
if not num_leaves_to_search:
num_leaves_to_search = max(num_leaves // 20, 1)
print('Partitioning params:')
print(f'num_leaves: {num_leaves}')
print(f'num_leaves_to_search: {num_leaves_to_search}')
print("Partitioning params:")
print(f"num_leaves: {num_leaves}")
print(f"num_leaves_to_search: {num_leaves_to_search}")
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search)
searcher = search_partioned_ah(
searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
)
print('Finish training searcher')
print("Finish training searcher")
searcher_savedir = opt.target_path
os.makedirs(searcher_savedir, exist_ok=True)
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
if __name__ == '__main__':
if __name__ == "__main__":
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
parser.add_argument('--database',
'-d',
default='data/rdm/retrieval_databases/openimages',
type=str,
help='path to folder containing the clip feature of the database')
parser.add_argument('--target_path',
'-t',
default='data/rdm/searchers/openimages',
type=str,
help='path to the target folder where the searcher shall be stored.')
parser.add_argument('--knn',
'-k',
default=20,
type=int,
help='number of nearest neighbors, for which the searcher shall be optimized')
parser.add_argument(
"--database",
"-d",
default="data/rdm/retrieval_databases/openimages",
type=str,
help="path to folder containing the clip feature of the database",
)
parser.add_argument(
"--target_path",
"-t",
default="data/rdm/searchers/openimages",
type=str,
help="path to the target folder where the searcher shall be stored.",
)
parser.add_argument(
"--knn",
"-k",
default=20,
type=int,
help="number of nearest neighbors, for which the searcher shall be optimized",
)
opt, _ = parser.parse_known_args()
train_searcher(opt,)
train_searcher(
opt,
)
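The pool-size branching above amounts to a three-way strategy choice, preceded by row normalization so that dot-product search behaves like cosine search. A condensed sketch (thresholds copied from the script, which cites scaNN's published design guidance):

import numpy as np

def unit_normalize(embeddings: np.ndarray) -> np.ndarray:
    # L2-normalize rows, matching the builder call above.
    return embeddings / np.linalg.norm(embeddings, axis=1)[:, np.newaxis]

def pick_search_strategy(pool_size: int) -> str:
    if pool_size < 2e4:
        return "brute force"
    elif pool_size < 1e5:
        return "asymmetric hashing + reordering"
    else:
        return "partitioning + asymmetric hashing + reordering"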

View File

@ -15,10 +15,11 @@ from contextlib import contextmanager, nullcontext
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
@ -53,23 +54,19 @@ def main():
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
action="store_true",
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
@ -80,22 +77,22 @@ def main():
)
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--klms",
action='store_true',
action="store_true",
help="use klms sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
action="store_true",
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
action='store_true',
action="store_true",
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
@ -176,11 +173,7 @@ def main():
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
@ -190,17 +183,17 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
#for klms
# for klms
model_wrap = K.external.CompVisDenoiser(model)
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
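Only the constructor of CFGDenoiser survives in this hunk. Its forward pass is the standard k-diffusion classifier-free-guidance wrapper; a representative reconstruction (hedged: this is the usual pattern, not necessarily this file's exact body):

import torch
import torch.nn as nn

class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Run the unconditional and conditional branches as one batch,
        # then blend: uncond + scale * (cond - uncond).
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond_out + (cond_out - uncond_out) * cond_scale

The keyword names line up with the extra_args dict ({"cond": c, "uncond": uc, "cond_scale": opt.scale}) passed to K.sampling.sample_lms further down.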
@ -232,10 +225,10 @@ def main():
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
if (len(data) >= batch_size):
if len(data) >= batch_size:
data = list(chunk(data, batch_size))
else:
while (len(data) < batch_size):
while len(data) < batch_size:
data.append(data[-1])
data = [data]
@ -247,14 +240,14 @@ def main():
start_code = None
if opt.fixed_code:
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == 'mps':
start_code = torch.randn(shape, device='cpu').to(device)
if device.type == "mps":
start_code = torch.randn(shape, device="cpu").to(device)
else:
start_code = torch.randn(shape, device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
@ -271,23 +264,25 @@ def main():
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if not opt.klms:
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code,
)
else:
sigmas = model_wrap.get_sigmas(opt.ddim_steps)
if start_code is not None:
x = start_code
else:
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
@ -295,9 +290,10 @@ def main():
if not opt.skip_save:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
if not opt.skip_grid:
@ -306,18 +302,17 @@ def main():
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":

View File

@ -5,7 +5,8 @@ import requests
from invokeai.version import __version__
local_version = str(__version__).replace("-", "")
package_name = 'InvokeAI'
package_name = "InvokeAI"
def get_pypi_versions(package_name=package_name) -> list[str]:
"""Get the versions of the package from PyPI"""

View File

@ -1,8 +1,8 @@
#!/usr/bin/env python
'''
"""
Scan the models directory and print out a new models.yaml
'''
"""
import os
import sys
@ -11,49 +11,51 @@ import argparse
from pathlib import Path
from omegaconf import OmegaConf
def main():
parser = argparse.ArgumentParser(description="Model directory scanner")
parser.add_argument('models_directory')
parser.add_argument('--all-models',
default=False,
action='store_true',
help='If true, then generates stanzas for all models; otherwise just diffusers'
)
parser.add_argument("models_directory")
parser.add_argument(
"--all-models",
default=False,
action="store_true",
help="If true, then generates stanzas for all models; otherwise just diffusers",
)
args = parser.parse_args()
directory = args.models_directory
conf = OmegaConf.create()
conf['_version'] = '3.0.0'
conf["_version"] = "3.0.0"
for root, dirs, files in os.walk(directory):
parents = root.split('/')
subpaths = parents[parents.index('models')+1:]
parents = root.split("/")
subpaths = parents[parents.index("models") + 1 :]
if len(subpaths) < 2:
continue
base, model_type, *_ = subpaths
if args.all_models or model_type=='diffusers':
if args.all_models or model_type == "diffusers":
for d in dirs:
conf[f'{base}/{model_type}/{d}'] = dict(
path = os.path.join(root,d),
description = f'{model_type} model {d}',
format = 'folder',
base = base,
conf[f"{base}/{model_type}/{d}"] = dict(
path=os.path.join(root, d),
description=f"{model_type} model {d}",
format="folder",
base=base,
)
for f in files:
basename = Path(f).stem
format = Path(f).suffix[1:]
conf[f'{base}/{model_type}/{basename}'] = dict(
path = os.path.join(root,f),
description = f'{model_type} model {basename}',
format = format,
base = base,
conf[f"{base}/{model_type}/{basename}"] = dict(
path=os.path.join(root, f),
description=f"{model_type} model {basename}",
format=format,
base=base,
)
OmegaConf.save(config=dict(sorted(conf.items())), f=sys.stdout)
if __name__ == '__main__':
if __name__ == "__main__":
main()
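For reference, each dict built above serializes to a flat-keyed YAML stanza of this shape (the 'sd-1' base and model name are illustrative):

# sd-1/diffusers/my-model:
#   path: /path/to/models/sd-1/diffusers/my-model
#   description: diffusers model my-model
#   format: folder
#   base: sd-1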

View File

@ -13,10 +13,10 @@ filenames = sys.argv[1:]
for f in filenames:
try:
metadata = retrieve_metadata(f)
print(f'{f}:\n',json.dumps(metadata['sd-metadata'], indent=4))
print(f"{f}:\n", json.dumps(metadata["sd-metadata"], indent=4))
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
sys.stderr.write(f"{f} not found\n")
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
sys.stderr.write(f"{f} could not be opened due to inadequate permissions\n")
continue
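A minimal single-file sketch of the same call pattern; the import path for retrieve_metadata sits outside this hunk, so the module named here is an assumption:

import json
from ldm.invoke.pngwriter import retrieve_metadata  # assumed import path

metadata = retrieve_metadata("sample.png")
print(json.dumps(metadata["sd-metadata"], indent=4))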