Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
simplified instructions to preload Bert and kornia prerequisites; fixed --grid and --batch handling; added timing information after image generation

parent fab1ae8685
commit a7532b386a

Files changed: README.md (65 lines), ldm/modules/encoders/modules.py, ldm/simplet2i.py, scripts/dream.py, scripts/preload_models.py (new file, 17 lines)
README.md:

@@ -57,57 +57,40 @@ weights (512x512) and the older (256x256) latent diffusion weights
 (laion400m). Within the script, the switches are (mostly) identical to
 those used in the Discord bot, except you don't need to type "!dream".
 
-## No need for internet connectivity when loading the model
+## Workaround for machines with limited internet connectivity
 
 My development machine is a GPU node in a high-performance compute
 cluster which has no connection to the internet. During model
 initialization, stable-diffusion tries to download the Bert tokenizer
-model from huggingface.co. This obviously didn't work for me.
+and a file needed by the kornia library. This obviously didn't work
+for me.
 
-Rather than set up a hugging face local hub, I found the most
-expedient thing to do was to download the Bert tokenizer in advance
-from a machine that had internet access (in this case, the head node
-of the cluster), and patch stable-diffusion to read it from the local
-disk. After you have completed the conda environment creation and
-activation steps,the steps to preload the Bert model are:
-
-~~~~
-(ldm) ~/stable-diffusion$ mkdir ./models/bert
-(ldm) ~/stable-diffusion$ python3
->>> from transformers import BertTokenizerFast
->>> model = BertTokenizerFast.from_pretrained("bert-base-uncased")
->>> model.save_pretrained("./models/bert")
-~~~~
-
-(Make sure you are in the stable-diffusion directory when you do
-this!)
-
-If you don't like this change, just copy over the file
-ldm/modules/encoders/modules.py from the CompVis/stable-diffusion
-repository.
-
-In addition, I have found that the Kornia library needs to do a
-one-time download of its own. On a non-internet connected system, you
-may see an error message like this one when running dream.py for the
-first time
+To work around this, I have modified ldm/modules/encoders/modules.py
+to look for locally cached Bert files rather than attempting to
+download them. For this to work, you must run
+"scripts/preload_models.py" once from an internet-connected machine
+prior to running the code on an isolated one. This assumes that both
+machines share a common network-mounted filesystem with a common
+.cache directory.
 
 ~~~~
+(ldm) ~/stable-diffusion$ python3 ./scripts/preload_models.py
+preloading bert tokenizer...
+Downloading: 100%|██████████████████████████████████| 28.0/28.0 [00:00<00:00, 49.3kB/s]
+Downloading: 100%|██████████████████████████████████| 226k/226k [00:00<00:00, 2.79MB/s]
+Downloading: 100%|██████████████████████████████████| 455k/455k [00:00<00:00, 4.36MB/s]
+Downloading: 100%|██████████████████████████████████| 570/570 [00:00<00:00, 477kB/s]
+...success
+preloading kornia requirements...
 Downloading: "https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth" to /u/lstein/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
-Traceback (most recent call last):
-  File "/u/lstein/.conda/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
-    h.request(req.get_method(), req.selector, req.data, headers,
-  File "/u/lstein/.conda/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
-  ...
+100%|███████████████████████████████████████████████| 5.10M/5.10M [00:00<00:00, 101MB/s]
+...success
 ~~~~
 
-The fix is to log into an internet-connected machine and manually
-download the file into the required location. On my system, the incantation was:
-
-~~~~
-(ldm) ~/stable-diffusion$ mkdir -p /u/lstein/.cache/torch/hub/checkpoints/
-(ldm) ~/stable-diffusion$ wget https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth \
-   -O /u/lstein/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
-~~~~
+If you don't need this change and want to download the files just in
+time, copy over the file ldm/modules/encoders/modules.py from the
+CompVis/stable-diffusion repository. Or you can run preload_models.py
+on the target machine.
 
 ## Minor fixes
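To make the README's shared-.cache assumption concrete, here is a hypothetical pre-flight check (not part of the commit; the paths are just the usual defaults) you could run on the isolated node before starting dream.py:

~~~~
import os

# Caches populated by scripts/preload_models.py on the connected machine.
hf_cache    = os.path.expanduser("~/.cache/huggingface")           # Bert tokenizer files
torch_cache = os.path.expanduser("~/.cache/torch/hub/checkpoints") # kornia's checkpoint

for path in (hf_cache, torch_cache):
    ok = os.path.isdir(path) and bool(os.listdir(path))
    print(f"{path}: {'ok' if ok else 'missing -- run scripts/preload_models.py first'}")
~~~~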
ldm/modules/encoders/modules.py:

@@ -55,10 +55,12 @@ class BERTTokenizer(AbstractEncoder):
     def __init__(self, device="cuda", vq_interface=True, max_length=77):
         super().__init__()
         from transformers import BertTokenizerFast  # TODO: add to reuquirements
-        fn = 'models/bert'
-        print(f'Loading Bert tokenizer from "{fn}"')
-        # self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
-        self.tokenizer = BertTokenizerFast.from_pretrained(fn,local_files_only=True)
+        # Modified to allow to run on non-internet connected compute nodes.
+        # Model needs to be loaded into cache from an internet-connected machine
+        # by running:
+        #   from transformers import BertTokenizerFast
+        #   BertTokenizerFast.from_pretrained("bert-base-uncased")
+        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
         self.device = device
         self.vq_interface = vq_interface
         self.max_length = max_length
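The key change is `local_files_only=True`, a standard transformers option that makes `from_pretrained` resolve everything from the on-disk cache and raise an error instead of downloading. A minimal sketch of the resulting two-step workflow, assuming the default Hugging Face cache location:

~~~~
from transformers import BertTokenizerFast

# Step 1, on an internet-connected machine: populate ~/.cache/huggingface.
BertTokenizerFast.from_pretrained("bert-base-uncased")

# Step 2, on the isolated machine sharing that .cache: load with no network
# access; this raises if the files were never cached.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", local_files_only=True)
~~~~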
@@ -235,4 +237,3 @@ if __name__ == "__main__":
     from ldm.util import count_params
     model = FrozenCLIPEmbedder()
     count_params(model, verbose=True)
-
ldm/simplet2i.py:

@@ -52,8 +52,8 @@ from torchvision.utils import make_grid
 from pytorch_lightning import seed_everything
 from torch import autocast
 from contextlib import contextmanager, nullcontext
-from time import time
-from math import sqrt
+import time
+import math
 
 from ldm.util import instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
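A note on why the import style changed: the new timing and grid-sizing code calls `time.time()` and `math.sqrt()`, which require module imports; under the old `from time import time`, the name `time` is a function, so `time.time()` raises AttributeError. For example:

~~~~
import math
import time

tic = time.time()           # fine: module-qualified call
rows = int(math.sqrt(16))   # 4

# With "from time import time" instead, time.time() would raise
# AttributeError: 'builtin_function_or_method' object has no attribute 'time'
~~~~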
@@ -72,6 +72,7 @@ class T2I:
         seed
         sampler
         grid
+        individual
         width
         height
         cfg_scale
@@ -84,9 +85,10 @@ class T2I:
                  outdir="outputs/txt2img-samples",
                  batch=1,
                  iterations = 1,
-                 width=256,    # change to 512 for stable diffusion
-                 height=256,   # change to 512 for stable diffusion
+                 width=512,
+                 height=512,
                  grid=False,
+                 individual=None, # redundant
                  steps=50,
                  seed=None,
                  cfg_scale=7.5,
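A hypothetical construction showing the new defaults in use (the prompt and argument values are made up; `txt2img`'s signature appears in the next hunk):

~~~~
t2i = T2I()                                  # now defaults to 512x512
results = t2i.txt2img("a lighthouse at dawn",
                      iterations=3,          # three images, reseeded between them
                      individual=True)       # force individual files even if grid is on
for filename, seed in results:
    print(f"{filename}  (reproduce with -S {seed})")
~~~~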
@@ -122,7 +124,7 @@ class T2I:
         else:
             self.seed = seed
 
     def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
-                steps=None,seed=None,grid=None,width=None,height=None,
+                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
                 cfg_scale=None,ddim_eta=None):
         """ generate an image from the prompt, writing iteration images into the outdir """
         outdir = outdir or self.outdir
@@ -134,13 +136,16 @@ class T2I:
         ddim_eta = ddim_eta or self.ddim_eta
         batch = batch or self.batch
         iterations = iterations or self.iterations
-        if batch > 1:
-            iterations = 1
 
         model = self.load_model()  # will instantiate the model or return it from cache
 
+        # grid and individual are mutually exclusive, with individual taking priority.
+        # not necessary, but needed for compatability with dream bot
         if (grid is None):
             grid = self.grid
+        if individual:
+            grid = False
+
         data = [batch * [prompt]]
 
         # make directories and establish names for the output files
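The two new conditionals encode the precedence rule the comment states: an explicit `individual` always wins over `grid`, and the instance default applies only when neither is given. The same logic as a hypothetical standalone helper (name and signature are mine, for illustration):

~~~~
def resolve_grid(grid, individual, default_grid):
    """Decide whether to assemble a grid; individual takes priority."""
    if grid is None:        # caller passed neither switch
        grid = default_grid
    if individual:          # explicit --individual overrides --grid
        grid = False
    return grid

assert resolve_grid(None, None, default_grid=True) is True
assert resolve_grid(True, True, default_grid=True) is False
~~~~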
@@ -159,6 +164,8 @@ class T2I:
         sampler = self.sampler
         images = list()
         seeds = list()
 
+        tic = time.time()
+
         with torch.no_grad():
             with precision_scope("cuda"):
@@ -171,7 +178,7 @@ class T2I:
                     if cfg_scale != 1.0:
                         uc = model.get_learned_conditioning(batch * [""])
                     if isinstance(prompts, tuple):
                         prompts = list(prompts)
                     c = model.get_learned_conditioning(prompts)
                     shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                     samples_ddim, _ = sampler.sample(S=steps,
@@ -187,20 +194,21 @@ class T2I:
                     x_samples_ddim = model.decode_first_stage(samples_ddim)
                     x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
 
-                    for x_sample in x_samples_ddim:
-                        if grid:
-                            all_samples.append(x_samples_ddim)
-                            seeds.append(seed)
-                        else:
+                    if not grid:
+                        for x_sample in x_samples_ddim:
                             x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                             filename = os.path.join(outdir, f"{base_count:05}.png")
                             Image.fromarray(x_sample.astype(np.uint8)).save(filename)
                             images.append([filename,seed])
                             base_count += 1
-                    seed = self._new_seed()
+                    else:
+                        all_samples.append(x_samples_ddim)
+                        seeds.append(seed)
+
+                    seed = self._new_seed()
 
                 if grid:
-                    n_rows = int(sqrt(batch * iterations))
+                    n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
                     # save as grid
                     grid = torch.stack(all_samples, 0)
                     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
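For context on the grid branch: `all_samples` holds one `(batch, c, h, w)` tensor per iteration, which is flattened into a single batch and tiled by torchvision's `make_grid` (imported at the top of this file). A self-contained sketch with dummy tensors:

~~~~
import torch
from einops import rearrange
from torchvision.utils import make_grid

all_samples = [torch.rand(2, 3, 64, 64) for _ in range(3)]   # 3 iterations, batch of 2
grid = torch.stack(all_samples, 0)                   # (3, 2, 3, 64, 64)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')   # (6, 3, 64, 64)
tiled = make_grid(grid, nrow=2)                      # 2 tiles per row -> 3 rows
print(tiled.shape)                                   # torch.Size([3, 200, 134]) with default padding
~~~~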
@@ -213,6 +221,9 @@ class T2I:
                     for s in seeds:
                         images.append([filename,s])
 
+        toc = time.time()
+        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+
         return images
 
 
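The timing report is the plain tic/toc idiom enabled by the module-level `time` import; a standalone version of the same pattern (the sleep stands in for the sampling loop):

~~~~
import time

n_images = 4                 # stands in for batch * iterations
tic = time.time()
time.sleep(0.25)             # placeholder workload
toc = time.time()
print(f'{n_images} images generated in', "%4.2fs" % (toc - tic))
~~~~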
scripts/dream.py:

@@ -7,6 +7,7 @@ import atexit
 import os
 
 def main():
+    ''' Initialize command-line parsers and the diffusion model '''
    arg_parser = create_argv_parser()
    opt = arg_parser.parse_args()
    if opt.laion400m:
@@ -59,6 +60,7 @@ def main():
     log.close()
 
 def main_loop(t2i,parser,log):
+    ''' prompt/read/execute loop '''
     while True:
         try:
             command = input("dream> ")
@@ -86,13 +88,35 @@ def main_loop(t2i,parser,log):
                 pass
             results = t2i.txt2img(**vars(opt))
             print("Outputs:")
-            for r in results:
-                log_message = " ".join([' ',str(r[0])+':',
-                                        f'"{switches[0]}"',
-                                        *switches[1:],f'-S {r[1]}'])
-                print(log_message)
-                log.write(log_message+"\n")
-                log.flush()
+            write_log_message(opt,switches,results,log)
+
+def write_log_message(opt,switches,results,logfile):
+    ''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
+    if opt.grid:
+        _output_for_grid(switches,results,logfile)
+    else:
+        _output_for_individual(switches,results,logfile)
+
+def _output_for_individual(switches,results,logfile):
+    for r in results:
+        log_message = " ".join([' ',str(r[0])+':',
+                                f'"{switches[0]}"',
+                                *switches[1:],f'-S {r[1]}'])
+        print(log_message)
+        logfile.write(log_message+"\n")
+        logfile.flush()
+
+def _output_for_grid(switches,results,logfile):
+    first_seed = results[0][1]
+    log_message = " ".join([' ',str(results[0][0])+':',
+                            f'"{switches[0]}"',
+                            *switches[1:],f'-S {results[0][1]}'])
+    print(log_message)
+    logfile.write(log_message+"\n")
+    all_seeds = [row[1] for row in results]
+    log_message = f'    seeds for individual rows: {all_seeds}'
+    print(log_message)
+    logfile.write(log_message+"\n")
 
 def create_argv_parser():
     parser = argparse.ArgumentParser(description="Parse script's command line args")
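To make the helpers' input concrete: `results`, as assembled in `txt2img()`, is a list of `[filename, seed]` pairs; a grid run appends the single grid file once per seed so every row stays reproducible. A hypothetical value (paths and seeds invented):

~~~~
# Individual run: one file per generated image.
results = [['outputs/txt2img-samples/00000.png', 42],
           ['outputs/txt2img-samples/00001.png', 1337]]

# Grid run: the same output file paired with each iteration's seed.
results = [['outputs/txt2img-samples/00002.png', 42],
           ['outputs/txt2img-samples/00002.png', 1337]]
~~~~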
@@ -133,6 +157,7 @@ def create_cmd_parser():
     parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
     parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
     parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
+    parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
     return parser
 
 def load_history():
scripts/preload_models.py (new file, 17 lines):

@@ -0,0 +1,17 @@
+#!/usr/bin/env python
+
+# Before running stable-diffusion on an internet-isolated machine,
+# run this script from one with internet connectivity. The
+# two machines must share a common .cache directory.
+
+# this will preload the Bert tokenizer fles
+print("preloading bert tokenizer...",end='')
+from transformers import BertTokenizerFast
+tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
+print("...success")
+
+# this will download requirements for Kornia
+print("preloading Kornia requirements...",end='')
+import kornia
+print("...success")
+
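An aside beyond the commit itself: transformers also supports a cache-only mode via the `TRANSFORMERS_OFFLINE` environment variable, which gives the same behavior as `local_files_only=True` without touching the code:

~~~~
import os
os.environ["TRANSFORMERS_OFFLINE"] = "1"   # set before importing transformers, to be safe

from transformers import BertTokenizerFast
# Resolved entirely from ~/.cache/huggingface; fails rather than downloading.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
~~~~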