mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into merge-main-into-development
This commit is contained in:
commit
36870a8f53
32
.github/workflows/build-container.yml
vendored
32
.github/workflows/build-container.yml
vendored
@ -6,14 +6,22 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- 'main'
|
- 'main'
|
||||||
- 'development'
|
- 'development'
|
||||||
pull_request:
|
|
||||||
branches:
|
|
||||||
- 'main'
|
|
||||||
- 'development'
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
docker:
|
docker:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch:
|
||||||
|
- x86_64
|
||||||
|
- aarch64
|
||||||
|
include:
|
||||||
|
- arch: x86_64
|
||||||
|
conda-env-file: environment.yml
|
||||||
|
- arch: aarch64
|
||||||
|
conda-env-file: environment-linux-aarch64.yml
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
name: ${{ matrix.arch }}
|
||||||
steps:
|
steps:
|
||||||
- name: prepare docker-tag
|
- name: prepare docker-tag
|
||||||
env:
|
env:
|
||||||
@ -25,18 +33,16 @@ jobs:
|
|||||||
uses: docker/setup-qemu-action@v2
|
uses: docker/setup-qemu-action@v2
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
- name: Cache Docker layers
|
|
||||||
uses: actions/cache@v2
|
|
||||||
with:
|
|
||||||
path: /tmp/.buildx-cache
|
|
||||||
key: buildx-${{ hashFiles('docker-build/Dockerfile') }}
|
|
||||||
- name: Build container
|
- name: Build container
|
||||||
uses: docker/build-push-action@v3
|
uses: docker/build-push-action@v3
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: docker-build/Dockerfile
|
file: docker-build/Dockerfile
|
||||||
platforms: linux/amd64
|
platforms: Linux/${{ matrix.arch }}
|
||||||
push: false
|
push: false
|
||||||
tags: ${{ env.dockertag }}:latest
|
tags: ${{ env.dockertag }}:${{ matrix.arch }}
|
||||||
cache-from: type=local,src=/tmp/.buildx-cache
|
build-args: |
|
||||||
cache-to: type=local,dest=/tmp/.buildx-cache
|
conda_env_file=${{ matrix.conda-env-file }}
|
||||||
|
conda_version=py39_4.12.0-Linux-${{ matrix.arch }}
|
||||||
|
invokeai_git=${{ github.repository }}
|
||||||
|
invokeai_branch=${{ github.ref_name }}
|
||||||
|
10
.github/workflows/test-invoke-conda.yml
vendored
10
.github/workflows/test-invoke-conda.yml
vendored
@ -76,8 +76,18 @@ jobs:
|
|||||||
if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
|
if: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/development' }}
|
||||||
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV
|
run: echo "TEST_PROMPTS=tests/validate_pr_prompt.txt" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Use Cached Stable Diffusion Model
|
||||||
|
id: cache-sd-model
|
||||||
|
uses: actions/cache@v3
|
||||||
|
env:
|
||||||
|
cache-name: cache-${{ matrix.stable-diffusion-model-switch }}
|
||||||
|
with:
|
||||||
|
path: ${{ matrix.stable-diffusion-model-dl-path }}
|
||||||
|
key: ${{ env.cache-name }}
|
||||||
|
|
||||||
- name: Download ${{ matrix.stable-diffusion-model-switch }}
|
- name: Download ${{ matrix.stable-diffusion-model-switch }}
|
||||||
id: download-stable-diffusion-model
|
id: download-stable-diffusion-model
|
||||||
|
if: ${{ steps.cache-sd-model.outputs.cache-hit != 'true' }}
|
||||||
run: |
|
run: |
|
||||||
[[ -d models/ldm/stable-diffusion-v1 ]] \
|
[[ -d models/ldm/stable-diffusion-v1 ]] \
|
||||||
|| mkdir -p models/ldm/stable-diffusion-v1
|
|| mkdir -p models/ldm/stable-diffusion-v1
|
||||||
|
@ -39,12 +39,13 @@ RUN apt-get update \
|
|||||||
&& apt-get clean \
|
&& apt-get clean \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
# clone repository and create symlinks
|
# clone repository, create models.yaml and create symlinks
|
||||||
ARG invokeai_git=https://github.com/invoke-ai/InvokeAI.git
|
ARG invokeai_git=invoke-ai/InvokeAI
|
||||||
|
ARG invokeai_branch=main
|
||||||
ARG project_name=invokeai
|
ARG project_name=invokeai
|
||||||
RUN git clone ${invokeai_git} /${project_name} \
|
RUN git clone -b ${invokeai_branch} https://github.com/${invokeai_git}.git /${project_name} \
|
||||||
&& mkdir /${project_name}/models/ldm/stable-diffusion-v1 \
|
&& cp /${project_name}/configs/models.yaml.example /${project_name}/configs/models.yaml \
|
||||||
&& ln -s /data/models/sd-v1-4.ckpt /${project_name}/models/ldm/stable-diffusion-v1/model.ckpt \
|
&& ln -s /data/models/v1-5-pruned-emaonly.ckpt /${project_name}/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt \
|
||||||
&& ln -s /data/outputs/ /${project_name}/outputs
|
&& ln -s /data/outputs/ /${project_name}/outputs
|
||||||
|
|
||||||
# set workdir
|
# set workdir
|
||||||
@ -63,9 +64,9 @@ RUN source ${conda_prefix}/etc/profile.d/conda.sh \
|
|||||||
&& rm -Rf ~/.cache \
|
&& rm -Rf ~/.cache \
|
||||||
&& conda clean -afy \
|
&& conda clean -afy \
|
||||||
&& echo "conda activate ${project_name}" >> ~/.bashrc \
|
&& echo "conda activate ${project_name}" >> ~/.bashrc \
|
||||||
&& ln -s /data/models/GFPGANv1.4.pth ./src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth \
|
|
||||||
&& conda activate ${project_name} \
|
&& conda activate ${project_name} \
|
||||||
&& python scripts/preload_models.py
|
&& python scripts/preload_models.py \
|
||||||
|
--no-interactive
|
||||||
|
|
||||||
# Copy entrypoint and set env
|
# Copy entrypoint and set env
|
||||||
ENV CONDA_PREFIX=${conda_prefix}
|
ENV CONDA_PREFIX=${conda_prefix}
|
||||||
|
@ -9,7 +9,8 @@ source ./docker-build/env.sh || echo "please run from repository root" || exit 1
|
|||||||
invokeai_conda_version=${INVOKEAI_CONDA_VERSION:-py39_4.12.0-${platform/\//-}}
|
invokeai_conda_version=${INVOKEAI_CONDA_VERSION:-py39_4.12.0-${platform/\//-}}
|
||||||
invokeai_conda_prefix=${INVOKEAI_CONDA_PREFIX:-\/opt\/conda}
|
invokeai_conda_prefix=${INVOKEAI_CONDA_PREFIX:-\/opt\/conda}
|
||||||
invokeai_conda_env_file=${INVOKEAI_CONDA_ENV_FILE:-environment.yml}
|
invokeai_conda_env_file=${INVOKEAI_CONDA_ENV_FILE:-environment.yml}
|
||||||
invokeai_git=${INVOKEAI_GIT:-https://github.com/invoke-ai/InvokeAI.git}
|
invokeai_git=${INVOKEAI_GIT:-invoke-ai/InvokeAI}
|
||||||
|
invokeai_branch=${INVOKEAI_BRANCH:-main}
|
||||||
huggingface_token=${HUGGINGFACE_TOKEN?}
|
huggingface_token=${HUGGINGFACE_TOKEN?}
|
||||||
|
|
||||||
# print the settings
|
# print the settings
|
||||||
@ -38,11 +39,12 @@ _copyCheckpoints() {
|
|||||||
echo "creating subfolders for models and outputs"
|
echo "creating subfolders for models and outputs"
|
||||||
_runAlpine mkdir models
|
_runAlpine mkdir models
|
||||||
_runAlpine mkdir outputs
|
_runAlpine mkdir outputs
|
||||||
echo -n "downloading sd-v1-4.ckpt"
|
echo "downloading v1-5-pruned-emaonly.ckpt"
|
||||||
_runAlpine wget --header="Authorization: Bearer ${huggingface_token}" -O models/sd-v1-4.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
|
_runAlpine wget \
|
||||||
|
--header="Authorization: Bearer ${huggingface_token}" \
|
||||||
|
-O models/v1-5-pruned-emaonly.ckpt \
|
||||||
|
https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
|
||||||
echo "done"
|
echo "done"
|
||||||
echo "downloading GFPGANv1.4.pth"
|
|
||||||
_runAlpine wget -O models/GFPGANv1.4.pth https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_checkVolumeContent() {
|
_checkVolumeContent() {
|
||||||
@ -51,7 +53,7 @@ _checkVolumeContent() {
|
|||||||
|
|
||||||
_getModelMd5s() {
|
_getModelMd5s() {
|
||||||
_runAlpine \
|
_runAlpine \
|
||||||
alpine sh -c "md5sum /data/models/*"
|
alpine sh -c "md5sum /data/models/*.ckpt"
|
||||||
}
|
}
|
||||||
|
|
||||||
if [[ -n "$(docker volume ls -f name="${volumename}" -q)" ]]; then
|
if [[ -n "$(docker volume ls -f name="${volumename}" -q)" ]]; then
|
||||||
@ -77,5 +79,6 @@ docker build \
|
|||||||
--build-arg conda_prefix="${invokeai_conda_prefix}" \
|
--build-arg conda_prefix="${invokeai_conda_prefix}" \
|
||||||
--build-arg conda_env_file="${invokeai_conda_env_file}" \
|
--build-arg conda_env_file="${invokeai_conda_env_file}" \
|
||||||
--build-arg invokeai_git="${invokeai_git}" \
|
--build-arg invokeai_git="${invokeai_git}" \
|
||||||
|
--build-arg invokeai_branch="${invokeai_branch}" \
|
||||||
--file ./docker-build/Dockerfile \
|
--file ./docker-build/Dockerfile \
|
||||||
.
|
.
|
||||||
|
@ -3,15 +3,14 @@ channels:
|
|||||||
- pytorch
|
- pytorch
|
||||||
- conda-forge
|
- conda-forge
|
||||||
dependencies:
|
dependencies:
|
||||||
- python>=3.9
|
- python=3.9.*
|
||||||
- pip>=20.3
|
- pip>=22.2.2
|
||||||
- cudatoolkit
|
- cudatoolkit
|
||||||
- pytorch
|
- pytorch
|
||||||
- torchvision
|
- torchvision
|
||||||
- numpy=1.19
|
- numpy=1.19
|
||||||
- imageio=2.9.0
|
- imageio=2.9.0
|
||||||
- opencv=4.6.0
|
- opencv=4.6.0
|
||||||
- getpass_asterisk
|
|
||||||
- pillow=8.*
|
- pillow=8.*
|
||||||
- flask=2.1.*
|
- flask=2.1.*
|
||||||
- flask_cors=3.0.10
|
- flask_cors=3.0.10
|
||||||
@ -30,6 +29,7 @@ dependencies:
|
|||||||
- torch-fidelity=0.3.0
|
- torch-fidelity=0.3.0
|
||||||
- tokenizers>=0.11.1,!=0.11.3,<0.13
|
- tokenizers>=0.11.1,!=0.11.3,<0.13
|
||||||
- pip:
|
- pip:
|
||||||
|
- getpass_asterisk
|
||||||
- omegaconf==2.1.1
|
- omegaconf==2.1.1
|
||||||
- realesrgan==0.2.5.0
|
- realesrgan==0.2.5.0
|
||||||
- test-tube>=0.7.5
|
- test-tube>=0.7.5
|
||||||
|
501
frontend/dist/assets/index.8eb7dfe4.js
vendored
Normal file
501
frontend/dist/assets/index.8eb7dfe4.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,50 +0,0 @@
|
|||||||
import { useAppDispatch } from '../../../app/store';
|
|
||||||
import IAISelect from '../../../common/components/IAISelect';
|
|
||||||
import IAISwitch from '../../../common/components/IAISwitch';
|
|
||||||
|
|
||||||
export function SettingsModalItem({
|
|
||||||
settingTitle,
|
|
||||||
isChecked,
|
|
||||||
dispatcher,
|
|
||||||
}: {
|
|
||||||
settingTitle: string;
|
|
||||||
isChecked: boolean;
|
|
||||||
dispatcher: any;
|
|
||||||
}) {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
return (
|
|
||||||
<IAISwitch
|
|
||||||
styleClass="settings-modal-item"
|
|
||||||
label={settingTitle}
|
|
||||||
isChecked={isChecked}
|
|
||||||
onChange={(e) => dispatch(dispatcher(e.target.checked))}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
export function SettingsModalSelectItem({
|
|
||||||
settingTitle,
|
|
||||||
validValues,
|
|
||||||
defaultValue,
|
|
||||||
dispatcher,
|
|
||||||
}: {
|
|
||||||
settingTitle: string;
|
|
||||||
validValues:
|
|
||||||
Array<number | string>
|
|
||||||
| Array<{ key: string; value: string | number }>;
|
|
||||||
defaultValue: string;
|
|
||||||
dispatcher: any;
|
|
||||||
}) {
|
|
||||||
const dispatch = useAppDispatch();
|
|
||||||
return (
|
|
||||||
<IAISelect
|
|
||||||
styleClass="settings-modal-item"
|
|
||||||
label={settingTitle}
|
|
||||||
validValues={validValues}
|
|
||||||
defaultValue={defaultValue}
|
|
||||||
onChange={(e) => dispatch(dispatcher(e.target.value))}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
@ -119,19 +119,19 @@ class Generator():
|
|||||||
# write an approximate RGB image from latent samples for a single step to PNG
|
# write an approximate RGB image from latent samples for a single step to PNG
|
||||||
|
|
||||||
def sample_to_lowres_estimated_image(self,samples):
|
def sample_to_lowres_estimated_image(self,samples):
|
||||||
# adapted from code by @erucipe and @keturn here:
|
# origingally adapted from code by @erucipe and @keturn here:
|
||||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||||
|
|
||||||
# these numbers were determined empirically by @keturn
|
# these updated numbers for v1.5 are from @torridgristle
|
||||||
v1_4_latent_rgb_factors = torch.tensor([
|
v1_5_latent_rgb_factors = torch.tensor([
|
||||||
# R G B
|
# R G B
|
||||||
[ 0.298, 0.207, 0.208], # L1
|
[ 0.3444, 0.1385, 0.0670], # L1
|
||||||
[ 0.187, 0.286, 0.173], # L2
|
[ 0.1247, 0.4027, 0.1494], # L2
|
||||||
[-0.158, 0.189, 0.264], # L3
|
[-0.3192, 0.2513, 0.2103], # L3
|
||||||
[-0.184, -0.271, -0.473], # L4
|
[-0.1307, -0.1874, -0.7445] # L4
|
||||||
], dtype=samples.dtype, device=samples.device)
|
], dtype=samples.dtype, device=samples.device)
|
||||||
|
|
||||||
latent_image = samples[0].permute(1, 2, 0) @ v1_4_latent_rgb_factors
|
latent_image = samples[0].permute(1, 2, 0) @ v1_5_latent_rgb_factors
|
||||||
latents_ubyte = (((latent_image + 1) / 2)
|
latents_ubyte = (((latent_image + 1) / 2)
|
||||||
.clamp(0, 1) # change scale from -1..1 to 0..1
|
.clamp(0, 1) # change scale from -1..1 to 0..1
|
||||||
.mul(0xFF) # to 0..255
|
.mul(0xFF) # to 0..255
|
||||||
|
@ -28,7 +28,7 @@ class Prompt():
|
|||||||
def __init__(self, parts: list):
|
def __init__(self, parts: list):
|
||||||
for c in parts:
|
for c in parts:
|
||||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
|
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
|
||||||
self.children = parts
|
self.children = parts
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Prompt:{self.children}"
|
return f"Prompt:{self.children}"
|
||||||
@ -102,12 +102,18 @@ class Attention():
|
|||||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||||
"""
|
"""
|
||||||
def __init__(self, weight: float, children: list):
|
def __init__(self, weight: float, children: list):
|
||||||
|
if type(weight) is not float:
|
||||||
|
raise PromptParser.ParsingException(
|
||||||
|
f"Attention weight must be float (got {type(weight).__name__} {weight})")
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
if type(children) is not list:
|
||||||
|
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
|
||||||
|
assert(type(children) is list)
|
||||||
self.children = children
|
self.children = children
|
||||||
#print(f"A: requested attention '{children}' to {weight}")
|
#print(f"A: requested attention '{children}' to {weight}")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Attention:'{self.children}' @ {self.weight}"
|
return f"Attention:{self.children} * {self.weight}"
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||||
|
|
||||||
@ -136,9 +142,9 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
Fragment('sitting on a car')
|
Fragment('sitting on a car')
|
||||||
])
|
])
|
||||||
"""
|
"""
|
||||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
|
def __init__(self, original: list, edited: list, options: dict=None):
|
||||||
self.original = original
|
self.original = original if len(original)>0 else [Fragment('')]
|
||||||
self.edited = edited
|
self.edited = edited if len(edited)>0 else [Fragment('')]
|
||||||
|
|
||||||
default_options = {
|
default_options = {
|
||||||
's_start': 0.0,
|
's_start': 0.0,
|
||||||
@ -190,12 +196,12 @@ class Conjunction():
|
|||||||
"""
|
"""
|
||||||
def __init__(self, prompts: list, weights: list = None):
|
def __init__(self, prompts: list, weights: list = None):
|
||||||
# force everything to be a Prompt
|
# force everything to be a Prompt
|
||||||
#print("making conjunction with", parts)
|
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
|
||||||
self.prompts = [x if (type(x) is Prompt
|
self.prompts = [x if (type(x) is Prompt
|
||||||
or type(x) is Blend
|
or type(x) is Blend
|
||||||
or type(x) is FlattenedPrompt)
|
or type(x) is FlattenedPrompt)
|
||||||
else Prompt(x) for x in prompts]
|
else Prompt(x) for x in prompts]
|
||||||
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
|
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||||
if len(self.weights) != len(self.prompts):
|
if len(self.weights) != len(self.prompts):
|
||||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||||
self.type = 'AND'
|
self.type = 'AND'
|
||||||
@ -216,6 +222,7 @@ class Blend():
|
|||||||
"""
|
"""
|
||||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||||
|
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||||
if len(prompts) != len(weights):
|
if len(prompts) != len(weights):
|
||||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
@ -244,6 +251,10 @@ class PromptParser():
|
|||||||
class ParsingException(Exception):
|
class ParsingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class UnrecognizedOperatorException(ParsingException):
|
||||||
|
def __init__(self, operator:str):
|
||||||
|
super().__init__("Unrecognized operator: " + operator)
|
||||||
|
|
||||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||||
|
|
||||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||||
@ -279,7 +290,7 @@ class PromptParser():
|
|||||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||||
|
|
||||||
|
|
||||||
def flatten(self, root: Conjunction) -> Conjunction:
|
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||||
"""
|
"""
|
||||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||||
@ -289,8 +300,6 @@ class PromptParser():
|
|||||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#print("flattening", root)
|
|
||||||
|
|
||||||
def fuse_fragments(items):
|
def fuse_fragments(items):
|
||||||
# print("fusing fragments in ", items)
|
# print("fusing fragments in ", items)
|
||||||
result = []
|
result = []
|
||||||
@ -313,8 +322,8 @@ class PromptParser():
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def flatten_internal(node, weight_scale, results, prefix):
|
def flatten_internal(node, weight_scale, results, prefix):
|
||||||
#print(prefix + "flattening", node, "...")
|
verbose and print(prefix + "flattening", node, "...")
|
||||||
if type(node) is pp.ParseResults:
|
if type(node) is pp.ParseResults or type(node) is list:
|
||||||
for x in node:
|
for x in node:
|
||||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||||
#print(prefix, " ParseResults expanded, results is now", results)
|
#print(prefix, " ParseResults expanded, results is now", results)
|
||||||
@ -345,67 +354,59 @@ class PromptParser():
|
|||||||
#print(prefix + "after flattening Prompt, results is", results)
|
#print(prefix + "after flattening Prompt, results is", results)
|
||||||
else:
|
else:
|
||||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||||
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
verbose and print("flattening", root)
|
||||||
|
|
||||||
flattened_parts = []
|
flattened_parts = []
|
||||||
for part in root.prompts:
|
for part in root.prompts:
|
||||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||||
|
|
||||||
#print("flattened to", flattened_parts)
|
verbose and print("flattened to", flattened_parts)
|
||||||
|
|
||||||
weights = root.weights
|
weights = root.weights
|
||||||
return Conjunction(flattened_parts, weights)
|
return Conjunction(flattened_parts, weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||||
|
def make_operator_object(x):
|
||||||
|
#print('making operator for', x)
|
||||||
|
target = x[0]
|
||||||
|
operator = x[1]
|
||||||
|
arguments = x[2]
|
||||||
|
if operator == '.attend':
|
||||||
|
weight_raw = arguments[0]
|
||||||
|
weight = 1.0
|
||||||
|
if type(weight_raw) is float or type(weight_raw) is int:
|
||||||
|
weight = weight_raw
|
||||||
|
elif type(weight_raw) is str:
|
||||||
|
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||||
|
weight = pow(base, len(weight_raw))
|
||||||
|
return Attention(weight=weight, children=[x for x in x[0]])
|
||||||
|
elif operator == '.swap':
|
||||||
|
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
|
||||||
|
elif operator == '.blend':
|
||||||
|
prompts = [Prompt(p) for p in x[0]]
|
||||||
|
weights_raw = x[2]
|
||||||
|
normalize_weights = True
|
||||||
|
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
|
||||||
|
normalize_weights = False
|
||||||
|
weights_raw = weights_raw[:-1]
|
||||||
|
weights = [float(w[0]) for w in weights_raw]
|
||||||
|
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
|
||||||
|
elif operator == '.and' or operator == '.add':
|
||||||
|
prompts = [Prompt(p) for p in x[0]]
|
||||||
|
weights = [float(w[0]) for w in x[2]]
|
||||||
|
return Conjunction(prompts=prompts, weights=weights)
|
||||||
|
|
||||||
lparen = pp.Literal("(").suppress()
|
raise PromptParser.UnrecognizedOperatorException(operator)
|
||||||
rparen = pp.Literal(")").suppress()
|
|
||||||
quotes = pp.Literal('"').suppress()
|
|
||||||
comma = pp.Literal(",").suppress()
|
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
|
||||||
number = pp.pyparsing_common.real | \
|
|
||||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
|
||||||
|
|
||||||
attention = pp.Forward()
|
|
||||||
quoted_fragment = pp.Forward()
|
|
||||||
parenthesized_fragment = pp.Forward()
|
|
||||||
cross_attention_substitute = pp.Forward()
|
|
||||||
|
|
||||||
def make_text_fragment(x):
|
|
||||||
#print("### making fragment for", x)
|
|
||||||
if type(x[0]) is Fragment:
|
|
||||||
assert(False)
|
|
||||||
if type(x) is str:
|
|
||||||
return Fragment(x)
|
|
||||||
elif type(x) is pp.ParseResults or type(x) is list:
|
|
||||||
#print(f'converting {type(x).__name__} to Fragment')
|
|
||||||
return Fragment(' '.join([s for s in x]))
|
|
||||||
else:
|
|
||||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
|
||||||
|
|
||||||
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
|
|
||||||
escapes = []
|
|
||||||
for c in escaped_chars_to_ignore:
|
|
||||||
escapes.append(pp.Literal('\\'+c))
|
|
||||||
return pp.Combine(pp.OneOrMore(
|
|
||||||
pp.MatchFirst(escapes + [pp.CharsNotIn(
|
|
||||||
string.whitespace + escaped_chars_to_ignore,
|
|
||||||
exact=1
|
|
||||||
)])
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
|
||||||
#print(f"parsing fragment string for {x}")
|
#print(f"parsing fragment string for {x}")
|
||||||
fragment_string = x[0]
|
fragment_string = x[0]
|
||||||
#print(f"ppparsing fragment string \"{fragment_string}\"")
|
|
||||||
|
|
||||||
if len(fragment_string.strip()) == 0:
|
if len(fragment_string.strip()) == 0:
|
||||||
return Fragment('')
|
return Fragment('')
|
||||||
|
|
||||||
@ -413,234 +414,198 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
# escape unescaped quotes
|
# escape unescaped quotes
|
||||||
fragment_string = fragment_string.replace('"', '\\"')
|
fragment_string = fragment_string.replace('"', '\\"')
|
||||||
|
|
||||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
|
||||||
try:
|
try:
|
||||||
result = pp.Group(pp.MatchFirst([
|
result = (expression + pp.StringEnd()).parse_string(fragment_string)
|
||||||
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
|
|
||||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
|
||||||
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
|
|
||||||
#print("parsed to", result)
|
#print("parsed to", result)
|
||||||
return result
|
return result
|
||||||
except pp.ParseException as e:
|
except pp.ParseException as e:
|
||||||
#print("parse_fragment_str couldn't parse prompt string:", e)
|
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# meaningful symbols
|
||||||
|
lparen = pp.Literal("(").suppress()
|
||||||
|
rparen = pp.Literal(")").suppress()
|
||||||
|
quote = pp.Literal('"').suppress()
|
||||||
|
comma = pp.Literal(",").suppress()
|
||||||
|
dot = pp.Literal(".").suppress()
|
||||||
|
equals = pp.Literal("=").suppress()
|
||||||
|
|
||||||
|
escaped_lparen = pp.Literal('\\(')
|
||||||
|
escaped_rparen = pp.Literal('\\)')
|
||||||
|
escaped_quote = pp.Literal('\\"')
|
||||||
|
escaped_comma = pp.Literal('\\,')
|
||||||
|
escaped_dot = pp.Literal('\\.')
|
||||||
|
escaped_plus = pp.Literal('\\+')
|
||||||
|
escaped_minus = pp.Literal('\\-')
|
||||||
|
escaped_equals = pp.Literal('\\=')
|
||||||
|
|
||||||
|
syntactic_symbols = {
|
||||||
|
'(': escaped_lparen,
|
||||||
|
')': escaped_rparen,
|
||||||
|
'"': escaped_quote,
|
||||||
|
',': escaped_comma,
|
||||||
|
'.': escaped_dot,
|
||||||
|
'+': escaped_plus,
|
||||||
|
'-': escaped_minus,
|
||||||
|
'=': escaped_equals,
|
||||||
|
}
|
||||||
|
syntactic_chars = "".join(syntactic_symbols.keys())
|
||||||
|
|
||||||
|
# accepts int or float notation, always maps to float
|
||||||
|
number = pp.pyparsing_common.real | \
|
||||||
|
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||||
|
|
||||||
|
# for options
|
||||||
|
keyword = pp.Word(pp.alphanums + '_')
|
||||||
|
|
||||||
|
# a word that absolutely does not contain any meaningful syntax
|
||||||
|
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
|
||||||
|
pp.Or(syntactic_symbols.values()),
|
||||||
|
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
|
||||||
|
# build character-by-character
|
||||||
|
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
|
||||||
|
])))
|
||||||
|
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
|
||||||
|
non_syntax_word.set_name('non_syntax_word')
|
||||||
|
non_syntax_word.set_debug(False)
|
||||||
|
|
||||||
|
# a word that can contain any character at all - greedily consumes syntax, so use with care
|
||||||
|
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
|
||||||
|
free_word.set_name('free_word')
|
||||||
|
free_word.set_debug(False)
|
||||||
|
|
||||||
|
|
||||||
|
# ok here we go. forward declare some things..
|
||||||
|
attention = pp.Forward()
|
||||||
|
cross_attention_substitute = pp.Forward()
|
||||||
|
parenthesized_fragment = pp.Forward()
|
||||||
|
quoted_fragment = pp.Forward()
|
||||||
|
|
||||||
|
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
|
||||||
|
fragment_part_expressions = [
|
||||||
|
attention,
|
||||||
|
cross_attention_substitute,
|
||||||
|
parenthesized_fragment,
|
||||||
|
quoted_fragment,
|
||||||
|
non_syntax_word
|
||||||
|
]
|
||||||
|
# a fragment that is permitted to contain commas
|
||||||
|
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||||
|
fragment_part_expressions + [
|
||||||
|
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
|
||||||
|
]
|
||||||
|
))
|
||||||
|
# a fragment that is not permitted to contain commas
|
||||||
|
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||||
|
fragment_part_expressions
|
||||||
|
))
|
||||||
|
|
||||||
|
# a fragment in double quotes (may be nested)
|
||||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
|
||||||
|
|
||||||
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
# a fragment inside parentheses (may be nested)
|
||||||
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
|
||||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
parenthesized_fragment.set_name('parenthesized_fragment')
|
||||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
parenthesized_fragment.set_debug(False)
|
||||||
|
|
||||||
empty = (
|
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
|
||||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
option = pp.Group(pp.MatchFirst([
|
||||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
keyword + equals + (number | keyword), # option=value
|
||||||
|
number.copy().set_parse_action(pp.token_map(str)), # weight
|
||||||
|
keyword # flag
|
||||||
def not_ends_with_swap(x):
|
|
||||||
#print("trying to match:", x)
|
|
||||||
return not x[0].endswith('.swap')
|
|
||||||
|
|
||||||
unquoted_word = (pp.Combine(pp.OneOrMore(
|
|
||||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
|
||||||
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
|
|
||||||
)))
|
|
||||||
# don't whitespace when the next word starts with +, eg "badly +formed"
|
|
||||||
+ (pp.White().suppress() |
|
|
||||||
# don't eat +/-
|
|
||||||
pp.NotAny(pp.Word('+') | pp.Word('-'))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
|
||||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
|
||||||
|
|
||||||
parenthesized_fragment << (lparen +
|
|
||||||
pp.Or([
|
|
||||||
(parenthesized_fragment),
|
|
||||||
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
|
|
||||||
(pp.Combine(pp.OneOrMore(
|
|
||||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
|
||||||
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
|
|
||||||
pp.White()
|
|
||||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
|
|
||||||
pp.Empty()
|
|
||||||
]) + rparen)
|
|
||||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
|
||||||
|
|
||||||
debug_attention = False
|
|
||||||
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
|
|
||||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
|
||||||
attention_with_parens = pp.Forward()
|
|
||||||
attention_without_parens = pp.Forward()
|
|
||||||
|
|
||||||
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
|
|
||||||
.set_name("attention_foot")\
|
|
||||||
.set_debug(False)
|
|
||||||
attention_with_parens <<= pp.Group(
|
|
||||||
lparen +
|
|
||||||
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
|
|
||||||
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
|
|
||||||
)
|
|
||||||
+ rparen + attention_with_parens_foot)
|
|
||||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
|
||||||
|
|
||||||
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
|
|
||||||
attention_without_parens <<= pp.Group(pp.MatchFirst([
|
|
||||||
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
|
|
||||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
|
||||||
+ attention_without_parens_foot#.leave_whitespace()
|
|
||||||
]))
|
]))
|
||||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
|
||||||
|
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
|
||||||
|
options.set_name('options')
|
||||||
|
options.set_debug(False)
|
||||||
|
|
||||||
|
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
|
||||||
|
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
|
||||||
|
|
||||||
attention << pp.MatchFirst([attention_with_parens,
|
# a fragment whose weight has been increased or decreased by a given amount
|
||||||
attention_without_parens
|
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
|
||||||
])
|
attention_explicit = (
|
||||||
|
pp.Group(potential_operator_target)
|
||||||
|
+ pp.Literal('.attend')
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(attention_weight_operator)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
attention_explicit.set_parse_action(make_operator_object)
|
||||||
|
attention_implicit = (
|
||||||
|
pp.Group(potential_operator_target)
|
||||||
|
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
|
||||||
|
+ pp.Group(attention_weight_operator)
|
||||||
|
)
|
||||||
|
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
|
||||||
|
attention << (attention_explicit | attention_implicit)
|
||||||
attention.set_name('attention')
|
attention.set_name('attention')
|
||||||
|
attention.set_debug(False)
|
||||||
|
|
||||||
def make_attention(x):
|
# cross-attention control by swapping one fragment for another
|
||||||
#print("entered make_attention with", x)
|
cross_attention_substitute << (
|
||||||
children = x[0][:-1]
|
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
|
||||||
weight_raw = x[0][-1]
|
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
|
||||||
weight = 1.0
|
+ lparen
|
||||||
if type(weight_raw) is float or type(weight_raw) is int:
|
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
|
||||||
weight = weight_raw
|
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
|
||||||
elif type(weight_raw) is str:
|
+ rparen
|
||||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
)
|
||||||
weight = pow(base, len(weight_raw))
|
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||||
|
cross_attention_substitute.set_debug(False)
|
||||||
#print("making Attention from", children, "with weight", weight)
|
cross_attention_substitute.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
|
|
||||||
|
|
||||||
attention_with_parens.set_parse_action(make_attention)
|
|
||||||
attention_without_parens.set_parse_action(make_attention)
|
|
||||||
|
|
||||||
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
|
|
||||||
|
|
||||||
# cross-attention control
|
|
||||||
empty_string = ((lparen + rparen) |
|
|
||||||
pp.Literal('""').suppress() |
|
|
||||||
(lparen + pp.Literal('""').suppress() + rparen)
|
|
||||||
).set_parse_action(lambda x: Fragment(""))
|
|
||||||
empty_string.set_name('empty_string')
|
|
||||||
|
|
||||||
# cross attention control
|
|
||||||
debug_cross_attention_control = False
|
|
||||||
original_fragment = pp.MatchFirst([
|
|
||||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
|
||||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
|
||||||
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
|
|
||||||
empty_string.set_debug(debug_cross_attention_control),
|
|
||||||
])
|
|
||||||
# support keyword=number arguments
|
|
||||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
|
||||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
|
||||||
edited_fragment = pp.MatchFirst([
|
|
||||||
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
|
||||||
lparen +
|
|
||||||
(quoted_fragment | attention |
|
|
||||||
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
|
|
||||||
) +
|
|
||||||
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
|
||||||
rparen,
|
|
||||||
parenthesized_fragment
|
|
||||||
])
|
|
||||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
|
|
||||||
|
|
||||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
|
||||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
|
||||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
|
||||||
|
|
||||||
def make_cross_attention_substitute(x):
|
|
||||||
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
|
|
||||||
#if len(x>2):
|
|
||||||
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
|
|
||||||
#print("made", cacs)
|
|
||||||
return cacs
|
|
||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
|
||||||
|
|
||||||
|
|
||||||
# root prompt definition
|
# an entire self-contained prompt, which can be used in a Blend or Conjunction
|
||||||
debug_root_prompt = False
|
prompt = pp.ZeroOrMore(pp.MatchFirst([
|
||||||
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
|
cross_attention_substitute,
|
||||||
attention.set_debug(debug_root_prompt),
|
attention,
|
||||||
quoted_fragment.set_debug(debug_root_prompt),
|
quoted_fragment,
|
||||||
parenthesized_fragment.set_debug(debug_root_prompt),
|
parenthesized_fragment,
|
||||||
unquoted_word.set_debug(debug_root_prompt),
|
free_word,
|
||||||
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
pp.White().suppress()
|
||||||
) + pp.StringEnd()) \
|
]))
|
||||||
.set_name('prompt') \
|
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
|
||||||
.set_parse_action(lambda x: Prompt(x)) \
|
|
||||||
.set_debug(debug_root_prompt)
|
|
||||||
|
|
||||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
|
||||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
|
||||||
|
|
||||||
# weighted blend of prompts
|
# a blend/lerp between the feature vectors for two or more prompts
|
||||||
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
|
blend = (
|
||||||
# int weights.
|
lparen
|
||||||
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
|
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(options).set_name('bl-options').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
blend.set_name('blend')
|
||||||
|
blend.set_debug(False)
|
||||||
|
blend.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
def make_prompt_from_quoted_string(x):
|
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
|
||||||
#print(' got quoted prompt', x)
|
explicit_conjunction = (
|
||||||
|
lparen
|
||||||
|
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(options).set_name('cj-options').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
explicit_conjunction.set_name('explicit_conjunction')
|
||||||
|
explicit_conjunction.set_debug(False)
|
||||||
|
explicit_conjunction.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
x_unquoted = x[0][1:-1]
|
# by default a prompt consists of a Conjunction with a single term
|
||||||
if len(x_unquoted.strip()) == 0:
|
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
|
||||||
# print(' b : just an empty string')
|
|
||||||
return Prompt([Fragment('')])
|
|
||||||
#print(f' b parsing \'{x_unquoted}\'')
|
|
||||||
x_parsed = prompt.parse_string(x_unquoted)
|
|
||||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
|
||||||
return x_parsed[0]
|
|
||||||
|
|
||||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
|
||||||
quoted_prompt.set_name('quoted_prompt')
|
|
||||||
|
|
||||||
debug_blend=False
|
|
||||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
|
||||||
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
|
|
||||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
|
||||||
+ pp.Literal(".blend").suppress()
|
|
||||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
|
||||||
blend.set_debug(debug_blend)
|
|
||||||
|
|
||||||
def make_blend(x):
|
|
||||||
prompts = x[0][0]
|
|
||||||
weights = x[0][1]
|
|
||||||
normalize = True
|
|
||||||
if weights[-1] == 'no_normalize':
|
|
||||||
normalize = False
|
|
||||||
weights = weights[:-1]
|
|
||||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
|
|
||||||
|
|
||||||
blend.set_parse_action(make_blend)
|
|
||||||
|
|
||||||
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
|
|
||||||
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
|
|
||||||
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
|
|
||||||
+ pp.Literal(".and").suppress()
|
|
||||||
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
|
|
||||||
def make_conjunction(x):
|
|
||||||
parts_raw = x[0][0]
|
|
||||||
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
|
|
||||||
parts = [part for part in parts_raw]
|
|
||||||
return Conjunction(parts, weights)
|
|
||||||
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
|
|
||||||
|
|
||||||
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
|
|
||||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||||
|
|
||||||
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
|
conjunction = (explicit_conjunction | implicit_conjunction)
|
||||||
conjunction.set_debug(False)
|
|
||||||
|
|
||||||
# top-level is a conjunction of one or more blends or prompts
|
|
||||||
return conjunction, prompt
|
return conjunction, prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||||
"""
|
"""
|
||||||
Legacy blend parsing.
|
Legacy blend parsing.
|
||||||
|
@ -66,7 +66,9 @@ def make_ddim_timesteps(
|
|||||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||||
if c < 1:
|
if c < 1:
|
||||||
c = 1
|
c = 1
|
||||||
ddim_timesteps = (np.arange(0, num_ddim_timesteps) * c).astype(int)
|
|
||||||
|
# remove 1 final step to prevent index out of bound error
|
||||||
|
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))[:-1]
|
||||||
elif ddim_discr_method == 'quad':
|
elif ddim_discr_method == 'quad':
|
||||||
ddim_timesteps = (
|
ddim_timesteps = (
|
||||||
(
|
(
|
||||||
@ -84,7 +86,6 @@ def make_ddim_timesteps(
|
|||||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||||
steps_out = ddim_timesteps + 1
|
steps_out = ddim_timesteps + 1
|
||||||
# steps_out = ddim_timesteps
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||||
|
@ -17,6 +17,7 @@ from omegaconf import OmegaConf
|
|||||||
from huggingface_hub import HfFolder, hf_hub_url
|
from huggingface_hub import HfFolder, hf_hub_url
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from getpass_asterisk import getpass_asterisk
|
from getpass_asterisk import getpass_asterisk
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
import traceback
|
import traceback
|
||||||
import requests
|
import requests
|
||||||
import clip
|
import clip
|
||||||
@ -30,10 +31,6 @@ warnings.filterwarnings('ignore')
|
|||||||
#warnings.filterwarnings('ignore',category=DeprecationWarning)
|
#warnings.filterwarnings('ignore',category=DeprecationWarning)
|
||||||
#warnings.filterwarnings('ignore',category=UserWarning)
|
#warnings.filterwarnings('ignore',category=UserWarning)
|
||||||
|
|
||||||
# deferred loading so that help message can be printed quickly
|
|
||||||
def load_libs():
|
|
||||||
pass
|
|
||||||
|
|
||||||
#--------------------------globals--
|
#--------------------------globals--
|
||||||
Model_dir = './models/ldm/stable-diffusion-v1/'
|
Model_dir = './models/ldm/stable-diffusion-v1/'
|
||||||
Default_config_file = './configs/models.yaml'
|
Default_config_file = './configs/models.yaml'
|
||||||
@ -347,7 +344,7 @@ def update_config_file(successfully_downloaded:dict,opt:dict):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.exists(Config_file):
|
if os.path.exists(Config_file):
|
||||||
print(f'* {Config_file} exists. Renaming to {Config_file}.orig')
|
print(f'** {Config_file} exists. Renaming to {Config_file}.orig')
|
||||||
os.rename(Config_file,f'{Config_file}.orig')
|
os.rename(Config_file,f'{Config_file}.orig')
|
||||||
tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
|
tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
|
||||||
with open(tmpfile, 'w') as outfile:
|
with open(tmpfile, 'w') as outfile:
|
||||||
@ -419,9 +416,6 @@ def download_kornia():
|
|||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_clip():
|
def download_clip():
|
||||||
print('Loading CLIP model...',end='')
|
print('Loading CLIP model...',end='')
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
version = 'openai/clip-vit-large-patch14'
|
version = 'openai/clip-vit-large-patch14'
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
@ -550,7 +544,6 @@ if __name__ == '__main__':
|
|||||||
default='./configs/models.yaml',
|
default='./configs/models.yaml',
|
||||||
help='path to configuration file to create')
|
help='path to configuration file to create')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
load_libs()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if opt.interactive:
|
if opt.interactive:
|
||||||
@ -562,16 +555,11 @@ if __name__ == '__main__':
|
|||||||
if models is None:
|
if models is None:
|
||||||
if yes_or_no('Quit?',default_yes=False):
|
if yes_or_no('Quit?',default_yes=False):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
done = False
|
|
||||||
while not done:
|
|
||||||
print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
|
print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
|
||||||
access_token = authenticate()
|
access_token = authenticate()
|
||||||
print('\n** DOWNLOADING WEIGHTS **')
|
print('\n** DOWNLOADING WEIGHTS **')
|
||||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||||
done = successfully_downloaded is not None
|
|
||||||
update_config_file(successfully_downloaded,opt)
|
update_config_file(successfully_downloaded,opt)
|
||||||
|
|
||||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||||
download_bert()
|
download_bert()
|
||||||
download_kornia()
|
download_kornia()
|
||||||
|
@ -28,8 +28,8 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
|
||||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||||
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
||||||
@ -37,14 +37,25 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_attention(self):
|
def test_attention(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
|
||||||
|
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
||||||
@ -102,20 +113,17 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
assert_if_prompt_string_not_untouched('a test prompt')
|
assert_if_prompt_string_not_untouched('a test prompt')
|
||||||
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
assert_if_prompt_string_not_untouched('a badly (formed test prompt')
|
||||||
parse_prompt('a badly (formed test prompt')
|
|
||||||
#with self.assertRaises(pyparsing.ParseException):
|
#with self.assertRaises(pyparsing.ParseException):
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
|
||||||
parse_prompt('a badly (formed +test prompt')
|
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
|
||||||
parse_prompt('(((a badly (formed +test )prompt')
|
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
|
||||||
parse_prompt('(a (ba)dly (f)ormed +test prompt')
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
|
||||||
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
|
parse_prompt('("((a badly (formed +test ").blend(1.0)'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
|
||||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
|
||||||
|
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||||
parse_prompt("hamburger ((bun))"))
|
parse_prompt("hamburger ((bun))"))
|
||||||
@ -128,6 +136,26 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def test_blend(self):
|
def test_blend(self):
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("(\"mountain\", \"man\").blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("(mountain, man).blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("((mountain), (man)).blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("((mountain), (tall man)).blend()")
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
|
||||||
|
|
||||||
self.assertEqual(Conjunction(
|
self.assertEqual(Conjunction(
|
||||||
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
||||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
||||||
@ -166,10 +194,20 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
|
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
|
||||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||||
|
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
|
||||||
|
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")
|
||||||
|
|
||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
@ -182,6 +220,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_cross_attention_control(self):
|
def test_cross_attention_control(self):
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
|
||||||
|
parse_prompt("sun.swap(moon)"))
|
||||||
|
|
||||||
self.assertEqual(Conjunction([
|
self.assertEqual(Conjunction([
|
||||||
FlattenedPrompt([Fragment('a', 1),
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||||
@ -231,6 +272,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||||
|
parse_prompt('a forest landscape ().swap(in winter)'))
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||||
@ -259,6 +303,12 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
Fragment('eating a', 1),
|
Fragment('eating a', 1),
|
||||||
@ -343,31 +393,31 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
|
||||||
|
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||||
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
|
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
|
||||||
|
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
|
self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
|
||||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
|
self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
|
||||||
|
|
||||||
def test_cross_attention_escaping(self):
|
def test_cross_attention_escaping(self):
|
||||||
|
|
||||||
@ -433,6 +483,15 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def test_single(self):
|
def test_single(self):
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||||
|
FlattenedPrompt([("a person with a hat", 1.0),
|
||||||
|
("riding a", 1.1*1.1),
|
||||||
|
CrossAttentionControlSubstitute(
|
||||||
|
[Fragment("bicycle", pow(1.1,2))],
|
||||||
|
[Fragment("skateboard", pow(1.1,2))])
|
||||||
|
])
|
||||||
|
], weights=[0.5, 0.5]),
|
||||||
|
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user