mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
simplified instructions to preload Bert and kornia prerequisites; fixed --grid and --batch handling; added timing information after image generation
This commit is contained in:
parent
fab1ae8685
commit
a7532b386a
65
README.md
65
README.md
@ -57,57 +57,40 @@ weights (512x512) and the older (256x256) latent diffusion weights
|
||||
(laion400m). Within the script, the switches are (mostly) identical to
|
||||
those used in the Discord bot, except you don't need to type "!dream".
|
||||
|
||||
## No need for internet connectivity when loading the model
|
||||
## Workaround for machines with limited internet connectivity
|
||||
|
||||
My development machine is a GPU node in a high-performance compute
|
||||
cluster which has no connection to the internet. During model
|
||||
initialization, stable-diffusion tries to download the Bert tokenizer
|
||||
model from huggingface.co. This obviously didn't work for me.
|
||||
and a file needed by the kornia library. This obviously didn't work
|
||||
for me.
|
||||
|
||||
Rather than set up a hugging face local hub, I found the most
|
||||
expedient thing to do was to download the Bert tokenizer in advance
|
||||
from a machine that had internet access (in this case, the head node
|
||||
of the cluster), and patch stable-diffusion to read it from the local
|
||||
disk. After you have completed the conda environment creation and
|
||||
activation steps,the steps to preload the Bert model are:
|
||||
|
||||
~~~~
|
||||
(ldm) ~/stable-diffusion$ mkdir ./models/bert
|
||||
(ldm) ~/stable-diffusion$ python3
|
||||
>>> from transformers import BertTokenizerFast
|
||||
>>> model = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
>>> model.save_pretrained("./models/bert")
|
||||
~~~~
|
||||
|
||||
(Make sure you are in the stable-diffusion directory when you do
|
||||
this!)
|
||||
|
||||
If you don't like this change, just copy over the file
|
||||
ldm/modules/encoders/modules.py from the CompVis/stable-diffusion
|
||||
repository.
|
||||
|
||||
In addition, I have found that the Kornia library needs to do a
|
||||
one-time download of its own. On a non-internet connected system, you
|
||||
may see an error message like this one when running dream.py for the
|
||||
first time
|
||||
To work around this, I have modified ldm/modules/encoders/modules.py
|
||||
to look for locally cached Bert files rather than attempting to
|
||||
download them. For this to work, you must run
|
||||
"scripts/preload_models.py" once from an internet-connected machine
|
||||
prior to running the code on an isolated one. This assumes that both
|
||||
machines share a common network-mounted filesystem with a common
|
||||
.cache directory.
|
||||
|
||||
~~~~
|
||||
(ldm) ~/stable-diffusion$ python3 ./scripts/preload_models.py
|
||||
preloading bert tokenizer...
|
||||
Downloading: 100%|██████████████████████████████████| 28.0/28.0 [00:00<00:00, 49.3kB/s]
|
||||
Downloading: 100%|██████████████████████████████████| 226k/226k [00:00<00:00, 2.79MB/s]
|
||||
Downloading: 100%|██████████████████████████████████| 455k/455k [00:00<00:00, 4.36MB/s]
|
||||
Downloading: 100%|██████████████████████████████████| 570/570 [00:00<00:00, 477kB/s]
|
||||
...success
|
||||
preloading kornia requirements...
|
||||
Downloading: "https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth" to /u/lstein/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
|
||||
Traceback (most recent call last):
|
||||
File "/u/lstein/.conda/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
|
||||
h.request(req.get_method(), req.selector, req.data, headers,
|
||||
File "/u/lstein/.conda/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
|
||||
...
|
||||
100%|███████████████████████████████████████████████| 5.10M/5.10M [00:00<00:00, 101MB/s]
|
||||
...success
|
||||
~~~~
|
||||
|
||||
The fix is to log into an internet-connected machine and manually
|
||||
download the file into the required location. On my system, the incantation was:
|
||||
|
||||
~~~~
|
||||
(ldm) ~/stable-diffusion$ mkdir -p /u/lstein/.cache/torch/hub/checkpoints/
|
||||
(ldm) ~/stable-diffusion$ wget https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth \
|
||||
-O /u/lstein/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth
|
||||
~~~~
|
||||
If you don't need this change and want to download the files just in
|
||||
time, copy over the file ldm/modules/encoders/modules.py from the
|
||||
CompVis/stable-diffusion repository. Or you can run preload_models.py
|
||||
on the target machine.
|
||||
|
||||
## Minor fixes
|
||||
|
||||
|
@ -55,10 +55,12 @@ class BERTTokenizer(AbstractEncoder):
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
fn = 'models/bert'
|
||||
print(f'Loading Bert tokenizer from "{fn}"')
|
||||
# self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained(fn,local_files_only=True)
|
||||
# Modified to allow to run on non-internet connected compute nodes.
|
||||
# Model needs to be loaded into cache from an internet-connected machine
|
||||
# by running:
|
||||
# from transformers import BertTokenizerFast
|
||||
# BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased",local_files_only=True)
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
@ -235,4 +237,3 @@ if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
model = FrozenCLIPEmbedder()
|
||||
count_params(model, verbose=True)
|
||||
|
||||
|
@ -52,8 +52,8 @@ from torchvision.utils import make_grid
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from time import time
|
||||
from math import sqrt
|
||||
import time
|
||||
import math
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
@ -72,6 +72,7 @@ class T2I:
|
||||
seed
|
||||
sampler
|
||||
grid
|
||||
individual
|
||||
width
|
||||
height
|
||||
cfg_scale
|
||||
@ -84,9 +85,10 @@ class T2I:
|
||||
outdir="outputs/txt2img-samples",
|
||||
batch=1,
|
||||
iterations = 1,
|
||||
width=256, # change to 512 for stable diffusion
|
||||
height=256, # change to 512 for stable diffusion
|
||||
width=512,
|
||||
height=512,
|
||||
grid=False,
|
||||
individual=None, # redundant
|
||||
steps=50,
|
||||
seed=None,
|
||||
cfg_scale=7.5,
|
||||
@ -122,7 +124,7 @@ class T2I:
|
||||
else:
|
||||
self.seed = seed
|
||||
def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
|
||||
steps=None,seed=None,grid=None,width=None,height=None,
|
||||
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
|
||||
cfg_scale=None,ddim_eta=None):
|
||||
""" generate an image from the prompt, writing iteration images into the outdir """
|
||||
outdir = outdir or self.outdir
|
||||
@ -134,13 +136,16 @@ class T2I:
|
||||
ddim_eta = ddim_eta or self.ddim_eta
|
||||
batch = batch or self.batch
|
||||
iterations = iterations or self.iterations
|
||||
if batch > 1:
|
||||
iterations = 1
|
||||
|
||||
model = self.load_model() # will instantiate the model or return it from cache
|
||||
|
||||
# grid and individual are mutually exclusive, with individual taking priority.
|
||||
# not necessary, but needed for compatability with dream bot
|
||||
if (grid is None):
|
||||
grid = self.grid
|
||||
if individual:
|
||||
grid = False
|
||||
|
||||
data = [batch * [prompt]]
|
||||
|
||||
# make directories and establish names for the output files
|
||||
@ -159,6 +164,8 @@ class T2I:
|
||||
sampler = self.sampler
|
||||
images = list()
|
||||
seeds = list()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
@ -171,7 +178,7 @@ class T2I:
|
||||
if cfg_scale != 1.0:
|
||||
uc = model.get_learned_conditioning(batch * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
prompts = list(prompts)
|
||||
c = model.get_learned_conditioning(prompts)
|
||||
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
|
||||
samples_ddim, _ = sampler.sample(S=steps,
|
||||
@ -187,20 +194,21 @@ class T2I:
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
for x_sample in x_samples_ddim:
|
||||
if grid:
|
||||
all_samples.append(x_samples_ddim)
|
||||
seeds.append(seed)
|
||||
else:
|
||||
if not grid:
|
||||
for x_sample in x_samples_ddim:
|
||||
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
|
||||
filename = os.path.join(outdir, f"{base_count:05}.png")
|
||||
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
|
||||
images.append([filename,seed])
|
||||
base_count += 1
|
||||
seed = self._new_seed()
|
||||
else:
|
||||
all_samples.append(x_samples_ddim)
|
||||
seeds.append(seed)
|
||||
|
||||
seed = self._new_seed()
|
||||
|
||||
if grid:
|
||||
n_rows = int(sqrt(batch * iterations))
|
||||
n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
|
||||
# save as grid
|
||||
grid = torch.stack(all_samples, 0)
|
||||
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
|
||||
@ -213,6 +221,9 @@ class T2I:
|
||||
for s in seeds:
|
||||
images.append([filename,s])
|
||||
|
||||
toc = time.time()
|
||||
print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
|
||||
|
||||
return images
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ import atexit
|
||||
import os
|
||||
|
||||
def main():
|
||||
''' Initialize command-line parsers and the diffusion model '''
|
||||
arg_parser = create_argv_parser()
|
||||
opt = arg_parser.parse_args()
|
||||
if opt.laion400m:
|
||||
@ -59,6 +60,7 @@ def main():
|
||||
log.close()
|
||||
|
||||
def main_loop(t2i,parser,log):
|
||||
''' prompt/read/execute loop '''
|
||||
while True:
|
||||
try:
|
||||
command = input("dream> ")
|
||||
@ -86,13 +88,35 @@ def main_loop(t2i,parser,log):
|
||||
pass
|
||||
results = t2i.txt2img(**vars(opt))
|
||||
print("Outputs:")
|
||||
for r in results:
|
||||
log_message = " ".join([' ',str(r[0])+':',
|
||||
f'"{switches[0]}"',
|
||||
*switches[1:],f'-S {r[1]}'])
|
||||
print(log_message)
|
||||
log.write(log_message+"\n")
|
||||
log.flush()
|
||||
write_log_message(opt,switches,results,log)
|
||||
|
||||
def write_log_message(opt,switches,results,logfile):
|
||||
''' logs the name of the output image, its prompt and seed to both the terminal and the log file '''
|
||||
if opt.grid:
|
||||
_output_for_grid(switches,results,logfile)
|
||||
else:
|
||||
_output_for_individual(switches,results,logfile)
|
||||
|
||||
def _output_for_individual(switches,results,logfile):
|
||||
for r in results:
|
||||
log_message = " ".join([' ',str(r[0])+':',
|
||||
f'"{switches[0]}"',
|
||||
*switches[1:],f'-S {r[1]}'])
|
||||
print(log_message)
|
||||
logfile.write(log_message+"\n")
|
||||
logfile.flush()
|
||||
|
||||
def _output_for_grid(switches,results,logfile):
|
||||
first_seed = results[0][1]
|
||||
log_message = " ".join([' ',str(results[0][0])+':',
|
||||
f'"{switches[0]}"',
|
||||
*switches[1:],f'-S {results[0][1]}'])
|
||||
print(log_message)
|
||||
logfile.write(log_message+"\n")
|
||||
all_seeds = [row[1] for row in results]
|
||||
log_message = f' seeds for individual rows: {all_seeds}'
|
||||
print(log_message)
|
||||
logfile.write(log_message+"\n")
|
||||
|
||||
def create_argv_parser():
|
||||
parser = argparse.ArgumentParser(description="Parse script's command line args")
|
||||
@ -133,6 +157,7 @@ def create_cmd_parser():
|
||||
parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
|
||||
parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
|
||||
parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
|
||||
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
|
||||
return parser
|
||||
|
||||
def load_history():
|
||||
|
17
scripts/preload_models.py
Normal file
17
scripts/preload_models.py
Normal file
@ -0,0 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Before running stable-diffusion on an internet-isolated machine,
|
||||
# run this script from one with internet connectivity. The
|
||||
# two machines must share a common .cache directory.
|
||||
|
||||
# this will preload the Bert tokenizer fles
|
||||
print("preloading bert tokenizer...",end='')
|
||||
from transformers import BertTokenizerFast
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
print("...success")
|
||||
|
||||
# this will download requirements for Kornia
|
||||
print("preloading Kornia requirements...",end='')
|
||||
import kornia
|
||||
print("...success")
|
||||
|
Loading…
Reference in New Issue
Block a user