Merge branch 'inpaint-model' of github.com:invoke-ai/InvokeAI into inpaint-model

Author: Lincoln Stein
Date:   2022-10-25 13:17:20 -04:00
Commit: c732fd0740

24 changed files with 902 additions and 796 deletions

assets/caution.png: new binary file (33 KiB), not shown.

================================================================================

@@ -10,7 +10,6 @@ stable-diffusion-1.4:
     weights: models/ldm/stable-diffusion-v1/model.ckpt
     # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     description: Stable Diffusion inference model version 1.4
-    default: true
     width: 512
     height: 512
 inpainting-1.5:
@@ -20,6 +19,7 @@ inpainting-1.5:
     # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     width: 512
     height: 512
+    default: true
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
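
The effect of moving `default: true` is that the inpainting model, not 1.4, is loaded at startup. As a minimal sketch of how a `default: true` stanza can be resolved (assuming the OmegaConf loading shown in the `Generate` constructor below; the helper name is illustrative):

```python
from omegaconf import OmegaConf

def find_default_model(conf_path='configs/models.yaml') -> str:
    '''Return the name of the first model stanza marked `default: true`.'''
    mconfig = OmegaConf.load(conf_path)
    for name, stanza in mconfig.items():
        if stanza.get('default', False):
            return name
    return next(iter(mconfig))   # fall back to the first entry

print(find_default_model())      # with this change applied: 'inpainting-1.5'
```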

================================================================================

@@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
 | `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
 | `--png_compression <0-9>` | `-z<0-9>` | `6` | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
+| `--safety_checker` | | `False` | Activate safety checker for NSFW and other potentially disturbing imagery |
 | `--web` | | `False` | Start in web server mode |
 | `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
 | `--port <port>` | | `9090` | Which port web server should listen for requests on. |
@@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
 | `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
 | `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
-| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
 | `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
 | `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |

================================================================================

@@ -75,6 +75,23 @@ combination of integers and floating point numbers, and they do not need to add
 ---

+## **Filename Format**
+
+The argument `--fnformat` lets you specify the filename of the
+image. Supported wildcards are all arguments that can be set, such as
+`perlin`, `seed`, `threshold`, `height`, `width`, `gfpgan_strength`,
+`sampler_name`, `steps`, `model`, `upscale`, `prompt`, `cfg_scale`,
+`prefix`.
+
+The following prompt
+
+```bash
+dream> a red car --steps 25 -C 9.8 --perlin 0.1 --fnformat {prompt}_steps.{steps}_cfg.{cfg_scale}_perlin.{perlin}.png
+```
+
+generates a file with the name `outputs/img-samples/a red car_steps.25_cfg.9.8_perlin.0.1.png`
+
+---
+
 ## **Thresholding and Perlin Noise Initialization Options**

 Two new options are the thresholding (`--threshold`) and the perlin noise initialization (`--perlin`) options. Thresholding limits the range of the latent values during optimization, which helps combat oversaturation with higher CFG scale values. Perlin noise initialization starts with a percentage (a value ranging from 0 to 1) of perlin noise mixed into the initial noise. Both features allow for more variations and options in the course of generating images.

================================================================================

@@ -19,6 +19,7 @@ dependencies:
 #    ```
   - albumentations==1.2.1
   - coloredlogs==15.0.1
+  - diffusers==0.6.0
   - einops==0.4.1
   - grpcio==1.46.4
   - humanfriendly==10.0

================================================================================

@@ -26,6 +26,7 @@ dependencies:
   - pyreadline3
   - torch-fidelity==0.3.0
   - transformers==4.21.3
+  - diffusers==0.6.0
   - torchmetrics==0.7.0
   - flask==2.1.3
   - flask_socketio==5.3.0

================================================================================

frontend/dist/assets/index.0a6593a2.js: new vendored file, 483 lines. The diffs of
this and the other built frontend assets are suppressed because one or more lines
are too long.

================================================================================

@@ -5,9 +5,9 @@
     <meta charset="UTF-8" />
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>InvokeAI - A Stable Diffusion Toolkit</title>
-    <link rel="shortcut icon" type="icon" href="/assets/favicon.0d253ced.ico" />
-    <script type="module" crossorigin src="/assets/index.2d646c45.js"></script>
-    <link rel="stylesheet" href="/assets/index.7749e179.css">
+    <link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
+    <script type="module" crossorigin src="./assets/index.0a6593a2.js"></script>
+    <link rel="stylesheet" href="./assets/index.193aec6f.css">
   </head>
   <body>

================================================================================

@@ -26,6 +26,7 @@ export const socketioMiddleware = () => {
   const socketio = io(origin, {
     timeout: 60000,
+    path: window.location.pathname + 'socket.io',
   });

   let areListenersSet = false;

================================================================================

@@ -68,7 +68,6 @@ const PromptInput = () => {
     <div className="prompt-bar">
       <FormControl
         isInvalid={prompt.length === 0 || Boolean(prompt.match(/^[\s\r\n]+$/))}
-        isDisabled={isProcessing}
       >
         <Textarea
           id="prompt"

================================================================================

@@ -5,6 +5,7 @@ import eslint from 'vite-plugin-eslint';
 // https://vitejs.dev/config/
 export default defineConfig(({ mode }) => {
   const common = {
+    base: '',
    plugins: [react(), eslint()],
    server: {
      // Proxy HTTP requests to the flask server

================================================================================

@@ -140,13 +140,14 @@ class Generate:
             ddim_eta = 0.0,   # deterministic
             full_precision = False,
             precision = 'auto',
-            # these are deprecated; if present they override values in the conf file
-            weights = None,
-            config = None,
             gfpgan = None,
             codeformer = None,
             esrgan = None,
             free_gpu_mem = False,
+            safety_checker:bool = False,
+            # these are deprecated; if present they override values in the conf file
+            weights = None,
+            config = None,
     ):
         mconfig = OmegaConf.load(conf)
         self.height = None
@@ -177,6 +178,7 @@ class Generate:
         self.free_gpu_mem = free_gpu_mem
         self.size_matters = True # used to warn once about large image sizes and VRAM
         self.txt2mask = None
+        self.safety_checker = None

         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -204,6 +206,19 @@ class Generate:
         # gets rid of annoying messages about random seed
         logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

+        # load safety checker if requested
+        if safety_checker:
+            try:
+                print('>> Initializing safety checker')
+                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+                from transformers import AutoFeatureExtractor
+                safety_model_id = "CompVis/stable-diffusion-safety-checker"
+                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
+            except Exception:
+                print('** An error was encountered while installing the safety checker:')
+                print(traceback.format_exc())
+
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -277,6 +292,7 @@ class Generate:
             # Set this True to handle KeyboardInterrupt internally
             catch_interrupts = False,
             hires_fix = False,
+            use_mps_noise = False,
             **args,
     ):  # eat up additional cruft
         """
@@ -421,6 +437,12 @@ class Generate:
                 generator.set_variation(
                     self.seed, variation_amount, with_variations
                 )
+            generator.use_mps_noise = use_mps_noise
+
+            checker = {
+                'checker': self.safety_checker,
+                'extractor': self.safety_feature_extractor
+            } if self.safety_checker else None

             results = generator.generate(
                 prompt,
@@ -444,7 +466,8 @@ class Generate:
                 embiggen=embiggen,
                 embiggen_tiles=embiggen_tiles,
                 inpaint_replace=inpaint_replace,
-                mask_blur_radius=mask_blur_radius
+                mask_blur_radius=mask_blur_radius,
+                safety_checker=checker
             )
             if init_color:
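
Taken together, the changes above mean a caller opts in once at construction time and everything else is automatic: flagged results come back blurred, with `assets/caution.png` pasted on top. A minimal usage sketch, assuming the class is importable as `ldm.generate.Generate` (the import path is an assumption here) and using the `prompt2png()` API shown above:

```python
from ldm.generate import Generate  # import path assumed for illustration

gr = Generate(safety_checker=True)  # loads the checker pair up front
gr.prompt2png('a red car', outdir='outputs/img-samples')
# any image the checker flags is returned blurred, with the caution
# overlay applied by Generator.blur() (see the generator hunk below)
```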

================================================================================

@@ -217,6 +217,7 @@ class Args(object):
             switches.append(f'-W {a["width"]}')
             switches.append(f'-H {a["height"]}')
             switches.append(f'-C {a["cfg_scale"]}')
+            switches.append(f'--fnformat {a["fnformat"]}')
             if a['perlin'] > 0:
                 switches.append(f'--perlin {a["perlin"]}')
             if a['threshold'] > 0:
@@ -419,6 +420,11 @@ class Args(object):
             help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
             default='auto',
         )
+        model_group.add_argument(
+            '--safety_checker',
+            action='store_true',
+            help='Check for and blur potentially NSFW images',
+        )
         file_group.add_argument(
             '--from_file',
             dest='infile',
@@ -438,6 +444,12 @@ class Args(object):
             action='store_true',
             help='Place images in subdirectories named after the prompt.',
         )
+        render_group.add_argument(
+            '--fnformat',
+            default='{prefix}.{seed}.png',
+            type=str,
+            help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
+        )
         render_group.add_argument(
             '--grid',
             '-g',
@@ -611,6 +623,12 @@ class Args(object):
             type=float,
             help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
         )
+        render_group.add_argument(
+            '--fnformat',
+            default='{prefix}.{seed}.png',
+            type=str,
+            help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
+        )
         render_group.add_argument(
             '--grid',
             '-g',
@@ -811,6 +829,13 @@ class Args(object):
             type=str,
             help='list of variations to apply, in the format `seed:weight,seed:weight,...'
         )
+        render_group.add_argument(
+            '--use_mps_noise',
+            action='store_true',
+            dest='use_mps_noise',
+            help='Simulate noise on M1 systems to get the same results'
+        )
         return parser

 def format_metadata(**kwargs):
@@ -846,9 +871,8 @@ def metadata_dumps(opt,
     # remove any image keys not mentioned in RFC #266
     rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
-                         'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
+                         'cfg_scale','threshold','perlin','fnformat','step_number','width','height','extra','strength',
                          'init_img','init_mask','facetool','facetool_strength','upscale']

     rfc_dict ={}

     for item in image_dict.items():

================================================================================

@@ -7,13 +7,14 @@ import numpy as np
 import random
 import os
 from tqdm import tqdm, trange
-from PIL import Image
+from PIL import Image, ImageFilter
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.invoke.devices import choose_autocast
 from ldm.util import rand_perlin_2d

 downsampling = 8
+CAUTION_IMG = 'assets/caution.png'

 class Generator():
     def __init__(self, model, precision):
@@ -22,10 +23,12 @@ class Generator():
         self.seed = None
         self.latent_channels = model.channels
         self.downsampling_factor = downsampling # BUG: should come from model or config
+        self.safety_checker = None
         self.perlin = 0.0
         self.threshold = 0
         self.variation_amount = 0
         self.with_variations = []
+        self.use_mps_noise = False

     # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
     def get_make_image(self,prompt,**kwargs):
@@ -42,8 +45,10 @@ class Generator():
     def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
+                 safety_checker:dict=None,
                  **kwargs):
         scope = choose_autocast(self.precision)
+        self.safety_checker = safety_checker
         make_image = self.get_make_image(
             prompt,
             sampler = sampler,
@@ -79,10 +84,17 @@ class Generator():
                 except:
                     pass
                 image = make_image(x_T)
+
+                if self.safety_checker is not None:
+                    image = self.safety_check(image)
+
                 results.append([image, seed])
+
                 if image_callback is not None:
                     image_callback(image, seed, first_seed=first_seed)
+
                 seed = self.new_seed()
+
         return results

     def sample_to_image(self,samples)->Image.Image:
@@ -171,6 +183,39 @@ class Generator():
         return v2

+    def safety_check(self,image:Image.Image):
+        '''
+        If the CompVis safety checker flags an NSFW image, we
+        blur it out.
+        '''
+        import diffusers
+
+        checker = self.safety_checker['checker']
+        extractor = self.safety_checker['extractor']
+        features = extractor([image], return_tensors="pt")
+
+        # unfortunately checker requires the numpy version, so we have to convert back
+        x_image = np.array(image).astype(np.float32) / 255.0
+        x_image = x_image[None].transpose(0, 3, 1, 2)
+
+        diffusers.logging.set_verbosity_error()
+        checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
+        if has_nsfw_concept[0]:
+            print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
+            return self.blur(image)
+        else:
+            return image
+
+    def blur(self,input):
+        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
+        try:
+            caution = Image.open(CAUTION_IMG)
+            caution = caution.resize((caution.width // 2, caution.height // 2))
+            blurry.paste(caution,(0,0),caution)
+        except FileNotFoundError:
+            pass
+        return blurry
+
     # this is a handy routine for debugging use. Given a generated sample,
     # convert it into a PNG image and store it at the indicated path
     def save_sample(self, sample, filepath):
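
For reference, the checker pair wired in above can be exercised standalone; this sketch follows the same conversion steps as `safety_check()` (the input path is illustrative, and it assumes the models are already in the local cache):

```python
import numpy as np
from PIL import Image
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

model_id = "CompVis/stable-diffusion-safety-checker"
extractor = AutoFeatureExtractor.from_pretrained(model_id)
checker = StableDiffusionSafetyChecker.from_pretrained(model_id)

image = Image.open('outputs/img-samples/000001.png')  # illustrative path
features = extractor([image], return_tensors="pt")

# the checker wants an NCHW float array in [0, 1], so convert the PIL image
x = np.array(image).astype(np.float32) / 255.0
x = x[None].transpose(0, 3, 1, 2)

_, has_nsfw = checker(images=x, clip_input=features.pixel_values)
print('flagged' if has_nsfw[0] else 'clean')
```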

================================================================================

@@ -59,7 +59,7 @@ class Txt2Img(Generator):
     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self,width,height):
         device = self.model.device
-        if device.type == 'mps':
+        if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
                              self.latent_channels,
                              height // self.downsampling_factor,
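
The motivation for `use_mps_noise` is spelled out by the `--use_mps_noise` help text earlier ("Simulate noise on M1 systems to get the same results"): MPS draws random numbers differently from CUDA, so forcing the MPS branch lets other devices reproduce seeds from an M1 machine. The underlying pattern, which this branch relies on (an assumption here, since the scraped hunk truncates before the device argument), is to draw noise on the CPU and move it over afterwards. A sketch with an illustrative latent shape:

```python
import torch

torch.manual_seed(42)
# drawing on the CPU gives the same tensor no matter where sampling runs...
x = torch.randn([1, 4, 64, 64], device='cpu')

device = 'mps' if torch.backends.mps.is_available() else 'cpu'
x = x.to(device)  # ...so a CUDA box can reproduce an M1 result, and vice versa
```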

================================================================================

@@ -118,7 +118,7 @@ class Txt2Img2Img(Generator):
             scaled_height = height

         device = self.model.device
-        if device.type == 'mps':
+        if self.use_mps_noise or device.type == 'mps':
             return torch.randn([1,
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,

================================================================================

@@ -1,5 +1,6 @@
 albumentations==0.4.3
 einops==0.3.0
+diffusers==0.6.0
 huggingface-hub==0.8.1
 imageio==2.9.0
 imageio-ffmpeg==0.4.2

================================================================================

@@ -32,6 +32,7 @@ send2trash
 dependency_injector==4.40.0
 eventlet
 realesrgan
+diffusers
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan

================================================================================

@@ -79,6 +79,7 @@ def main():
             codeformer=codeformer,
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
+            safety_checker=opt.safety_checker,
         )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')
@@ -673,6 +674,16 @@ def prepare_image_metadata(
     if postprocessed and opt.save_original:
         filename = choose_postprocess_name(opt,prefix,seed)
     else:
+        wildcards = dict(opt.__dict__)
+        wildcards['prefix'] = prefix
+        wildcards['seed'] = seed
+        try:
+            filename = opt.fnformat.format(**wildcards)
+        except KeyError as e:
+            print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
+            filename = f'{prefix}.{seed}.png'
+        except IndexError as e:
+            print('** The filename format is broken or incomplete. Will use \'{prefix}.{seed}.png\' instead')
             filename = f'{prefix}.{seed}.png'

     if opt.variation_amount > 0:
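
The fallback logic in `prepare_image_metadata()` above is easy to check in isolation: an unknown `{key}` raises `KeyError`, a positional field like `{0}` raises `IndexError`, and both fall back to the default pattern. A sketch with illustrative values:

```python
wildcards = {'prefix': '000001', 'seed': 1234, 'steps': 30}

for fmt in ('{prefix}.{seed}.png',   # valid format
            '{no_such_key}.png',     # KeyError  -> fall back
            '{0}.png'):              # IndexError -> fall back
    try:
        name = fmt.format(**wildcards)
    except (KeyError, IndexError):
        name = '{prefix}.{seed}.png'.format(**wildcards)
    print(name)  # all three print 000001.1234.png
```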

================================================================================

@@ -5,7 +5,7 @@
 # two machines must share a common .cache directory.
 from transformers import CLIPTokenizer, CLIPTextModel
 import clip
-from transformers import BertTokenizerFast
+from transformers import BertTokenizerFast, AutoFeatureExtractor
 import sys
 import transformers
 import os
@@ -17,21 +17,27 @@ import traceback
 transformers.logging.set_verbosity_error()

+#---------------------------------------------
 # this will preload the Bert tokenizer files
-print('Loading bert tokenizer (ignore deprecation errors)...', end='')
-with warnings.catch_warnings():
-    warnings.filterwarnings('ignore', category=DeprecationWarning)
-    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-    print('...success')
-sys.stdout.flush()
+def download_bert():
+    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=DeprecationWarning)
+        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+        print('...success')
+    sys.stdout.flush()

+#---------------------------------------------
 # this will download requirements for Kornia
-print('Loading Kornia requirements...', end='')
-with warnings.catch_warnings():
-    warnings.filterwarnings('ignore', category=DeprecationWarning)
-    import kornia
-print('...success')
+def download_kornia():
+    print('Installing Kornia requirements...', end='')
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=DeprecationWarning)
+        import kornia
+    print('...success')

-version = 'openai/clip-vit-large-patch14'
-sys.stdout.flush()
-print('Loading CLIP model...',end='')
+#---------------------------------------------
+def download_clip():
+    version = 'openai/clip-vit-large-patch14'
+    sys.stdout.flush()
+    print('Loading CLIP model...',end='')
@@ -39,19 +45,11 @@ tokenizer = CLIPTokenizer.from_pretrained(version)
     transformer = CLIPTextModel.from_pretrained(version)
     print('...success')

-# In the event that the user has installed GFPGAN and also elected to use
-# RealESRGAN, this will attempt to download the model needed by RealESRGANer
-gfpgan = False
+#---------------------------------------------
+def download_gfpgan():
+    print('Installing models from RealESRGAN and facexlib...',end='')
     try:
         from realesrgan import RealESRGANer
-        gfpgan = True
-    except ModuleNotFoundError:
-        pass
-
-if gfpgan:
-    print('Loading models from RealESRGAN and facexlib...',end='')
-    try:
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact
         from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -94,7 +92,9 @@ if gfpgan:
         print('Error loading GFPGAN:')
         print(traceback.format_exc())

-print('preloading CodeFormer model file...',end='')
+#---------------------------------------------
+def download_codeformer():
+    print('Installing CodeFormer model file...',end='')
     try:
         model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
         model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
@@ -107,7 +107,9 @@ except Exception:
         print(traceback.format_exc())
     print('...success')

-print('Loading clipseg model for text-based masking...',end='')
+#---------------------------------------------
+def download_clipseg():
+    print('Installing clipseg model for text-based masking...',end='')
     try:
         model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
         model_dest = 'src/clipseg/clipseg_weights.zip'
@@ -134,4 +136,28 @@ except Exception:
         print(traceback.format_exc())
     print('...success')

+#-------------------------------------
+def download_safety_checker():
+    print('Installing safety model for NSFW content detection...',end='')
+    try:
+        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+    except ModuleNotFoundError:
+        print('Error installing safety checker model:')
+        print(traceback.format_exc())
+        return
+    safety_model_id = "CompVis/stable-diffusion-safety-checker"
+    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+    print('...success')
+
+#-------------------------------------
+if __name__ == '__main__':
+    download_bert()
+    download_kornia()
+    download_clip()
+    download_gfpgan()
+    download_codeformer()
+    download_clipseg()
+    download_safety_checker()
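
Note how this pairs with the `local_files_only=True` loads in `Generate.__init__` earlier in the commit: running the preload script once with network access warms the Hugging Face cache, after which the generation-time loads never touch the network. A sketch of that round trip, using the same model id as above:

```python
from transformers import AutoFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

safety_model_id = "CompVis/stable-diffusion-safety-checker"

# once, with network access: downloads into the local cache (what the preload does)
AutoFeatureExtractor.from_pretrained(safety_model_id)
StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

# later, at generation time: served strictly from the cache, no network needed
extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
```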

================================================================================

shell.nix: new file, 162 lines
{ pkgs ? import <nixpkgs> {}
, lib ? pkgs.lib
, stdenv ? pkgs.stdenv
, fetchurl ? pkgs.fetchurl
, runCommand ? pkgs.runCommand
, makeWrapper ? pkgs.makeWrapper
, mkShell ? pkgs.mkShell
, buildFHSUserEnv ? pkgs.buildFHSUserEnv
, frameworks ? pkgs.darwin.apple_sdk.frameworks
}:

# Setup InvokeAI environment using nix
# Simple usage:
# nix-shell
# python3 scripts/preload_models.py
# python3 scripts/invoke.py -h

let
  conda-shell = { url, sha256, installPath, packages, shellHook }:
    let
      src = fetchurl { inherit url sha256; };
      libPath = lib.makeLibraryPath ([] ++ lib.optionals (stdenv.isLinux) [ pkgs.zlib ]);
      condaArch = if stdenv.system == "aarch64-darwin" then "osx-arm64" else "";
      installer =
        if stdenv.isDarwin then
          runCommand "conda-install" {
            nativeBuildInputs = [ makeWrapper ];
          } ''
            mkdir -p $out/bin
            cp ${src} $out/bin/miniconda-installer.sh
            chmod +x $out/bin/miniconda-installer.sh

            makeWrapper                        \
              $out/bin/miniconda-installer.sh  \
              $out/bin/conda-install           \
              --add-flags "-p ${installPath}"  \
              --add-flags "-b"
          ''
        else if stdenv.isLinux then
          runCommand "conda-install" {
            nativeBuildInputs = [ makeWrapper ];
            buildInputs = [ pkgs.zlib ];
          }
          # on line 10, we have 'unset LD_LIBRARY_PATH'
          # we have to comment it out however in a way that the number of bytes in the
          # file does not change. So we replace the 'u' in the line with a '#'
          # The reason is that the binary payload is encoded as number
          # of bytes from the top of the installer script
          # and unsetting the library path prevents the zlib library from being discovered
          ''
            mkdir -p $out/bin
            sed 's/unset LD_LIBRARY_PATH/#nset LD_LIBRARY_PATH/' ${src} > $out/bin/miniconda-installer.sh
            chmod +x $out/bin/miniconda-installer.sh

            makeWrapper                        \
              $out/bin/miniconda-installer.sh  \
              $out/bin/conda-install           \
              --add-flags "-p ${installPath}"  \
              --add-flags "-b"                 \
              --prefix "LD_LIBRARY_PATH" : "${libPath}"
          ''
        else {};

      hook = ''
        export CONDA_SUBDIR=${condaArch}
      '' + shellHook;

      fhs = buildFHSUserEnv {
        name = "conda-shell";
        targetPkgs = pkgs: [ stdenv.cc pkgs.git installer ] ++ packages;
        profile = hook;
        runScript = "bash";
      };

      shell = mkShell {
        shellHook = if stdenv.isDarwin then hook else "conda-shell; exit";
        packages = if stdenv.isDarwin then [ pkgs.git installer ] ++ packages else [ fhs ];
      };
    in shell;

  packages = with pkgs; [
    cmake
    protobuf
    libiconv
    rustc
    cargo
    rustPlatform.bindgenHook
  ];

  env = {
    aarch64-darwin = {
      envFile = "environment-mac.yml";
      condaPath = (builtins.toString ./.) + "/.conda";
      ptrSize = "8";
    };
    x86_64-linux = {
      envFile = "environment.yml";
      condaPath = (builtins.toString ./.) + "/.conda";
      ptrSize = "8";
    };
  };

  envFile = env.${stdenv.system}.envFile;
  installPath = env.${stdenv.system}.condaPath;
  ptrSize = env.${stdenv.system}.ptrSize;

  shellHook = ''
    conda-install

    # tmpdir is too small in nix
    export TMPDIR="${installPath}/tmp"
    # Add conda to PATH
    export PATH="${installPath}/bin:$PATH"
    # Allows `conda activate` to work properly
    source ${installPath}/etc/profile.d/conda.sh

    # Paths for gcc if compiling some C sources with pip
    export NIX_CFLAGS_COMPILE="-I${installPath}/include -I$TMPDIR/include"
    export NIX_CFLAGS_LINK="-L${installPath}/lib $BINDGEN_EXTRA_CLANG_ARGS"

    export PIP_EXISTS_ACTION=w

    # rust-onig fails (think it writes config.h to wrong location)
    mkdir -p "$TMPDIR/include"
    cat <<'EOF' > "$TMPDIR/include/config.h"
    #define HAVE_PROTOTYPES 1
    #define STDC_HEADERS 1
    #define HAVE_STRING_H 1
    #define HAVE_STDARG_H 1
    #define HAVE_STDLIB_H 1
    #define HAVE_LIMITS_H 1
    #define HAVE_INTTYPES_H 1
    #define SIZEOF_INT 4
    #define SIZEOF_SHORT 2
    #define SIZEOF_LONG ${ptrSize}
    #define SIZEOF_VOIDP ${ptrSize}
    #define SIZEOF_LONG_LONG 8
    EOF
    conda env create -f "${envFile}" || conda env update --prune -f "${envFile}"
    conda activate invokeai
  '';

  version = "4.12.0";

  conda = {
    aarch64-darwin = {
      shell = conda-shell {
        inherit shellHook installPath;
        url = "https://repo.anaconda.com/miniconda/Miniconda3-py39_${version}-MacOSX-arm64.sh";
        sha256 = "4bd112168cc33f8a4a60d3ef7e72b52a85972d588cd065be803eb21d73b625ef";
        packages = [ frameworks.Security ] ++ packages;
      };
    };
    x86_64-linux = {
      shell = conda-shell {
        inherit shellHook installPath;
        url = "https://repo.continuum.io/miniconda/Miniconda3-py39_${version}-Linux-x86_64.sh";
        sha256 = "78f39f9bae971ec1ae7969f0516017f2413f17796670f7040725dd83fcff5689";
        packages = with pkgs; [ libGL glib ] ++ packages;
      };
    };
  };

in conda.${stdenv.system}.shell