mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'inpaint-model' of github.com:invoke-ai/InvokeAI into inpaint-model
This commit is contained in:
commit
c732fd0740
BIN
assets/caution.png
Normal file
BIN
assets/caution.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 33 KiB |
@ -10,7 +10,6 @@ stable-diffusion-1.4:
|
|||||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
description: Stable Diffusion inference model version 1.4
|
description: Stable Diffusion inference model version 1.4
|
||||||
default: true
|
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
inpainting-1.5:
|
inpainting-1.5:
|
||||||
@ -20,6 +19,7 @@ inpainting-1.5:
|
|||||||
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
|
default: true
|
||||||
stable-diffusion-1.5:
|
stable-diffusion-1.5:
|
||||||
config: configs/stable-diffusion/v1-inference.yaml
|
config: configs/stable-diffusion/v1-inference.yaml
|
||||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||||
|
@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
|||||||
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
|
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
|
||||||
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
|
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
|
||||||
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
||||||
|
| `--safety-checker` | | False | Activate safety checker for NSFW and other potentially disturbing imagery |
|
||||||
| `--web` | | `False` | Start in web server mode |
|
| `--web` | | `False` | Start in web server mode |
|
||||||
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
|
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
|
||||||
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
|
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
|
||||||
@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
|||||||
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
|
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
|
||||||
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
|
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
|
||||||
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
|
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
|
||||||
| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
|
|
||||||
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
|
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
|
||||||
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
|
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
|
||||||
|
|
||||||
|
@ -75,6 +75,23 @@ combination of integers and floating point numbers, and they do not need to add
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## **Filename Format**
|
||||||
|
|
||||||
|
The argument `--fnformat` allows to specify the filename of the
|
||||||
|
image. Supported wildcards are all arguments what can be set such as
|
||||||
|
`perlin`, `seed`, `threshold`, `height`, `width`, `gfpgan_strength`,
|
||||||
|
`sampler_name`, `steps`, `model`, `upscale`, `prompt`, `cfg_scale`,
|
||||||
|
`prefix`.
|
||||||
|
|
||||||
|
The following prompt
|
||||||
|
```bash
|
||||||
|
dream> a red car --steps 25 -C 9.8 --perlin 0.1 --fnformat {prompt}_steps.{steps}_cfg.{cfg_scale}_perlin.{perlin}.png
|
||||||
|
```
|
||||||
|
|
||||||
|
generates a file with the name: `outputs/img-samples/a red car_steps.25_cfg.9.8_perlin.0.1.png`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## **Thresholding and Perlin Noise Initialization Options**
|
## **Thresholding and Perlin Noise Initialization Options**
|
||||||
|
|
||||||
Two new options are the thresholding (`--threshold`) and the perlin noise initialization (`--perlin`) options. Thresholding limits the range of the latent values during optimization, which helps combat oversaturation with higher CFG scale values. Perlin noise initialization starts with a percentage (a value ranging from 0 to 1) of perlin noise mixed into the initial noise. Both features allow for more variations and options in the course of generating images.
|
Two new options are the thresholding (`--threshold`) and the perlin noise initialization (`--perlin`) options. Thresholding limits the range of the latent values during optimization, which helps combat oversaturation with higher CFG scale values. Perlin noise initialization starts with a percentage (a value ranging from 0 to 1) of perlin noise mixed into the initial noise. Both features allow for more variations and options in the course of generating images.
|
||||||
|
@ -19,6 +19,7 @@ dependencies:
|
|||||||
# ```
|
# ```
|
||||||
- albumentations==1.2.1
|
- albumentations==1.2.1
|
||||||
- coloredlogs==15.0.1
|
- coloredlogs==15.0.1
|
||||||
|
- diffusers==0.6.0
|
||||||
- einops==0.4.1
|
- einops==0.4.1
|
||||||
- grpcio==1.46.4
|
- grpcio==1.46.4
|
||||||
- humanfriendly==10.0
|
- humanfriendly==10.0
|
||||||
|
@ -26,6 +26,7 @@ dependencies:
|
|||||||
- pyreadline3
|
- pyreadline3
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- transformers==4.21.3
|
- transformers==4.21.3
|
||||||
|
- diffusers==0.6.0
|
||||||
- torchmetrics==0.7.0
|
- torchmetrics==0.7.0
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
- flask_socketio==5.3.0
|
- flask_socketio==5.3.0
|
||||||
|
483
frontend/dist/assets/index.0a6593a2.js
vendored
Normal file
483
frontend/dist/assets/index.0a6593a2.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
frontend/dist/assets/index.193aec6f.css
vendored
Normal file
1
frontend/dist/assets/index.193aec6f.css
vendored
Normal file
File diff suppressed because one or more lines are too long
690
frontend/dist/assets/index.2d646c45.js
vendored
690
frontend/dist/assets/index.2d646c45.js
vendored
File diff suppressed because one or more lines are too long
1
frontend/dist/assets/index.7749e179.css
vendored
1
frontend/dist/assets/index.7749e179.css
vendored
File diff suppressed because one or more lines are too long
6
frontend/dist/index.html
vendored
6
frontend/dist/index.html
vendored
@ -5,9 +5,9 @@
|
|||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||||
<link rel="shortcut icon" type="icon" href="/assets/favicon.0d253ced.ico" />
|
<link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
|
||||||
<script type="module" crossorigin src="/assets/index.2d646c45.js"></script>
|
<script type="module" crossorigin src="./assets/index.0a6593a2.js"></script>
|
||||||
<link rel="stylesheet" href="/assets/index.7749e179.css">
|
<link rel="stylesheet" href="./assets/index.193aec6f.css">
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
<body>
|
<body>
|
||||||
|
@ -26,6 +26,7 @@ export const socketioMiddleware = () => {
|
|||||||
|
|
||||||
const socketio = io(origin, {
|
const socketio = io(origin, {
|
||||||
timeout: 60000,
|
timeout: 60000,
|
||||||
|
path: window.location.pathname + 'socket.io',
|
||||||
});
|
});
|
||||||
|
|
||||||
let areListenersSet = false;
|
let areListenersSet = false;
|
||||||
|
@ -68,7 +68,6 @@ const PromptInput = () => {
|
|||||||
<div className="prompt-bar">
|
<div className="prompt-bar">
|
||||||
<FormControl
|
<FormControl
|
||||||
isInvalid={prompt.length === 0 || Boolean(prompt.match(/^[\s\r\n]+$/))}
|
isInvalid={prompt.length === 0 || Boolean(prompt.match(/^[\s\r\n]+$/))}
|
||||||
isDisabled={isProcessing}
|
|
||||||
>
|
>
|
||||||
<Textarea
|
<Textarea
|
||||||
id="prompt"
|
id="prompt"
|
||||||
|
@ -5,6 +5,7 @@ import eslint from 'vite-plugin-eslint';
|
|||||||
// https://vitejs.dev/config/
|
// https://vitejs.dev/config/
|
||||||
export default defineConfig(({ mode }) => {
|
export default defineConfig(({ mode }) => {
|
||||||
const common = {
|
const common = {
|
||||||
|
base: '',
|
||||||
plugins: [react(), eslint()],
|
plugins: [react(), eslint()],
|
||||||
server: {
|
server: {
|
||||||
// Proxy HTTP requests to the flask server
|
// Proxy HTTP requests to the flask server
|
||||||
|
@ -140,13 +140,14 @@ class Generate:
|
|||||||
ddim_eta = 0.0, # deterministic
|
ddim_eta = 0.0, # deterministic
|
||||||
full_precision = False,
|
full_precision = False,
|
||||||
precision = 'auto',
|
precision = 'auto',
|
||||||
# these are deprecated; if present they override values in the conf file
|
|
||||||
weights = None,
|
|
||||||
config = None,
|
|
||||||
gfpgan=None,
|
gfpgan=None,
|
||||||
codeformer=None,
|
codeformer=None,
|
||||||
esrgan=None,
|
esrgan=None,
|
||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
|
safety_checker:bool=False,
|
||||||
|
# these are deprecated; if present they override values in the conf file
|
||||||
|
weights = None,
|
||||||
|
config = None,
|
||||||
):
|
):
|
||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.height = None
|
self.height = None
|
||||||
@ -177,6 +178,7 @@ class Generate:
|
|||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||||
self.txt2mask = None
|
self.txt2mask = None
|
||||||
|
self.safety_checker = None
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -204,6 +206,19 @@ class Generate:
|
|||||||
# gets rid of annoying messages about random seed
|
# gets rid of annoying messages about random seed
|
||||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# load safety checker if requested
|
||||||
|
if safety_checker:
|
||||||
|
try:
|
||||||
|
print('>> Initializing safety checker')
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
|
||||||
|
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
|
||||||
|
except Exception:
|
||||||
|
print('** An error was encountered while installing the safety checker:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def prompt2png(self, prompt, outdir, **kwargs):
|
def prompt2png(self, prompt, outdir, **kwargs):
|
||||||
"""
|
"""
|
||||||
Takes a prompt and an output directory, writes out the requested number
|
Takes a prompt and an output directory, writes out the requested number
|
||||||
@ -277,6 +292,7 @@ class Generate:
|
|||||||
# Set this True to handle KeyboardInterrupt internally
|
# Set this True to handle KeyboardInterrupt internally
|
||||||
catch_interrupts = False,
|
catch_interrupts = False,
|
||||||
hires_fix = False,
|
hires_fix = False,
|
||||||
|
use_mps_noise = False,
|
||||||
**args,
|
**args,
|
||||||
): # eat up additional cruft
|
): # eat up additional cruft
|
||||||
"""
|
"""
|
||||||
@ -421,6 +437,12 @@ class Generate:
|
|||||||
generator.set_variation(
|
generator.set_variation(
|
||||||
self.seed, variation_amount, with_variations
|
self.seed, variation_amount, with_variations
|
||||||
)
|
)
|
||||||
|
generator.use_mps_noise = use_mps_noise
|
||||||
|
|
||||||
|
checker = {
|
||||||
|
'checker':self.safety_checker,
|
||||||
|
'extractor':self.safety_feature_extractor
|
||||||
|
} if self.safety_checker else None
|
||||||
|
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
@ -444,7 +466,8 @@ class Generate:
|
|||||||
embiggen=embiggen,
|
embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
mask_blur_radius=mask_blur_radius
|
mask_blur_radius=mask_blur_radius,
|
||||||
|
safety_checker=checker
|
||||||
)
|
)
|
||||||
|
|
||||||
if init_color:
|
if init_color:
|
||||||
|
@ -217,6 +217,7 @@ class Args(object):
|
|||||||
switches.append(f'-W {a["width"]}')
|
switches.append(f'-W {a["width"]}')
|
||||||
switches.append(f'-H {a["height"]}')
|
switches.append(f'-H {a["height"]}')
|
||||||
switches.append(f'-C {a["cfg_scale"]}')
|
switches.append(f'-C {a["cfg_scale"]}')
|
||||||
|
switches.append(f'--fnformat {a["fnformat"]}')
|
||||||
if a['perlin'] > 0:
|
if a['perlin'] > 0:
|
||||||
switches.append(f'--perlin {a["perlin"]}')
|
switches.append(f'--perlin {a["perlin"]}')
|
||||||
if a['threshold'] > 0:
|
if a['threshold'] > 0:
|
||||||
@ -419,6 +420,11 @@ class Args(object):
|
|||||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||||
default='auto',
|
default='auto',
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--safety_checker',
|
||||||
|
action='store_true',
|
||||||
|
help='Check for and blur potentially NSFW images',
|
||||||
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--from_file',
|
'--from_file',
|
||||||
dest='infile',
|
dest='infile',
|
||||||
@ -438,6 +444,12 @@ class Args(object):
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
help='Place images in subdirectories named after the prompt.',
|
help='Place images in subdirectories named after the prompt.',
|
||||||
)
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--fnformat',
|
||||||
|
default='{prefix}.{seed}.png',
|
||||||
|
type=str,
|
||||||
|
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||||
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--grid',
|
'--grid',
|
||||||
'-g',
|
'-g',
|
||||||
@ -611,6 +623,12 @@ class Args(object):
|
|||||||
type=float,
|
type=float,
|
||||||
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
|
help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.',
|
||||||
)
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--fnformat',
|
||||||
|
default='{prefix}.{seed}.png',
|
||||||
|
type=str,
|
||||||
|
help='Overwrite the filename format. You can use any argument as wildcard enclosed in curly braces. Default is {prefix}.{seed}.png',
|
||||||
|
)
|
||||||
render_group.add_argument(
|
render_group.add_argument(
|
||||||
'--grid',
|
'--grid',
|
||||||
'-g',
|
'-g',
|
||||||
@ -811,6 +829,13 @@ class Args(object):
|
|||||||
type=str,
|
type=str,
|
||||||
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
|
||||||
)
|
)
|
||||||
|
render_group.add_argument(
|
||||||
|
'--use_mps_noise',
|
||||||
|
action='store_true',
|
||||||
|
dest='use_mps_noise',
|
||||||
|
help='Simulate noise on M1 systems to get the same results'
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def format_metadata(**kwargs):
|
def format_metadata(**kwargs):
|
||||||
@ -846,9 +871,8 @@ def metadata_dumps(opt,
|
|||||||
|
|
||||||
# remove any image keys not mentioned in RFC #266
|
# remove any image keys not mentioned in RFC #266
|
||||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
'cfg_scale','threshold','perlin','fnformat', 'step_number','width','height','extra','strength',
|
||||||
'init_img','init_mask','facetool','facetool_strength','upscale']
|
'init_img','init_mask','facetool','facetool_strength','upscale']
|
||||||
|
|
||||||
rfc_dict ={}
|
rfc_dict ={}
|
||||||
|
|
||||||
for item in image_dict.items():
|
for item in image_dict.items():
|
||||||
|
@ -7,13 +7,14 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from PIL import Image
|
from PIL import Image, ImageFilter
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.util import rand_perlin_2d
|
from ldm.util import rand_perlin_2d
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
CAUTION_IMG = 'assets/caution.png'
|
||||||
|
|
||||||
class Generator():
|
class Generator():
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
@ -22,10 +23,12 @@ class Generator():
|
|||||||
self.seed = None
|
self.seed = None
|
||||||
self.latent_channels = model.channels
|
self.latent_channels = model.channels
|
||||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
|
self.safety_checker = None
|
||||||
self.perlin = 0.0
|
self.perlin = 0.0
|
||||||
self.threshold = 0
|
self.threshold = 0
|
||||||
self.variation_amount = 0
|
self.variation_amount = 0
|
||||||
self.with_variations = []
|
self.with_variations = []
|
||||||
|
self.use_mps_noise = False
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self,prompt,**kwargs):
|
def get_make_image(self,prompt,**kwargs):
|
||||||
@ -42,8 +45,10 @@ class Generator():
|
|||||||
|
|
||||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||||
|
safety_checker:dict=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
scope = choose_autocast(self.precision)
|
scope = choose_autocast(self.precision)
|
||||||
|
self.safety_checker = safety_checker
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
prompt,
|
||||||
sampler = sampler,
|
sampler = sampler,
|
||||||
@ -79,10 +84,17 @@ class Generator():
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
|
|
||||||
|
if self.safety_checker is not None:
|
||||||
|
image = self.safety_check(image)
|
||||||
|
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
image_callback(image, seed, first_seed=first_seed)
|
image_callback(image, seed, first_seed=first_seed)
|
||||||
|
|
||||||
seed = self.new_seed()
|
seed = self.new_seed()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def sample_to_image(self,samples)->Image.Image:
|
def sample_to_image(self,samples)->Image.Image:
|
||||||
@ -171,6 +183,39 @@ class Generator():
|
|||||||
|
|
||||||
return v2
|
return v2
|
||||||
|
|
||||||
|
def safety_check(self,image:Image.Image):
|
||||||
|
'''
|
||||||
|
If the CompViz safety checker flags an NSFW image, we
|
||||||
|
blur it out.
|
||||||
|
'''
|
||||||
|
import diffusers
|
||||||
|
|
||||||
|
checker = self.safety_checker['checker']
|
||||||
|
extractor = self.safety_checker['extractor']
|
||||||
|
features = extractor([image], return_tensors="pt")
|
||||||
|
|
||||||
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||||
|
if has_nsfw_concept[0]:
|
||||||
|
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||||
|
return self.blur(image)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def blur(self,input):
|
||||||
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
try:
|
||||||
|
caution = Image.open(CAUTION_IMG)
|
||||||
|
caution = caution.resize((caution.width // 2, caution.height //2))
|
||||||
|
blurry.paste(caution,(0,0),caution)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return blurry
|
||||||
|
|
||||||
# this is a handy routine for debugging use. Given a generated sample,
|
# this is a handy routine for debugging use. Given a generated sample,
|
||||||
# convert it into a PNG image and store it at the indicated path
|
# convert it into a PNG image and store it at the indicated path
|
||||||
def save_sample(self, sample, filepath):
|
def save_sample(self, sample, filepath):
|
||||||
|
@ -59,7 +59,7 @@ class Txt2Img(Generator):
|
|||||||
# returns a tensor filled with random numbers from a normal distribution
|
# returns a tensor filled with random numbers from a normal distribution
|
||||||
def get_noise(self,width,height):
|
def get_noise(self,width,height):
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
if device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
x = torch.randn([1,
|
x = torch.randn([1,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
height // self.downsampling_factor,
|
height // self.downsampling_factor,
|
||||||
|
@ -118,7 +118,7 @@ class Txt2Img2Img(Generator):
|
|||||||
scaled_height = height
|
scaled_height = height
|
||||||
|
|
||||||
device = self.model.device
|
device = self.model.device
|
||||||
if device.type == 'mps':
|
if self.use_mps_noise or device.type == 'mps':
|
||||||
return torch.randn([1,
|
return torch.randn([1,
|
||||||
self.latent_channels,
|
self.latent_channels,
|
||||||
scaled_height // self.downsampling_factor,
|
scaled_height // self.downsampling_factor,
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
albumentations==0.4.3
|
albumentations==0.4.3
|
||||||
einops==0.3.0
|
einops==0.3.0
|
||||||
|
diffusers==0.6.0
|
||||||
huggingface-hub==0.8.1
|
huggingface-hub==0.8.1
|
||||||
imageio==2.9.0
|
imageio==2.9.0
|
||||||
imageio-ffmpeg==0.4.2
|
imageio-ffmpeg==0.4.2
|
||||||
|
@ -32,6 +32,7 @@ send2trash
|
|||||||
dependency_injector==4.40.0
|
dependency_injector==4.40.0
|
||||||
eventlet
|
eventlet
|
||||||
realesrgan
|
realesrgan
|
||||||
|
diffusers
|
||||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||||
|
@ -79,6 +79,7 @@ def main():
|
|||||||
codeformer=codeformer,
|
codeformer=codeformer,
|
||||||
esrgan=esrgan,
|
esrgan=esrgan,
|
||||||
free_gpu_mem=opt.free_gpu_mem,
|
free_gpu_mem=opt.free_gpu_mem,
|
||||||
|
safety_checker=opt.safety_checker,
|
||||||
)
|
)
|
||||||
except (FileNotFoundError, IOError, KeyError) as e:
|
except (FileNotFoundError, IOError, KeyError) as e:
|
||||||
print(f'{e}. Aborting.')
|
print(f'{e}. Aborting.')
|
||||||
@ -673,6 +674,16 @@ def prepare_image_metadata(
|
|||||||
if postprocessed and opt.save_original:
|
if postprocessed and opt.save_original:
|
||||||
filename = choose_postprocess_name(opt,prefix,seed)
|
filename = choose_postprocess_name(opt,prefix,seed)
|
||||||
else:
|
else:
|
||||||
|
wildcards = dict(opt.__dict__)
|
||||||
|
wildcards['prefix'] = prefix
|
||||||
|
wildcards['seed'] = seed
|
||||||
|
try:
|
||||||
|
filename = opt.fnformat.format(**wildcards)
|
||||||
|
except KeyError as e:
|
||||||
|
print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||||
|
filename = f'{prefix}.{seed}.png'
|
||||||
|
except IndexError as e:
|
||||||
|
print(f'** The filename format is broken or complete. Will use \'{{prefix}}.{{seed}}.png\' instead')
|
||||||
filename = f'{prefix}.{seed}.png'
|
filename = f'{prefix}.{seed}.png'
|
||||||
|
|
||||||
if opt.variation_amount > 0:
|
if opt.variation_amount > 0:
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
# two machines must share a common .cache directory.
|
# two machines must share a common .cache directory.
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
import clip
|
import clip
|
||||||
from transformers import BertTokenizerFast
|
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||||
import sys
|
import sys
|
||||||
import transformers
|
import transformers
|
||||||
import os
|
import os
|
||||||
@ -17,21 +17,27 @@ import traceback
|
|||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
# this will preload the Bert tokenizer fles
|
# this will preload the Bert tokenizer fles
|
||||||
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
|
def download_bert():
|
||||||
|
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||||
print('...success')
|
print('...success')
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
# this will download requirements for Kornia
|
# this will download requirements for Kornia
|
||||||
print('Loading Kornia requirements...', end='')
|
def download_kornia():
|
||||||
|
print('Installing Kornia requirements...', end='')
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
import kornia
|
import kornia
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
|
def download_clip():
|
||||||
version = 'openai/clip-vit-large-patch14'
|
version = 'openai/clip-vit-large-patch14'
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
print('Loading CLIP model...',end='')
|
print('Loading CLIP model...',end='')
|
||||||
@ -39,19 +45,11 @@ tokenizer = CLIPTokenizer.from_pretrained(version)
|
|||||||
transformer = CLIPTextModel.from_pretrained(version)
|
transformer = CLIPTextModel.from_pretrained(version)
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
# In the event that the user has installed GFPGAN and also elected to use
|
#---------------------------------------------
|
||||||
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
|
def download_gfpgan():
|
||||||
gfpgan = False
|
print('Installing models from RealESRGAN and facexlib...',end='')
|
||||||
try:
|
try:
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
gfpgan = True
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if gfpgan:
|
|
||||||
print('Loading models from RealESRGAN and facexlib...',end='')
|
|
||||||
try:
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
@ -94,7 +92,9 @@ if gfpgan:
|
|||||||
print('Error loading GFPGAN:')
|
print('Error loading GFPGAN:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
print('preloading CodeFormer model file...',end='')
|
#---------------------------------------------
|
||||||
|
def download_codeformer():
|
||||||
|
print('Installing CodeFormer model file...',end='')
|
||||||
try:
|
try:
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||||
@ -107,7 +107,9 @@ except Exception:
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
print('Loading clipseg model for text-based masking...',end='')
|
#---------------------------------------------
|
||||||
|
def download_clipseg():
|
||||||
|
print('Installing clipseg model for text-based masking...',end='')
|
||||||
try:
|
try:
|
||||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||||
@ -134,4 +136,28 @@ except Exception:
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
print('...success')
|
print('...success')
|
||||||
|
|
||||||
|
#-------------------------------------
|
||||||
|
def download_safety_checker():
|
||||||
|
print('Installing safety model for NSFW content detection...',end='')
|
||||||
|
try:
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print('Error installing safety checker model:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||||
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||||
|
print('...success')
|
||||||
|
|
||||||
|
#-------------------------------------
|
||||||
|
if __name__ == '__main__':
|
||||||
|
download_bert()
|
||||||
|
download_kornia()
|
||||||
|
download_clip()
|
||||||
|
download_gfpgan()
|
||||||
|
download_codeformer()
|
||||||
|
download_clipseg()
|
||||||
|
download_safety_checker()
|
||||||
|
|
||||||
|
|
||||||
|
162
shell.nix
Normal file
162
shell.nix
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
{ pkgs ? import <nixpkgs> {}
|
||||||
|
, lib ? pkgs.lib
|
||||||
|
, stdenv ? pkgs.stdenv
|
||||||
|
, fetchurl ? pkgs.fetchurl
|
||||||
|
, runCommand ? pkgs.runCommand
|
||||||
|
, makeWrapper ? pkgs.makeWrapper
|
||||||
|
, mkShell ? pkgs.mkShell
|
||||||
|
, buildFHSUserEnv ? pkgs.buildFHSUserEnv
|
||||||
|
, frameworks ? pkgs.darwin.apple_sdk.frameworks
|
||||||
|
}:
|
||||||
|
|
||||||
|
# Setup InvokeAI environment using nix
|
||||||
|
# Simple usage:
|
||||||
|
# nix-shell
|
||||||
|
# python3 scripts/preload_models.py
|
||||||
|
# python3 scripts/invoke.py -h
|
||||||
|
|
||||||
|
let
|
||||||
|
conda-shell = { url, sha256, installPath, packages, shellHook }:
|
||||||
|
let
|
||||||
|
src = fetchurl { inherit url sha256; };
|
||||||
|
libPath = lib.makeLibraryPath ([] ++ lib.optionals (stdenv.isLinux) [ pkgs.zlib ]);
|
||||||
|
condaArch = if stdenv.system == "aarch64-darwin" then "osx-arm64" else "";
|
||||||
|
installer =
|
||||||
|
if stdenv.isDarwin then
|
||||||
|
runCommand "conda-install" {
|
||||||
|
nativeBuildInputs = [ makeWrapper ];
|
||||||
|
} ''
|
||||||
|
mkdir -p $out/bin
|
||||||
|
cp ${src} $out/bin/miniconda-installer.sh
|
||||||
|
chmod +x $out/bin/miniconda-installer.sh
|
||||||
|
makeWrapper \
|
||||||
|
$out/bin/miniconda-installer.sh \
|
||||||
|
$out/bin/conda-install \
|
||||||
|
--add-flags "-p ${installPath}" \
|
||||||
|
--add-flags "-b"
|
||||||
|
''
|
||||||
|
else if stdenv.isLinux then
|
||||||
|
runCommand "conda-install" {
|
||||||
|
nativeBuildInputs = [ makeWrapper ];
|
||||||
|
buildInputs = [ pkgs.zlib ];
|
||||||
|
}
|
||||||
|
# on line 10, we have 'unset LD_LIBRARY_PATH'
|
||||||
|
# we have to comment it out however in a way that the number of bytes in the
|
||||||
|
# file does not change. So we replace the 'u' in the line with a '#'
|
||||||
|
# The reason is that the binary payload is encoded as number
|
||||||
|
# of bytes from the top of the installer script
|
||||||
|
# and unsetting the library path prevents the zlib library from being discovered
|
||||||
|
''
|
||||||
|
mkdir -p $out/bin
|
||||||
|
sed 's/unset LD_LIBRARY_PATH/#nset LD_LIBRARY_PATH/' ${src} > $out/bin/miniconda-installer.sh
|
||||||
|
chmod +x $out/bin/miniconda-installer.sh
|
||||||
|
makeWrapper \
|
||||||
|
$out/bin/miniconda-installer.sh \
|
||||||
|
$out/bin/conda-install \
|
||||||
|
--add-flags "-p ${installPath}" \
|
||||||
|
--add-flags "-b" \
|
||||||
|
--prefix "LD_LIBRARY_PATH" : "${libPath}"
|
||||||
|
''
|
||||||
|
else {};
|
||||||
|
|
||||||
|
hook = ''
|
||||||
|
export CONDA_SUBDIR=${condaArch}
|
||||||
|
'' + shellHook;
|
||||||
|
|
||||||
|
fhs = buildFHSUserEnv {
|
||||||
|
name = "conda-shell";
|
||||||
|
targetPkgs = pkgs: [ stdenv.cc pkgs.git installer ] ++ packages;
|
||||||
|
profile = hook;
|
||||||
|
runScript = "bash";
|
||||||
|
};
|
||||||
|
|
||||||
|
shell = mkShell {
|
||||||
|
shellHook = if stdenv.isDarwin then hook else "conda-shell; exit";
|
||||||
|
packages = if stdenv.isDarwin then [ pkgs.git installer ] ++ packages else [ fhs ];
|
||||||
|
};
|
||||||
|
in shell;
|
||||||
|
|
||||||
|
packages = with pkgs; [
|
||||||
|
cmake
|
||||||
|
protobuf
|
||||||
|
libiconv
|
||||||
|
rustc
|
||||||
|
cargo
|
||||||
|
rustPlatform.bindgenHook
|
||||||
|
];
|
||||||
|
|
||||||
|
env = {
|
||||||
|
aarch64-darwin = {
|
||||||
|
envFile = "environment-mac.yml";
|
||||||
|
condaPath = (builtins.toString ./.) + "/.conda";
|
||||||
|
ptrSize = "8";
|
||||||
|
};
|
||||||
|
x86_64-linux = {
|
||||||
|
envFile = "environment.yml";
|
||||||
|
condaPath = (builtins.toString ./.) + "/.conda";
|
||||||
|
ptrSize = "8";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
envFile = env.${stdenv.system}.envFile;
|
||||||
|
installPath = env.${stdenv.system}.condaPath;
|
||||||
|
ptrSize = env.${stdenv.system}.ptrSize;
|
||||||
|
shellHook = ''
|
||||||
|
conda-install
|
||||||
|
|
||||||
|
# tmpdir is too small in nix
|
||||||
|
export TMPDIR="${installPath}/tmp"
|
||||||
|
|
||||||
|
# Add conda to PATH
|
||||||
|
export PATH="${installPath}/bin:$PATH"
|
||||||
|
|
||||||
|
# Allows `conda activate` to work properly
|
||||||
|
source ${installPath}/etc/profile.d/conda.sh
|
||||||
|
|
||||||
|
# Paths for gcc if compiling some C sources with pip
|
||||||
|
export NIX_CFLAGS_COMPILE="-I${installPath}/include -I$TMPDIR/include"
|
||||||
|
export NIX_CFLAGS_LINK="-L${installPath}/lib $BINDGEN_EXTRA_CLANG_ARGS"
|
||||||
|
|
||||||
|
export PIP_EXISTS_ACTION=w
|
||||||
|
|
||||||
|
# rust-onig fails (think it writes config.h to wrong location)
|
||||||
|
mkdir -p "$TMPDIR/include"
|
||||||
|
cat <<'EOF' > "$TMPDIR/include/config.h"
|
||||||
|
#define HAVE_PROTOTYPES 1
|
||||||
|
#define STDC_HEADERS 1
|
||||||
|
#define HAVE_STRING_H 1
|
||||||
|
#define HAVE_STDARG_H 1
|
||||||
|
#define HAVE_STDLIB_H 1
|
||||||
|
#define HAVE_LIMITS_H 1
|
||||||
|
#define HAVE_INTTYPES_H 1
|
||||||
|
#define SIZEOF_INT 4
|
||||||
|
#define SIZEOF_SHORT 2
|
||||||
|
#define SIZEOF_LONG ${ptrSize}
|
||||||
|
#define SIZEOF_VOIDP ${ptrSize}
|
||||||
|
#define SIZEOF_LONG_LONG 8
|
||||||
|
EOF
|
||||||
|
|
||||||
|
conda env create -f "${envFile}" || conda env update --prune -f "${envFile}"
|
||||||
|
conda activate invokeai
|
||||||
|
'';
|
||||||
|
|
||||||
|
version = "4.12.0";
|
||||||
|
conda = {
|
||||||
|
aarch64-darwin = {
|
||||||
|
shell = conda-shell {
|
||||||
|
inherit shellHook installPath;
|
||||||
|
url = "https://repo.anaconda.com/miniconda/Miniconda3-py39_${version}-MacOSX-arm64.sh";
|
||||||
|
sha256 = "4bd112168cc33f8a4a60d3ef7e72b52a85972d588cd065be803eb21d73b625ef";
|
||||||
|
packages = [ frameworks.Security ] ++ packages;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
x86_64-linux = {
|
||||||
|
shell = conda-shell {
|
||||||
|
inherit shellHook installPath;
|
||||||
|
url = "https://repo.continuum.io/miniconda/Miniconda3-py39_${version}-Linux-x86_64.sh";
|
||||||
|
sha256 = "78f39f9bae971ec1ae7969f0516017f2413f17796670f7040725dd83fcff5689";
|
||||||
|
packages = with pkgs; [ libGL glib ] ++ packages;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
in conda.${stdenv.system}.shell
|
Loading…
x
Reference in New Issue
Block a user