mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into fix-disabled-prompt
This commit is contained in:
commit
fb4dc7eaf9
BIN
assets/caution.png
Normal file
BIN
assets/caution.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 33 KiB |
@ -1,21 +1,23 @@
|
||||
# This file describes the alternative machine learning models
|
||||
# available to the dream script.
|
||||
# available to the dream script.
|
||||
#
|
||||
# To add a new model, follow the examples below. Each
|
||||
# model requires a model config file, a weights file,
|
||||
# and the width and height of the images it
|
||||
# was trained on.
|
||||
|
||||
stable-diffusion-1.4:
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
description: Stable Diffusion inference model version 1.4
|
||||
width: 512
|
||||
height: 512
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/model.ckpt
|
||||
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
description: Stable Diffusion inference model version 1.4
|
||||
default: true
|
||||
width: 512
|
||||
height: 512
|
||||
default: true
|
||||
stable-diffusion-1.5:
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||
description: Stable Diffusion inference model version 1.5
|
||||
width: 512
|
||||
height: 512
|
||||
config: configs/stable-diffusion/v1-inference.yaml
|
||||
weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
|
||||
# vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
description: Stable Diffusion inference model version 1.5
|
||||
width: 512
|
||||
height: 512
|
||||
|
@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
||||
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
|
||||
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
|
||||
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
||||
| `--safety-checker` | | False | Activate safety checker for NSFW and other potentially disturbing imagery |
|
||||
| `--web` | | `False` | Start in web server mode |
|
||||
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
|
||||
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
|
||||
@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
||||
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
|
||||
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
|
||||
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
|
||||
| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
|
||||
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
|
||||
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
|
||||
|
||||
|
@ -81,15 +81,18 @@ text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
|
||||
It will generate three files:
|
||||
|
||||
- The image with the selected area highlighted.
|
||||
- it will be named XXXXX.<imagename>.<prompt>.selected.png
|
||||
- The image with the un-selected area highlighted.
|
||||
- it will be named XXXXX.<imagename>.<prompt>.deselected.png
|
||||
- The image with the selected area converted into a black and white
|
||||
image according to the threshold level.
|
||||
image according to the threshold level
|
||||
- it will be named XXXXX.<imagename>.<prompt>.masked.png
|
||||
|
||||
Note that none of these images are intended to be used as the mask
|
||||
passed to invoke via `-M` and may give unexpected results if you try
|
||||
to use them this way. Instead, use `!mask` for testing that you are
|
||||
selecting the right mask area, and then do inpainting using the
|
||||
best selection term and threshold.
|
||||
The `.masked.png` file can then be directly passed to the `invoke>`
|
||||
prompt in the CLI via the `-M` argument. Do not attempt this with
|
||||
the `selected.png` or `deselected.png` files, as they contain some
|
||||
transparency throughout the image and will not produce the desired
|
||||
results.
|
||||
|
||||
Here is an example of how `!mask` works:
|
||||
|
||||
@ -120,7 +123,7 @@ It looks like we selected the hair pretty well at the 0.5 threshold
|
||||
let's have some fun:
|
||||
|
||||
```
|
||||
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair 0.5 -C20
|
||||
invoke> medusa with cobras -I ./test-pictures/curly.png -M 000019.curly.hair.masked.png -C20
|
||||
>> loaded input image of size 512x512 from ./test-pictures/curly.png
|
||||
...
|
||||
Outputs:
|
||||
@ -129,6 +132,13 @@ Outputs:
|
||||
|
||||
<img src="../assets/inpainting/000024.801380492.png">
|
||||
|
||||
You can also skip the `!mask` creation step and just select the masked
|
||||
|
||||
region directly:
|
||||
```
|
||||
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
|
||||
```
|
||||
|
||||
### Inpainting is not changing the masked region enough!
|
||||
|
||||
One of the things to understand about how inpainting works is that it
|
||||
|
@ -19,6 +19,7 @@ dependencies:
|
||||
# ```
|
||||
- albumentations==1.2.1
|
||||
- coloredlogs==15.0.1
|
||||
- diffusers==0.6.0
|
||||
- einops==0.4.1
|
||||
- grpcio==1.46.4
|
||||
- humanfriendly==10.0
|
||||
|
@ -26,6 +26,7 @@ dependencies:
|
||||
- pyreadline3
|
||||
- torch-fidelity==0.3.0
|
||||
- transformers==4.21.3
|
||||
- diffusers==0.6.0
|
||||
- torchmetrics==0.7.0
|
||||
- flask==2.1.3
|
||||
- flask_socketio==5.3.0
|
||||
|
690
frontend/dist/assets/index.48782019.js
vendored
Normal file
690
frontend/dist/assets/index.48782019.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
frontend/dist/assets/index.556a5ea7.css
vendored
Normal file
1
frontend/dist/assets/index.556a5ea7.css
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -26,6 +26,7 @@ export const socketioMiddleware = () => {
|
||||
|
||||
const socketio = io(origin, {
|
||||
timeout: 60000,
|
||||
path: window.location.pathname + 'socket.io',
|
||||
});
|
||||
|
||||
let areListenersSet = false;
|
||||
|
@ -5,6 +5,7 @@ import eslint from 'vite-plugin-eslint';
|
||||
// https://vitejs.dev/config/
|
||||
export default defineConfig(({ mode }) => {
|
||||
const common = {
|
||||
base: '',
|
||||
plugins: [react(), eslint()],
|
||||
server: {
|
||||
// Proxy HTTP requests to the flask server
|
||||
|
117
ldm/generate.py
117
ldm/generate.py
@ -110,12 +110,13 @@ still work.
|
||||
The full list of arguments to Generate() are:
|
||||
gr = Generate(
|
||||
# these values are set once and shouldn't be changed
|
||||
conf = path to configuration file ('configs/models.yaml')
|
||||
model = symbolic name of the model in the configuration file
|
||||
precision = float precision to be used
|
||||
conf:str = path to configuration file ('configs/models.yaml')
|
||||
model:str = symbolic name of the model in the configuration file
|
||||
precision:float = float precision to be used
|
||||
safety_checker:bool = activate safety checker [False]
|
||||
|
||||
# this value is sticky and maintained between generation calls
|
||||
sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
|
||||
|
||||
# these are deprecated - use conf and model instead
|
||||
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
|
||||
@ -132,20 +133,21 @@ class Generate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
full_precision = False,
|
||||
precision = 'auto',
|
||||
# these are deprecated; if present they override values in the conf file
|
||||
weights = None,
|
||||
config = None,
|
||||
model = None,
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
full_precision = False,
|
||||
precision = 'auto',
|
||||
gfpgan=None,
|
||||
codeformer=None,
|
||||
esrgan=None,
|
||||
free_gpu_mem=False,
|
||||
safety_checker:bool=False,
|
||||
# these are deprecated; if present they override values in the conf file
|
||||
weights = None,
|
||||
config = None,
|
||||
):
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.height = None
|
||||
@ -176,6 +178,7 @@ class Generate:
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -203,6 +206,19 @@ class Generate:
|
||||
# gets rid of annoying messages about random seed
|
||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
try:
|
||||
print('>> Initializing safety checker')
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
|
||||
except Exception:
|
||||
print('** An error was encountered while installing the safety checker:')
|
||||
print(traceback.format_exc())
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
@ -271,6 +287,8 @@ class Generate:
|
||||
upscale = None,
|
||||
# this is specific to inpainting and causes more extreme inpainting
|
||||
inpaint_replace = 0.0,
|
||||
# This will help match inpainted areas to the original image more smoothly
|
||||
mask_blur_radius: int = 8,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
catch_interrupts = False,
|
||||
hires_fix = False,
|
||||
@ -391,7 +409,7 @@ class Generate:
|
||||
log_tokens =self.log_tokenization
|
||||
)
|
||||
|
||||
init_image,mask_image = self._make_images(
|
||||
init_image, mask_image = self._make_images(
|
||||
init_img,
|
||||
init_mask,
|
||||
width,
|
||||
@ -416,6 +434,11 @@ class Generate:
|
||||
self.seed, variation_amount, with_variations
|
||||
)
|
||||
|
||||
checker = {
|
||||
'checker':self.safety_checker,
|
||||
'extractor':self.safety_feature_extractor
|
||||
} if self.safety_checker else None
|
||||
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
@ -426,10 +449,10 @@ class Generate:
|
||||
conditioning=(uc, c),
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=image_callback, # called after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
width=width,
|
||||
height=height,
|
||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||
init_image=init_image, # notice that init_image is different from init_img
|
||||
mask_image=mask_image,
|
||||
strength=strength,
|
||||
@ -438,6 +461,8 @@ class Generate:
|
||||
embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
inpaint_replace=inpaint_replace,
|
||||
mask_blur_radius=mask_blur_radius,
|
||||
safety_checker=checker
|
||||
)
|
||||
|
||||
if init_color:
|
||||
@ -631,23 +656,22 @@ class Generate:
|
||||
# if image has a transparent area and no mask was provided, then try to generate mask
|
||||
if self._has_transparency(image):
|
||||
self._transparency_check_and_warning(image, mask)
|
||||
# this returns a torch tensor
|
||||
init_mask = self._create_init_mask(image, width, height, fit=fit)
|
||||
|
||||
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
|
||||
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
|
||||
self.size_matters = False
|
||||
|
||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||
init_image = self._create_init_image(image,width,height,fit=fit)
|
||||
|
||||
if mask:
|
||||
mask_image = self._load_img(mask) # this returns an Image
|
||||
mask_image = self._load_img(mask)
|
||||
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
||||
|
||||
elif text_mask:
|
||||
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||
|
||||
return init_image, init_mask
|
||||
return init_image,init_mask
|
||||
|
||||
def _make_base(self):
|
||||
if not self.generators.get('base'):
|
||||
@ -864,46 +888,31 @@ class Generate:
|
||||
|
||||
def _create_init_image(self, image, width, height, fit=True):
|
||||
image = image.convert('RGB')
|
||||
if fit:
|
||||
image = self._fit_image(image, (width, height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.device)
|
||||
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
|
||||
return image
|
||||
|
||||
def _create_init_mask(self, image, width, height, fit=True):
|
||||
# convert into a black/white mask
|
||||
image = self._image_to_mask(image)
|
||||
image = image.convert('RGB')
|
||||
|
||||
# now we adjust the size
|
||||
if fit:
|
||||
image = self._fit_image(image, (width, height))
|
||||
else:
|
||||
image = self._squeeze_image(image)
|
||||
image = image.resize((image.width//downsampling, image.height //
|
||||
downsampling), resample=Image.Resampling.NEAREST)
|
||||
image = np.array(image)
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return image.to(self.device)
|
||||
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
|
||||
return image
|
||||
|
||||
# The mask is expected to have the region to be inpainted
|
||||
# with alpha transparency. It converts it into a black/white
|
||||
# image with the transparent part black.
|
||||
def _image_to_mask(self, mask_image, invert=False) -> Image:
|
||||
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
|
||||
# Obtain the mask from the transparency channel
|
||||
mask = Image.new(mode="L", size=mask_image.size, color=255)
|
||||
mask.putdata(mask_image.getdata(band=3))
|
||||
if mask_image.mode == 'L':
|
||||
mask = mask_image
|
||||
else:
|
||||
# Obtain the mask from the transparency channel
|
||||
mask = Image.new(mode="L", size=mask_image.size, color=255)
|
||||
mask.putdata(mask_image.getdata(band=3))
|
||||
if invert:
|
||||
mask = ImageOps.invert(mask)
|
||||
return mask
|
||||
|
||||
# TODO: The latter part of this method repeats code from _create_init_mask()
|
||||
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
||||
prompt = text_mask[0]
|
||||
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
||||
@ -913,18 +922,8 @@ class Generate:
|
||||
segmented = self.txt2mask.segment(image, prompt)
|
||||
mask = segmented.to_mask(float(confidence_level))
|
||||
mask = mask.convert('RGB')
|
||||
# now we adjust the size
|
||||
if fit:
|
||||
mask = self._fit_image(mask, (width, height))
|
||||
else:
|
||||
mask = self._squeeze_image(mask)
|
||||
mask = mask.resize((mask.width//downsampling, mask.height //
|
||||
downsampling), resample=Image.Resampling.NEAREST)
|
||||
mask = np.array(mask)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None].transpose(0, 3, 1, 2)
|
||||
mask = torch.from_numpy(mask)
|
||||
return mask.to(self.device)
|
||||
mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
|
||||
return mask
|
||||
|
||||
def _has_transparency(self, image):
|
||||
if image.info.get("transparency", None) is not None:
|
||||
|
@ -113,8 +113,8 @@ PRECISION_CHOICES = [
|
||||
]
|
||||
|
||||
# is there a way to pick this up during git commits?
|
||||
APP_ID = 'lstein/stable-diffusion'
|
||||
APP_VERSION = 'v1.15'
|
||||
APP_ID = 'invoke-ai/InvokeAI'
|
||||
APP_VERSION = 'v2.02'
|
||||
|
||||
class ArgFormatter(argparse.RawTextHelpFormatter):
|
||||
# use defined argument order to display usage
|
||||
@ -172,6 +172,7 @@ class Args(object):
|
||||
command = cmd_string.replace("'", "\\'")
|
||||
try:
|
||||
elements = shlex.split(command)
|
||||
elements = [x.replace("\\'","'") for x in elements]
|
||||
except ValueError:
|
||||
import sys, traceback
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
@ -418,6 +419,11 @@ class Args(object):
|
||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||
default='auto',
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--safety_checker',
|
||||
action='store_true',
|
||||
help='Check for and blur potentially NSFW images',
|
||||
)
|
||||
file_group.add_argument(
|
||||
'--from_file',
|
||||
dest='infile',
|
||||
@ -846,7 +852,7 @@ def metadata_dumps(opt,
|
||||
# remove any image keys not mentioned in RFC #266
|
||||
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
|
||||
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
|
||||
'init_img','init_mask']
|
||||
'init_img','init_mask','facetool','facetool_strength','upscale']
|
||||
|
||||
rfc_dict ={}
|
||||
|
||||
@ -930,7 +936,7 @@ def metadata_loads(metadata) -> list:
|
||||
for image in images:
|
||||
# repack the prompt and variations
|
||||
if 'prompt' in image:
|
||||
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
|
||||
image['prompt'] = repack_prompt(image['prompt'])
|
||||
if 'variations' in image:
|
||||
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
|
||||
# fix a bit of semantic drift here
|
||||
@ -938,12 +944,19 @@ def metadata_loads(metadata) -> list:
|
||||
opt = Args()
|
||||
opt._cmd_switches = Namespace(**image)
|
||||
results.append(opt)
|
||||
except KeyError as e:
|
||||
except Exception as e:
|
||||
import sys, traceback
|
||||
print('>> badly-formatted metadata',file=sys.stderr)
|
||||
print('>> could not read metadata',file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return results
|
||||
|
||||
def repack_prompt(prompt_list:list)->str:
|
||||
# in the common case of no weighting syntax, just return the prompt as is
|
||||
if len(prompt_list) > 1:
|
||||
return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
|
||||
else:
|
||||
return prompt_list[0]['prompt']
|
||||
|
||||
# image can either be a file path on disk or a base64-encoded
|
||||
# representation of the file's contents
|
||||
def calculate_init_img_hash(image_string):
|
||||
|
@ -7,25 +7,27 @@ import numpy as np
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = 'assets/caution.png'
|
||||
|
||||
class Generator():
|
||||
def __init__(self, model, precision):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
@ -42,8 +44,10 @@ class Generator():
|
||||
|
||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
self.safety_checker = safety_checker
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
init_image = init_image,
|
||||
@ -77,10 +81,17 @@ class Generator():
|
||||
pass
|
||||
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, first_seed=first_seed)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples):
|
||||
@ -169,6 +180,39 @@ class Generator():
|
||||
|
||||
return v2
|
||||
|
||||
def safety_check(self,image:Image.Image):
|
||||
'''
|
||||
If the CompViz safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
'''
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker['checker']
|
||||
extractor = self.safety_checker['extractor']
|
||||
features = extractor([image], return_tensors="pt")
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||
if has_nsfw_concept[0]:
|
||||
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self,input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = Image.open(CAUTION_IMG)
|
||||
caution = caution.resize((caution.width // 2, caution.height //2))
|
||||
blurry.paste(caution,(0,0),caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
|
@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
import PIL
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -25,6 +28,9 @@ class Img2Img(Generator):
|
||||
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
|
||||
)
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
scope = choose_autocast(self.precision)
|
||||
with scope(self.model.device.type):
|
||||
self.init_latent = self.model.get_first_stage_encoding(
|
||||
@ -68,3 +74,11 @@ class Img2Img(Generator):
|
||||
shape = init_latent.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
return x
|
||||
|
||||
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
if normalize:
|
||||
image = 2.0 * image - 1.0
|
||||
return image.to(self.model.device)
|
||||
|
@ -3,27 +3,55 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
import PIL
|
||||
from PIL import Image, ImageFilter
|
||||
from skimage.exposure.histogram_matching import match_histograms
|
||||
from einops import rearrange, repeat
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.img2img import Img2Img
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
from ldm.invoke.generator.base import downsampling
|
||||
|
||||
class Inpaint(Img2Img):
|
||||
def __init__(self, model, precision):
|
||||
self.init_latent = None
|
||||
self.pil_image = None
|
||||
self.pil_mask = None
|
||||
self.mask_blur_radius = 0
|
||||
super().__init__(model, precision)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
|
||||
conditioning,init_image,mask_image,strength,
|
||||
step_callback=None,inpaint_replace=False,**kwargs):
|
||||
mask_blur_radius: int = 8,
|
||||
step_callback=None,inpaint_replace=False, **kwargs):
|
||||
"""
|
||||
Returns a function returning an image derived from the prompt and
|
||||
the initial image + mask. Return value depends on the seed at
|
||||
the time you call it. kwargs are 'init_latent' and 'strength'
|
||||
"""
|
||||
|
||||
if isinstance(init_image, PIL.Image.Image):
|
||||
self.pil_image = init_image
|
||||
init_image = self._image_to_tensor(init_image)
|
||||
|
||||
if isinstance(mask_image, PIL.Image.Image):
|
||||
self.pil_mask = mask_image
|
||||
mask_image = mask_image.resize(
|
||||
(
|
||||
mask_image.width // downsampling,
|
||||
mask_image.height // downsampling
|
||||
),
|
||||
resample=Image.Resampling.NEAREST
|
||||
)
|
||||
mask_image = self._image_to_tensor(mask_image,normalize=False)
|
||||
|
||||
self.mask_blur_radius = mask_blur_radius
|
||||
|
||||
# klms samplers not supported yet, so ignore previous sampler
|
||||
if isinstance(sampler,KSampler):
|
||||
print(
|
||||
@ -77,10 +105,50 @@ class Inpaint(Img2Img):
|
||||
mask = mask_image,
|
||||
init_latent = self.init_latent
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
||||
return make_image
|
||||
|
||||
def sample_to_image(self, samples)->Image.Image:
|
||||
gen_result = super().sample_to_image(samples).convert('RGB')
|
||||
|
||||
if self.pil_image is None or self.pil_mask is None:
|
||||
return gen_result
|
||||
|
||||
pil_mask = self.pil_mask
|
||||
pil_image = self.pil_image
|
||||
mask_blur_radius = self.mask_blur_radius
|
||||
|
||||
# Get the original alpha channel of the mask if there is one.
|
||||
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
|
||||
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
|
||||
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
|
||||
|
||||
# Build an image with only visible pixels from source to use as reference for color-matching.
|
||||
# Note that this doesn't use the mask, which would exclude some source image pixels from the
|
||||
# histogram and cause slight color changes.
|
||||
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
|
||||
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
|
||||
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
|
||||
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
|
||||
|
||||
# Get numpy version
|
||||
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
|
||||
|
||||
# Color correct
|
||||
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
|
||||
matched_result = Image.fromarray(np_matched_result, mode='RGB')
|
||||
|
||||
# Blur the mask out (into init image) by specified amount
|
||||
if mask_blur_radius > 0:
|
||||
nm = np.asarray(pil_init_mask, dtype=np.uint8)
|
||||
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
|
||||
pmd = Image.fromarray(nmd, mode='L')
|
||||
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
|
||||
else:
|
||||
blurred_init_mask = pil_init_mask
|
||||
|
||||
# Paste original on color-corrected generation (using blurred mask)
|
||||
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
|
||||
return matched_result
|
||||
|
||||
|
@ -38,7 +38,7 @@ class PngWriter:
|
||||
info = PngImagePlugin.PngInfo()
|
||||
info.add_text('Dream', dream_prompt)
|
||||
if metadata:
|
||||
info.add_text('sd-metadata', json.dumps(metadata))
|
||||
info.add_text('sd-metadata', json.dumps(metadata))
|
||||
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
|
||||
return path
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
albumentations==0.4.3
|
||||
einops==0.3.0
|
||||
diffusers==0.6.0
|
||||
huggingface-hub==0.8.1
|
||||
imageio==2.9.0
|
||||
imageio-ffmpeg==0.4.2
|
||||
|
@ -32,7 +32,8 @@ send2trash
|
||||
dependency_injector==4.40.0
|
||||
eventlet
|
||||
realesrgan
|
||||
diffusers
|
||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||
git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
|
||||
|
@ -69,16 +69,17 @@ def main():
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
safety_checker=opt.safety_checker,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
@ -799,26 +800,38 @@ def retrieve_dream_command(opt,command,completer):
|
||||
will retrieve and format the dream command used to generate the image,
|
||||
and pop it into the readline buffer (linux, Mac), or print out a comment
|
||||
for cut-and-paste (windows)
|
||||
|
||||
Given a wildcard path to a folder with image png files,
|
||||
will retrieve and format the dream command used to generate the images,
|
||||
and save them to a file commands.txt for further processing
|
||||
'''
|
||||
if len(command) == 0:
|
||||
return
|
||||
|
||||
tokens = command.split()
|
||||
if len(tokens) > 1:
|
||||
outfilepath = tokens[1]
|
||||
else:
|
||||
outfilepath = "commands.txt"
|
||||
|
||||
file_path = tokens[0]
|
||||
dir,basename = os.path.split(file_path)
|
||||
dir,basename = os.path.split(tokens[0])
|
||||
if len(dir) == 0:
|
||||
dir = opt.outdir
|
||||
|
||||
outdir,outname = os.path.split(outfilepath)
|
||||
if len(outdir) == 0:
|
||||
outfilepath = os.path.join(dir,outname)
|
||||
path = os.path.join(opt.outdir,basename)
|
||||
else:
|
||||
path = tokens[0]
|
||||
|
||||
if len(tokens) > 1:
|
||||
return write_commands(opt, path, tokens[1])
|
||||
|
||||
cmd = ''
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except OSError:
|
||||
print(f'## {tokens[0]}: file could not be read')
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f'## {tokens[0]}: file has no metadata')
|
||||
except:
|
||||
print(f'## {tokens[0]}: file could not be processed')
|
||||
if len(cmd)>0:
|
||||
completer.set_line(cmd)
|
||||
|
||||
def write_commands(opt, file_path:str, outfilepath:str):
|
||||
dir,basename = os.path.split(file_path)
|
||||
try:
|
||||
paths = list(Path(dir).glob(basename))
|
||||
except ValueError:
|
||||
@ -826,28 +839,24 @@ def retrieve_dream_command(opt,command,completer):
|
||||
return
|
||||
|
||||
commands = []
|
||||
cmd = None
|
||||
for path in paths:
|
||||
try:
|
||||
cmd = dream_cmd_from_png(path)
|
||||
except OSError:
|
||||
print(f'## {path}: file could not be read')
|
||||
continue
|
||||
except (KeyError, AttributeError, IndexError):
|
||||
print(f'## {path}: file has no metadata')
|
||||
continue
|
||||
except:
|
||||
print(f'## {path}: file could not be processed')
|
||||
continue
|
||||
|
||||
commands.append(f'# {path}')
|
||||
commands.append(cmd)
|
||||
|
||||
with open(outfilepath, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(commands))
|
||||
print(f'>> File {outfilepath} with commands created')
|
||||
|
||||
if len(commands) == 2:
|
||||
completer.set_line(commands[1])
|
||||
if cmd:
|
||||
commands.append(f'# {path}')
|
||||
commands.append(cmd)
|
||||
if len(commands)>0:
|
||||
dir,basename = os.path.split(outfilepath)
|
||||
if len(dir)==0:
|
||||
outfilepath = os.path.join(opt.outdir,basename)
|
||||
with open(outfilepath, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(commands))
|
||||
print(f'>> File {outfilepath} with commands created')
|
||||
|
||||
######################################
|
||||
|
||||
|
@ -5,7 +5,7 @@
|
||||
# two machines must share a common .cache directory.
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import clip
|
||||
from transformers import BertTokenizerFast
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
import sys
|
||||
import transformers
|
||||
import os
|
||||
@ -17,41 +17,39 @@ import traceback
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
#---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
sys.stdout.flush()
|
||||
def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
sys.stdout.flush()
|
||||
|
||||
#---------------------------------------------
|
||||
# this will download requirements for Kornia
|
||||
print('Loading Kornia requirements...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
import kornia
|
||||
print('...success')
|
||||
def download_kornia():
|
||||
print('Installing Kornia requirements...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
import kornia
|
||||
print('...success')
|
||||
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model...',end='')
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model...',end='')
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
|
||||
# In the event that the user has installed GFPGAN and also elected to use
|
||||
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
|
||||
gfpgan = False
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
gfpgan = True
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
if gfpgan:
|
||||
print('Loading models from RealESRGAN and facexlib...',end='')
|
||||
#---------------------------------------------
|
||||
def download_gfpgan():
|
||||
print('Installing models from RealESRGAN and facexlib...',end='')
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
@ -94,44 +92,72 @@ if gfpgan:
|
||||
print('Error loading GFPGAN:')
|
||||
print(traceback.format_exc())
|
||||
|
||||
print('preloading CodeFormer model file...',end='')
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
#---------------------------------------------
|
||||
def download_codeformer():
|
||||
print('Installing CodeFormer model file...',end='')
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clipseg():
|
||||
print('Installing clipseg model for text-based masking...',end='')
|
||||
try:
|
||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||
weights_dir = 'src/clipseg/weights'
|
||||
if not os.path.exists(weights_dir):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||
zip.extractall('src/clipseg')
|
||||
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||
os.remove(model_dest)
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
'src/clipseg/weights/rd64-uni-refined.pth',
|
||||
map_location=torch.device('cpu')
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
except Exception:
|
||||
print('Error installing clipseg model:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
||||
print('Loading clipseg model for text-based masking...',end='')
|
||||
try:
|
||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||
weights_dir = 'src/clipseg/weights'
|
||||
if not os.path.exists(weights_dir):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||
zip.extractall('src/clipseg')
|
||||
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||
os.remove(model_dest)
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
'src/clipseg/weights/rd64-uni-refined.pth',
|
||||
map_location=torch.device('cpu')
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
except Exception:
|
||||
print('Error installing clipseg model:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
#-------------------------------------
|
||||
def download_safety_checker():
|
||||
print('Installing safety model for NSFW content detection...',end='')
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
except ModuleNotFoundError:
|
||||
print('Error installing safety checker model:')
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
print('...success')
|
||||
|
||||
#-------------------------------------
|
||||
if __name__ == '__main__':
|
||||
download_bert()
|
||||
download_kornia()
|
||||
download_clip()
|
||||
download_gfpgan()
|
||||
download_codeformer()
|
||||
download_clipseg()
|
||||
download_safety_checker()
|
||||
|
||||
|
||||
|
||||
|
162
shell.nix
Normal file
162
shell.nix
Normal file
@ -0,0 +1,162 @@
|
||||
{ pkgs ? import <nixpkgs> {}
|
||||
, lib ? pkgs.lib
|
||||
, stdenv ? pkgs.stdenv
|
||||
, fetchurl ? pkgs.fetchurl
|
||||
, runCommand ? pkgs.runCommand
|
||||
, makeWrapper ? pkgs.makeWrapper
|
||||
, mkShell ? pkgs.mkShell
|
||||
, buildFHSUserEnv ? pkgs.buildFHSUserEnv
|
||||
, frameworks ? pkgs.darwin.apple_sdk.frameworks
|
||||
}:
|
||||
|
||||
# Setup InvokeAI environment using nix
|
||||
# Simple usage:
|
||||
# nix-shell
|
||||
# python3 scripts/preload_models.py
|
||||
# python3 scripts/invoke.py -h
|
||||
|
||||
let
|
||||
conda-shell = { url, sha256, installPath, packages, shellHook }:
|
||||
let
|
||||
src = fetchurl { inherit url sha256; };
|
||||
libPath = lib.makeLibraryPath ([] ++ lib.optionals (stdenv.isLinux) [ pkgs.zlib ]);
|
||||
condaArch = if stdenv.system == "aarch64-darwin" then "osx-arm64" else "";
|
||||
installer =
|
||||
if stdenv.isDarwin then
|
||||
runCommand "conda-install" {
|
||||
nativeBuildInputs = [ makeWrapper ];
|
||||
} ''
|
||||
mkdir -p $out/bin
|
||||
cp ${src} $out/bin/miniconda-installer.sh
|
||||
chmod +x $out/bin/miniconda-installer.sh
|
||||
makeWrapper \
|
||||
$out/bin/miniconda-installer.sh \
|
||||
$out/bin/conda-install \
|
||||
--add-flags "-p ${installPath}" \
|
||||
--add-flags "-b"
|
||||
''
|
||||
else if stdenv.isLinux then
|
||||
runCommand "conda-install" {
|
||||
nativeBuildInputs = [ makeWrapper ];
|
||||
buildInputs = [ pkgs.zlib ];
|
||||
}
|
||||
# on line 10, we have 'unset LD_LIBRARY_PATH'
|
||||
# we have to comment it out however in a way that the number of bytes in the
|
||||
# file does not change. So we replace the 'u' in the line with a '#'
|
||||
# The reason is that the binary payload is encoded as number
|
||||
# of bytes from the top of the installer script
|
||||
# and unsetting the library path prevents the zlib library from being discovered
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
sed 's/unset LD_LIBRARY_PATH/#nset LD_LIBRARY_PATH/' ${src} > $out/bin/miniconda-installer.sh
|
||||
chmod +x $out/bin/miniconda-installer.sh
|
||||
makeWrapper \
|
||||
$out/bin/miniconda-installer.sh \
|
||||
$out/bin/conda-install \
|
||||
--add-flags "-p ${installPath}" \
|
||||
--add-flags "-b" \
|
||||
--prefix "LD_LIBRARY_PATH" : "${libPath}"
|
||||
''
|
||||
else {};
|
||||
|
||||
hook = ''
|
||||
export CONDA_SUBDIR=${condaArch}
|
||||
'' + shellHook;
|
||||
|
||||
fhs = buildFHSUserEnv {
|
||||
name = "conda-shell";
|
||||
targetPkgs = pkgs: [ stdenv.cc pkgs.git installer ] ++ packages;
|
||||
profile = hook;
|
||||
runScript = "bash";
|
||||
};
|
||||
|
||||
shell = mkShell {
|
||||
shellHook = if stdenv.isDarwin then hook else "conda-shell; exit";
|
||||
packages = if stdenv.isDarwin then [ pkgs.git installer ] ++ packages else [ fhs ];
|
||||
};
|
||||
in shell;
|
||||
|
||||
packages = with pkgs; [
|
||||
cmake
|
||||
protobuf
|
||||
libiconv
|
||||
rustc
|
||||
cargo
|
||||
rustPlatform.bindgenHook
|
||||
];
|
||||
|
||||
env = {
|
||||
aarch64-darwin = {
|
||||
envFile = "environment-mac.yml";
|
||||
condaPath = (builtins.toString ./.) + "/.conda";
|
||||
ptrSize = "8";
|
||||
};
|
||||
x86_64-linux = {
|
||||
envFile = "environment.yml";
|
||||
condaPath = (builtins.toString ./.) + "/.conda";
|
||||
ptrSize = "8";
|
||||
};
|
||||
};
|
||||
|
||||
envFile = env.${stdenv.system}.envFile;
|
||||
installPath = env.${stdenv.system}.condaPath;
|
||||
ptrSize = env.${stdenv.system}.ptrSize;
|
||||
shellHook = ''
|
||||
conda-install
|
||||
|
||||
# tmpdir is too small in nix
|
||||
export TMPDIR="${installPath}/tmp"
|
||||
|
||||
# Add conda to PATH
|
||||
export PATH="${installPath}/bin:$PATH"
|
||||
|
||||
# Allows `conda activate` to work properly
|
||||
source ${installPath}/etc/profile.d/conda.sh
|
||||
|
||||
# Paths for gcc if compiling some C sources with pip
|
||||
export NIX_CFLAGS_COMPILE="-I${installPath}/include -I$TMPDIR/include"
|
||||
export NIX_CFLAGS_LINK="-L${installPath}/lib $BINDGEN_EXTRA_CLANG_ARGS"
|
||||
|
||||
export PIP_EXISTS_ACTION=w
|
||||
|
||||
# rust-onig fails (think it writes config.h to wrong location)
|
||||
mkdir -p "$TMPDIR/include"
|
||||
cat <<'EOF' > "$TMPDIR/include/config.h"
|
||||
#define HAVE_PROTOTYPES 1
|
||||
#define STDC_HEADERS 1
|
||||
#define HAVE_STRING_H 1
|
||||
#define HAVE_STDARG_H 1
|
||||
#define HAVE_STDLIB_H 1
|
||||
#define HAVE_LIMITS_H 1
|
||||
#define HAVE_INTTYPES_H 1
|
||||
#define SIZEOF_INT 4
|
||||
#define SIZEOF_SHORT 2
|
||||
#define SIZEOF_LONG ${ptrSize}
|
||||
#define SIZEOF_VOIDP ${ptrSize}
|
||||
#define SIZEOF_LONG_LONG 8
|
||||
EOF
|
||||
|
||||
conda env create -f "${envFile}" || conda env update --prune -f "${envFile}"
|
||||
conda activate invokeai
|
||||
'';
|
||||
|
||||
version = "4.12.0";
|
||||
conda = {
|
||||
aarch64-darwin = {
|
||||
shell = conda-shell {
|
||||
inherit shellHook installPath;
|
||||
url = "https://repo.anaconda.com/miniconda/Miniconda3-py39_${version}-MacOSX-arm64.sh";
|
||||
sha256 = "4bd112168cc33f8a4a60d3ef7e72b52a85972d588cd065be803eb21d73b625ef";
|
||||
packages = [ frameworks.Security ] ++ packages;
|
||||
};
|
||||
};
|
||||
x86_64-linux = {
|
||||
shell = conda-shell {
|
||||
inherit shellHook installPath;
|
||||
url = "https://repo.continuum.io/miniconda/Miniconda3-py39_${version}-Linux-x86_64.sh";
|
||||
sha256 = "78f39f9bae971ec1ae7969f0516017f2413f17796670f7040725dd83fcff5689";
|
||||
packages = with pkgs; [ libGL glib ] ++ packages;
|
||||
};
|
||||
};
|
||||
};
|
||||
in conda.${stdenv.system}.shell
|
Loading…
Reference in New Issue
Block a user