Merge branch 'development' into fix-disabled-prompt

Commit fb4dc7eaf9 by Lincoln Stein, 2022-10-25 07:13:57 -04:00 (committed by GitHub)
21 changed files with 1244 additions and 200 deletions

assets/caution.png — new binary file (33 KiB; binary content not shown)


@@ -1,21 +1,23 @@
 # This file describes the alternative machine learning models
 # available to the dream script.
 #
 # To add a new model, follow the examples below. Each
 # model requires a model config file, a weights file,
 # and the width and height of the images it
 # was trained on.
 stable-diffusion-1.4:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/model.ckpt
-    vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
+    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     description: Stable Diffusion inference model version 1.4
+    default: true
     width: 512
     height: 512
-    default: true
 stable-diffusion-1.5:
     config: configs/stable-diffusion/v1-inference.yaml
     weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
+    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
     description: Stable Diffusion inference model version 1.5
     width: 512
     height: 512


@@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
 | `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
 | `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
+| `--safety-checker` | | False | Activate safety checker for NSFW and other potentially disturbing imagery |
 | `--web` | | `False` | Start in web server mode |
 | `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
 | `--port <port>` | | `9090` | Which port web server should listen for requests on. |
@@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
 | `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
 | `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
-| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
 | `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
 | `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
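As a usage sketch, the new checker can be switched on at launch (a hypothetical invocation; `scripts/invoke.py` and the underlying `--safety_checker` argparse flag both appear elsewhere in this commit):

```
python scripts/invoke.py --safety_checker --web
```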


@@ -81,15 +81,18 @@ text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
 It will generate three files:
 - The image with the selected area highlighted.
+  - it will be named XXXXX.<imagename>.<prompt>.selected.png
 - The image with the un-selected area highlighted.
+  - it will be named XXXXX.<imagename>.<prompt>.deselected.png
 - The image with the selected area converted into a black and white
-  image according to the threshold level.
+  image according to the threshold level
+  - it will be named XXXXX.<imagename>.<prompt>.masked.png

-Note that none of these images are intended to be used as the mask
-passed to invoke via `-M` and may give unexpected results if you try
-to use them this way. Instead, use `!mask` for testing that you are
-selecting the right mask area, and then do inpainting using the
-best selection term and threshold.
+The `.masked.png` file can then be directly passed to the `invoke>`
+prompt in the CLI via the `-M` argument. Do not attempt this with
+the `selected.png` or `deselected.png` files, as they contain some
+transparency throughout the image and will not produce the desired
+results.

 Here is an example of how `!mask` works:

@@ -120,7 +123,7 @@ It looks like we selected the hair pretty well at the 0.5 threshold
 let's have some fun:

 ```
-invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair 0.5 -C20
+invoke> medusa with cobras -I ./test-pictures/curly.png -M 000019.curly.hair.masked.png -C20
 >> loaded input image of size 512x512 from ./test-pictures/curly.png
 ...
 Outputs:

@@ -129,6 +132,13 @@ Outputs:
 <img src="../assets/inpainting/000024.801380492.png">

+You can also skip the `!mask` creation step and just select the masked
+region directly:
+
+```
+invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
+```
+
 ### Inpainting is not changing the masked region enough!

 One of the things to understand about how inpainting works is that it


@@ -19,6 +19,7 @@ dependencies:
 # ```
   - albumentations==1.2.1
   - coloredlogs==15.0.1
+  - diffusers==0.6.0
   - einops==0.4.1
   - grpcio==1.46.4
   - humanfriendly==10.0


@@ -26,6 +26,7 @@ dependencies:
   - pyreadline3
   - torch-fidelity==0.3.0
   - transformers==4.21.3
+  - diffusers==0.6.0
   - torchmetrics==0.7.0
   - flask==2.1.3
   - flask_socketio==5.3.0

frontend/dist/assets/index.48782019.js (vendored, new file, 690 lines)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -26,6 +26,7 @@ export const socketioMiddleware = () => {
   const socketio = io(origin, {
     timeout: 60000,
+    path: window.location.pathname + 'socket.io',
   });

   let areListenersSet = false;


@@ -5,6 +5,7 @@ import eslint from 'vite-plugin-eslint';
 // https://vitejs.dev/config/
 export default defineConfig(({ mode }) => {
   const common = {
+    base: '',
     plugins: [react(), eslint()],
     server: {
       // Proxy HTTP requests to the flask server


@@ -110,12 +110,13 @@ still work.
 The full list of arguments to Generate() are:
 gr = Generate(
           # these values are set once and shouldn't be changed
-          conf            = path to configuration file ('configs/models.yaml')
-          model           = symbolic name of the model in the configuration file
-          precision       = float precision to be used
+          conf:str        = path to configuration file ('configs/models.yaml')
+          model:str       = symbolic name of the model in the configuration file
+          precision:float = float precision to be used
+          safety_checker:bool = activate safety checker [False]

           # this value is sticky and maintained between generation calls
-          sampler_name = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms
+          sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms']  // k_lms

           # these are deprecated - use conf and model instead
           weights     = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
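To make the argument list concrete, here is a minimal usage sketch. It assumes `Generate` is importable from `ldm.generate` and leans on `prompt2png`, the file-writing entry point shown later in this diff; both are inferences from the surrounding hunks, not part of the commit itself:

```python
# Minimal sketch: build a Generate object with the new safety checker
# enabled, then write one generated image into the output directory.
from ldm.generate import Generate

gr = Generate(
    conf           = 'configs/models.yaml',
    model          = 'stable-diffusion-1.4',
    sampler_name   = 'k_lms',
    safety_checker = True,   # new in this commit: blur NSFW-flagged outputs
)
gr.prompt2png('a watercolor of a lighthouse at dawn', outdir='outputs/')
```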
@@ -132,20 +133,21 @@ class Generate:
     def __init__(
             self,
             model                 = None,
             conf                  = 'configs/models.yaml',
             embedding_path        = None,
             sampler_name          = 'k_lms',
             ddim_eta              = 0.0,  # deterministic
             full_precision        = False,
             precision             = 'auto',
-            # these are deprecated; if present they override values in the conf file
-            weights               = None,
-            config                = None,
             gfpgan                = None,
             codeformer            = None,
             esrgan                = None,
             free_gpu_mem          = False,
+            safety_checker:bool   = False,
+            # these are deprecated; if present they override values in the conf file
+            weights               = None,
+            config                = None,
     ):
         mconfig = OmegaConf.load(conf)
         self.height = None
@@ -176,6 +178,7 @@ class Generate:
         self.free_gpu_mem = free_gpu_mem
         self.size_matters = True  # used to warn once about large image sizes and VRAM
         self.txt2mask = None
+        self.safety_checker = None

         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -203,6 +206,19 @@ class Generate:
         # gets rid of annoying messages about random seed
         logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

+        # load safety checker if requested
+        if safety_checker:
+            try:
+                print('>> Initializing safety checker')
+                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+                from transformers import AutoFeatureExtractor
+                safety_model_id = "CompVis/stable-diffusion-safety-checker"
+                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
+            except Exception:
+                print('** An error was encountered while installing the safety checker:')
+                print(traceback.format_exc())
+
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -271,6 +287,8 @@ class Generate:
             upscale = None,
             # this is specific to inpainting and causes more extreme inpainting
             inpaint_replace = 0.0,
+            # This will help match inpainted areas to the original image more smoothly
+            mask_blur_radius: int = 8,
             # Set this True to handle KeyboardInterrupt internally
             catch_interrupts = False,
             hires_fix = False,
@@ -391,7 +409,7 @@ class Generate:
                 log_tokens = self.log_tokenization
             )

-            init_image,mask_image = self._make_images(
+            init_image, mask_image = self._make_images(
                 init_img,
                 init_mask,
                 width,
@@ -416,6 +434,11 @@ class Generate:
                 self.seed, variation_amount, with_variations
             )

+            checker = {
+                'checker':self.safety_checker,
+                'extractor':self.safety_feature_extractor
+            } if self.safety_checker else None
+
             results = generator.generate(
                 prompt,
                 iterations=iterations,
@@ -426,10 +449,10 @@ class Generate:
                 conditioning=(uc, c),
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,   # called after the final image is generated
                 step_callback=step_callback,     # called after each intermediate image is generated
                 width=width,
                 height=height,
                 init_img=init_img,               # embiggen needs to manipulate from the unmodified init_img
                 init_image=init_image,           # notice that init_image is different from init_img
                 mask_image=mask_image,
                 strength=strength,
@@ -438,6 +461,8 @@ class Generate:
                 embiggen=embiggen,
                 embiggen_tiles=embiggen_tiles,
                 inpaint_replace=inpaint_replace,
+                mask_blur_radius=mask_blur_radius,
+                safety_checker=checker
             )

         if init_color:
@@ -631,23 +656,22 @@ class Generate:
         # if image has a transparent area and no mask was provided, then try to generate mask
         if self._has_transparency(image):
             self._transparency_check_and_warning(image, mask)
-            # this returns a torch tensor
             init_mask = self._create_init_mask(image, width, height, fit=fit)

         if (image.width * image.height) > (self.width * self.height) and self.size_matters:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
             self.size_matters = False

-        init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
+        init_image = self._create_init_image(image,width,height,fit=fit)

         if mask:
-            mask_image = self._load_img(mask) # this returns an Image
+            mask_image = self._load_img(mask)
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)

         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)

-        return init_image, init_mask
+        return init_image,init_mask

     def _make_base(self):
         if not self.generators.get('base'):
@@ -864,46 +888,31 @@ class Generate:
     def _create_init_image(self, image, width, height, fit=True):
         image = image.convert('RGB')
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = np.array(image).astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        image = 2.0 * image - 1.0
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image

     def _create_init_mask(self, image, width, height, fit=True):
         # convert into a black/white mask
         image = self._image_to_mask(image)
         image = image.convert('RGB')
-
-        # now we adjust the size
-        if fit:
-            image = self._fit_image(image, (width, height))
-        else:
-            image = self._squeeze_image(image)
-        image = image.resize((image.width//downsampling, image.height //
-                              downsampling), resample=Image.Resampling.NEAREST)
-        image = np.array(image)
-        image = image.astype(np.float32) / 255.0
-        image = image[None].transpose(0, 3, 1, 2)
-        image = torch.from_numpy(image)
-        return image.to(self.device)
+        image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
+        return image
     # The mask is expected to have the region to be inpainted
     # with alpha transparency. It converts it into a black/white
     # image with the transparent part black.
-    def _image_to_mask(self, mask_image, invert=False) -> Image:
+    def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
         # Obtain the mask from the transparency channel
-        mask = Image.new(mode="L", size=mask_image.size, color=255)
-        mask.putdata(mask_image.getdata(band=3))
+        if mask_image.mode == 'L':
+            mask = mask_image
+        else:
+            # Obtain the mask from the transparency channel
+            mask = Image.new(mode="L", size=mask_image.size, color=255)
+            mask.putdata(mask_image.getdata(band=3))
         if invert:
             mask = ImageOps.invert(mask)
         return mask

+    # TODO: The latter part of this method repeats code from _create_init_mask()
     def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
         prompt = text_mask[0]
         confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@@ -913,18 +922,8 @@ class Generate:
         segmented = self.txt2mask.segment(image, prompt)
         mask = segmented.to_mask(float(confidence_level))
         mask = mask.convert('RGB')
-        # now we adjust the size
-        if fit:
-            mask = self._fit_image(mask, (width, height))
-        else:
-            mask = self._squeeze_image(mask)
-        mask = mask.resize((mask.width//downsampling, mask.height //
-                            downsampling), resample=Image.Resampling.NEAREST)
-        mask = np.array(mask)
-        mask = mask.astype(np.float32) / 255.0
-        mask = mask[None].transpose(0, 3, 1, 2)
-        mask = torch.from_numpy(mask)
-        return mask.to(self.device)
+        mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
+        return mask

     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:


@@ -113,8 +113,8 @@ PRECISION_CHOICES = [
 ]

 # is there a way to pick this up during git commits?
-APP_ID      = 'lstein/stable-diffusion'
-APP_VERSION = 'v1.15'
+APP_ID      = 'invoke-ai/InvokeAI'
+APP_VERSION = 'v2.02'

 class ArgFormatter(argparse.RawTextHelpFormatter):
     # use defined argument order to display usage
@@ -172,6 +172,7 @@ class Args(object):
             command = cmd_string.replace("'", "\\'")
             try:
                 elements = shlex.split(command)
+                elements = [x.replace("\\'","'") for x in elements]
             except ValueError:
                 import sys, traceback
                 print(traceback.format_exc(), file=sys.stderr)
@@ -418,6 +419,11 @@ class Args(object):
             help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
             default='auto',
         )
+        model_group.add_argument(
+            '--safety_checker',
+            action='store_true',
+            help='Check for and blur potentially NSFW images',
+        )
         file_group.add_argument(
             '--from_file',
             dest='infile',
@@ -846,7 +852,7 @@ def metadata_dumps(opt,
     # remove any image keys not mentioned in RFC #266
     rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
                          'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
-                         'init_img','init_mask']
+                         'init_img','init_mask','facetool','facetool_strength','upscale']

     rfc_dict ={}
@@ -930,7 +936,7 @@ def metadata_loads(metadata) -> list:
             for image in images:
                 # repack the prompt and variations
                 if 'prompt' in image:
-                    image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
+                    image['prompt'] = repack_prompt(image['prompt'])
                 if 'variations' in image:
                     image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
                 # fix a bit of semantic drift here
@@ -938,12 +944,19 @@ def metadata_loads(metadata) -> list:
             opt = Args()
             opt._cmd_switches = Namespace(**image)
             results.append(opt)
-    except KeyError as e:
+    except Exception as e:
         import sys, traceback
-        print('>> badly-formatted metadata',file=sys.stderr)
+        print('>> could not read metadata',file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
     return results

+def repack_prompt(prompt_list:list)->str:
+    # in the common case of no weighting syntax, just return the prompt as is
+    if len(prompt_list) > 1:
+        return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
+    else:
+        return prompt_list[0]['prompt']
+
 # image can either be a file path on disk or a base64-encoded
 # representation of the file's contents
 def calculate_init_img_hash(image_string):


@@ -7,25 +7,27 @@ import numpy as np
 import random
 import os
 from tqdm import tqdm, trange
-from PIL import Image
+from PIL import Image, ImageFilter
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.invoke.devices import choose_autocast
 from ldm.util import rand_perlin_2d

 downsampling = 8
+CAUTION_IMG = 'assets/caution.png'

 class Generator():
     def __init__(self, model, precision):
         self.model = model
         self.precision = precision
         self.seed = None
         self.latent_channels = model.channels
         self.downsampling_factor = downsampling  # BUG: should come from model or config
+        self.safety_checker = None
         self.perlin = 0.0
         self.threshold = 0
         self.variation_amount = 0
         self.with_variations = []

     # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
     def get_make_image(self,prompt,**kwargs):
@@ -42,8 +44,10 @@ class Generator():
     def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
+                 safety_checker:dict=None,
                  **kwargs):
         scope = choose_autocast(self.precision)
+        self.safety_checker = safety_checker
         make_image = self.get_make_image(
             prompt,
             init_image = init_image,
@@ -77,10 +81,17 @@ class Generator():
                     pass

                 image = make_image(x_T)
+
+                if self.safety_checker is not None:
+                    image = self.safety_check(image)
+
                 results.append([image, seed])
+
                 if image_callback is not None:
                     image_callback(image, seed, first_seed=first_seed)
+
                 seed = self.new_seed()
+
         return results

     def sample_to_image(self,samples):
@@ -169,6 +180,39 @@ class Generator():
         return v2

+    def safety_check(self,image:Image.Image):
+        '''
+        If the CompViz safety checker flags an NSFW image, we
+        blur it out.
+        '''
+        import diffusers
+
+        checker = self.safety_checker['checker']
+        extractor = self.safety_checker['extractor']
+        features = extractor([image], return_tensors="pt")
+
+        # unfortunately checker requires the numpy version, so we have to convert back
+        x_image = np.array(image).astype(np.float32) / 255.0
+        x_image = x_image[None].transpose(0, 3, 1, 2)
+
+        diffusers.logging.set_verbosity_error()
+        checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
+        if has_nsfw_concept[0]:
+            print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
+            return self.blur(image)
+        else:
+            return image
+
+    def blur(self,input):
+        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
+        try:
+            caution = Image.open(CAUTION_IMG)
+            caution = caution.resize((caution.width // 2, caution.height //2))
+            blurry.paste(caution,(0,0),caution)
+        except FileNotFoundError:
+            pass
+        return blurry
+
     # this is a handy routine for debugging use. Given a generated sample,
     # convert it into a PNG image and store it at the indicated path
     def save_sample(self, sample, filepath):


@@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
 import torch
 import numpy as np
+import PIL
+from torch import Tensor
+from PIL import Image
 from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.base import Generator
 from ldm.models.diffusion.ddim import DDIMSampler

 class Img2Img(Generator):
     def __init__(self, model, precision):

@@ -25,6 +28,9 @@ class Img2Img(Generator):
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )

+        if isinstance(init_image, PIL.Image.Image):
+            init_image = self._image_to_tensor(init_image)
+
         scope = choose_autocast(self.precision)
         with scope(self.model.device.type):
             self.init_latent = self.model.get_first_stage_encoding(

@@ -68,3 +74,11 @@ class Img2Img(Generator):
         shape = init_latent.shape
         x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
+
+    def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        if normalize:
+            image = 2.0 * image - 1.0
+        return image.to(self.model.device)


@@ -3,27 +3,55 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
 '''

 import torch
+import torchvision.transforms as T
 import numpy as np
+import cv2 as cv
+import PIL
+from PIL import Image, ImageFilter
+from skimage.exposure.histogram_matching import match_histograms
 from einops import rearrange, repeat
 from ldm.invoke.devices import choose_autocast
 from ldm.invoke.generator.img2img import Img2Img
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.ksampler import KSampler
+from ldm.invoke.generator.base import downsampling

 class Inpaint(Img2Img):
     def __init__(self, model, precision):
         self.init_latent = None
+        self.pil_image = None
+        self.pil_mask = None
+        self.mask_blur_radius = 0
         super().__init__(model, precision)

     @torch.no_grad()
     def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                        conditioning,init_image,mask_image,strength,
-                       step_callback=None,inpaint_replace=False,**kwargs):
+                       mask_blur_radius: int = 8,
+                       step_callback=None,inpaint_replace=False, **kwargs):
         """
         Returns a function returning an image derived from the prompt and
         the initial image + mask.  Return value depends on the seed at
         the time you call it.  kwargs are 'init_latent' and 'strength'
         """

+        if isinstance(init_image, PIL.Image.Image):
+            self.pil_image = init_image
+            init_image = self._image_to_tensor(init_image)
+
+        if isinstance(mask_image, PIL.Image.Image):
+            self.pil_mask = mask_image
+            mask_image = mask_image.resize(
+                (
+                    mask_image.width // downsampling,
+                    mask_image.height // downsampling
+                ),
+                resample=Image.Resampling.NEAREST
+            )
+            mask_image = self._image_to_tensor(mask_image,normalize=False)
+
+        self.mask_blur_radius = mask_blur_radius
+
         # klms samplers not supported yet, so ignore previous sampler
         if isinstance(sampler,KSampler):
             print(
@@ -77,10 +105,50 @@ class Inpaint(Img2Img):
                 mask = mask_image,
                 init_latent = self.init_latent
             )
+
             return self.sample_to_image(samples)

         return make_image
+
+    def sample_to_image(self, samples)->Image.Image:
+        gen_result = super().sample_to_image(samples).convert('RGB')
+
+        if self.pil_image is None or self.pil_mask is None:
+            return gen_result
+
+        pil_mask = self.pil_mask
+        pil_image = self.pil_image
+        mask_blur_radius = self.mask_blur_radius
+
+        # Get the original alpha channel of the mask if there is one.
+        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
+        pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
+        pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
+
+        # Build an image with only visible pixels from source to use as reference for color-matching.
+        # Note that this doesn't use the mask, which would exclude some source image pixels from the
+        # histogram and cause slight color changes.
+        init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
+        init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
+        init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
+        init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Filter to just pixels that have any alpha, this is now our histogram
+
+        # Get numpy version
+        np_gen_result = np.asarray(gen_result, dtype=np.uint8)
+
+        # Color correct
+        np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
+        matched_result = Image.fromarray(np_matched_result, mode='RGB')
+
+        # Blur the mask out (into init image) by specified amount
+        if mask_blur_radius > 0:
+            nm = np.asarray(pil_init_mask, dtype=np.uint8)
+            nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
+            pmd = Image.fromarray(nmd, mode='L')
+            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
+        else:
+            blurred_init_mask = pil_init_mask
+
+        # Paste original on color-corrected generation (using blurred mask)
+        matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
+        return matched_result


@@ -38,7 +38,7 @@ class PngWriter:
         info = PngImagePlugin.PngInfo()
         info.add_text('Dream', dream_prompt)
         if metadata:
             info.add_text('sd-metadata', json.dumps(metadata))
         image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
         return path


@@ -1,5 +1,6 @@
 albumentations==0.4.3
 einops==0.3.0
+diffusers==0.6.0
 huggingface-hub==0.8.1
 imageio==2.9.0
 imageio-ffmpeg==0.4.2


@@ -32,7 +32,8 @@ send2trash
 dependency_injector==4.40.0
 eventlet
 realesrgan
+diffusers
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
+-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg


@@ -69,16 +69,17 @@ def main():
     # creating a Generate object:
     try:
         gen = Generate(
            conf           = opt.conf,
            model          = opt.model,
            sampler_name   = opt.sampler_name,
            embedding_path = opt.embedding_path,
            full_precision = opt.full_precision,
            precision      = opt.precision,
            gfpgan=gfpgan,
            codeformer=codeformer,
            esrgan=esrgan,
            free_gpu_mem=opt.free_gpu_mem,
+           safety_checker=opt.safety_checker,
         )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')
@@ -799,26 +800,38 @@ def retrieve_dream_command(opt,command,completer):
     will retrieve and format the dream command used to generate the image,
     and pop it into the readline buffer (linux, Mac), or print out a comment
     for cut-and-paste (windows)

     Given a wildcard path to a folder with image png files,
     will retrieve and format the dream command used to generate the images,
     and save them to a file commands.txt for further processing
     '''
     if len(command) == 0:
         return

     tokens = command.split()
-    if len(tokens) > 1:
-        outfilepath = tokens[1]
-    else:
-        outfilepath = "commands.txt"
-
-    file_path = tokens[0]
-    dir,basename = os.path.split(file_path)
+    dir,basename = os.path.split(tokens[0])
     if len(dir) == 0:
-        dir = opt.outdir
-
-    outdir,outname = os.path.split(outfilepath)
-    if len(outdir) == 0:
-        outfilepath = os.path.join(dir,outname)
+        path = os.path.join(opt.outdir,basename)
+    else:
+        path = tokens[0]
+
+    if len(tokens) > 1:
+        return write_commands(opt, path, tokens[1])
+
+    cmd = ''
+    try:
+        cmd = dream_cmd_from_png(path)
+    except OSError:
+        print(f'## {tokens[0]}: file could not be read')
+    except (KeyError, AttributeError, IndexError):
+        print(f'## {tokens[0]}: file has no metadata')
+    except:
+        print(f'## {tokens[0]}: file could not be processed')
+    if len(cmd)>0:
+        completer.set_line(cmd)
+
+def write_commands(opt, file_path:str, outfilepath:str):
+    dir,basename = os.path.split(file_path)
     try:
         paths = list(Path(dir).glob(basename))
     except ValueError:
@@ -826,28 +839,24 @@ def retrieve_dream_command(opt,command,completer):
         return

     commands = []
+    cmd = None
     for path in paths:
         try:
             cmd = dream_cmd_from_png(path)
-        except OSError:
-            print(f'## {path}: file could not be read')
-            continue
         except (KeyError, AttributeError, IndexError):
             print(f'## {path}: file has no metadata')
-            continue
         except:
             print(f'## {path}: file could not be processed')
-            continue
-
-        commands.append(f'# {path}')
-        commands.append(cmd)
-
-    with open(outfilepath, 'w', encoding='utf-8') as f:
-        f.write('\n'.join(commands))
-    print(f'>> File {outfilepath} with commands created')
-
-    if len(commands) == 2:
-        completer.set_line(commands[1])
+        if cmd:
+            commands.append(f'# {path}')
+            commands.append(cmd)
+    if len(commands)>0:
+        dir,basename = os.path.split(outfilepath)
+        if len(dir)==0:
+            outfilepath = os.path.join(opt.outdir,basename)
+        with open(outfilepath, 'w', encoding='utf-8') as f:
+            f.write('\n'.join(commands))
+        print(f'>> File {outfilepath} with commands created')

 ######################################


@@ -5,7 +5,7 @@
 # two machines must share a common .cache directory.
 from transformers import CLIPTokenizer, CLIPTextModel
 import clip
-from transformers import BertTokenizerFast
+from transformers import BertTokenizerFast, AutoFeatureExtractor
 import sys
 import transformers
 import os
@@ -17,41 +17,39 @@ import traceback
 transformers.logging.set_verbosity_error()

+#---------------------------------------------
 # this will preload the Bert tokenizer fles
-print('Loading bert tokenizer (ignore deprecation errors)...', end='')
-with warnings.catch_warnings():
-    warnings.filterwarnings('ignore', category=DeprecationWarning)
-    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-print('...success')
-sys.stdout.flush()
+def download_bert():
+    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=DeprecationWarning)
+        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
+        print('...success')
+        sys.stdout.flush()

+#---------------------------------------------
 # this will download requirements for Kornia
-print('Loading Kornia requirements...', end='')
-with warnings.catch_warnings():
-    warnings.filterwarnings('ignore', category=DeprecationWarning)
-    import kornia
-print('...success')
+def download_kornia():
+    print('Installing Kornia requirements...', end='')
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=DeprecationWarning)
+        import kornia
+    print('...success')

-version = 'openai/clip-vit-large-patch14'
-sys.stdout.flush()
-print('Loading CLIP model...',end='')
-tokenizer = CLIPTokenizer.from_pretrained(version)
-transformer = CLIPTextModel.from_pretrained(version)
-print('...success')
+#---------------------------------------------
+def download_clip():
+    version = 'openai/clip-vit-large-patch14'
+    sys.stdout.flush()
+    print('Loading CLIP model...',end='')
+    tokenizer = CLIPTokenizer.from_pretrained(version)
+    transformer = CLIPTextModel.from_pretrained(version)
+    print('...success')

-# In the event that the user has installed GFPGAN and also elected to use
-# RealESRGAN, this will attempt to download the model needed by RealESRGANer
-gfpgan = False
-try:
-    from realesrgan import RealESRGANer
-    gfpgan = True
-except ModuleNotFoundError:
-    pass
-
-if gfpgan:
-    print('Loading models from RealESRGAN and facexlib...',end='')
+#---------------------------------------------
+def download_gfpgan():
+    print('Installing models from RealESRGAN and facexlib...',end='')
     try:
+        from realesrgan import RealESRGANer
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact
         from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -94,44 +92,72 @@ if gfpgan:
         print('Error loading GFPGAN:')
         print(traceback.format_exc())

-print('preloading CodeFormer model file...',end='')
-try:
-    model_url  = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-    model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
-    if not os.path.exists(model_dest):
-        print('Downloading codeformer model file...')
-        os.makedirs(os.path.dirname(model_dest), exist_ok=True)
-        urllib.request.urlretrieve(model_url,model_dest)
-except Exception:
-    print('Error loading CodeFormer:')
-    print(traceback.format_exc())
-print('...success')
+#---------------------------------------------
+def download_codeformer():
+    print('Installing CodeFormer model file...',end='')
+    try:
+        model_url  = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+        model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
+        if not os.path.exists(model_dest):
+            print('Downloading codeformer model file...')
+            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
+            urllib.request.urlretrieve(model_url,model_dest)
+    except Exception:
+        print('Error loading CodeFormer:')
+        print(traceback.format_exc())
+    print('...success')

-print('Loading clipseg model for text-based masking...',end='')
-try:
-    model_url  = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
-    model_dest = 'src/clipseg/clipseg_weights.zip'
-    weights_dir = 'src/clipseg/weights'
-    if not os.path.exists(weights_dir):
-        os.makedirs(os.path.dirname(model_dest), exist_ok=True)
-        urllib.request.urlretrieve(model_url,model_dest)
-        with zipfile.ZipFile(model_dest,'r') as zip:
-            zip.extractall('src/clipseg')
-            os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
-        os.remove(model_dest)
-        from clipseg_models.clipseg import CLIPDensePredT
-        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
-        model.eval()
-        model.load_state_dict(
-            torch.load(
-                'src/clipseg/weights/rd64-uni-refined.pth',
-                map_location=torch.device('cpu')
-            ),
-            strict=False,
-        )
-except Exception:
-    print('Error installing clipseg model:')
-    print(traceback.format_exc())
-print('...success')
+#---------------------------------------------
+def download_clipseg():
+    print('Installing clipseg model for text-based masking...',end='')
+    try:
+        model_url  = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
+        model_dest = 'src/clipseg/clipseg_weights.zip'
+        weights_dir = 'src/clipseg/weights'
+        if not os.path.exists(weights_dir):
+            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
+            urllib.request.urlretrieve(model_url,model_dest)
+            with zipfile.ZipFile(model_dest,'r') as zip:
+                zip.extractall('src/clipseg')
+                os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
+            os.remove(model_dest)
+            from clipseg_models.clipseg import CLIPDensePredT
+            model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
+            model.eval()
+            model.load_state_dict(
+                torch.load(
+                    'src/clipseg/weights/rd64-uni-refined.pth',
+                    map_location=torch.device('cpu')
+                ),
+                strict=False,
+            )
+    except Exception:
+        print('Error installing clipseg model:')
+        print(traceback.format_exc())
+    print('...success')

+#-------------------------------------
+def download_safety_checker():
+    print('Installing safety model for NSFW content detection...',end='')
+    try:
+        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+    except ModuleNotFoundError:
+        print('Error installing safety checker model:')
+        print(traceback.format_exc())
+        return
+    safety_model_id = "CompVis/stable-diffusion-safety-checker"
+    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+    print('...success')

+#-------------------------------------
+if __name__ == '__main__':
+    download_bert()
+    download_kornia()
+    download_clip()
+    download_gfpgan()
+    download_codeformer()
+    download_clipseg()
+    download_safety_checker()

shell.nix (new file, 162 lines)

@@ -0,0 +1,162 @@
{ pkgs ? import <nixpkgs> {}
, lib ? pkgs.lib
, stdenv ? pkgs.stdenv
, fetchurl ? pkgs.fetchurl
, runCommand ? pkgs.runCommand
, makeWrapper ? pkgs.makeWrapper
, mkShell ? pkgs.mkShell
, buildFHSUserEnv ? pkgs.buildFHSUserEnv
, frameworks ? pkgs.darwin.apple_sdk.frameworks
}:

# Setup InvokeAI environment using nix
# Simple usage:
#   nix-shell
#   python3 scripts/preload_models.py
#   python3 scripts/invoke.py -h

let
  conda-shell = { url, sha256, installPath, packages, shellHook }:
    let
      src = fetchurl { inherit url sha256; };
      libPath = lib.makeLibraryPath ([] ++ lib.optionals (stdenv.isLinux) [ pkgs.zlib ]);
      condaArch = if stdenv.system == "aarch64-darwin" then "osx-arm64" else "";
      installer =
        if stdenv.isDarwin then
          runCommand "conda-install" {
            nativeBuildInputs = [ makeWrapper ];
          } ''
            mkdir -p $out/bin
            cp ${src} $out/bin/miniconda-installer.sh
            chmod +x $out/bin/miniconda-installer.sh

            makeWrapper \
              $out/bin/miniconda-installer.sh \
              $out/bin/conda-install \
              --add-flags "-p ${installPath}" \
              --add-flags "-b"
          ''
        else if stdenv.isLinux then
          runCommand "conda-install" {
            nativeBuildInputs = [ makeWrapper ];
            buildInputs = [ pkgs.zlib ];
          }
          # on line 10, we have 'unset LD_LIBRARY_PATH'
          # we have to comment it out however in a way that the number of bytes in the
          # file does not change. So we replace the 'u' in the line with a '#'
          # The reason is that the binary payload is encoded as number
          # of bytes from the top of the installer script
          # and unsetting the library path prevents the zlib library from being discovered
          ''
            mkdir -p $out/bin
            sed 's/unset LD_LIBRARY_PATH/#nset LD_LIBRARY_PATH/' ${src} > $out/bin/miniconda-installer.sh
            chmod +x $out/bin/miniconda-installer.sh

            makeWrapper \
              $out/bin/miniconda-installer.sh \
              $out/bin/conda-install \
              --add-flags "-p ${installPath}" \
              --add-flags "-b" \
              --prefix "LD_LIBRARY_PATH" : "${libPath}"
          ''
        else {};

      hook = ''
        export CONDA_SUBDIR=${condaArch}
      '' + shellHook;

      fhs = buildFHSUserEnv {
        name = "conda-shell";
        targetPkgs = pkgs: [ stdenv.cc pkgs.git installer ] ++ packages;
        profile = hook;
        runScript = "bash";
      };

      shell = mkShell {
        shellHook = if stdenv.isDarwin then hook else "conda-shell; exit";
        packages = if stdenv.isDarwin then [ pkgs.git installer ] ++ packages else [ fhs ];
      };
    in shell;

  packages = with pkgs; [
    cmake
    protobuf
    libiconv
    rustc
    cargo
    rustPlatform.bindgenHook
  ];

  env = {
    aarch64-darwin = {
      envFile = "environment-mac.yml";
      condaPath = (builtins.toString ./.) + "/.conda";
      ptrSize = "8";
    };
    x86_64-linux = {
      envFile = "environment.yml";
      condaPath = (builtins.toString ./.) + "/.conda";
      ptrSize = "8";
    };
  };

  envFile = env.${stdenv.system}.envFile;
  installPath = env.${stdenv.system}.condaPath;
  ptrSize = env.${stdenv.system}.ptrSize;
  shellHook = ''
    conda-install

    # tmpdir is too small in nix
    export TMPDIR="${installPath}/tmp"
    # Add conda to PATH
    export PATH="${installPath}/bin:$PATH"
    # Allows `conda activate` to work properly
    source ${installPath}/etc/profile.d/conda.sh

    # Paths for gcc if compiling some C sources with pip
    export NIX_CFLAGS_COMPILE="-I${installPath}/include -I$TMPDIR/include"
    export NIX_CFLAGS_LINK="-L${installPath}/lib $BINDGEN_EXTRA_CLANG_ARGS"
    export PIP_EXISTS_ACTION=w

    # rust-onig fails (think it writes config.h to wrong location)
    mkdir -p "$TMPDIR/include"
    cat <<'EOF' > "$TMPDIR/include/config.h"
    #define HAVE_PROTOTYPES 1
    #define STDC_HEADERS 1
    #define HAVE_STRING_H 1
    #define HAVE_STDARG_H 1
    #define HAVE_STDLIB_H 1
    #define HAVE_LIMITS_H 1
    #define HAVE_INTTYPES_H 1
    #define SIZEOF_INT 4
    #define SIZEOF_SHORT 2
    #define SIZEOF_LONG ${ptrSize}
    #define SIZEOF_VOIDP ${ptrSize}
    #define SIZEOF_LONG_LONG 8
    EOF

    conda env create -f "${envFile}" || conda env update --prune -f "${envFile}"
    conda activate invokeai
  '';

  version = "4.12.0";
  conda = {
    aarch64-darwin = {
      shell = conda-shell {
        inherit shellHook installPath;
        url = "https://repo.anaconda.com/miniconda/Miniconda3-py39_${version}-MacOSX-arm64.sh";
        sha256 = "4bd112168cc33f8a4a60d3ef7e72b52a85972d588cd065be803eb21d73b625ef";
        packages = [ frameworks.Security ] ++ packages;
      };
    };
    x86_64-linux = {
      shell = conda-shell {
        inherit shellHook installPath;
        url = "https://repo.continuum.io/miniconda/Miniconda3-py39_${version}-Linux-x86_64.sh";
        sha256 = "78f39f9bae971ec1ae7969f0516017f2413f17796670f7040725dd83fcff5689";
        packages = with pkgs; [ libGL glib ] ++ packages;
      };
    };
  };
in conda.${stdenv.system}.shell