Merge branch 'development' into vite-relative-paths

This commit is contained in:
Lincoln Stein 2022-10-25 07:09:14 -04:00 committed by GitHub
commit 71a1e0d0e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 2162 additions and 1171 deletions

BIN assets/caution.png: new binary file, 33 KiB (binary file not shown)


@ -1,20 +1,23 @@
# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
laion400m:
    config: configs/latent-diffusion/txt2img-1p4B-eval.yaml
    weights: models/ldm/text2img-large/model.ckpt
    description: Latent Diffusion LAION400M model
    width: 256
    height: 256
stable-diffusion-1.4:
    config: configs/stable-diffusion/v1-inference.yaml
    weights: models/ldm/stable-diffusion-v1/model.ckpt
    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
    description: Stable Diffusion inference model version 1.4
    default: true
    width: 512
    height: 512
stable-diffusion-1.5:
    config: configs/stable-diffusion/v1-inference.yaml
    weights: models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt
    # vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
    description: Stable Diffusion inference model version 1.5
    width: 512
    height: 512
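Elsewhere in this diff, `ldm/generate.py` loads this file with OmegaConf and falls back to `stable-diffusion-1.4` when no stanza is marked `default: true`. A minimal sketch of reading the registry the same way, assuming the file lives at `configs/models.yaml`:

```python
from omegaconf import OmegaConf

# Load the model registry; each top-level key names one model stanza.
config = OmegaConf.load('configs/models.yaml')

# Pick the stanza marked `default: true`, falling back to
# stable-diffusion-1.4 (mirrors FALLBACK_MODEL_NAME in ldm/generate.py).
default_name = next(
    (name for name, stanza in config.items() if stanza.get('default')),
    'stable-diffusion-1.4',
)
stanza = config[default_name]
print(default_name, stanza.weights, stanza.width, stanza.height)
```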


@ -8,7 +8,7 @@ hide:
## **Interactive Command Line Interface**
The `invoke.py` script, located in `scripts/`, provides an interactive
interface to image generation similar to the "invoke mothership" bot that
Stability AI provided on its Discord server.
@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
| `--safety_checker` | | `False` | Activate safety checker for NSFW and other potentially disturbing imagery |
| `--web` | | `False` | Start in web server mode |
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
@ -283,12 +283,20 @@ Some examples:
Outputs:
[1] outputs/img-samples/000017.4829112.gfpgan-00.png: !fix "outputs/img-samples/0000045.4829112.png" -s 50 -S -W 512 -H 512 -C 7.5 -A k_lms -G 0.8
### !mask
This command takes an image and a text prompt, and uses the `clipseg`
algorithm to automatically generate a mask of the area that matches
the text prompt. It is useful for debugging the text masking process
prior to inpainting with the `--text_mask` argument. See
[INPAINTING.md] for details.
## Model selection and importation
The CLI allows you to add new models on the fly, as well as to switch
among them rapidly without leaving the script.
### !models
This prints out a list of the models defined in `configs/models.yaml`.
The active model is bold-faced.
@ -300,7 +308,7 @@ laion400m not loaded <no description>
waifu-diffusion not loaded Waifu Diffusion v1.3
</pre>
### !switch <model>
This quickly switches from one model to another without leaving the
CLI script. `invoke.py` uses a memory caching system; once a model
@ -346,7 +354,7 @@ laion400m not loaded <no description>
waifu-diffusion cached Waifu Diffusion v1.3
</pre>
### !import_model <path/to/model/weights>
This command imports a new model weights file into InvokeAI, makes it
available for image generation within the script, and writes out the
@ -398,7 +406,7 @@ OK to import [n]? <b>y</b>
invoke>
</pre>
### !edit_model <name_of_model>
The `!edit_model` command can be used to modify a model that is
already defined in `configs/models.yaml`. Call it with the short
@ -434,20 +442,12 @@ OK to import [n]? y
Outputs:
[2] outputs/img-samples/000018.2273800735.embiggen-00.png: !fix "outputs/img-samples/000017.243781548.gfpgan-00.png" -s 50 -S 2273800735 -W 512 -H 512 -C 7.5 -A k_lms --embiggen 3.0 0.75 0.25
```
## History processing
The CLI provides a series of convenient commands for reviewing previous
actions, retrieving them, modifying them, and re-running them.
### !history
The invoke script keeps track of all the commands you issue during a
session, allowing you to re-run them. On Mac and Linux systems, it
@ -472,20 +472,41 @@ invoke> !20
invoke> watercolor of beautiful woman sitting under tree wearing broad hat and flowing garment -v0.2 -n6 -S2878767194
```
### !fetch
This command retrieves the generation parameters from a previously
generated image and either loads them into the command line
(Linux/Mac) or prints them out in a comment for copy-and-paste
(Windows). You may provide either the name of a file in the current
output directory or a full file path. You may also specify a path to a
folder of PNG files with the `*.png` wildcard, plus the name of an
output file such as `commands.txt`, to retrieve the dream commands used
to generate the images and save them to that file for further
processing.
This example loads the generation command for a single png file:
```bash
invoke> !fetch 0000015.8929913.png
# the script returns the next line, ready for editing and running:
invoke> a fantastic alien landscape -W 576 -H 512 -s 60 -A plms -C 7.5
```
This one fetches the generation commands from a batch of files and
stores them into `selected.txt`:
```bash
invoke> !fetch outputs\selected-imgs\*.png selected.txt
```
### !replay
This command replays a text file generated by `!fetch` or created manually.
~~~
invoke> !replay outputs\selected-imgs\selected.txt
~~~
Note that these commands may behave unexpectedly if given a PNG file that
was not generated by InvokeAI.
### !search <search string>
@ -503,16 +524,6 @@ invoke> !search surreal
This clears the search history from memory and disk. Be advised that
this operation is irreversible and does not issue any warnings!
## Command-line editing and completion
The command-line offers convenient history tracking, editing, and


@ -81,15 +81,18 @@ text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
It will generate three files:
- The image with the selected area highlighted.
- it will be named XXXXX.<imagename>.<prompt>.selected.png
- The image with the un-selected area highlighted.
- it will be named XXXXX.<imagename>.<prompt>.deselected.png
- The image with the selected area converted into a black and white
image according to the threshold level
- it will be named XXXXX.<imagename>.<prompt>.masked.png
The `.masked.png` file can then be directly passed to the `invoke>`
prompt in the CLI via the `-M` argument. Do not attempt this with
the `selected.png` or `deselected.png` files, as they contain some
transparency throughout the image and will not produce the desired
results.
Here is an example of how `!mask` works:
@ -120,7 +123,7 @@ It looks like we selected the hair pretty well at the 0.5 threshold
let's have some fun:
```
invoke> medusa with cobras -I ./test-pictures/curly.png -M 000019.curly.hair.masked.png -C20
>> loaded input image of size 512x512 from ./test-pictures/curly.png
...
Outputs:
@ -129,6 +132,13 @@ Outputs:
<img src="../assets/inpainting/000024.801380492.png">
You can also skip the `!mask` creation step and just select the masked
region directly:
```
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair -C20
```
### Inpainting is not changing the masked region enough!
One of the things to understand about how inpainting works is that it


@ -19,6 +19,7 @@ dependencies:
# ```
- albumentations==1.2.1
- coloredlogs==15.0.1
- diffusers==0.6.0
- einops==0.4.1
- grpcio==1.46.4
- humanfriendly==10.0


@ -26,6 +26,7 @@ dependencies:
- pyreadline3
- torch-fidelity==0.3.0
- transformers==4.21.3
- diffusers==0.6.0
- torchmetrics==0.7.0
- flask==2.1.3
- flask_socketio==5.3.0
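Both environment files now pin `diffusers==0.6.0`, the package that provides the safety checker wired up in `ldm/generate.py` later in this diff. A quick sanity check of an installed environment (the import path is the one used below; the version check is an assumption based on the pins):

```python
import diffusers
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,  # class loaded by Generate() when the safety checker is enabled
)

print(diffusers.__version__)  # expect 0.6.0 per the pinned environment files
```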

frontend/dist/assets/index.2d646c45.js vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@ -14,6 +14,7 @@
"@chakra-ui/react": "^2.3.1",
"@emotion/react": "^11.10.4",
"@emotion/styled": "^11.10.4",
"@radix-ui/react-context-menu": "^2.0.1",
"@reduxjs/toolkit": "^1.8.5",
"@types/uuid": "^8.3.4",
"dateformat": "^5.0.3",
@ -25,7 +26,6 @@
"react-dropzone": "^14.2.2",
"react-hotkeys-hook": "^3.4.7",
"react-icons": "^4.4.0",
"react-masonry-css": "^1.0.16",
"react-redux": "^8.0.2",
"redux-persist": "^6.0.0",
"socket.io": "^4.5.2",


@ -26,7 +26,7 @@ const makeSocketIOEmitters = (
const options = { ...getState().options };
if (tabMap[options.activeTab] !== 'img2img') {
options.shouldUseInitImage = false;
}


@ -7,12 +7,17 @@ export const PostProcessingWIP = () => {
<p>
Invoke AI offers a wide variety of post processing features. Image
Upscaling and Face Restoration are already available in the WebUI. You
can access them from the Advanced Options menu of the Text To Image and
Image To Image tabs. You can also process images directly, using the
image action buttons above the main image display.
</p>
<p>
A dedicated UI will be released soon to facilitate more advanced post
processing workflows.
</p>
<p>
The Invoke AI Command Line Interface offers various other features
including Embiggen.
</p>
</div>
);


@ -12,6 +12,7 @@ import {
FormControl,
FormLabel,
Flex,
useToast,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import {
@ -57,6 +58,7 @@ const DeleteImageModal = forwardRef(
const dispatch = useAppDispatch();
const shouldConfirmOnDelete = useAppSelector(systemSelector);
const cancelRef = useRef<HTMLButtonElement>(null);
const toast = useToast();
const handleClickDelete = (e: SyntheticEvent) => {
e.stopPropagation();
@ -65,6 +67,12 @@ const DeleteImageModal = forwardRef(
const handleDelete = () => {
dispatch(deleteImage(image));
toast({
title: 'Image Deleted',
status: 'success',
duration: 2500,
isClosable: true,
});
onClose();
};


@ -17,6 +17,12 @@
max-height: 100%;
}
.hoverable-image-delete-button {
position: absolute;
top: 0.25rem;
right: 0.25rem;
}
.hoverable-image-content {
display: flex;
position: absolute;
@ -57,3 +63,39 @@
}
}
}
.hoverable-image-context-menu {
z-index: 999;
padding: 0.4rem;
border-radius: 0.25rem;
background-color: var(--context-menu-bg-color);
box-shadow: var(--context-menu-box-shadow);
[role='menuitem'] {
font-size: 0.8rem;
line-height: 1rem;
border-radius: 3px;
display: flex;
align-items: center;
height: 1.75rem;
padding: 0 0.5rem;
position: relative;
user-select: none;
cursor: pointer;
outline: none;
&[data-disabled] {
color: grey;
pointer-events: none;
cursor: not-allowed;
}
&[data-warning] {
color: var(--status-bad-color);
}
&[data-highlighted] {
background-color: var(--context-menu-bg-color-hover);
}
}
}


@ -1,17 +1,27 @@
import {
Box,
Icon,
IconButton,
Image,
Tooltip,
useToast,
} from '@chakra-ui/react';
import { RootState, useAppDispatch, useAppSelector } from '../../app/store';
import { setCurrentImage } from './gallerySlice';
import { FaCheck, FaTrashAlt } from 'react-icons/fa';
import DeleteImageModal from './DeleteImageModal';
import { memo, useState } from 'react';
import {
setActiveTab,
setAllImageToImageParameters,
setAllTextToImageParameters,
setInitialImagePath,
setPrompt,
setSeed,
} from '../options/optionsSlice';
import * as InvokeAI from '../../app/invokeai';
import * as ContextMenu from '@radix-ui/react-context-menu';
import { tabMap } from '../tabs/InvokeTabs';
interface HoverableImageProps {
image: InvokeAI.Image;
@ -27,115 +37,174 @@ const memoEqualityCheck = (
* Gallery image component with delete/use all/use seed buttons on hover.
*/
const HoverableImage = memo((props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const activeTab = useAppSelector(
(state: RootState) => state.options.activeTab
);
const [isHovered, setIsHovered] = useState<boolean>(false);
const toast = useToast();
const { image, isSelected } = props;
const { url, uuid, metadata } = image;
const handleMouseOver = () => setIsHovered(true);
const handleMouseOut = () => setIsHovered(false);
const handleUsePrompt = () => {
dispatch(setPrompt(image.metadata.image.prompt));
toast({
title: 'Prompt Set',
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseSeed = () => {
dispatch(setSeed(image.metadata.image.seed));
toast({
title: 'Seed Set',
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleSendToImageToImage = () => {
dispatch(setInitialImagePath(image.url));
if (activeTab !== 1) {
dispatch(setActiveTab(1));
}
toast({
title: 'Sent to Image To Image',
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseAllParameters = () => {
dispatch(setAllTextToImageParameters(metadata));
toast({
title: 'Parameters Set',
status: 'success',
duration: 2500,
isClosable: true,
});
};
const handleUseInitialImage = async () => {
// check if the image exists before setting it as initial image
if (metadata?.image?.init_image_path) {
const response = await fetch(metadata.image.init_image_path);
if (response.ok) {
dispatch(setActiveTab(tabMap.indexOf('img2img')));
dispatch(setAllImageToImageParameters(metadata));
toast({
title: 'Initial Image Set',
status: 'success',
duration: 2500,
isClosable: true,
});
return;
}
}
toast({
title: 'Initial Image Not Set',
description: 'Could not load initial image.',
status: 'error',
duration: 2500,
isClosable: true,
});
};
const handleSelectImage = () => dispatch(setCurrentImage(image));
return (
<Box
position={'relative'}
key={uuid}
className="hoverable-image"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<Image
objectFit="cover"
rounded={'md'}
src={url}
loading={'lazy'}
className="hoverable-image-image"
/>
<div className="hoverable-image-content" onClick={handleClickImage}>
{isSelected && (
<Icon
width={'50%'}
height={'50%'}
as={FaCheck}
className="hoverable-image-check"
<ContextMenu.Root>
<ContextMenu.Trigger>
<Box
position={'relative'}
key={uuid}
className="hoverable-image"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<Image
className="hoverable-image-image"
objectFit="cover"
rounded={'md'}
src={url}
loading={'lazy'}
/>
)}
</div>
{isHovered && (
<div className="hoverable-image-icons">
<Tooltip label={'Delete image'} hasArrow>
<DeleteImageModal image={image}>
<IconButton
colorScheme="red"
aria-label="Delete image"
icon={<FaTrashAlt />}
size="xs"
variant={'imageHoverIconButton'}
fontSize={14}
<div className="hoverable-image-content" onClick={handleSelectImage}>
{isSelected && (
<Icon
width={'50%'}
height={'50%'}
as={FaCheck}
className="hoverable-image-check"
/>
</DeleteImageModal>
</Tooltip>
{['txt2img', 'img2img'].includes(image?.metadata?.image?.type) && (
<Tooltip label="Use All Parameters" hasArrow>
<IconButton
aria-label="Use All Parameters"
icon={<IoArrowUndoCircleOutline />}
size="xs"
fontSize={18}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetAllParameters}
/>
</Tooltip>
)}
</div>
{isHovered && (
<div className="hoverable-image-delete-button">
<Tooltip label={'Delete image'} hasArrow>
<DeleteImageModal image={image}>
<IconButton
aria-label="Delete image"
icon={<FaTrashAlt />}
size="xs"
variant={'imageHoverIconButton'}
fontSize={14}
/>
</DeleteImageModal>
</Tooltip>
</div>
)}
{image?.metadata?.image?.seed !== undefined && (
<Tooltip label="Use Seed" hasArrow>
<IconButton
aria-label="Use Seed"
icon={<FaSeedling />}
size="xs"
fontSize={16}
variant={'imageHoverIconButton'}
onClickCapture={handleClickSetSeed}
/>
</Tooltip>
)}
<Tooltip label="Send To Image To Image" hasArrow>
<IconButton
aria-label="Send To Image To Image"
icon={<FaImage />}
size="xs"
fontSize={16}
variant={'imageHoverIconButton'}
onClickCapture={handleSetInitImage}
/>
</Tooltip>
</div>
)}
</Box>
</Box>
</ContextMenu.Trigger>
<ContextMenu.Content className="hoverable-image-context-menu">
<ContextMenu.Item
onClickCapture={handleUsePrompt}
disabled={image?.metadata?.image?.prompt === undefined}
>
Use Prompt
</ContextMenu.Item>
<ContextMenu.Item
onClickCapture={handleUseSeed}
disabled={image?.metadata?.image?.seed === undefined}
>
Use Seed
</ContextMenu.Item>
<ContextMenu.Item
onClickCapture={handleUseAllParameters}
disabled={
!['txt2img', 'img2img'].includes(image?.metadata?.image?.type)
}
>
Use All Parameters
</ContextMenu.Item>
<Tooltip label="Load initial image used for this generation">
<ContextMenu.Item
onClickCapture={handleUseInitialImage}
disabled={image?.metadata?.image?.type !== 'img2img'}
>
Use Initial Image
</ContextMenu.Item>
</Tooltip>
<ContextMenu.Item onClickCapture={handleSendToImageToImage}>
Send to Image To Image
</ContextMenu.Item>
<DeleteImageModal image={image}>
<ContextMenu.Item data-warning>Delete Image</ContextMenu.Item>
</DeleteImageModal>
</ContextMenu.Content>
</ContextMenu.Root>
);
}, memoEqualityCheck);


@ -55,31 +55,37 @@
@include HideScrollbar;
}
.masonry-grid {
display: -webkit-box; /* Not needed if autoprefixing */
display: -ms-flexbox; /* Not needed if autoprefixing */
display: flex;
margin-left: 0.5rem; /* gutter size offset */
width: auto;
}
.masonry-grid_column {
padding-left: 0.5rem; /* gutter size */
background-clip: padding-box;
}
// from https://css-tricks.com/a-grid-of-logos-in-squares/
.image-gallery {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(80px, auto));
grid-gap: 0.5rem;
.hoverable-image {
padding: 0.5rem;
position: relative;
&::before {
// for aspect ratio
content: '';
display: block;
padding-bottom: 100%;
}
.hoverable-image-image {
position: absolute;
max-width: 100%;
/* Style your items */
.masonry-grid_column > .hoverable-image {
/* change div to reference your elements you put in <Masonry> */
background: var(--tab-color);
margin-bottom: 0.5rem;
}
// Alternate Version
// top: 0;
// bottom: 0;
// right: 0;
// left: 0;
// margin: auto;
// .image-gallery {
// display: flex;
// grid-template-columns: repeat(auto-fill, minmax(80px, auto));
// gap: 0.5rem;
// justify-items: center;
// }
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
}
}
}
.image-gallery-load-more-btn {
background-color: var(--btn-load-more) !important;


@ -1,10 +1,9 @@
import { Button, IconButton } from '@chakra-ui/button';
import { Resizable } from 're-resizable';
import React from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { MdClear, MdPhotoLibrary } from 'react-icons/md';
import Masonry from 'react-masonry-css';
import { requestImages } from '../../app/socketio/actions';
import { RootState, useAppDispatch, useAppSelector } from '../../app/store';
import IAIIconButton from '../../common/components/IAIIconButton';
@ -27,12 +26,6 @@ export default function ImageGallery() {
const dispatch = useAppDispatch();
const [column, setColumn] = useState<number | undefined>();
const handleResize = (event: MouseEvent | TouchEvent | any) => {
setColumn(Math.floor((window.innerWidth - event.x) / 120));
};
const handleShowGalleryToggle = () => {
dispatch(setShouldShowGallery(!shouldShowGallery));
};
@ -89,9 +82,7 @@ export default function ImageGallery() {
minWidth={'300'}
maxWidth={activeTab == 1 ? '300' : '600'}
className="image-gallery-popup"
onResize={handleResize}
>
{/* <div className="image-gallery-popup"></div> */}
<div className="image-gallery-header">
<h1>Your Invocations</h1>
<IconButton
@ -104,12 +95,7 @@ export default function ImageGallery() {
</div>
<div className="image-gallery-container">
{images.length ? (
<Masonry
className="masonry-grid"
columnClassName="masonry-grid_column"
breakpointCols={column}
>
{/* <div className="image-gallery"> */}
<div className="image-gallery">
{images.map((image) => {
const { uuid } = image;
const isSelected = currentImageUuid === uuid;
@ -121,8 +107,7 @@ export default function ImageGallery() {
/>
);
})}
{/* </div> */}
</Masonry>
</div>
) : (
<div className="image-gallery-container-placeholder">
<MdPhotoLibrary />


@ -183,6 +183,67 @@ export const optionsSlice = createSlice({
setSeedWeights: (state, action: PayloadAction<string>) => {
state.seedWeights = action.payload;
},
setAllTextToImageParameters: (
state,
action: PayloadAction<InvokeAI.Metadata>
) => {
const {
sampler,
prompt,
seed,
variations,
steps,
cfg_scale,
threshold,
perlin,
seamless,
hires_fix,
width,
height,
} = action.payload.image;
if (variations && variations.length > 0) {
state.seedWeights = seedWeightsToString(variations);
state.shouldGenerateVariations = true;
} else {
state.shouldGenerateVariations = false;
}
if (seed) {
state.seed = seed;
state.shouldRandomizeSeed = false;
}
if (prompt) state.prompt = promptToString(prompt);
if (sampler) state.sampler = sampler;
if (steps) state.steps = steps;
if (cfg_scale) state.cfgScale = cfg_scale;
if (threshold) state.threshold = threshold;
if (typeof threshold === 'undefined') state.threshold = 0;
if (perlin) state.perlin = perlin;
if (typeof perlin === 'undefined') state.perlin = 0;
if (typeof seamless === 'boolean') state.seamless = seamless;
if (typeof hires_fix === 'boolean') state.hiresFix = hires_fix;
if (width) state.width = width;
if (height) state.height = height;
},
setAllImageToImageParameters: (
state,
action: PayloadAction<InvokeAI.Metadata>
) => {
const { type, strength, fit, init_image_path, mask_image_path } =
action.payload.image;
if (type === 'img2img') {
if (init_image_path) state.initialImagePath = init_image_path;
if (mask_image_path) state.maskPath = mask_image_path;
if (strength) state.img2imgStrength = strength;
if (typeof fit === 'boolean') state.shouldFitToWidthHeight = fit;
state.shouldUseInitImage = true;
} else {
state.shouldUseInitImage = false;
}
},
setAllParameters: (state, action: PayloadAction<InvokeAI.Metadata>) => {
const {
type,
@ -226,43 +287,6 @@ export const optionsSlice = createSlice({
state.shouldRandomizeSeed = false;
}
/**
* We support arbitrary numbers of postprocessing steps, so it
* doesnt make sense to be include postprocessing metadata when
* we use all parameters. Because this code needed a bit of braining
* to figure out, I am leaving it, in case it is needed again.
*/
// let postprocessingNotDone = ['gfpgan', 'esrgan'];
// if (postprocessing && postprocessing.length > 0) {
// postprocessing.forEach(
// (postprocess: InvokeAI.PostProcessedImageMetadata) => {
// if (postprocess.type === 'gfpgan') {
// const { strength } = postprocess;
// if (strength) state.facetoolStrength = strength;
// state.shouldRunFacetool = true;
// postprocessingNotDone = postprocessingNotDone.filter(
// (p) => p !== 'gfpgan'
// );
// }
// if (postprocess.type === 'esrgan') {
// const { scale, strength } = postprocess;
// if (scale) state.upscalingLevel = scale;
// if (strength) state.upscalingStrength = strength;
// state.shouldRunESRGAN = true;
// postprocessingNotDone = postprocessingNotDone.filter(
// (p) => p !== 'esrgan'
// );
// }
// }
// );
// }
// postprocessingNotDone.forEach((p) => {
// if (p === 'esrgan') state.shouldRunESRGAN = false;
// if (p === 'gfpgan') state.shouldRunFacetool = false;
// });
if (prompt) state.prompt = promptToString(prompt);
if (sampler) state.sampler = sampler;
if (steps) state.steps = steps;
@ -346,6 +370,8 @@ export const {
setActiveTab,
setShouldShowImageDetails,
setShouldShowGallery,
setAllTextToImageParameters,
setAllImageToImageParameters,
} = optionsSlice.actions;
export default optionsSlice.reducer;


@ -1,4 +1,4 @@
import { IconButton, Image, useToast } from '@chakra-ui/react';
import React, { SyntheticEvent } from 'react';
import { MdClear } from 'react-icons/md';
import { RootState, useAppDispatch, useAppSelector } from '../../../app/store';
@ -11,10 +11,23 @@ export default function InitImagePreview() {
const dispatch = useAppDispatch();
const toast = useToast();
const handleClickResetInitialImage = (e: SyntheticEvent) => {
e.stopPropagation();
dispatch(setInitialImagePath(null));
};
const alertMissingInitImage = () => {
toast({
title: 'Problem loading parameters',
description: 'Unable to load init image.',
status: 'error',
isClosable: true,
});
dispatch(setInitialImagePath(null));
};
return (
<div className="init-image-preview">
<div className="init-image-preview-header">
@ -29,7 +42,12 @@ export default function InitImagePreview() {
</div>
{initialImagePath && (
<div className="init-image-image">
<Image
fit={'contain'}
src={initialImagePath}
rounded={'md'}
onError={alertMissingInitImage}
/>
</div>
)}
</div>


@ -50,8 +50,13 @@ export const tab_dict = {
},
};
// Array where index maps to the key of tab_dict
export const tabMap = _.map(tab_dict, (tab, key) => key);
// Use tabMap to generate a union type of tab names
const tabMapTypes = [...tabMap] as const;
export type InvokeTabName = typeof tabMapTypes[number];
export default function InvokeTabs() {
const activeTab = useAppSelector(
(state: RootState) => state.options.activeTab


@ -95,4 +95,9 @@
// Gallery
--gallery-resizeable-color: rgb(36, 38, 48);
// Context Menus
--context-menu-bg-color: rgb(46, 48, 58);
--context-menu-box-shadow: none;
--context-menu-bg-color-hover: rgb(30, 32, 42);
}


@ -94,4 +94,11 @@
// Gallery
--gallery-resizeable-color: rgb(192, 194, 196);
// Context Menus
--context-menu-bg-color: var(--background-color);
--context-menu-box-shadow: 0px 10px 38px -10px rgba(22, 23, 24, 0.35),
0px 10px 20px -15px rgba(22, 23, 24, 0.2);
--context-menu-bg-color-hover: var(--background-color-secondary);
}


@ -213,6 +213,13 @@
dependencies:
regenerator-runtime "^0.13.4"
"@babel/runtime@^7.13.10":
version "7.19.4"
resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.19.4.tgz#a42f814502ee467d55b38dd1c256f53a7b885c78"
integrity sha512-EXpLCrk55f+cYqmHsSR+yD/0gAIMxxA9QK9lnQWzhMCvt+YmoBN7Zx94s++Kv0+unHk39vxNO8t+CMA2WSS3wA==
dependencies:
regenerator-runtime "^0.13.4"
"@babel/template@^7.18.10":
version "7.18.10"
resolved "https://registry.yarnpkg.com/@babel/template/-/template-7.18.10.tgz#6f9134835970d1dbf0835c0d100c9f38de0c5e71"
@ -1122,6 +1129,26 @@
minimatch "^3.1.2"
strip-json-comments "^3.1.1"
"@floating-ui/core@^0.7.3":
version "0.7.3"
resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-0.7.3.tgz#d274116678ffae87f6b60e90f88cc4083eefab86"
integrity sha512-buc8BXHmG9l82+OQXOFU3Kr2XQx9ys01U/Q9HMIrZ300iLc8HLMgh7dcCqgYzAzf4BkoQvDcXf5Y+CuEZ5JBYg==
"@floating-ui/dom@^0.5.3":
version "0.5.4"
resolved "https://registry.yarnpkg.com/@floating-ui/dom/-/dom-0.5.4.tgz#4eae73f78bcd4bd553ae2ade30e6f1f9c73fe3f1"
integrity sha512-419BMceRLq0RrmTSDxn8hf9R3VCJv2K9PUfugh5JyEFmdjzDo+e8U5EdR8nzKq8Yj1htzLm3b6eQEEam3/rrtg==
dependencies:
"@floating-ui/core" "^0.7.3"
"@floating-ui/react-dom@0.7.2":
version "0.7.2"
resolved "https://registry.yarnpkg.com/@floating-ui/react-dom/-/react-dom-0.7.2.tgz#0bf4ceccb777a140fc535c87eb5d6241c8e89864"
integrity sha512-1T0sJcpHgX/u4I1OzIEhlcrvkUN8ln39nz7fMoE/2HDHrPiMFoOGR7++GYyfUmIQHkkrTinaeQsO3XWubjSvGg==
dependencies:
"@floating-ui/dom" "^0.5.3"
use-isomorphic-layout-effect "^1.1.1"
"@humanwhocodes/config-array@^0.10.4":
version "0.10.4"
resolved "https://registry.yarnpkg.com/@humanwhocodes/config-array/-/config-array-0.10.4.tgz#01e7366e57d2ad104feea63e72248f22015c520c"
@ -1265,6 +1292,246 @@
resolved "https://registry.yarnpkg.com/@popperjs/core/-/core-2.11.6.tgz#cee20bd55e68a1720bdab363ecf0c821ded4cd45"
integrity sha512-50/17A98tWUfQ176raKiOGXuYpLyyVMkxxG6oylzL3BPOlA6ADGdK7EYunSa4I064xerltq9TGXs8HmOk5E+vw==
"@radix-ui/primitive@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/primitive/-/primitive-1.0.0.tgz#e1d8ef30b10ea10e69c76e896f608d9276352253"
integrity sha512-3e7rn8FDMin4CgeL7Z/49smCA3rFYY3Ha2rUQ7HRWFadS5iCRw08ZgVT1LaNTCNqgvrUiyczLflrVrF0SRQtNA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-arrow@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-arrow/-/react-arrow-1.0.1.tgz#5246adf79e97f89e819af68da51ddcf349ecf1c4"
integrity sha512-1yientwXqXcErDHEv8av9ZVNEBldH8L9scVR3is20lL+jOCfcJyMFZFEY5cgIrgexsq1qggSXqiEL/d/4f+QXA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-collection@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-collection/-/react-collection-1.0.1.tgz#259506f97c6703b36291826768d3c1337edd1de5"
integrity sha512-uuiFbs+YCKjn3X1DTSx9G7BHApu4GHbi3kgiwsnFUbOKCrwejAJv4eE4Vc8C0Oaxt9T0aV4ox0WCOdx+39Xo+g==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-slot" "1.0.1"
"@radix-ui/react-compose-refs@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-compose-refs/-/react-compose-refs-1.0.0.tgz#37595b1f16ec7f228d698590e78eeed18ff218ae"
integrity sha512-0KaSv6sx787/hK3eF53iOkiSLwAGlFMx5lotrqD2pTjB18KbybKoEIgkNZTKC60YECDQTKGTRcDBILwZVqVKvA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-context-menu@^2.0.1":
version "2.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-context-menu/-/react-context-menu-2.0.1.tgz#aee7c81bac9983b3748284bf3925dd63796c90b4"
integrity sha512-7DuhU4xDcUk3AMJUlb5tHHOvJZ1GF4+snDIpjtWGlTvO0VktNKgbvBuGLlirdkYoUSI0mJXwOUcUXQapgIyefw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-menu" "2.0.1"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-controllable-state" "1.0.0"
"@radix-ui/react-context@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-context/-/react-context-1.0.0.tgz#f38e30c5859a9fb5e9aa9a9da452ee3ed9e0aee0"
integrity sha512-1pVM9RfOQ+n/N5PJK33kRSKsr1glNxomxONs5c49MliinBY6Yw2Q995qfBUUo0/Mbg05B/sGA0gkgPI7kmSHBg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-direction@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-direction/-/react-direction-1.0.0.tgz#a2e0b552352459ecf96342c79949dd833c1e6e45"
integrity sha512-2HV05lGUgYcA6xgLQ4BKPDmtL+QbIZYH5fCOTAOOcJ5O0QbWS3i9lKaurLzliYUDhORI2Qr3pyjhJh44lKA3rQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-dismissable-layer@1.0.2":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.2.tgz#f04d1061bddf00b1ca304148516b9ddc62e45fb2"
integrity sha512-WjJzMrTWROozDqLB0uRWYvj4UuXsM/2L19EmQ3Au+IJWqwvwq9Bwd+P8ivo0Deg9JDPArR1I6MbWNi1CmXsskg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-escape-keydown" "1.0.2"
"@radix-ui/react-focus-guards@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-focus-guards/-/react-focus-guards-1.0.0.tgz#339c1c69c41628c1a5e655f15f7020bf11aa01fa"
integrity sha512-UagjDk4ijOAnGu4WMUPj9ahi7/zJJqNZ9ZAiGPp7waUWJO0O1aWXi/udPphI0IUjvrhBsZJGSN66dR2dsueLWQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-focus-scope@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-focus-scope/-/react-focus-scope-1.0.1.tgz#faea8c25f537c5a5c38c50914b63722db0e7f951"
integrity sha512-Ej2MQTit8IWJiS2uuujGUmxXjF/y5xZptIIQnyd2JHLwtV0R2j9NRVoRj/1j/gJ7e3REdaBw4Hjf4a1ImhkZcQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-id@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-id/-/react-id-1.0.0.tgz#8d43224910741870a45a8c9d092f25887bb6d11e"
integrity sha512-Q6iAB/U7Tq3NTolBBQbHTgclPmGWE3OlktGGqrClPozSw4vkQ1DfQAOtzgRPecKsMdJINE05iaoDUG8tRzCBjw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-menu@2.0.1":
version "2.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-menu/-/react-menu-2.0.1.tgz#44ebfd45d8482db678b935c0b9d1102d683372d8"
integrity sha512-I5FFZQxCl2fHoJ7R0m5/oWA9EX8/ttH4AbgneoCH7DAXQioFeb0XMAYnOVSp1GgJZ1Nx/mohxNQSeTMcaF1YPw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-collection" "1.0.1"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-direction" "1.0.0"
"@radix-ui/react-dismissable-layer" "1.0.2"
"@radix-ui/react-focus-guards" "1.0.0"
"@radix-ui/react-focus-scope" "1.0.1"
"@radix-ui/react-id" "1.0.0"
"@radix-ui/react-popper" "1.0.1"
"@radix-ui/react-portal" "1.0.1"
"@radix-ui/react-presence" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-roving-focus" "1.0.1"
"@radix-ui/react-slot" "1.0.1"
"@radix-ui/react-use-callback-ref" "1.0.0"
aria-hidden "^1.1.1"
react-remove-scroll "2.5.5"
"@radix-ui/react-popper@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-popper/-/react-popper-1.0.1.tgz#9fa8a6a493404afa225866a5cd75af23d141baa0"
integrity sha512-J4Vj7k3k+EHNWgcKrE+BLlQfpewxA7Zd76h5I0bIa+/EqaIZ3DuwrbPj49O3wqN+STnXsBuxiHLiF0iU3yfovw==
dependencies:
"@babel/runtime" "^7.13.10"
"@floating-ui/react-dom" "0.7.2"
"@radix-ui/react-arrow" "1.0.1"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-use-rect" "1.0.0"
"@radix-ui/react-use-size" "1.0.0"
"@radix-ui/rect" "1.0.0"
"@radix-ui/react-portal@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-portal/-/react-portal-1.0.1.tgz#169c5a50719c2bb0079cf4c91a27aa6d37e5dd33"
integrity sha512-NY2vUWI5WENgAT1nfC6JS7RU5xRYBfjZVLq0HmgEN1Ezy3rk/UruMV4+Rd0F40PEaFC5SrLS1ixYvcYIQrb4Ig==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-presence@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-presence/-/react-presence-1.0.0.tgz#814fe46df11f9a468808a6010e3f3ca7e0b2e84a"
integrity sha512-A+6XEvN01NfVWiKu38ybawfHsBjWum42MRPnEuqPsBZ4eV7e/7K321B5VgYMPv3Xx5An6o1/l9ZuDBgmcmWK3w==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/react-primitive@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-primitive/-/react-primitive-1.0.1.tgz#c1ebcce283dd2f02e4fbefdaa49d1cb13dbc990a"
integrity sha512-fHbmislWVkZaIdeF6GZxF0A/NH/3BjrGIYj+Ae6eTmTCr7EB0RQAAVEiqsXK6p3/JcRqVSBQoceZroj30Jj3XA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-slot" "1.0.1"
"@radix-ui/react-roving-focus@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-roving-focus/-/react-roving-focus-1.0.1.tgz#475621f63aee43faa183a5270f35d49e530de3d7"
integrity sha512-TB76u5TIxKpqMpUAuYH2VqMhHYKa+4Vs1NHygo/llLvlffN6mLVsFhz0AnSFlSBAvTBYVHYAkHAyEt7x1gPJOA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/primitive" "1.0.0"
"@radix-ui/react-collection" "1.0.1"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-context" "1.0.0"
"@radix-ui/react-direction" "1.0.0"
"@radix-ui/react-id" "1.0.0"
"@radix-ui/react-primitive" "1.0.1"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-controllable-state" "1.0.0"
"@radix-ui/react-slot@1.0.1":
version "1.0.1"
resolved "https://registry.yarnpkg.com/@radix-ui/react-slot/-/react-slot-1.0.1.tgz#e7868c669c974d649070e9ecbec0b367ee0b4d81"
integrity sha512-avutXAFL1ehGvAXtPquu0YK5oz6ctS474iM3vNGQIkswrVhdrS52e3uoMQBzZhNRAIE0jBnUyXWNmSjGHhCFcw==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-compose-refs" "1.0.0"
"@radix-ui/react-use-callback-ref@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.0.tgz#9e7b8b6b4946fe3cbe8f748c82a2cce54e7b6a90"
integrity sha512-GZtyzoHz95Rhs6S63D2t/eqvdFCm7I+yHMLVQheKM7nBD8mbZIt+ct1jz4536MDnaOGKIxynJ8eHTkVGVVkoTg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-controllable-state@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-controllable-state/-/react-use-controllable-state-1.0.0.tgz#a64deaafbbc52d5d407afaa22d493d687c538b7f"
integrity sha512-FohDoZvk3mEXh9AWAVyRTYR4Sq7/gavuofglmiXB2g1aKyboUD4YtgWxKj8O5n+Uak52gXQ4wKz5IFST4vtJHg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-escape-keydown@1.0.2":
version "1.0.2"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-escape-keydown/-/react-use-escape-keydown-1.0.2.tgz#09ab6455ab240b4f0a61faf06d4e5132c4d639f6"
integrity sha512-DXGim3x74WgUv+iMNCF+cAo8xUHHeqvjx8zs7trKf+FkQKPQXLk2sX7Gx1ysH7Q76xCpZuxIJE7HLPxRE+Q+GA==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-callback-ref" "1.0.0"
"@radix-ui/react-use-layout-effect@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-layout-effect/-/react-use-layout-effect-1.0.0.tgz#2fc19e97223a81de64cd3ba1dc42ceffd82374dc"
integrity sha512-6Tpkq+R6LOlmQb1R5NNETLG0B4YP0wc+klfXafpUCj6JGyaUc8il7/kUZ7m59rGbXGczE9Bs+iz2qloqsZBduQ==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-rect@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-rect/-/react-use-rect-1.0.0.tgz#b040cc88a4906b78696cd3a32b075ed5b1423b3e"
integrity sha512-TB7pID8NRMEHxb/qQJpvSt3hQU4sqNPM1VCTjTRjEOa7cEop/QMuq8S6fb/5Tsz64kqSvB9WnwsDHtjnrM9qew==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/rect" "1.0.0"
"@radix-ui/react-use-size@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/react-use-size/-/react-use-size-1.0.0.tgz#a0b455ac826749419f6354dc733e2ca465054771"
integrity sha512-imZ3aYcoYCKhhgNpkNDh/aTiU05qw9hX+HHI1QDBTyIlcFjgeFlKKySNGMwTp7nYFLQg/j0VA2FmCY4WPDDHMg==
dependencies:
"@babel/runtime" "^7.13.10"
"@radix-ui/react-use-layout-effect" "1.0.0"
"@radix-ui/rect@1.0.0":
version "1.0.0"
resolved "https://registry.yarnpkg.com/@radix-ui/rect/-/rect-1.0.0.tgz#0dc8e6a829ea2828d53cbc94b81793ba6383bf3c"
integrity sha512-d0O68AYy/9oeEy1DdC07bz1/ZXX+DqCskRd3i4JzLSTXwefzaepQrKjXC7aNM8lTHjFLDO0pDgaEiQ7jEk+HVg==
dependencies:
"@babel/runtime" "^7.13.10"
"@reduxjs/toolkit@^1.8.5":
version "1.8.5"
resolved "https://registry.yarnpkg.com/@reduxjs/toolkit/-/toolkit-1.8.5.tgz#c14bece03ee08be88467f22dc0ecf9cf875527cd"
@ -2850,11 +3117,6 @@ react-is@^18.0.0:
resolved "https://registry.yarnpkg.com/react-is/-/react-is-18.2.0.tgz#199431eeaaa2e09f86427efbb4f1473edb47609b"
integrity sha512-xWGDIW6x921xtzPkhiULtthJHoJvBbF3q26fzloPCK0hsvxtPVelvftw3zjbHWSkR2km9Z+4uxbDDK/6Zw9B8w==
react-masonry-css@^1.0.16:
version "1.0.16"
resolved "https://registry.yarnpkg.com/react-masonry-css/-/react-masonry-css-1.0.16.tgz#72b28b4ae3484e250534700860597553a10f1a2c"
integrity sha512-KSW0hR2VQmltt/qAa3eXOctQDyOu7+ZBevtKgpNDSzT7k5LA/0XntNa9z9HKCdz3QlxmJHglTZ18e4sX4V8zZQ==
react-redux@^8.0.2:
version "8.0.2"
resolved "https://registry.yarnpkg.com/react-redux/-/react-redux-8.0.2.tgz#bc2a304bb21e79c6808e3e47c50fe1caf62f7aad"
@ -2880,7 +3142,7 @@ react-remove-scroll-bar@^2.3.3:
react-style-singleton "^2.2.1"
tslib "^2.0.0"
react-remove-scroll@2.5.5, react-remove-scroll@^2.5.4:
version "2.5.5"
resolved "https://registry.yarnpkg.com/react-remove-scroll/-/react-remove-scroll-2.5.5.tgz#1e31a1260df08887a8a0e46d09271b52b3a37e77"
integrity sha512-ImKhrzJJsyXJfBZ4bzu8Bwpka14c/fQt0k+cyFp/PBhTfyDnU5hjOtM4AG/0AMyy8oKzOTR0lDgJIM7pYXI0kw==
@ -3255,6 +3517,11 @@ use-callback-ref@^1.3.0:
dependencies:
tslib "^2.0.0"
use-isomorphic-layout-effect@^1.1.1:
version "1.1.2"
resolved "https://registry.yarnpkg.com/use-isomorphic-layout-effect/-/use-isomorphic-layout-effect-1.1.2.tgz#497cefb13d863d687b08477d9e5a164ad8c1a6fb"
integrity sha512-49L8yCO3iGT/ZF9QttjwLF/ZD9Iwto5LnH5LmEdk/6cFmXddqi2ulF0edxTwjj+7mqvpVVGQWvbXZdn32wRSHA==
use-sidecar@^1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/use-sidecar/-/use-sidecar-1.1.2.tgz#2f43126ba2d7d7e117aa5855e5d8f0276dfe73c2"


@ -55,23 +55,8 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
torch.randn_like = fix_func(torch.randn_like)
torch.randint = fix_func(torch.randint)
torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
# this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
"""Simplified text to image API for stable diffusion/latent diffusion
@ -125,12 +110,13 @@ still work.
The full list of arguments to Generate() are:
gr = Generate(
# these values are set once and shouldn't be changed
conf:str = path to configuration file ('configs/models.yaml')
model:str = symbolic name of the model in the configuration file
precision:float = float precision to be used
safety_checker:bool = activate safety checker [False]
# this value is sticky and maintained between generation calls
sampler_name:str = ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms
# these are deprecated - use conf and model instead
weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt')
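For orientation, a hedged sketch of how the updated constructor might be called, using only arguments and methods that appear in this diff (`prompt2png` is defined further down; the output directory is arbitrary):

```python
from ldm.generate import Generate  # module path assumed from this repo's layout

# model=None now means: use the stanza marked `default: true` in
# configs/models.yaml, or FALLBACK_MODEL_NAME if none is marked.
gr = Generate(
    conf='configs/models.yaml',
    sampler_name='k_lms',
    safety_checker=True,  # new in this diff: screens NSFW outputs
)
gr.prompt2png('a fantastic alien landscape', outdir='outputs/img-samples')
```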
@ -147,23 +133,23 @@ class Generate:
def __init__(
self,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
):
mconfig = OmegaConf.load(conf)
self.model_name = model
self.height = None
self.width = None
self.model_cache = None
@ -192,6 +178,7 @@ class Generate:
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -210,6 +197,7 @@ class Generate:
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics
self.session_peakmem = torch.cuda.max_memory_allocated() if self._has_cuda else None
@ -218,6 +206,19 @@ class Generate:
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load safety checker if requested
if safety_checker:
try:
print('>> Initializing safety checker')
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
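The checker and its feature extractor are kept as a pair and handed to the generators through the `checker` dict built further down in this diff. As a rough sketch of how such a pair is typically applied in diffusers 0.6.x (hypothetical helper; the real call site lives in the generator classes, not shown in this hunk):

```python
import numpy as np
from PIL import Image, ImageFilter

def run_safety_check(image: Image.Image, checker: dict) -> Image.Image:
    # CLIP-preprocess the image, then ask the checker to screen it.
    features = checker['extractor'](image, return_tensors='pt')
    np_image = np.asarray(image)[None].astype(np.float32) / 255.0
    _, has_nsfw = checker['checker'](
        images=np_image, clip_input=features.pixel_values
    )
    if has_nsfw[0]:
        # Blur flagged output; the ImageFilter import and CAUTION_IMG
        # added elsewhere in this diff suggest blur-plus-watermark.
        return image.filter(ImageFilter.GaussianBlur(radius=32))
    return image
```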
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@ -286,6 +287,8 @@ class Generate:
upscale = None,
# this is specific to inpainting and causes more extreme inpainting
inpaint_replace = 0.0,
# This will help match inpainted areas to the original image more smoothly
mask_blur_radius: int = 8,
# Set this True to handle KeyboardInterrupt internally
catch_interrupts = False,
hires_fix = False,
@ -406,7 +409,7 @@ class Generate:
log_tokens =self.log_tokenization
)
init_image, mask_image = self._make_images(
init_img,
init_mask,
width,
@ -431,6 +434,11 @@ class Generate:
self.seed, variation_amount, with_variations
)
checker = {
'checker':self.safety_checker,
'extractor':self.safety_feature_extractor
} if self.safety_checker else None
results = generator.generate(
prompt,
iterations=iterations,
@ -441,10 +449,10 @@ class Generate:
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
@ -453,6 +461,8 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius,
safety_checker=checker
)
if init_color:
@ -570,16 +580,19 @@ class Generate:
from ldm.invoke.restoration.outcrop import Outcrop
extend_instructions = {}
for direction,pixels in _pairwise(opt.outcrop):
try:
extend_instructions[direction]=int(pixels)
except ValueError:
print(f'** invalid extension instruction. Use <directions> <pixels>..., as in "top 64 left 128 right 64 bottom 64"')
if len(extend_instructions)>0:
restorer = Outcrop(image,self,)
return restorer.process (
extend_instructions,
opt = opt,
orig_opt = args,
image_callback = callback,
prefix = prefix,
)
elif tool == 'embiggen':
# fetch the metadata from the image
@ -643,23 +656,22 @@ class Generate:
# if image has a transparent area and no mask was provided, then try to generate mask
if self._has_transparency(image):
self._transparency_check_and_warning(image, mask)
# this returns a torch tensor
init_mask = self._create_init_mask(image, width, height, fit=fit)
if (image.width * image.height) > (self.width * self.height) and self.size_matters:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
self.size_matters = False
init_image = self._create_init_image(image,width,height,fit=fit)
if mask:
mask_image = self._load_img(mask)
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
return init_image,init_mask
def _make_base(self):
if not self.generators.get('base'):
@ -715,8 +727,7 @@ class Generate:
model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0:
print(f'** Model switch failed **')
return None
self.model = model_data['model']
self.width = model_data['width']
@ -877,46 +888,31 @@ class Generate:
def _create_init_image(self, image, width, height, fit=True):
image = image.convert('RGB')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image
def _create_init_mask(self, image, width, height, fit=True):
# convert into a black/white mask
image = self._image_to_mask(image)
image = image.convert('RGB')
image = self._fit_image(image, (width, height)) if fit else self._squeeze_image(image)
return image
# The mask is expected to have the region to be inpainted
# with alpha transparency. It converts it into a black/white
# image with the transparent part black.
def _image_to_mask(self, mask_image: Image.Image, invert=False) -> Image:
if mask_image.mode == 'L':
mask = mask_image
else:
# Obtain the mask from the transparency channel
mask = Image.new(mode="L", size=mask_image.size, color=255)
mask.putdata(mask_image.getdata(band=3))
if invert:
mask = ImageOps.invert(mask)
return mask
# TODO: The latter part of this method repeats code from _create_init_mask()
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
prompt = text_mask[0]
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
@ -926,18 +922,8 @@ class Generate:
segmented = self.txt2mask.segment(image, prompt)
mask = segmented.to_mask(float(confidence_level))
mask = mask.convert('RGB')
mask = self._fit_image(mask, (width, height)) if fit else self._squeeze_image(mask)
return mask
def _has_transparency(self, image):
if image.info.get("transparency", None) is not None:


@ -113,8 +113,8 @@ PRECISION_CHOICES = [
]
# is there a way to pick this up during git commits?
APP_ID = 'invoke-ai/InvokeAI'
APP_VERSION = 'v2.02'
class ArgFormatter(argparse.RawTextHelpFormatter):
# use defined argument order to display usage
@ -172,6 +172,7 @@ class Args(object):
command = cmd_string.replace("'", "\\'")
try:
elements = shlex.split(command)
elements = [x.replace("\\'","'") for x in elements]
except ValueError:
import sys, traceback
print(traceback.format_exc(), file=sys.stderr)
@ -366,17 +367,16 @@ class Args(object):
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument(
'--conf',
'--config',
'-c',
'-conf',
'-config',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
help='Indicates which diffusion model to load (defaults to "default" stanza in configs/models.yaml)',
)
model_group.add_argument(
'--png_compression','-z',
@ -419,6 +419,11 @@ class Args(object):
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
default='auto',
)
model_group.add_argument(
'--safety_checker',
action='store_true',
help='Check for and blur potentially NSFW images',
)
file_group.add_argument(
'--from_file',
dest='infile',
@ -529,7 +534,7 @@ class Args(object):
formatter_class=ArgFormatter,
description=
"""
*Image generation*
invoke> a fantastic alien landscape -W576 -H512 -s60 -n4
*postprocessing*
@ -544,6 +549,13 @@ class Args(object):
!history lists all the commands issued during the current session.
!NN retrieves the NNth command from the history
*Model manipulation*
!models -- list models in configs/models.yaml
!switch <model_name> -- switch to model named <model_name>
!import_model path/to/weights/file.ckpt -- adds a model to your config
!edit_model <model_name> -- edit a model's description
!del_model <model_name> -- delete a model
"""
)
render_group = parser.add_argument_group('General rendering')
@ -840,7 +852,7 @@ def metadata_dumps(opt,
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','threshold','perlin','step_number','width','height','extra','strength',
'init_img','init_mask']
'init_img','init_mask','facetool','facetool_strength','upscale']
rfc_dict ={}
@ -924,7 +936,7 @@ def metadata_loads(metadata) -> list:
for image in images:
# repack the prompt and variations
if 'prompt' in image:
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
image['prompt'] = repack_prompt(image['prompt'])
if 'variations' in image:
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
# fix a bit of semantic drift here
@ -932,12 +944,19 @@ def metadata_loads(metadata) -> list:
opt = Args()
opt._cmd_switches = Namespace(**image)
results.append(opt)
except KeyError as e:
except Exception as e:
import sys, traceback
print('>> badly-formatted metadata',file=sys.stderr)
print('>> could not read metadata',file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return results
def repack_prompt(prompt_list:list)->str:
# in the common case of no weighting syntax, just return the prompt as is
if len(prompt_list) > 1:
return ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in prompt_list])
else:
return prompt_list[0]['prompt']
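For context, the unpacked prompt metadata that repack_prompt() consumes is a list of weighted fragments. A fabricated sample and its round trip:

    prompt_list = [
        {'prompt': 'a castle on a hill', 'weight': 1.0},
        {'prompt': 'oil painting',       'weight': 0.5},
    ]
    repack_prompt(prompt_list)      # -> 'a castle on a hill:1.0,oil painting:0.5'
    repack_prompt(prompt_list[:1])  # -> 'a castle on a hill' (common unweighted case)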
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):
@ -967,17 +986,17 @@ def sha256(path):
return sha.hexdigest()
def legacy_metadata_load(meta,pathname) -> Args:
opt = Args()
if 'Dream' in meta and len(meta['Dream']) > 0:
dream_prompt = meta['Dream']
opt = Args()
opt.parse_cmd(dream_prompt)
return opt
else: # if nothing else, we can get the seed
match = re.search('\d+\.(\d+)',pathname)
if match:
seed = match.groups()[0]
opt = Args()
opt.seed = seed
return opt
return None
else:
opt.prompt = ''
opt.seed = 0
return opt

View File

@ -7,25 +7,27 @@ import numpy as np
import random
import os
from tqdm import tqdm, trange
from PIL import Image
from PIL import Image, ImageFilter
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.invoke.devices import choose_autocast
from ldm.util import rand_perlin_2d
downsampling = 8
CAUTION_IMG = 'assets/caution.png'
class Generator():
def __init__(self, model, precision):
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.model = model
self.precision = precision
self.seed = None
self.latent_channels = model.channels
self.downsampling_factor = downsampling # BUG: should come from model or config
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.safety_checker = None
self.perlin = 0.0
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@ -42,8 +44,10 @@ class Generator():
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
safety_checker:dict=None,
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
init_image = init_image,
@ -77,10 +81,17 @@ class Generator():
pass
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed, first_seed=first_seed)
seed = self.new_seed()
return results
def sample_to_image(self,samples):
@ -169,6 +180,39 @@ class Generator():
return v2
def safety_check(self,image:Image.Image):
'''
If the CompVis safety checker flags an NSFW image, we
blur it out.
'''
import diffusers
checker = self.safety_checker['checker']
extractor = self.safety_checker['extractor']
features = extractor([image], return_tensors="pt")
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
if has_nsfw_concept[0]:
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
return self.blur(image)
else:
return image
def blur(self,input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = Image.open(CAUTION_IMG)
caution = caution.resize((caution.width // 2, caution.height //2))
blurry.paste(caution,(0,0),caution)
except FileNotFoundError:
pass
return blurry
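For reference, a minimal sketch of how the checker/extractor pair consumed by safety_check() can be assembled; it mirrors download_safety_checker() later in this commit, so treat it as illustrative wiring rather than the exact call site:

    from transformers import AutoFeatureExtractor
    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

    safety_model_id = 'CompVis/stable-diffusion-safety-checker'
    safety_checker = {
        'extractor': AutoFeatureExtractor.from_pretrained(safety_model_id),
        'checker':   StableDiffusionSafetyChecker.from_pretrained(safety_model_id),
    }
    # generate(..., safety_checker=safety_checker) then routes every finished
    # image through safety_check() before the image callback fires.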
# this is a handy routine for debugging: given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):

View File

@ -4,9 +4,12 @@ ldm.invoke.generator.img2img descends from ldm.invoke.generator
import torch
import numpy as np
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
import PIL
from torch import Tensor
from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
class Img2Img(Generator):
def __init__(self, model, precision):
@ -25,6 +28,9 @@ class Img2Img(Generator):
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
)
if isinstance(init_image, PIL.Image.Image):
init_image = self._image_to_tensor(init_image)
scope = choose_autocast(self.precision)
with scope(self.model.device.type):
self.init_latent = self.model.get_first_stage_encoding(
@ -68,3 +74,11 @@ class Img2Img(Generator):
shape = init_latent.shape
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
return x
def _image_to_tensor(self, image:Image, normalize:bool=True)->Tensor:
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
if normalize:
image = 2.0 * image - 1.0
return image.to(self.model.device)
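The normalize flag is the one subtlety here: init images must land in [-1, 1] for the first-stage encoder, while inpainting masks (see the Inpaint hunk below) stay in [0, 1]:

    init_tensor = self._image_to_tensor(init_image)                   # [-1, 1]
    mask_tensor = self._image_to_tensor(mask_image, normalize=False)  # [0, 1]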

View File

@ -3,27 +3,55 @@ ldm.invoke.generator.inpaint descends from ldm.invoke.generator
'''
import torch
import torchvision.transforms as T
import numpy as np
import cv2 as cv
import PIL
from PIL import Image, ImageFilter
from skimage.exposure.histogram_matching import match_histograms
from einops import rearrange, repeat
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.invoke.generator.base import downsampling
class Inpaint(Img2Img):
def __init__(self, model, precision):
self.init_latent = None
self.pil_image = None
self.pil_mask = None
self.mask_blur_radius = 0
super().__init__(model, precision)
@torch.no_grad()
def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
conditioning,init_image,mask_image,strength,
step_callback=None,inpaint_replace=False,**kwargs):
mask_blur_radius: int = 8,
step_callback=None,inpaint_replace=False, **kwargs):
"""
Returns a function returning an image derived from the prompt and
the initial image + mask. Return value depends on the seed at
the time you call it. kwargs are 'init_latent' and 'strength'
"""
if isinstance(init_image, PIL.Image.Image):
self.pil_image = init_image
init_image = self._image_to_tensor(init_image)
if isinstance(mask_image, PIL.Image.Image):
self.pil_mask = mask_image
mask_image = mask_image.resize(
(
mask_image.width // downsampling,
mask_image.height // downsampling
),
resample=Image.Resampling.NEAREST
)
mask_image = self._image_to_tensor(mask_image,normalize=False)
self.mask_blur_radius = mask_blur_radius
# klms samplers not supported yet, so ignore previous sampler
if isinstance(sampler,KSampler):
print(
@ -77,10 +105,50 @@ class Inpaint(Img2Img):
mask = mask_image,
init_latent = self.init_latent
)
return self.sample_to_image(samples)
return make_image
def sample_to_image(self, samples)->Image.Image:
gen_result = super().sample_to_image(samples).convert('RGB')
if self.pil_image is None or self.pil_mask is None:
return gen_result
pil_mask = self.pil_mask
pil_image = self.pil_image
mask_blur_radius = self.mask_blur_radius
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
pil_init_mask = pil_mask.getchannel('A') if pil_mask.mode == 'RGBA' else pil_mask.convert('L')
pil_init_image = pil_image.convert('RGBA') # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.
# Note that this doesn't use the mask, which would exclude some source image pixels from the
# histogram and cause slight color changes.
init_rgb_pixels = np.asarray(pil_image.convert('RGB'), dtype=np.uint8).reshape(pil_image.width * pil_image.height, 3)
init_a_pixels = np.asarray(pil_init_image.getchannel('A'), dtype=np.uint8).reshape(pil_init_mask.width * pil_init_mask.height)
init_rgb_pixels = init_rgb_pixels[init_a_pixels > 0]
init_rgb_pixels = init_rgb_pixels.reshape(1, init_rgb_pixels.shape[0], init_rgb_pixels.shape[1]) # Restore a batch dimension; these alpha-visible pixels are now our histogram reference
# Get numpy version
np_gen_result = np.asarray(gen_result, dtype=np.uint8)
# Color correct
np_matched_result = match_histograms(np_gen_result, init_rgb_pixels, channel_axis=-1)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
# Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0:
nm = np.asarray(pil_init_mask, dtype=np.uint8)
nmd = cv.erode(nm, kernel=np.ones((3,3), dtype=np.uint8), iterations=int(mask_blur_radius / 2))
pmd = Image.fromarray(nmd, mode='L')
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(mask_blur_radius))
else:
blurred_init_mask = pil_init_mask
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(pil_image, (0,0), mask = blurred_init_mask)
return matched_result
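The color correction above leans on scikit-image's histogram matching. Stripped of the alpha filtering, the core operation is just this (a sketch with placeholder images):

    import numpy as np
    from PIL import Image
    from skimage.exposure.histogram_matching import match_histograms

    np_gen = np.asarray(gen_image.convert('RGB'), dtype=np.uint8)
    np_ref = np.asarray(src_image.convert('RGB'), dtype=np.uint8)
    matched = Image.fromarray(match_histograms(np_gen, np_ref, channel_axis=-1), mode='RGB')

The production code instead builds the reference from only the alpha-visible source pixels, for the reason given in the comment above.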

View File

@ -13,6 +13,7 @@ import gc
import hashlib
import psutil
import transformers
import os
from sys import getrefcount
from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
@ -73,7 +74,8 @@ class ModelCache(object):
except Exception as e:
print(f'** model {model_name} could not be loaded: {str(e)}')
print(f'** restoring {self.current_model}')
return self.get_model(self.current_model)
self.get_model(self.current_model)
return None
self.current_model = model_name
self._push_newest_model(model_name)
@ -84,6 +86,26 @@ class ModelCache(object):
'hash': hash
}
def default_model(self) -> str:
'''
Returns the name of the default model, or None
if none is defined.
'''
for model_name in self.config:
if self.config[model_name].get('default',False):
return model_name
return None
def set_default_model(self,model_name:str):
'''
Set the default model. The change will not take
effect until you call model_cache.commit()
'''
assert model_name in self.models,f"unknown model '{model_name}'"
for model in self.models:
self.models[model].pop('default',None)
self.models[model_name]['default'] = True
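A hypothetical interactive-session sketch of the two helpers above, using the model_cache handle that scripts/invoke.py already goes through (model names are placeholders):

    name = gen.model_cache.default_model()         # e.g. 'stable-diffusion-1.4'
    gen.model_cache.set_default_model('stable-diffusion-1.5')
    gen.model_cache.commit(opt.conf)               # change takes effect on disk here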
def list_models(self) -> dict:
'''
Return a dict of models in the format:
@ -121,12 +143,23 @@ class ModelCache(object):
else:
print(line)
def add_model(self, model_name:str, model_attributes:dict, clobber=False) ->str:
def del_model(self, model_name:str) ->bool:
'''
Delete the named model.
'''
omega = self.config
del omega[model_name]
if model_name in self.stack:
self.stack.remove(model_name)
return True
def add_model(self, model_name:str, model_attributes:dict, clobber=False) -> bool:
'''
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
On a successful update, the config will be changed in memory and a YAML
string will be returned.
On a successful update, the config will be changed in memory and the
method will return True. Will fail with an assertion error if provided
attributes are incorrect or the model name is missing.
'''
omega = self.config
# check that all the required fields are present
@ -139,7 +172,9 @@ class ModelCache(object):
config[field] = model_attributes[field]
omega[model_name] = config
return OmegaConf.to_yaml(omega)
if clobber:
self._invalidate_cached_model(model_name)
return True
def _check_memory(self):
avail_memory = psutil.virtual_memory()[1]
@ -159,6 +194,7 @@ class ModelCache(object):
mconfig = self.config[model_name]
config = mconfig.config
weights = mconfig.weights
vae = mconfig.get('vae',None)
width = mconfig.width
height = mconfig.height
@ -188,9 +224,17 @@ class ModelCache(object):
else:
print(' | Using more accurate float32 precision')
# look for and load a matching VAE file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae and os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
model.to(self.device)
# model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
model.cond_stage_model.device = self.device
model.eval()
for m in model.modules():
@ -219,6 +263,36 @@ class ModelCache(object):
if self._has_cuda():
torch.cuda.empty_cache()
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
'''
yaml_str = OmegaConf.to_yaml(self.config)
tmpfile = os.path.join(os.path.dirname(config_file_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(self.preamble())
outfile.write(yaml_str)
os.rename(tmpfile,config_file_path)
def preamble(self):
'''
Returns the preamble for the config file.
'''
return '''# This file describes the alternative machine learning models
# available to the dream script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
'''
def _invalidate_cached_model(self,model_name:str):
self.unload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)
def _model_to_cpu(self,model):
if self.device != 'cpu':
model.cond_stage_model.device = 'cpu'

View File

@ -38,7 +38,7 @@ class PngWriter:
info = PngImagePlugin.PngInfo()
info.add_text('Dream', dream_prompt)
if metadata:
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info, compress_level=compress_level)
return path

View File

@ -22,6 +22,7 @@ except (ImportError,ModuleNotFoundError):
IMG_EXTENSIONS = ('.png','.jpg','.jpeg','.PNG','.JPG','.JPEG','.gif','.GIF')
WEIGHT_EXTENSIONS = ('.ckpt','.bae')
TEXT_EXTENSIONS = ('.txt','.TXT')
CONFIG_EXTENSIONS = ('.yaml','.yml')
COMMANDS = (
'--steps','-s',
@ -55,13 +56,14 @@ COMMANDS = (
'--inpaint_replace','-r',
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!history','!search','!clear',
'!fix','!fetch','!replay','!history','!search','!clear',
'!models','!switch','!import_model','!edit_model','!del_model',
'!mask',
'!models','!switch','!import_model','!edit_model'
)
MODEL_COMMANDS = (
'!switch',
'!edit_model',
'!del_model',
)
WEIGHT_COMMANDS = (
'!import_model',
@ -69,6 +71,9 @@ WEIGHT_COMMANDS = (
IMG_PATH_COMMANDS = (
'--outdir[=\s]',
)
TEXT_PATH_COMMANDS=(
'!replay',
)
IMG_FILE_COMMANDS=(
'!fix',
'!fetch',
@ -78,8 +83,9 @@ IMG_FILE_COMMANDS=(
'--init_color[=\s]',
'--embedding_path[=\s]',
)
path_regexp = '('+'|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
weight_regexp = '('+'|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
path_regexp = '(' + '|'.join(IMG_PATH_COMMANDS+IMG_FILE_COMMANDS) + ')\s*\S*$'
weight_regexp = '(' + '|'.join(WEIGHT_COMMANDS) + ')\s*\S*$'
text_regexp = '(' + '|'.join(TEXT_PATH_COMMANDS) + ')\s*\S*$'
class Completer(object):
def __init__(self, options, models=[]):
@ -122,6 +128,9 @@ class Completer(object):
elif re.search(weight_regexp,buffer):
self.matches = self._path_completions(text, state, WEIGHT_EXTENSIONS)
elif re.search(text_regexp,buffer):
self.matches = self._path_completions(text, state, TEXT_EXTENSIONS)
# This is the first time for this text, so build a match list.
elif text:
self.matches = [
@ -210,9 +219,24 @@ class Completer(object):
pydoc.pager('\n'.join(lines))
def set_line(self,line)->None:
'''
Set the default string displayed in the next line of input.
'''
self.linebuffer = line
readline.redisplay()
def add_model(self,model_name:str)->None:
'''
add a model name to the completion list
'''
self.models.append(model_name)
def del_model(self,model_name:str)->None:
'''
removes a model name from the completion list
'''
self.models.remove(model_name)
def _seed_completions(self, text, state):
m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
if m:

View File

@ -64,7 +64,8 @@ def make_ddim_timesteps(
):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
# ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
if c < 1:
c = 1
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
elif ddim_discr_method == 'quad':
ddim_timesteps = (
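A worked example of the stride logic in the uniform branch, including the edge case the new guard covers (numbers are illustrative):

    import numpy as np
    # 1000 DDPM steps sampled at 50 DDIM steps -> stride c = 20
    c = max(1000 // 50, 1)
    ddim_timesteps = (np.arange(0, 50) * c).astype(int)  # [0, 20, ..., 980]
    # Without clamping c to 1, requesting more DDIM steps than DDPM steps
    # would give c = 0 and a schedule stuck at timestep 0.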

main.py
View File

@ -439,7 +439,7 @@ class ImageLogger(Callback):
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = { pl.loggers.TestTubeLogger: self._testtube, } if torch.cuda.is_available() else { }
self.logger_log_images = { }
self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
]
@ -451,17 +451,6 @@ class ImageLogger(Callback):
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
@rank_zero_only
def _testtube(self, pl_module, images, batch_idx, split):
for k in images:
grid = torchvision.utils.make_grid(images[k])
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
tag = f'{split}/{k}'
pl_module.logger.experiment.add_image(
tag, grid, global_step=pl_module.global_step
)
@rank_zero_only
def log_local(
self, save_dir, split, images, global_step, current_epoch, batch_idx
@ -714,7 +703,7 @@ if __name__ == '__main__':
# merge trainer cli with config
trainer_config = lightning_config.get('trainer', OmegaConf.create())
# default to ddp
trainer_config['accelerator'] = 'ddp'
trainer_config['accelerator'] = 'auto'
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not 'gpus' in trainer_config:
@ -751,12 +740,8 @@ if __name__ == '__main__':
trainer_kwargs = dict()
# default logger configs
if torch.cuda.is_available():
def_logger = 'testtube'
def_logger_target = 'TestTubeLogger'
else:
def_logger = 'csv'
def_logger_target = 'CSVLogger'
def_logger = 'csv'
def_logger_target = 'CSVLogger'
default_logger_cfgs = {
'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger',
@ -918,7 +903,8 @@ if __name__ == '__main__':
config.model.base_learning_rate,
)
if not cpu:
ngpu = len(lightning_config.trainer.gpus.strip(',').split(','))
gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
ngpu = len(gpus)
else:
ngpu = 1
if 'accumulate_grad_batches' in lightning_config.trainer:

View File

@ -1,5 +1,6 @@
albumentations==0.4.3
einops==0.3.0
diffusers==0.6.0
huggingface-hub==0.8.1
imageio==2.9.0
imageio-ffmpeg==0.4.2

View File

@ -32,7 +32,8 @@ send2trash
dependency_injector==4.40.0
eventlet
realesrgan
diffusers
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg

View File

@ -17,9 +17,15 @@ from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from pathlib import Path
# global used in multiple functions (fix)
infile = None
def main():
"""Initialize command-line parsers and the diffusion model"""
global infile
opt = Args()
args = opt.parse_args()
if not args:
@ -48,7 +54,6 @@ def main():
os.makedirs(opt.outdir)
# load the infile as a list of lines
infile = None
if opt.infile:
try:
if os.path.isfile(opt.infile):
@ -64,16 +69,17 @@ def main():
# creating a Generate object:
try:
gen = Generate(
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
full_precision = opt.full_precision,
precision = opt.precision,
precision = opt.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
@ -96,14 +102,16 @@ def main():
)
try:
main_loop(gen, opt, infile)
main_loop(gen, opt)
except KeyboardInterrupt:
print("\ngoodbye!")
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt, infile):
def main_loop(gen, opt):
"""prompt/read/execute loop"""
global infile
done = False
doneAfterInFile = infile is not None
path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list()
model_config = OmegaConf.load(opt.conf)
@ -130,7 +138,8 @@ def main_loop(gen, opt, infile):
try:
command = get_next_command(infile)
except EOFError:
done = True
done = infile is None or doneAfterInFile
infile = None
continue
# skip empty lines
@ -368,7 +377,10 @@ def main_loop(gen, opt, infile):
print('goodbye!')
# TO DO: remove repetitive code and the awkward command.replace() trope
# Just do a simple parse of the command!
def do_command(command:str, gen, opt:Args, completer) -> tuple:
global infile
operation = 'generate' # default operation, alternative is 'postprocess'
if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
@ -413,9 +425,26 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
completer.add_history(command)
operation = None
elif command.startswith('!del'):
path = shlex.split(command)
if len(path) < 2:
print('** please provide the name of a model')
else:
del_config(path[1], gen, opt, completer)
completer.add_history(command)
operation = None
elif command.startswith('!fetch'):
file_path = command.replace('!fetch ','',1)
file_path = command.replace('!fetch','',1).strip()
retrieve_dream_command(opt,file_path,completer)
completer.add_history(command)
operation = None
elif command.startswith('!replay'):
file_path = command.replace('!replay','',1).strip()
if infile is None and os.path.isfile(file_path):
infile = open(file_path, 'r', encoding='utf-8')
completer.add_history(command)
operation = None
elif command.startswith('!history'):
@ -423,7 +452,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
operation = None
elif command.startswith('!search'):
search_str = command.replace('!search ','',1)
search_str = command.replace('!search','',1).strip()
completer.show_history(search_str)
operation = None
@ -465,6 +494,16 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
new_config['config'] = input('Configuration file for this model: ')
done = os.path.exists(new_config['config'])
done = False
completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
while not done:
vae = input('VAE autoencoder file for this model [None]: ')
if os.path.exists(vae):
new_config['vae'] = vae
done = True
else:
done = len(vae)==0
completer.complete_extensions(None)
for field in ('width','height'):
@ -479,9 +518,25 @@ def add_weights_to_config(model_path:str, gen, opt, completer):
except:
print('** Please enter a valid integer between 64 and 2048')
if write_config_file(opt.conf, gen, model_name, new_config):
gen.set_model(model_name)
make_default = input('Make this the default model? [n] ') in ('y','Y')
if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
completer.add_model(model_name)
def del_config(model_name:str, gen, opt, completer):
current_model = gen.model_name
if model_name == current_model:
print("** Can't delete active model. !switch to another model first. **")
return
gen.model_cache.del_model(model_name)
gen.model_cache.commit(opt.conf)
print(f'** {model_name} deleted')
completer.del_model(model_name)
def edit_config(model_name:str, gen, opt, completer):
config = gen.model_cache.config
@ -493,33 +548,46 @@ def edit_config(model_name:str, gen, opt, completer):
conf = config[model_name]
new_config = {}
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae'))
for field in ('description', 'weights', 'config', 'width','height'):
completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
for field in ('description', 'weights', 'vae', 'config', 'width','height'):
completer.linebuffer = str(conf[field]) if field in conf else ''
new_value = input(f'{field}: ')
new_config[field] = int(new_value) if field in ('width','height') else new_value
make_default = input('Make this the default model? [n] ') in ('y','Y')
completer.complete_extensions(None)
if write_config_file(opt.conf, gen, model_name, new_config, clobber=True):
gen.set_model(model_name)
write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
current_model = gen.model_name
def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
op = 'modify' if clobber else 'import'
print('\n>> New configuration:')
if make_default:
new_config['default'] = True
print(yaml.dump({model_name:new_config}))
if input(f'OK to {op} [n]? ') not in ('y','Y'):
return False
try:
print('>> Verifying that new model loads...')
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber)
assert gen.set_model(model_name) is not None, 'model failed to load'
except AssertionError as e:
print(f'** configuration failed: {str(e)}')
print(f'** aborting **')
gen.model_cache.del_model(model_name)
return False
if make_default:
print('making this default')
gen.model_cache.set_default_model(model_name)
gen.model_cache.commit(conf_path)
tmpfile = os.path.join(os.path.dirname(conf_path),'new_config.tmp')
with open(tmpfile, 'w') as outfile:
outfile.write(yaml_str)
os.rename(tmpfile,conf_path)
do_switch = input('Keep model loaded? [y] ')
if len(do_switch)==0 or do_switch[0] in ('y','Y'):
pass
else:
gen.set_model(current_model)
return True
def do_textmask(gen, opt, callback):
@ -579,7 +647,10 @@ def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
original_file = original_file if os.path.exists(original_file) else os.path.join(opt.outdir,original_file)
new_file = new_file if os.path.exists(new_file) else os.path.join(opt.outdir,new_file)
meta = retrieve_metadata(original_file)['sd-metadata']
img_data = meta['image']
if 'image' not in meta:
meta = metadata_dumps(opt,seeds=[opt.seed])['image']
meta['image'] = {}
img_data = meta.get('image')
pp = img_data.get('postprocessing',[]) or []
pp.append(
{
@ -723,27 +794,71 @@ def make_step_callback(gen, opt, prefix):
image.save(filename,'PNG')
return callback
def retrieve_dream_command(opt,file_path,completer):
def retrieve_dream_command(opt,command,completer):
'''
Given a full or partial path to a previously-generated image file,
will retrieve and format the dream command used to generate the image,
and pop it into the readline buffer (linux, Mac), or print out a comment
for cut-and-paste (windows)
Given a path or wildcard pattern matching previously-generated png files,
will retrieve and format the dream commands used to generate the images,
and either pop a single command into the readline buffer or, when an
output file is named, save the commands there for further processing
'''
dir,basename = os.path.split(file_path)
if len(command) == 0:
return
tokens = command.split()
dir,basename = os.path.split(tokens[0])
if len(dir) == 0:
path = os.path.join(opt.outdir,basename)
else:
path = file_path
path = tokens[0]
if len(tokens) > 1:
return write_commands(opt, path, tokens[1])
cmd = ''
try:
cmd = dream_cmd_from_png(path)
except OSError:
print(f'** {path}: file could not be read')
print(f'## {tokens[0]}: file could not be read')
except (KeyError, AttributeError, IndexError):
print(f'## {tokens[0]}: file has no metadata')
except:
print(f'## {tokens[0]}: file could not be processed')
if len(cmd)>0:
completer.set_line(cmd)
def write_commands(opt, file_path:str, outfilepath:str):
dir,basename = os.path.split(file_path)
try:
paths = list(Path(dir).glob(basename))
except ValueError:
print(f'## "{basename}": unacceptable pattern')
return
except (KeyError, AttributeError):
print(f'** {path}: file has no metadata')
return
completer.set_line(cmd)
commands = []
cmd = None
for path in paths:
try:
cmd = dream_cmd_from_png(path)
except (KeyError, AttributeError, IndexError):
print(f'## {path}: file has no metadata')
except:
print(f'## {path}: file could not be processed')
if cmd:
commands.append(f'# {path}')
commands.append(cmd)
if len(commands)>0:
dir,basename = os.path.split(outfilepath)
if len(dir)==0:
outfilepath = os.path.join(opt.outdir,basename)
with open(outfilepath, 'w', encoding='utf-8') as f:
f.write('\n'.join(commands))
print(f'>> File {outfilepath} with commands created')
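Putting !fetch and !replay together, the intended round trip looks something like this (hypothetical session; paths are placeholders and assume the default output directory):

    invoke> !fetch outputs/img-samples/*.png commands.txt
    >> File outputs/img-samples/commands.txt with commands created
    invoke> !replay outputs/img-samples/commands.txt

With a single argument, !fetch still pops one image's dream command into the readline buffer; with a second argument it writes the collected commands to that file, which !replay then feeds back through the prompt loop.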
######################################
if __name__ == '__main__':
main()

View File

@ -5,7 +5,7 @@
# two machines must share a common .cache directory.
from transformers import CLIPTokenizer, CLIPTextModel
import clip
from transformers import BertTokenizerFast
from transformers import BertTokenizerFast, AutoFeatureExtractor
import sys
import transformers
import os
@ -17,41 +17,39 @@ import traceback
transformers.logging.set_verbosity_error()
#---------------------------------------------
# this will preload the Bert tokenizer files
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
sys.stdout.flush()
def download_bert():
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success')
sys.stdout.flush()
#---------------------------------------------
# this will download requirements for Kornia
print('Loading Kornia requirements...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print('...success')
def download_kornia():
print('Installing Kornia requirements...', end='')
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
import kornia
print('...success')
version = 'openai/clip-vit-large-patch14'
sys.stdout.flush()
print('Loading CLIP model...',end='')
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('...success')
#---------------------------------------------
def download_clip():
version = 'openai/clip-vit-large-patch14'
sys.stdout.flush()
print('Loading CLIP model...',end='')
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('...success')
# In the event that the user has installed GFPGAN and also elected to use
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
gfpgan = False
try:
from realesrgan import RealESRGANer
gfpgan = True
except ModuleNotFoundError:
pass
if gfpgan:
print('Loading models from RealESRGAN and facexlib...',end='')
#---------------------------------------------
def download_gfpgan():
print('Installing models from RealESRGAN and facexlib...',end='')
try:
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -94,44 +92,72 @@ if gfpgan:
print('Error loading GFPGAN:')
print(traceback.format_exc())
print('preloading CodeFormer model file...',end='')
try:
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
if not os.path.exists(model_dest):
print('Downloading codeformer model file...')
#---------------------------------------------
def download_codeformer():
print('Installing CodeFormer model file...',end='')
try:
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
if not os.path.exists(model_dest):
print('Downloading codeformer model file...')
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
except Exception:
print('Error loading CodeFormer:')
print(traceback.format_exc())
print('...success')
#---------------------------------------------
def download_clipseg():
print('Installing clipseg model for text-based masking...',end='')
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
model_dest = 'src/clipseg/clipseg_weights.zip'
weights_dir = 'src/clipseg/weights'
if not os.path.exists(weights_dir):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
except Exception:
print('Error loading CodeFormer:')
print(traceback.format_exc())
print('...success')
with zipfile.ZipFile(model_dest,'r') as zip:
zip.extractall('src/clipseg')
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
os.remove(model_dest)
from clipseg_models.clipseg import CLIPDensePredT
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
torch.load(
'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu')
),
strict=False,
)
except Exception:
print('Error installing clipseg model:')
print(traceback.format_exc())
print('...success')
print('Loading clipseg model for text-based masking...',end='')
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
model_dest = 'src/clipseg/clipseg_weights.zip'
weights_dir = 'src/clipseg/weights'
if not os.path.exists(weights_dir):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
with zipfile.ZipFile(model_dest,'r') as zip:
zip.extractall('src/clipseg')
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
os.remove(model_dest)
from clipseg_models.clipseg import CLIPDensePredT
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
torch.load(
'src/clipseg/weights/rd64-uni-refined.pth',
map_location=torch.device('cpu')
),
strict=False,
)
except Exception:
print('Error installing clipseg model:')
print(traceback.format_exc())
print('...success')
#-------------------------------------
def download_safety_checker():
print('Installing safety model for NSFW content detection...',end='')
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
except ModuleNotFoundError:
print('Error installing safety checker model:')
print(traceback.format_exc())
return
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
print('...success')
#-------------------------------------
if __name__ == '__main__':
download_bert()
download_kornia()
download_clip()
download_gfpgan()
download_codeformer()
download_clipseg()
download_safety_checker()
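With the preload script reorganized into explicit download functions, a first-time setup remains a single invocation (as the shell.nix below also notes):

    python3 scripts/preload_models.py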

shell.nix Normal file
View File

@ -0,0 +1,162 @@
{ pkgs ? import <nixpkgs> {}
, lib ? pkgs.lib
, stdenv ? pkgs.stdenv
, fetchurl ? pkgs.fetchurl
, runCommand ? pkgs.runCommand
, makeWrapper ? pkgs.makeWrapper
, mkShell ? pkgs.mkShell
, buildFHSUserEnv ? pkgs.buildFHSUserEnv
, frameworks ? pkgs.darwin.apple_sdk.frameworks
}:
# Setup InvokeAI environment using nix
# Simple usage:
# nix-shell
# python3 scripts/preload_models.py
# python3 scripts/invoke.py -h
let
conda-shell = { url, sha256, installPath, packages, shellHook }:
let
src = fetchurl { inherit url sha256; };
libPath = lib.makeLibraryPath ([] ++ lib.optionals (stdenv.isLinux) [ pkgs.zlib ]);
condaArch = if stdenv.system == "aarch64-darwin" then "osx-arm64" else "";
installer =
if stdenv.isDarwin then
runCommand "conda-install" {
nativeBuildInputs = [ makeWrapper ];
} ''
mkdir -p $out/bin
cp ${src} $out/bin/miniconda-installer.sh
chmod +x $out/bin/miniconda-installer.sh
makeWrapper \
$out/bin/miniconda-installer.sh \
$out/bin/conda-install \
--add-flags "-p ${installPath}" \
--add-flags "-b"
''
else if stdenv.isLinux then
runCommand "conda-install" {
nativeBuildInputs = [ makeWrapper ];
buildInputs = [ pkgs.zlib ];
}
# Line 10 of the installer contains 'unset LD_LIBRARY_PATH'.
# We have to comment it out in a way that does not change the number of
# bytes in the file, so we replace the 'u' in that line with a '#'.
# The reason: the binary payload is located by its byte offset from the
# top of the installer script, and unsetting the library path would
# prevent the zlib library from being discovered.
''
mkdir -p $out/bin
sed 's/unset LD_LIBRARY_PATH/#nset LD_LIBRARY_PATH/' ${src} > $out/bin/miniconda-installer.sh
chmod +x $out/bin/miniconda-installer.sh
makeWrapper \
$out/bin/miniconda-installer.sh \
$out/bin/conda-install \
--add-flags "-p ${installPath}" \
--add-flags "-b" \
--prefix "LD_LIBRARY_PATH" : "${libPath}"
''
else {};
hook = ''
export CONDA_SUBDIR=${condaArch}
'' + shellHook;
fhs = buildFHSUserEnv {
name = "conda-shell";
targetPkgs = pkgs: [ stdenv.cc pkgs.git installer ] ++ packages;
profile = hook;
runScript = "bash";
};
shell = mkShell {
shellHook = if stdenv.isDarwin then hook else "conda-shell; exit";
packages = if stdenv.isDarwin then [ pkgs.git installer ] ++ packages else [ fhs ];
};
in shell;
packages = with pkgs; [
cmake
protobuf
libiconv
rustc
cargo
rustPlatform.bindgenHook
];
env = {
aarch64-darwin = {
envFile = "environment-mac.yml";
condaPath = (builtins.toString ./.) + "/.conda";
ptrSize = "8";
};
x86_64-linux = {
envFile = "environment.yml";
condaPath = (builtins.toString ./.) + "/.conda";
ptrSize = "8";
};
};
envFile = env.${stdenv.system}.envFile;
installPath = env.${stdenv.system}.condaPath;
ptrSize = env.${stdenv.system}.ptrSize;
shellHook = ''
conda-install
# tmpdir is too small in nix
export TMPDIR="${installPath}/tmp"
# Add conda to PATH
export PATH="${installPath}/bin:$PATH"
# Allows `conda activate` to work properly
source ${installPath}/etc/profile.d/conda.sh
# Paths for gcc if compiling some C sources with pip
export NIX_CFLAGS_COMPILE="-I${installPath}/include -I$TMPDIR/include"
export NIX_CFLAGS_LINK="-L${installPath}/lib $BINDGEN_EXTRA_CLANG_ARGS"
export PIP_EXISTS_ACTION=w
# rust-onig fails (we think it writes config.h to the wrong location)
mkdir -p "$TMPDIR/include"
cat <<'EOF' > "$TMPDIR/include/config.h"
#define HAVE_PROTOTYPES 1
#define STDC_HEADERS 1
#define HAVE_STRING_H 1
#define HAVE_STDARG_H 1
#define HAVE_STDLIB_H 1
#define HAVE_LIMITS_H 1
#define HAVE_INTTYPES_H 1
#define SIZEOF_INT 4
#define SIZEOF_SHORT 2
#define SIZEOF_LONG ${ptrSize}
#define SIZEOF_VOIDP ${ptrSize}
#define SIZEOF_LONG_LONG 8
EOF
conda env create -f "${envFile}" || conda env update --prune -f "${envFile}"
conda activate invokeai
'';
version = "4.12.0";
conda = {
aarch64-darwin = {
shell = conda-shell {
inherit shellHook installPath;
url = "https://repo.anaconda.com/miniconda/Miniconda3-py39_${version}-MacOSX-arm64.sh";
sha256 = "4bd112168cc33f8a4a60d3ef7e72b52a85972d588cd065be803eb21d73b625ef";
packages = [ frameworks.Security ] ++ packages;
};
};
x86_64-linux = {
shell = conda-shell {
inherit shellHook installPath;
url = "https://repo.continuum.io/miniconda/Miniconda3-py39_${version}-Linux-x86_64.sh";
sha256 = "78f39f9bae971ec1ae7969f0516017f2413f17796670f7040725dd83fcff5689";
packages = with pkgs; [ libGL glib ] ++ packages;
};
};
};
in conda.${stdenv.system}.shell