Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00

Compare commits: v2.3.4rc1 ... bugfix/con
71 Commits
Commit SHA1s:

020fb6c1e4
c63ef500b6
994a76aeaa
144dfe4a5b
5dbc63e2ae
c6ae1edc82
2cd0e036ac
5f1d311c52
c088cf0344
264af3c054
b332432a88
7f7d5894fa
96c39b61cf
40744ed996
f36452d650
e5188309ec
a9e8005a92
c2e6d98e66
40d9b5dc27
1a704efff1
f49d2619be
da96ec9dd5
298ccda365
967d853020
e91117bc74
4d58444153
3667eb4d0d
203a7157e1
6365a7c790
5fcb3d90e4
2c449bfb34
8fb4b05556
4d7289b20f
d81584c8fd
1183bf96ed
d81394cda8
0eda1a03e1
be7e067c95
afa3cdce27
6dfbd1c677
a775c7730e
018d5dab53
96a5de30e3
2251d3abfe
0b22a3f34d
2528e14fe9
4d62d5b802
17de5c7008
f95403dcda
16ccc807cc
e54d060d17
a01f1d4940
1873817ac9
31333a736c
03274b6da6
0646649c05
2af511c98a
f0039cc70a
8fa7d5ca64
d90aa42799
c5b34d21e5
40a4867143
4b25f80427
894e2e643d
a38ff1a16b
41f268b475
b3ae3f595f
29962613d8
1170cee1d8
5983e65b22
bc724fcdc3
.github/CODEOWNERS (vendored): 34 changed lines

@@ -1,13 +1,13 @@
# continuous integration
-/.github/workflows/ @mauwii @lstein @blessedcoolant
+/.github/workflows/ @lstein @blessedcoolant

# documentation
-/docs/ @lstein @mauwii @blessedcoolant
-mkdocs.yml @mauwii @lstein
+/docs/ @lstein @blessedcoolant
+mkdocs.yml @lstein @ebr

# installation and configuration
-/pyproject.toml @mauwii @lstein @ebr
-/docker/ @mauwii
+/pyproject.toml @lstein @ebr
+/docker/ @lstein
/scripts/ @ebr @lstein @blessedcoolant
/installer/ @ebr @lstein
ldm/invoke/config @lstein @ebr

@@ -21,13 +21,13 @@ invokeai/configs @lstein @ebr @blessedcoolant

# generation and model management
/ldm/*.py @lstein @blessedcoolant
-/ldm/generate.py @lstein @keturn
+/ldm/generate.py @lstein @gregghelt2
/ldm/invoke/args.py @lstein @blessedcoolant
/ldm/invoke/ckpt* @lstein @blessedcoolant
/ldm/invoke/ckpt_generator @lstein @blessedcoolant
/ldm/invoke/CLI.py @lstein @blessedcoolant
-/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
-/ldm/invoke/generator @keturn @damian0815
+/ldm/invoke/config @lstein @ebr @blessedcoolant
+/ldm/invoke/generator @gregghelt2 @damian0815
/ldm/invoke/globals.py @lstein @blessedcoolant
/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
/ldm/invoke/model_manager.py @lstein @blessedcoolant

@@ -36,17 +36,17 @@ invokeai/configs @lstein @ebr @blessedcoolant
/ldm/invoke/restoration @lstein @blessedcoolant

# attention, textual inversion, model configuration
-/ldm/models @damian0815 @keturn @blessedcoolant
+/ldm/models @damian0815 @gregghelt2 @blessedcoolant
/ldm/modules/textual_inversion_manager.py @lstein @blessedcoolant
-/ldm/modules/attention.py @damian0815 @keturn
-/ldm/modules/diffusionmodules @damian0815 @keturn
-/ldm/modules/distributions @damian0815 @keturn
-/ldm/modules/ema.py @damian0815 @keturn
+/ldm/modules/attention.py @damian0815 @gregghelt2
+/ldm/modules/diffusionmodules @damian0815 @gregghelt2
+/ldm/modules/distributions @damian0815 @gregghelt2
+/ldm/modules/ema.py @damian0815 @gregghelt2
/ldm/modules/embedding_manager.py @lstein
-/ldm/modules/encoders @damian0815 @keturn
-/ldm/modules/image_degradation @damian0815 @keturn
-/ldm/modules/losses @damian0815 @keturn
-/ldm/modules/x_transformer.py @damian0815 @keturn
+/ldm/modules/encoders @damian0815 @gregghelt2
+/ldm/modules/image_degradation @damian0815 @gregghelt2
+/ldm/modules/losses @damian0815 @gregghelt2
+/ldm/modules/x_transformer.py @damian0815 @gregghelt2

# Nodes
apps/ @Kyle0654 @jpphoto
@@ -30,7 +30,6 @@ from ldm.invoke.conditioning import (
get_tokens_for_prompt_object,
get_prompt_structure,
split_weighted_subprompts,
-get_tokenizer,
)
from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
from ldm.invoke.generator.inpaint import infill_methods

@@ -38,11 +37,11 @@ from ldm.invoke.globals import (
Globals,
global_converted_ckpts_dir,
global_models_dir,
-global_lora_models_dir,
)
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from compel.prompt_parser import Blend
from ldm.invoke.merge_diffusers import merge_diffusion_models
+from ldm.modules.lora_manager import LoraManager

# Loading Arguments
opt = Args()

@@ -524,20 +523,12 @@ class InvokeAIWebServer:
@socketio.on("getLoraModels")
def get_lora_models():
try:
-lora_path = global_lora_models_dir()
-loras = []
-for root, _, files in os.walk(lora_path):
-models = [
-Path(root, x)
-for x in files
-if Path(x).suffix in [".ckpt", ".pt", ".safetensors"]
-]
-loras = loras + models
-
+model = self.generate.model
+lora_mgr = LoraManager(model)
+loras = lora_mgr.list_compatible_loras()
found_loras = []
-for lora in sorted(loras, key=lambda s: s.stem.lower()):
-location = str(lora.resolve()).replace("\\", "/")
-found_loras.append({"name": lora.stem, "location": location})
+for lora in sorted(loras, key=str.casefold):
+found_loras.append({"name":lora,"location":str(loras[lora])})
socketio.emit("foundLoras", found_loras)
except Exception as e:
self.handle_exceptions(e)

@@ -1314,7 +1305,7 @@ class InvokeAIWebServer:
None
if type(parsed_prompt) is Blend
else get_tokens_for_prompt_object(
-get_tokenizer(self.generate.model), parsed_prompt
+self.generate.model.tokenizer, parsed_prompt
)
)
attention_maps_image_base64_url = (
@@ -80,7 +80,8 @@ trinart-2.0:
repo_id: stabilityai/sd-vae-ft-mse
recommended: False
waifu-diffusion-1.4:
-description: An SD-1.5 model trained on 680k anime/manga-style images (2.13 GB)
+description: An SD-2.1 model trained on 5.4M anime/manga-style images (4.27 GB)
+revision: main
repo_id: hakurei/waifu-diffusion
format: diffusers
vae:
File diff suppressed because one or more lines are too long
invokeai/frontend/dist/index.html (vendored): 2 changed lines

@@ -5,7 +5,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
-<script type="module" crossorigin src="./assets/index-c1535364.js"></script>
+<script type="module" crossorigin src="./assets/index-b12e648e.js"></script>
<link rel="stylesheet" href="./assets/index-2ab0eb58.css">
</head>
invokeai/frontend/dist/locales/en.json (vendored): 3 changed lines

@@ -328,8 +328,11 @@
"updateModel": "Update Model",
"availableModels": "Available Models",
+"addLora": "Add Lora",
+"clearLoras": "Clear Loras",
+"noLoraModels": "No Loras Found",
"addTextualInversionTrigger": "Add Textual Inversion",
"addTIToNegative": "Add To Negative",
"clearTextualInversions": "Clear Textual Inversions",
"noTextualInversionTriggers": "No Textual Inversions Found",
"search": "Search",
"load": "Load",

@@ -328,8 +328,11 @@
"updateModel": "Update Model",
"availableModels": "Available Models",
+"addLora": "Add Lora",
+"clearLoras": "Clear Loras",
+"noLoraModels": "No Loras Found",
"addTextualInversionTrigger": "Add Textual Inversion",
"addTIToNegative": "Add To Negative",
"clearTextualInversions": "Clear Textual Inversions",
"noTextualInversionTriggers": "No Textual Inversions Found",
"search": "Search",
"load": "Load",
@@ -33,6 +33,10 @@ import {
setIntermediateImage,
} from 'features/gallery/store/gallerySlice';

+import {
+getLoraModels,
+getTextualInversionTriggers,
+} from 'app/socketio/actions';
import type { RootState } from 'app/store';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {

@@ -463,6 +467,8 @@ const makeSocketIOListeners = (
const { model_name, model_list } = data;
dispatch(setModelList(model_list));
dispatch(setCurrentStatus(i18n.t('common.statusModelChanged')));
+dispatch(getLoraModels());
+dispatch(getTextualInversionTriggers());
dispatch(setIsProcessing(false));
dispatch(setIsCancelable(true));
dispatch(
@@ -92,7 +92,8 @@ export default function IAISimpleMenu(props: IAIMenuProps) {
zIndex={15}
padding={0}
borderRadius="0.5rem"
-overflowY="scroll"
+overflow="scroll"
maxWidth={'22.5rem'}
+maxHeight={500}
backgroundColor="var(--background-color-secondary)"
color="var(--text-color-secondary)"
@@ -34,7 +34,6 @@ export default function MainWidth() {
withSliderMarks
sliderMarkRightOffset={-8}
inputWidth="6.2rem"
-inputReadOnly
sliderNumberInputProps={{ max: 15360 }}
/>
) : (
@@ -1,10 +1,15 @@
-import { Box } from '@chakra-ui/react';
+import { Box, Flex } from '@chakra-ui/react';
import { getLoraModels } from 'app/socketio/actions';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
import IAISimpleMenu, { IAIMenuItem } from 'common/components/IAISimpleMenu';
-import { setLorasInUse } from 'features/parameters/store/generationSlice';
+import {
+setClearLoras,
+setLorasInUse,
+} from 'features/parameters/store/generationSlice';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';
+import { MdClear } from 'react-icons/md';

export default function LoraManager() {
const dispatch = useAppDispatch();

@@ -53,11 +58,20 @@ export default function LoraManager() {
};

return foundLoras && foundLoras?.length > 0 ? (
-<IAISimpleMenu
-menuItems={makeLoraItems()}
-menuType="regular"
-buttonText={`${t('modelManager.addLora')} (${numOfActiveLoras()})`}
-/>
+<Flex columnGap={2}>
+<IAISimpleMenu
+menuItems={makeLoraItems()}
+menuType="regular"
+buttonText={`${t('modelManager.addLora')} (${numOfActiveLoras()})`}
+menuButtonProps={{ width: '100%', padding: '0 1rem' }}
+/>
+<IAIIconButton
+icon={<MdClear />}
+tooltip={t('modelManager.clearLoras')}
+aria-label={t('modelManager.clearLoras')}
+onClick={() => dispatch(setClearLoras())}
+/>
+</Flex>
) : (
<Box
background="var(--btn-base-color)"
@@ -0,0 +1,12 @@
+import { Flex } from '@chakra-ui/react';
+import LoraManager from './LoraManager/LoraManager';
+import TextualInversionManager from './TextualInversionManager/TextualInversionManager';
+
+export default function PromptExtras() {
+return (
+<Flex flexDir="column" rowGap={2}>
+<LoraManager />
+<TextualInversionManager />
+</Flex>
+);
+}
@@ -1,17 +1,28 @@
-import { Box } from '@chakra-ui/react';
+import { Box, Flex } from '@chakra-ui/react';
import { getTextualInversionTriggers } from 'app/socketio/actions';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
+import IAIIconButton from 'common/components/IAIIconButton';
import IAISimpleMenu, { IAIMenuItem } from 'common/components/IAISimpleMenu';
-import { setTextualInversionsInUse } from 'features/parameters/store/generationSlice';
+import {
+setAddTIToNegative,
+setClearTextualInversions,
+setTextualInversionsInUse,
+} from 'features/parameters/store/generationSlice';
import { useEffect } from 'react';
import { useTranslation } from 'react-i18next';
+import { MdArrowDownward, MdClear } from 'react-icons/md';

export default function TextualInversionManager() {
const dispatch = useAppDispatch();
const textualInversionsInUse = useAppSelector(
(state: RootState) => state.generation.textualInversionsInUse
);

+const negativeTextualInversionsInUse = useAppSelector(
+(state: RootState) => state.generation.negativeTextualInversionsInUse
+);
+
const foundLocalTextualInversionTriggers = useAppSelector(
(state) => state.system.foundLocalTextualInversionTriggers
);

@@ -31,6 +42,10 @@ export default function TextualInversionManager() {
(state) => state.ui.shouldShowHuggingFaceConcepts
);

+const addTIToNegative = useAppSelector(
+(state) => state.generation.addTIToNegative
+);
+
const { t } = useTranslation();

useEffect(() => {

@@ -41,14 +56,25 @@ export default function TextualInversionManager() {
dispatch(setTextualInversionsInUse(textual_inversion));
};

-const renderTextualInversionOption = (textual_inversion: string) => {
-const thisTIExists = textualInversionsInUse.includes(textual_inversion);
-const tiExistsStyle = {
-fontWeight: 'bold',
-color: 'var(--context-menu-active-item)',
-};
+const TIPip = ({ color }: { color: string }) => {
return (
-<Box style={thisTIExists ? tiExistsStyle : {}}>{textual_inversion}</Box>
+<Box width={2} height={2} borderRadius={9999} backgroundColor={color}>
+{' '}
+</Box>
);
};

+const renderTextualInversionOption = (textual_inversion: string) => {
+return (
+<Flex alignItems="center" columnGap={1}>
+{textual_inversion}
+{textualInversionsInUse.includes(textual_inversion) && (
+<TIPip color="var(--context-menu-active-item)" />
+)}
+{negativeTextualInversionsInUse.includes(textual_inversion) && (
+<TIPip color="var(--status-bad-color)" />
+)}
+</Flex>
+);
+};

@@ -56,8 +82,10 @@ export default function TextualInversionManager() {
const allTextualInversions = localTextualInversionTriggers.concat(
huggingFaceTextualInversionConcepts
);
-return allTextualInversions.filter((ti) =>
-textualInversionsInUse.includes(ti)
+return allTextualInversions.filter(
+(ti) =>
+textualInversionsInUse.includes(ti) ||
+negativeTextualInversionsInUse.includes(ti)
).length;
};

@@ -93,13 +121,34 @@ export default function TextualInversionManager() {
(foundHuggingFaceTextualInversionTriggers &&
foundHuggingFaceTextualInversionTriggers?.length > 0 &&
shouldShowHuggingFaceConcepts)) ? (
-<IAISimpleMenu
-menuItems={makeTextualInversionItems()}
-menuType="regular"
-buttonText={`${t(
-'modelManager.addTextualInversionTrigger'
-)} (${numOfActiveTextualInversions()})`}
-/>
+<Flex columnGap={2}>
+<IAISimpleMenu
+menuItems={makeTextualInversionItems()}
+menuType="regular"
+buttonText={`${t(
+'modelManager.addTextualInversionTrigger'
+)} (${numOfActiveTextualInversions()})`}
+menuButtonProps={{
+width: '100%',
+padding: '0 1rem',
+}}
+/>
+<IAIIconButton
+icon={<MdArrowDownward />}
+style={{
+backgroundColor: addTIToNegative ? 'var(--btn-delete-image)' : '',
+}}
+tooltip={t('modelManager.addTIToNegative')}
+aria-label={t('modelManager.addTIToNegative')}
+onClick={() => dispatch(setAddTIToNegative(!addTIToNegative))}
+/>
+<IAIIconButton
+icon={<MdClear />}
+tooltip={t('modelManager.clearTextualInversions')}
+aria-label={t('modelManager.clearTextualInversions')}
+onClick={() => dispatch(setClearTextualInversions())}
+/>
+</Flex>
) : (
<Box
background="var(--btn-base-color)"
@@ -1,24 +1,43 @@
import { FormControl, Textarea } from '@chakra-ui/react';
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
-import { setNegativePrompt } from 'features/parameters/store/generationSlice';
+import {
+handlePromptCheckers,
+setNegativePrompt,
+} from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
+import { ChangeEvent, useState } from 'react';

const NegativePromptInput = () => {
const negativePrompt = useAppSelector(
(state: RootState) => state.generation.negativePrompt
);

+const [promptTimer, setPromptTimer] = useState<number | undefined>(undefined);
+
const dispatch = useAppDispatch();
const { t } = useTranslation();

+const handleNegativeChangePrompt = (e: ChangeEvent<HTMLTextAreaElement>) => {
+dispatch(setNegativePrompt(e.target.value));
+
+// Debounce Prompt UI Checking
+clearTimeout(promptTimer);
+const newPromptTimer = window.setTimeout(() => {
+dispatch(
+handlePromptCheckers({ prompt: e.target.value, toNegative: true })
+);
+}, 500);
+setPromptTimer(newPromptTimer);
+};
+
return (
<FormControl>
<Textarea
id="negativePrompt"
name="negativePrompt"
value={negativePrompt}
-onChange={(e) => dispatch(setNegativePrompt(e.target.value))}
+onChange={handleNegativeChangePrompt}
background="var(--prompt-bg-color)"
placeholder={t('parameters.negativePrompts')}
_placeholder={{ fontSize: '0.8rem' }}
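Both prompt inputs now debounce the handlePromptCheckers dispatch: each keystroke clears the pending window.setTimeout and schedules a new scan 500 ms later, so the store only re-parses the prompt once typing pauses. The following is a minimal, self-contained sketch of that debounce technique in Python; it is not code from this repository, and the Debouncer class and sample prompts are invented for illustration only.

```python
# Minimal sketch of the debounce technique (not InvokeAI code): every call
# cancels the previously scheduled check and schedules a new one, so the
# wrapped function runs only after input has been quiet for `delay` seconds.
import threading
from typing import Optional


class Debouncer:
    def __init__(self, delay: float = 0.5):
        self.delay = delay
        self._timer: Optional[threading.Timer] = None

    def call(self, fn, *args, **kwargs):
        if self._timer is not None:
            self._timer.cancel()  # drop the pending, now-stale invocation
        self._timer = threading.Timer(self.delay, fn, args=args, kwargs=kwargs)
        self._timer.start()


# Usage: only the last keystroke's prompt actually gets checked.
debouncer = Debouncer(0.5)
for partial_prompt in ["a ca", "a cat", "a cat withLora(someLora,0.75)"]:
    debouncer.call(print, f"re-scanning prompt: {partial_prompt}")
```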
@@ -51,7 +51,9 @@ const PromptInput = () => {
// Debounce Prompt UI Checking
clearTimeout(promptTimer);
const newPromptTimer = window.setTimeout(() => {
-dispatch(handlePromptCheckers(e.target.value));
+dispatch(
+handlePromptCheckers({ prompt: e.target.value, toNegative: false })
+);
}, 500);
setPromptTimer(newPromptTimer);
};
@@ -3,7 +3,11 @@ import { getPromptAndNegative } from 'common/util/getPromptAndNegative';
import * as InvokeAI from 'app/invokeai';
import promptToString from 'common/util/promptToString';
import { useAppDispatch } from 'app/storeHooks';
-import { setNegativePrompt, setPrompt } from '../store/generationSlice';
+import {
+handlePromptCheckers,
+setNegativePrompt,
+setPrompt,
+} from '../store/generationSlice';

// TECHDEBT: We have two metadata prompt formats and need to handle recalling either of them.
// This hook provides a function to do that.

@@ -20,6 +24,10 @@ const useSetBothPrompts = () => {

dispatch(setPrompt(prompt));
dispatch(setNegativePrompt(negativePrompt));
+dispatch(handlePromptCheckers({ prompt: prompt, toNegative: false }));
+dispatch(
+handlePromptCheckers({ prompt: negativePrompt, toNegative: true })
+);
};
};

@@ -18,9 +18,11 @@ export interface GenerationState {
prompt: string;
negativePrompt: string;
lorasInUse: string[];
-localTextualInversionTriggers: string[];
huggingFaceTextualInversionConcepts: string[];
+localTextualInversionTriggers: string[];
textualInversionsInUse: string[];
+negativeTextualInversionsInUse: string[];
+addTIToNegative: boolean;
sampler: string;
seamBlur: number;
seamless: boolean;

@@ -53,9 +55,11 @@ const initialGenerationState: GenerationState = {
prompt: '',
negativePrompt: '',
lorasInUse: [],
-localTextualInversionTriggers: [],
huggingFaceTextualInversionConcepts: [],
+localTextualInversionTriggers: [],
textualInversionsInUse: [],
+negativeTextualInversionsInUse: [],
+addTIToNegative: false,
sampler: 'k_lms',
seamBlur: 16,
seamless: false,

@@ -85,15 +89,86 @@ const loraExists = (state: GenerationState, lora: string) => {
return false;
};

+const getTIRegex = (textualInversion: string) => {
+if (textualInversion.includes('<' || '>')) {
+return new RegExp(`${textualInversion}`);
+} else {
+return new RegExp(`\\b${textualInversion}\\b`);
+}
+};
+
const textualInversionExists = (
state: GenerationState,
textualInversion: string
) => {
-const textualInversionRegex = new RegExp(textualInversion);
-if (state.prompt.match(textualInversionRegex)) return true;
+const textualInversionRegex = getTIRegex(textualInversion);
+
+if (!state.addTIToNegative) {
+if (state.prompt.match(textualInversionRegex)) return true;
+} else {
+if (state.negativePrompt.match(textualInversionRegex)) return true;
+}
return false;
};

+const handleTypedTICheck = (
+state: GenerationState,
+newPrompt: string,
+toNegative: boolean
+) => {
+let textualInversionsInUse = !toNegative
+? [...state.textualInversionsInUse]
+: [...state.negativeTextualInversionsInUse]; // Get Words In Prompt
+
+const textualInversionRegex = /([\w<>!@%&*_-]+)/g; // Scan For Each Word
+
+const textualInversionMatches = [
+...newPrompt.matchAll(textualInversionRegex),
+]; // Match All Words
+
+if (textualInversionMatches.length > 0) {
+textualInversionsInUse = []; // Reset Textual Inversions In Use
+
+textualInversionMatches.forEach((textualInversionMatch) => {
+const textualInversionName = textualInversionMatch[0];
+if (
+(!textualInversionsInUse.includes(textualInversionName) &&
+state.localTextualInversionTriggers.includes(textualInversionName)) ||
+state.huggingFaceTextualInversionConcepts.includes(textualInversionName)
+) {
+textualInversionsInUse.push(textualInversionName); // Add Textual Inversions In Prompt
+}
+});
+} else {
+textualInversionsInUse = []; // If No Matches, Remove Textual Inversions In Use
+}
+
+if (!toNegative) {
+state.textualInversionsInUse = textualInversionsInUse;
+} else {
+state.negativeTextualInversionsInUse = textualInversionsInUse;
+}
+};
+
+const handleTypedLoraCheck = (state: GenerationState, newPrompt: string) => {
+let lorasInUse = [...state.lorasInUse]; // Get Loras In Prompt
+
+const loraRegex = /withLora\(([^\\)]+)\)/g; // Scan For Lora Syntax
+const loraMatches = [...newPrompt.matchAll(loraRegex)]; // Match All Lora Syntaxes
+
+if (loraMatches.length > 0) {
+lorasInUse = []; // Reset Loras In Use
+loraMatches.forEach((loraMatch) => {
+const loraName = loraMatch[1].split(',')[0];
+if (!lorasInUse.includes(loraName)) lorasInUse.push(loraName); // Add Loras In Prompt
+});
+} else {
+lorasInUse = []; // If No Matches, Remove Loras In Use
+}
+
+state.lorasInUse = lorasInUse;
+};

export const generationSlice = createSlice({
name: 'generation',
initialState,

@@ -118,6 +193,20 @@ export const generationSlice = createSlice({
state.negativePrompt = promptToString(newPrompt);
}
},
+handlePromptCheckers: (
+state,
+action: PayloadAction<{
+prompt: string | InvokeAI.Prompt;
+toNegative: boolean;
+}>
+) => {
+const newPrompt = action.payload.prompt;
+
+if (typeof newPrompt === 'string') {
+if (!action.payload.toNegative) handleTypedLoraCheck(state, newPrompt);
+handleTypedTICheck(state, newPrompt, action.payload.toNegative);
+}
+},
setLorasInUse: (state, action: PayloadAction<string>) => {
const newLora = action.payload;
const loras = [...state.lorasInUse];

@@ -128,94 +217,99 @@ export const generationSlice = createSlice({
'g'
);
const newPrompt = state.prompt.replaceAll(loraRegex, '');
-state.prompt = newPrompt;
+state.prompt = newPrompt.trim();

if (loras.includes(newLora)) {
const newLoraIndex = loras.indexOf(newLora);
if (newLoraIndex > -1) loras.splice(newLoraIndex, 1);
}
} else {
-state.prompt = `${state.prompt} withLora(${newLora},0.75)`;
+state.prompt = `${state.prompt.trim()} withLora(${newLora},0.75)`;
if (!loras.includes(newLora)) loras.push(newLora);
}
state.lorasInUse = loras;
},
-handlePromptCheckers: (
-state,
-action: PayloadAction<string | InvokeAI.Prompt>
-) => {
-const newPrompt = action.payload;
+setClearLoras: (state) => {
+const lorasInUse = [...state.lorasInUse];

-// Tackle User Typed Lora Syntax
-let lorasInUse = [...state.lorasInUse]; // Get Loras In Prompt
-const loraRegex = /withLora\(([^\\)]+)\)/g; // Scan For Lora Syntax
-if (typeof newPrompt === 'string') {
-const loraMatches = [...newPrompt.matchAll(loraRegex)]; // Match All Lora Syntaxes
-if (loraMatches.length > 0) {
-lorasInUse = []; // Reset Loras In Use
-loraMatches.forEach((loraMatch) => {
-const loraName = loraMatch[1].split(',')[0];
-if (!lorasInUse.includes(loraName)) lorasInUse.push(loraName); // Add Loras In Prompt
-});
-} else {
-lorasInUse = []; // If No Matches, Remove Loras In Use
-}
-}
-state.lorasInUse = lorasInUse;
+lorasInUse.forEach((lora) => {
+const loraRegex = new RegExp(
+`withLora\\(${lora},?\\s*([^\\)]+)?\\)`,
+'g'
+);
+const newPrompt = state.prompt.replaceAll(loraRegex, '');
+state.prompt = newPrompt.trim();
+});

-// Tackle User Typed Textual Inversion
-let textualInversionsInUse = [...state.textualInversionsInUse]; // Get Words In Prompt
-const textualInversionRegex = /([\w<>!@%&*_-]+)/g; // Scan For Each Word
-if (typeof newPrompt === 'string') {
-const textualInversionMatches = [
-...newPrompt.matchAll(textualInversionRegex),
-]; // Match All Words
-if (textualInversionMatches.length > 0) {
-textualInversionsInUse = []; // Reset Textual Inversions In Use
-console.log(textualInversionMatches);
-textualInversionMatches.forEach((textualInversionMatch) => {
-const textualInversionName = textualInversionMatch[0];
-console.log(textualInversionName);
-if (
-!textualInversionsInUse.includes(textualInversionName) &&
-(state.localTextualInversionTriggers.includes(
-textualInversionName
-) ||
-state.huggingFaceTextualInversionConcepts.includes(
-textualInversionName
-))
-)
-textualInversionsInUse.push(textualInversionName); // Add Textual Inversions In Prompt
-});
-} else {
-textualInversionsInUse = []; // If No Matches, Remove Textual Inversions In Use
-}
-}

-console.log([...state.huggingFaceTextualInversionConcepts]);
-state.textualInversionsInUse = textualInversionsInUse;
+state.lorasInUse = [];
},
setTextualInversionsInUse: (state, action: PayloadAction<string>) => {
const newTextualInversion = action.payload;

const textualInversions = [...state.textualInversionsInUse];
+const negativeTextualInversions = [
+...state.negativeTextualInversionsInUse,
+];

if (textualInversionExists(state, newTextualInversion)) {
-const textualInversionRegex = new RegExp(newTextualInversion, 'g');
-const newPrompt = state.prompt.replaceAll(textualInversionRegex, '');
-state.prompt = newPrompt;
+const textualInversionRegex = getTIRegex(newTextualInversion);
+
+if (!state.addTIToNegative) {
+const newPrompt = state.prompt.replace(textualInversionRegex, '');
+state.prompt = newPrompt.trim();
+
if (textualInversions.includes(newTextualInversion)) {
const newTIIndex = textualInversions.indexOf(newTextualInversion);
if (newTIIndex > -1) textualInversions.splice(newTIIndex, 1);
} else {
+const newPrompt = state.negativePrompt.replace(
+textualInversionRegex,
+''
+);
+state.negativePrompt = newPrompt.trim();
+
+const newTIIndex =
+negativeTextualInversions.indexOf(newTextualInversion);
+if (newTIIndex > -1) negativeTextualInversions.splice(newTIIndex, 1);
+}
} else {
-state.prompt = `${state.prompt} ${newTextualInversion}`;
-if (!textualInversions.includes(newTextualInversion))
+if (!state.addTIToNegative) {
+state.prompt = `${state.prompt.trim()} ${newTextualInversion}`;
textualInversions.push(newTextualInversion);
+} else {
+state.negativePrompt = `${state.negativePrompt.trim()} ${newTextualInversion}`;
+negativeTextualInversions.push(newTextualInversion);
+}
}
-state.lorasInUse = textualInversions;
+
+state.textualInversionsInUse = textualInversions;
+state.negativeTextualInversionsInUse = negativeTextualInversions;
},
+setClearTextualInversions: (state) => {
+const textualInversions = [...state.textualInversionsInUse];
+const negativeTextualInversions = [
+...state.negativeTextualInversionsInUse,
+];
+
+textualInversions.forEach((ti) => {
+const textualInversionRegex = getTIRegex(ti);
+const newPrompt = state.prompt.replace(textualInversionRegex, '');
+state.prompt = newPrompt.trim();
+});
+
+negativeTextualInversions.forEach((ti) => {
+const textualInversionRegex = getTIRegex(ti);
+const newPrompt = state.negativePrompt.replace(
+textualInversionRegex,
+''
+);
+state.negativePrompt = newPrompt.trim();
+});
+
+state.textualInversionsInUse = [];
+state.negativeTextualInversionsInUse = [];
+},
+setAddTIToNegative: (state, action: PayloadAction<boolean>) => {
+state.addTIToNegative = action.payload;
+},
setLocalTextualInversionTriggers: (
state,

@@ -509,11 +603,14 @@ export const {
setPerlin,
setPrompt,
setNegativePrompt,
-setLorasInUse,
-setLocalTextualInversionTriggers,
-setHuggingFaceTextualInversionConcepts,
-setTextualInversionsInUse,
+handlePromptCheckers,
+setLorasInUse,
+setClearLoras,
+setHuggingFaceTextualInversionConcepts,
+setLocalTextualInversionTriggers,
+setTextualInversionsInUse,
+setAddTIToNegative,
+setClearTextualInversions,
setSampler,
setSeamBlur,
setSeamless,
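The new handleTypedLoraCheck helper and the setLorasInUse/setClearLoras reducers all revolve around the withLora(name,weight) prompt syntax: occurrences are found with the global regex /withLora\(([^\\)]+)\)/g, and a per-name regex is rebuilt when a specific call has to be stripped out of the prompt. The block below is a standalone Python sketch of that parsing logic, mirroring the TypeScript above for illustration only; the LoRA names and sample prompt are invented, and none of this is project code.

```python
# Illustrative re-implementation of the prompt scanning in generationSlice.ts:
# detect withLora(name,weight) calls and strip a single call back out.
import re

LORA_CALL = re.compile(r"withLora\(([^)]+)\)")


def loras_in_prompt(prompt: str) -> list[str]:
    names: list[str] = []
    for match in LORA_CALL.finditer(prompt):
        name = match.group(1).split(",")[0].strip()  # "name,0.75" -> "name"
        if name not in names:
            names.append(name)
    return names


def remove_lora(prompt: str, name: str) -> str:
    # matches both withLora(name) and withLora(name,<weight>)
    call = re.compile(rf"withLora\({re.escape(name)},?\s*[^)]*\)")
    return call.sub("", prompt).strip()


prompt = "a painting of a fox withLora(inkStyle,0.75) withLora(fluffyTail,0.5)"
print(loras_in_prompt(prompt))          # ['inkStyle', 'fluffyTail']
print(remove_lora(prompt, "inkStyle"))  # the inkStyle call is gone, the rest remains
```

Anchoring the match on the literal withLora( prefix is what lets the slice tell LoRA calls apart from ordinary prompt words, while the separate word-scanning regex in handleTypedTICheck handles textual-inversion trigger tokens.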
@@ -18,8 +18,7 @@ import PromptInput from 'features/parameters/components/PromptInput/PromptInput'
import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
import { useTranslation } from 'react-i18next';
import ImageToImageOptions from './ImageToImageOptions';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

export default function ImageToImagePanel() {
const { t } = useTranslation();

@@ -65,8 +64,7 @@ export default function ImageToImagePanel() {
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput />
<NegativePromptInput />
-<LoraManager />
-<TextualInversionManager />
+<PromptExtras />
</Flex>
<ProcessButtons />
<MainSettings />
@@ -10,8 +10,6 @@ import UpscaleSettings from 'features/parameters/components/AdvancedParameters/U
import UpscaleToggle from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleToggle';
import GenerateVariationsToggle from 'features/parameters/components/AdvancedParameters/Variations/GenerateVariations';
import VariationsSettings from 'features/parameters/components/AdvancedParameters/Variations/VariationsSettings';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
import MainSettings from 'features/parameters/components/MainParameters/MainParameters';
import ParametersAccordion from 'features/parameters/components/ParametersAccordion';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';

@@ -19,6 +17,7 @@ import NegativePromptInput from 'features/parameters/components/PromptInput/Nega
import PromptInput from 'features/parameters/components/PromptInput/PromptInput';
import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
import { useTranslation } from 'react-i18next';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

export default function TextToImagePanel() {
const { t } = useTranslation();

@@ -64,8 +63,7 @@ export default function TextToImagePanel() {
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput />
<NegativePromptInput />
-<LoraManager />
-<TextualInversionManager />
+<PromptExtras />
</Flex>
<ProcessButtons />
<MainSettings />
@@ -10,8 +10,6 @@ import SymmetryToggle from 'features/parameters/components/AdvancedParameters/Ou
import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
import GenerateVariationsToggle from 'features/parameters/components/AdvancedParameters/Variations/GenerateVariations';
import VariationsSettings from 'features/parameters/components/AdvancedParameters/Variations/VariationsSettings';
-import LoraManager from 'features/parameters/components/LoraManager/LoraManager';
-import TextualInversionManager from 'features/parameters/components/TextualInversionManager/TextualInversionManager';
import MainSettings from 'features/parameters/components/MainParameters/MainParameters';
import ParametersAccordion from 'features/parameters/components/ParametersAccordion';
import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons';

@@ -19,6 +17,7 @@ import NegativePromptInput from 'features/parameters/components/PromptInput/Nega
import PromptInput from 'features/parameters/components/PromptInput/PromptInput';
import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
import { useTranslation } from 'react-i18next';
+import PromptExtras from 'features/parameters/components/PromptInput/Extras/PromptExtras';

export default function UnifiedCanvasPanel() {
const { t } = useTranslation();

@@ -75,8 +74,7 @@ export default function UnifiedCanvasPanel() {
<Flex flexDir="column" rowGap="0.5rem">
<PromptInput />
<NegativePromptInput />
-<LoraManager />
-<TextualInversionManager />
+<PromptExtras />
</Flex>
<ProcessButtons />
<MainSettings />
File diff suppressed because one or more lines are too long
@@ -633,9 +633,8 @@ class Generate:
except RuntimeError:
# Clear the CUDA cache on an exception
self.clear_cuda_cache()

print(traceback.format_exc(), file=sys.stderr)
-print(">> Could not generate image.")
+print("** Could not generate image.")
raise

toc = time.time()
print("\n>> Usage stats:")
@@ -1 +1 @@
-__version__='2.3.4rc1'
+__version__='2.3.5-rc1'
@@ -12,21 +12,13 @@ from typing import Union, Optional, Any
-from transformers import CLIPTokenizer

from compel import Compel
-from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser
+from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute, PromptParser, \
+Conjunction
from .devices import torch_dtype
+from .generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.invoke.globals import Globals

-def get_tokenizer(model) -> CLIPTokenizer:
-# TODO remove legacy ckpt fallback handling
-return (getattr(model, 'tokenizer', None) # diffusers
-or model.cond_stage_model.tokenizer) # ldm
-
-def get_text_encoder(model) -> Any:
-# TODO remove legacy ckpt fallback handling
-return (getattr(model, 'text_encoder', None) # diffusers
-or UnsqueezingLDMTransformer(model.cond_stage_model.transformer)) # ldm

class UnsqueezingLDMTransformer:
def __init__(self, ldm_transformer):
self.ldm_transformer = ldm_transformer

@@ -40,48 +32,57 @@ class UnsqueezingLDMTransformer:
return insufficiently_unsqueezed_tensor.unsqueeze(0)


-def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
+def get_uc_and_c_and_ec(prompt_string,
+model: StableDiffusionGeneratorPipeline,
+log_tokens=False, skip_normalize_legacy_blend=False):
# lazy-load any deferred textual inversions.
# this might take a couple of seconds the first time a textual inversion is used.
model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)

-tokenizer = get_tokenizer(model)
-text_encoder = get_text_encoder(model)
-compel = Compel(tokenizer=tokenizer,
-text_encoder=text_encoder,
+compel = Compel(tokenizer=model.tokenizer,
+text_encoder=model.text_encoder,
textual_inversion_manager=model.textual_inversion_manager,
dtype_for_device_getter=torch_dtype)

# get rid of any newline characters
prompt_string = prompt_string.replace("\n", " ")
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)

legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
-positive_prompt: FlattenedPrompt|Blend
-lora_conditions = None
+positive_conjunction: Conjunction
if legacy_blend is not None:
-positive_prompt = legacy_blend
+positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
-positive_prompt = positive_conjunction.prompts[0]
-should_use_lora_manager = True
-lora_weights = positive_conjunction.lora_weights
-if model.peft_manager:
-should_use_lora_manager = model.peft_manager.should_use(lora_weights)
-if not should_use_lora_manager:
-model.peft_manager.set_loras(lora_weights)
-if model.lora_manager and should_use_lora_manager:
-lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)
+positive_prompt = positive_conjunction.prompts[0]
+
+should_use_lora_manager = True
+lora_weights = positive_conjunction.lora_weights
+lora_conditions = None
+if model.peft_manager:
+should_use_lora_manager = model.peft_manager.should_use(lora_weights)
+if not should_use_lora_manager:
+model.peft_manager.set_loras(lora_weights)
+if model.lora_manager and should_use_lora_manager:
+lora_conditions = model.lora_manager.set_loras_conditions(lora_weights)

negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt | Blend = negative_conjunction.prompts[0]

+tokens_count = get_max_token_count(model.tokenizer, positive_prompt)
if log_tokens or getattr(Globals, "log_tokenization", False):
-log_tokenization(positive_prompt, negative_prompt, tokenizer=tokenizer)
+log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)

-c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
-uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
-
-tokens_count = get_max_token_count(tokenizer, positive_prompt)
+# some LoRA models also mess with the text encoder, so they must be active while compel builds conditioning tensors
+lora_conditioning_ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+lora_conditions=lora_conditions)
+with InvokeAIDiffuserComponent.custom_attention_context(model.unet,
+extra_conditioning_info=lora_conditioning_ec,
+step_count=-1):
+c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)

+# now build the "real" ec
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
cross_attention_control_args=options.get(
'cross_attention_control', None),

@@ -93,12 +94,12 @@ def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = Fals
Union[FlattenedPrompt, Blend], FlattenedPrompt):
positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
legacy_blend = try_parse_legacy_blend(positive_prompt_string, skip_normalize_legacy_blend)
-positive_prompt: FlattenedPrompt|Blend
+positive_conjunction: Conjunction
if legacy_blend is not None:
-positive_prompt = legacy_blend
+positive_conjunction = legacy_blend
else:
positive_conjunction = Compel.parse_prompt_string(positive_prompt_string)
-positive_prompt = positive_conjunction.prompts[0]
+positive_prompt = positive_conjunction.prompts[0]
negative_conjunction = Compel.parse_prompt_string(negative_prompt_string)
negative_prompt: FlattenedPrompt|Blend = negative_conjunction.prompts[0]

@@ -217,18 +218,26 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
print(f'{discarded}\x1b[0m')


-def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Blend]:
+def try_parse_legacy_blend(text: str, skip_normalize: bool=False) -> Optional[Conjunction]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
if len(weighted_subprompts) <= 1:
return None
strings = [x[0] for x in weighted_subprompts]
-weights = [x[1] for x in weighted_subprompts]

pp = PromptParser()
parsed_conjunctions = [pp.parse_conjunction(x) for x in strings]
-flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
+flattened_prompts = []
+weights = []
+loras = []
+for i, x in enumerate(parsed_conjunctions):
+if len(x.prompts)>0:
+flattened_prompts.append(x.prompts[0])
+weights.append(weighted_subprompts[i][1])
+if len(x.lora_weights)>0:
+loras.extend(x.lora_weights)

-return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
+return Conjunction([Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)],
+lora_weights = loras)


def split_weighted_subprompts(text, skip_normalize=False)->list:
@@ -4,14 +4,13 @@ pip install <path_to_git_source>.
'''
import os
import platform
+import psutil
import requests
from rich import box, print
-from rich.console import Console, Group, group
+from rich.console import Console, group
from rich.panel import Panel
from rich.prompt import Prompt
from rich.style import Style
from rich.syntax import Syntax
from rich.text import Text

from ldm.invoke import __version__

@@ -32,6 +31,19 @@ else:
def get_versions()->dict:
return requests.get(url=INVOKE_AI_REL).json()

+def invokeai_is_running()->bool:
+for p in psutil.process_iter():
+try:
+cmdline = p.cmdline()
+matches = [x for x in cmdline if x.endswith(('invokeai','invokeai.exe'))]
+if matches:
+print(f':exclamation: [bold red]An InvokeAI instance appears to be running as process {p.pid}[/red bold]')
+return True
+except psutil.AccessDenied:
+continue
+return False
+
+
def welcome(versions: dict):

@group()

@@ -62,6 +74,10 @@ def welcome(versions: dict):

def main():
versions = get_versions()
+if invokeai_is_running():
+print(f':exclamation: [bold red]Please terminate all running instances of InvokeAI before updating.[/red bold]')
+return
+
welcome(versions)

tag = None
@@ -196,16 +196,6 @@ class addModelsForm(npyscreen.FormMultiPage):
scroll_exit=True,
)
-self.nextrely += 1
-self.convert_models = self.add_widget_intelligent(
-npyscreen.TitleSelectOne,
-name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
-values=["Keep original format", "Convert to diffusers"],
-value=0,
-begin_entry_at=4,
-max_height=4,
-hidden=True, # will appear when imported models box is edited
-scroll_exit=True,
-)
self.cancel = self.add_widget_intelligent(
npyscreen.ButtonPress,
name="CANCEL",

@@ -240,8 +230,6 @@ class addModelsForm(npyscreen.FormMultiPage):
self.show_directory_fields.addVisibleWhenSelected(i)

self.show_directory_fields.when_value_edited = self._clear_scan_directory
-self.import_model_paths.when_value_edited = self._show_hide_convert
-self.autoload_directory.when_value_edited = self._show_hide_convert

def resize(self):
super().resize()

@@ -252,13 +240,6 @@ class addModelsForm(npyscreen.FormMultiPage):
if not self.show_directory_fields.value:
self.autoload_directory.value = ""

-def _show_hide_convert(self):
-model_paths = self.import_model_paths.value or ""
-autoload_directory = self.autoload_directory.value or ""
-self.convert_models.hidden = (
-len(model_paths) == 0 and len(autoload_directory) == 0
-)
-
def _get_starter_model_labels(self) -> List[str]:
window_width, window_height = get_terminal_size()
label_width = 25

@@ -318,7 +299,6 @@ class addModelsForm(npyscreen.FormMultiPage):
.scan_directory: Path to a directory of models to scan and import
.autoscan_on_startup: True if invokeai should scan and import at startup time
.import_model_paths: list of URLs, repo_ids and file paths to import
-.convert_to_diffusers: if True, convert legacy checkpoints into diffusers
"""
# we're using a global here rather than storing the result in the parentapp
# due to some bug in npyscreen that is causing attributes to be lost

@@ -354,7 +334,6 @@ class addModelsForm(npyscreen.FormMultiPage):

# URLs and the like
selections.import_model_paths = self.import_model_paths.value.split()
-selections.convert_to_diffusers = self.convert_models.value[0] == 1


class AddModelApplication(npyscreen.NPSAppManaged):

@@ -367,7 +346,6 @@ class AddModelApplication(npyscreen.NPSAppManaged):
scan_directory=None,
autoscan_on_startup=None,
import_model_paths=None,
-convert_to_diffusers=None,
)

def onStart(self):

@@ -387,7 +365,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
directory_to_scan = selections.scan_directory
scan_at_startup = selections.autoscan_on_startup
potential_models_to_install = selections.import_model_paths
-convert_to_diffusers = selections.convert_to_diffusers

install_requested_models(
install_initial_models=models_to_install,

@@ -395,7 +372,6 @@ def process_and_execute(opt: Namespace, selections: Namespace):
scan_directory=Path(directory_to_scan) if directory_to_scan else None,
external_models=potential_models_to_install,
scan_at_startup=scan_at_startup,
-convert_to_diffusers=convert_to_diffusers,
precision="float32"
if opt.full_precision
else choose_precision(torch.device(choose_torch_device())),
@@ -11,6 +11,7 @@ from tempfile import TemporaryFile

import requests
from diffusers import AutoencoderKL
+from diffusers import logging as dlogging
from huggingface_hub import hf_hub_url
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

@@ -68,7 +69,6 @@ def install_requested_models(
scan_directory: Path = None,
external_models: List[str] = None,
scan_at_startup: bool = False,
-convert_to_diffusers: bool = False,
precision: str = "float16",
purge_deleted: bool = False,
config_file_path: Path = None,

@@ -111,20 +111,20 @@ def install_requested_models(
if len(external_models)>0:
print("== INSTALLING EXTERNAL MODELS ==")
for path_url_or_repo in external_models:
+print(f'DEBUG: path_url_or_repo = {path_url_or_repo}')
try:
model_manager.heuristic_import(
path_url_or_repo,
-convert=convert_to_diffusers,
config_file_callback=_pick_configuration_file,
commit_to_conf=config_file_path
)
except KeyboardInterrupt:
sys.exit(-1)
-except Exception:
-pass
+except Exception as e:
+print(f'An exception has occurred: {str(e)}')

if scan_at_startup and scan_directory.is_dir():
-argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
+argument = '--autoconvert'
initfile = Path(Globals.root, Globals.initfile)
replacement = Path(Globals.root, f'{Globals.initfile}.new')
directory = str(scan_directory).replace('\\','/')

@@ -296,13 +296,21 @@ def _download_diffusion_weights(
mconfig: DictConfig, access_token: str, precision: str = "float32"
):
repo_id = mconfig["repo_id"]
+revision = mconfig.get('revision',None)
model_class = (
StableDiffusionGeneratorPipeline
if mconfig.get("format", None) == "diffusers"
else AutoencoderKL
)
-extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
+extra_arg_list = [{"revision": revision}] if revision \
+else [{"revision": "fp16"}, {}] if precision == "float16" \
+else [{}]
+path = None
+
+# quench safety checker warnings
+verbosity = dlogging.get_verbosity()
+dlogging.set_verbosity_error()

for extra_args in extra_arg_list:
try:
path = download_from_hf(

@@ -318,6 +326,7 @@ def _download_diffusion_weights(
print(f"An unexpected error occurred while downloading the model: {e})")
if path:
break
+dlogging.set_verbosity(verbosity)
return path


@@ -448,6 +457,8 @@ def new_config_file_contents(
stanza["description"] = mod["description"]
stanza["repo_id"] = mod["repo_id"]
stanza["format"] = mod["format"]
+if "revision" in mod:
+stanza["revision"] = mod["revision"]
# diffusers don't need width and height (probably .ckpt doesn't either)
# so we no longer require these in INITIAL_MODELS.yaml
if "width" in mod:

@@ -472,10 +483,9 @@ def new_config_file_contents(

conf[model] = stanza

-# if no default model was chosen, then we select the first
-# one in the list
+# if no default model was chosen, then we select the first one in the list
if not default_selected:
-conf[list(successfully_downloaded.keys())[0]]["default"] = True
+conf[list(conf.keys())[0]]["default"] = True

return OmegaConf.to_yaml(conf)
@@ -99,8 +99,9 @@ def expand_prompts(
sequence = 0
for command in commands:
sequence += 1
-parent_conn.send(
-command + f' --fnformat="dp.{sequence:04}.{{prompt}}.png"'
+format = _get_fn_format(outdir, sequence)
+parent_conn.send_bytes(
+(command + f' --fnformat="{format}"').encode('utf-8')
)
parent_conn.close()
else:

@@ -110,7 +111,20 @@ def expand_prompts(
for p in children:
p.terminate()


+def _get_fn_format(directory:str, sequence:int)->str:
+"""
+Get a filename that doesn't exceed filename length restrictions
+on the current platform.
+"""
+try:
+max_length = os.pathconf(directory,'PC_NAME_MAX')
+except:
+max_length = 255
+prefix = f'dp.{sequence:04}.'
+suffix = '.png'
+max_length -= len(prefix)+len(suffix)
+return f'{prefix}{{prompt:0.{max_length}}}{suffix}'

class MessageToStdin(object):
def __init__(self, connection: Connection):
self.connection = connection

@@ -119,7 +133,7 @@ class MessageToStdin(object):
def readline(self) -> str:
try:
if len(self.linebuffer) == 0:
-message = self.connection.recv()
+message = self.connection.recv_bytes().decode('utf-8')
self.linebuffer = message.split("\n")
result = self.linebuffer.pop(0)
return result
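_get_fn_format builds a format string whose prompt placeholder carries a precision field, which truncates the interpolated prompt so that dp.NNNN.<prompt>.png stays inside the platform's file-name limit (os.pathconf(directory, 'PC_NAME_MAX'), falling back to 255). A small illustration of that precision-field truncation follows; the prompt text and the 30-character budget are made up, and the snippet is not project code.

```python
# Small illustration of the truncation trick behind _get_fn_format: a precision
# field in a str.format spec caps the interpolated text at that many characters,
# so the generated file name stays under the platform's limit.
prompt = "a very long prompt describing an elaborate baroque castle at sunset"
max_length = 30  # stand-in for os.pathconf(outdir, 'PC_NAME_MAX') minus prefix/suffix
template = f"dp.0001.{{prompt:.{max_length}}}.png"

print(template)                         # dp.0001.{prompt:.30}.png
print(template.format(prompt=prompt))   # dp.0001.<first 30 characters of the prompt>.png
```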
@@ -467,8 +467,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if additional_guidance is None:
additional_guidance = []
extra_conditioning_info = conditioning_data.extra
-with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
-step_count=len(self.scheduler.timesteps)
+with InvokeAIDiffuserComponent.custom_attention_context(self.invokeai_diffuser.model,
+extra_conditioning_info=extra_conditioning_info,
+step_count=len(self.scheduler.timesteps)
):

yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
|
||||
pipeline.scheduler = sampler
|
||||
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
conditioning_data = (ConditioningData(uc, c, cfg_scale)
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
conditioning_data = (ConditioningData(uc, c, cfg_scale, extra_conditioning_info)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
|
@ -372,12 +372,6 @@ class ModelManager(object):
|
||||
)
|
||||
from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
# try:
|
||||
# if self.list_models()[self.current_model]['status'] == 'active':
|
||||
# self.offload_model(self.current_model)
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
@ -423,9 +417,9 @@ class ModelManager(object):
|
||||
pipeline_args.update(cache_dir=global_cache_dir("hub"))
|
||||
if using_fp16:
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
fp_args_list = [{}]
|
||||
revision = mconfig.get('revision') or ('fp16' if using_fp16 else None)
|
||||
fp_args_list = [{"revision": revision}] if revision else []
|
||||
fp_args_list.append({})
|
||||
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
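The revised fp_args_list logic prefers an explicit revision from the model config, falls back to the implied "fp16" revision when half precision is requested, and always ends with a no-revision fallback. The sketch below restates that list construction and shows how such a list is typically consumed in order; the download callable and the retry loop are assumptions for illustration, not the ModelManager's actual from_pretrained call.

def build_fp_args_list(mconfig: dict, using_fp16: bool) -> list:
    # Prefer an explicit revision from the model config; otherwise imply "fp16" when requested.
    revision = mconfig.get("revision") or ("fp16" if using_fp16 else None)
    fp_args_list = [{"revision": revision}] if revision else []
    fp_args_list.append({})  # always end with a plain, no-revision fallback
    return fp_args_list

def fetch_with_fallback(download, mconfig: dict, using_fp16: bool):
    last_error = None
    for extra_args in build_fp_args_list(mconfig, using_fp16):
        try:
            return download(**extra_args)  # stand-in for a from_pretrained-style callable
        except Exception as err:
            last_error = err
    raise last_error

print(build_fp_args_list({}, using_fp16=True))                    # [{'revision': 'fp16'}, {}]
print(build_fp_args_list({"revision": "main"}, using_fp16=False)) # [{'revision': 'main'}, {}]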
@ -288,16 +288,7 @@ class InvokeAICrossAttentionMixin:
return self.einsum_op_tensor_mem(q, k, v, 32)
def restore_default_cross_attention(model, is_running_diffusers: bool, processors_to_restore: Optional[AttnProcessor]=None):
if is_running_diffusers:
unet = model
unet.set_attn_processor(processors_to_restore or CrossAttnProcessor())
else:
remove_attention_function(model)
def override_cross_attention(model, context: Context, is_running_diffusers = False):
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
@ -323,22 +314,15 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
context.cross_attention_mask = mask.to(device)
context.cross_attention_index_map = indices.to(device)
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
old_attn_processors = unet.attn_processors
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
context.register_cross_attention_modules(model)
inject_attention_function(model, context)
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
@ -12,17 +12,6 @@ class DDIMSampler(Sampler):
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# This is the central routine
@torch.no_grad()
@ -38,15 +38,6 @@ class CFGDenoiser(nn.Module):
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.restore_default_cross_attention()
def forward(self, x, sigma, uncond, cond, cond_scale):
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
@ -14,17 +14,6 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.override_attention_processors(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.restore_default_cross_attention()
# this is the essential routine
@torch.no_grad()
@ -1,18 +1,18 @@
from contextlib import contextmanager
from dataclasses import dataclass
from math import ceil
from typing import Callable, Optional, Union, Any, Dict
from typing import Callable, Optional, Union, Any
import numpy as np
import torch
from diffusers.models.cross_attention import AttnProcessor
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from ldm.invoke.globals import Globals
from ldm.models.diffusion.cross_attention_control import (
Arguments,
restore_default_cross_attention,
override_cross_attention,
setup_cross_attention_control_attention_processors,
Context,
get_cross_attention_modules,
CrossAttentionType,
@ -84,66 +84,45 @@ class InvokeAIDiffuserComponent:
self.cross_attention_control_context = None
self.sequential_guidance = Globals.sequential_guidance
@classmethod
@contextmanager
def custom_attention_context(
self, extra_conditioning_info: Optional[ExtraConditioningInfo], step_count: int
clss,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int
):
old_attn_processor = None
old_attn_processors = None
if extra_conditioning_info and (
extra_conditioning_info.wants_cross_attention_control
| extra_conditioning_info.has_lora_conditions
):
old_attn_processor = self.override_attention_processors(
extra_conditioning_info, step_count=step_count
)
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.has_lora_conditions:
for condition in extra_conditioning_info.lora_conditions:
condition() # target model is stored in condition state for some reason
if extra_conditioning_info.wants_cross_attention_control:
cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
cross_attention_control_context,
)
try:
yield None
finally:
if old_attn_processor is not None:
self.restore_default_cross_attention(old_attn_processor)
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
if extra_conditioning_info and extra_conditioning_info.has_lora_conditions:
for lora_condition in extra_conditioning_info.lora_conditions:
lora_condition.unload()
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
def override_attention_processors(
self, conditioning: ExtraConditioningInfo, step_count: int
) -> Dict[str, AttnProcessor]:
"""
setup cross attention .swap control. for diffusers this replaces the attention processor, so
the previous attention processor is returned so that the caller can restore it later.
"""
old_attn_processors = self.model.attn_processors
# Load lora conditions into the model
if conditioning.has_lora_conditions:
for condition in conditioning.lora_conditions:
condition(self.model)
if conditioning.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=conditioning.cross_attention_control_args,
step_count=step_count,
)
override_cross_attention(
self.model,
self.cross_attention_control_context,
is_running_diffusers=self.is_running_diffusers,
)
return old_attn_processors
def restore_default_cross_attention(
self, processors_to_restore: Optional[dict[str, "AttnProcessor"]] = None
):
self.cross_attention_control_context = None
restore_default_cross_attention(
self.model,
is_running_diffusers=self.is_running_diffusers,
processors_to_restore=processors_to_restore,
)
def setup_attention_map_saving(self, saver: AttentionMapSaver):
def callback(slice, dim, offset, slice_size, key):
if dim is not None:
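The reworked custom_attention_context above is now a classmethod context manager: it snapshots the unet's attention processors, applies LoRA conditions and/or the cross-attention-control processors, and restores the originals in the finally block. The sketch below illustrates only that save/patch/restore pattern; FakeUNet and the string "swap-cross-attn" are stand-ins for diffusers' UNet2DConditionModel and the real processor classes, not the project's implementation.

from contextlib import contextmanager

class FakeUNet:
    """Stand-in exposing the same attn_processors / set_attn_processor surface."""
    def __init__(self):
        self._procs = {"down_blocks.0.attn1.processor": "default"}
    @property
    def attn_processors(self):
        return dict(self._procs)
    def set_attn_processor(self, procs):
        self._procs = dict(procs) if isinstance(procs, dict) else {k: procs for k in self._procs}

@contextmanager
def custom_attention_context(unet, wants_cross_attention_control: bool):
    old_attn_processors = None
    if wants_cross_attention_control:
        old_attn_processors = unet.attn_processors       # snapshot current processors
        unet.set_attn_processor("swap-cross-attn")       # install the editing processors
    try:
        yield
    finally:
        if old_attn_processors is not None:
            unet.set_attn_processor(old_attn_processors)  # restore the originals

unet = FakeUNet()
with custom_attention_context(unet, wants_cross_attention_control=True):
    assert unet.attn_processors["down_blocks.0.attn1.processor"] == "swap-cross-attn"
assert unet.attn_processors["down_blocks.0.attn1.processor"] == "default"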
@ -1,15 +1,16 @@
import re
import json
from pathlib import Path
from typing import Optional
import torch
from compel import Compel
from diffusers.models import UNet2DConditionModel
from filelock import FileLock, Timeout
from safetensors.torch import load_file
from torch.utils.hooks import RemovableHandle
from transformers import CLIPTextModel
from ldm.invoke.devices import choose_torch_device
from ..invoke.globals import global_lora_models_dir
from ..invoke.devices import choose_torch_device
"""
This module supports loading LoRA weights trained with https://github.com/kohya-ss/sd-scripts
@ -17,6 +18,11 @@ To be removed once support for diffusers LoRA weights is well supported
"""
class IncompatibleModelException(Exception):
"Raised when there is an attempt to load a LoRA into a model that is incompatible with it"
pass
class LoRALayer:
lora_name: str
name: str
@ -31,18 +37,14 @@ class LoRALayer:
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h, output):
def forward(self, lora, input_h):
if self.mid is None:
output = (
output
+ self.up(self.down(*input_h)) * lora.multiplier * self.scale
)
weight = self.up(self.down(*input_h))
else:
output = (
output
+ self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
)
return output
weight = self.up(self.mid(self.down(*input_h)))
return weight * lora.multiplier * self.scale
class LoHALayer:
lora_name: str
@ -64,8 +66,7 @@ class LoHALayer:
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h, output):
def forward(self, lora, input_h):
if type(self.org_module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
@ -80,21 +81,87 @@ class LoHALayer:
extra_args = {}
if self.t1 is None:
weight = ((self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b))
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
else:
rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
rebuild1 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a
)
rebuild2 = torch.einsum(
"i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a
)
weight = rebuild1 * rebuild2
bias = self.bias if self.bias is not None else 0
return output + op(
return op(
*input_h,
(weight + bias).view(self.org_module.weight.shape),
None,
**extra_args,
) * lora.multiplier * self.scale
class LoKRLayer:
lora_name: str
name: str
scale: float
w1: Optional[torch.Tensor] = None
w1_a: Optional[torch.Tensor] = None
w1_b: Optional[torch.Tensor] = None
w2: Optional[torch.Tensor] = None
w2_a: Optional[torch.Tensor] = None
w2_b: Optional[torch.Tensor] = None
t2: Optional[torch.Tensor] = None
bias: Optional[torch.Tensor] = None
org_module: torch.nn.Module
def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
self.lora_name = lora_name
self.name = name
self.scale = alpha / rank if (alpha and rank) else 1.0
def forward(self, lora, input_h):
if type(self.org_module) == torch.nn.Conv2d:
op = torch.nn.functional.conv2d
extra_args = dict(
stride=self.org_module.stride,
padding=self.org_module.padding,
dilation=self.org_module.dilation,
groups=self.org_module.groups,
)
else:
op = torch.nn.functional.linear
extra_args = {}
w1 = self.w1
if w1 is None:
w1 = self.w1_a @ self.w1_b
w2 = self.w2
if w2 is None:
if self.t2 is None:
w2 = self.w2_a @ self.w2_b
else:
w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
w2 = w2.contiguous()
weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)
bias = self.bias if self.bias is not None else 0
return op(
*input_h,
(weight + bias).view(self.org_module.weight.shape),
None,
**extra_args
) * lora.multiplier * self.scale
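The new LoKRLayer reconstructs its weight update as a Kronecker product of two small factors, which expands to the full weight shape of the wrapped module. Below is a toy illustration of just that shape arithmetic with made-up dimensions and random tensors; it is not the layer's bookkeeping, only a demonstration that kron((2x3), (4x4)) yields an (8x12) update that can be applied like a linear weight.

import torch

out_features, in_features = 8, 12
w1 = torch.randn(2, 3)          # in the real layer this may be rebuilt from w1_a @ w1_b
w2 = torch.randn(4, 4)          # or from w2_a @ w2_b / an einsum over t2
delta = torch.kron(w1, w2)      # shape (2*4, 3*4) == (8, 12), the full-size update

multiplier, scale = 0.8, 1.0
x = torch.randn(5, in_features)
update = torch.nn.functional.linear(x, delta.view(out_features, in_features)) * multiplier * scale
print(update.shape)             # torch.Size([5, 8])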
class LoRAModuleWrapper:
unet: UNet2DConditionModel
@ -111,12 +178,22 @@ class LoRAModuleWrapper:
self.applied_loras = {}
self.loaded_loras = {}
self.UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention", "ResnetBlock2D", "Downsample2D", "Upsample2D", "SpatialTransformer"]
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = ["ResidualAttentionBlock", "CLIPAttention", "CLIPMLP"]
self.UNET_TARGET_REPLACE_MODULE = [
"Transformer2DModel",
"Attention",
"ResnetBlock2D",
"Downsample2D",
"Upsample2D",
"SpatialTransformer",
]
self.TEXT_ENCODER_TARGET_REPLACE_MODULE = [
"ResidualAttentionBlock",
"CLIPAttention",
"CLIPMLP",
]
self.LORA_PREFIX_UNET = "lora_unet"
self.LORA_PREFIX_TEXT_ENCODER = "lora_te"
def find_modules(
prefix, root_module: torch.nn.Module, target_replace_modules
) -> dict[str, torch.nn.Module]:
@ -147,7 +224,6 @@ class LoRAModuleWrapper:
self.LORA_PREFIX_UNET, unet, self.UNET_TARGET_REPLACE_MODULE
)
def lora_forward_hook(self, name):
wrapper = self
@ -159,7 +235,7 @@ class LoRAModuleWrapper:
layer = lora.layers.get(name, None)
if layer is None:
continue
output = layer.forward(lora, input_h, output)
output += layer.forward(lora, input_h)
return output
return lora_forward
@ -180,6 +256,7 @@ class LoRAModuleWrapper:
def clear_loaded_loras(self):
self.loaded_loras.clear()
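With the change above, each layer's forward now returns only its LoRA delta and the wrapper's forward hook adds that delta onto the wrapped module's output. The self-contained sketch below shows the general forward-hook pattern the wrapper relies on; the single delta Linear layer is a stand-in for the up/down projection pair and the hook body is a simplification, not the wrapper's code.

import torch

base = torch.nn.Linear(4, 4, bias=False)
delta = torch.nn.Linear(4, 4, bias=False)   # stand-in for an up(down(x)) LoRA branch
multiplier = 0.5

def lora_forward_hook(module, inputs, output):
    # Returning a value from a forward hook replaces the module's output,
    # so the LoRA contribution is added without touching the base weights.
    return output + delta(*inputs) * multiplier

handle = base.register_forward_hook(lora_forward_hook)
x = torch.randn(2, 4)
patched = base(x)
handle.remove()
assert torch.allclose(patched, base(x) + delta(x) * multiplier)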
class LoRA:
name: str
layers: dict[str, LoRALayer]
@ -205,7 +282,6 @@ class LoRA:
state_dict_groupped[stem] = dict()
state_dict_groupped[stem][leaf] = value
for stem, values in state_dict_groupped.items():
if stem.startswith(self.wrapper.LORA_PREFIX_TEXT_ENCODER):
wrapped = self.wrapper.text_modules.get(stem, None)
@ -226,34 +302,59 @@ class LoRA:
if "alpha" in values:
alpha = values["alpha"].item()
if "bias_indices" in values and "bias_values" in values and "bias_size" in values:
if (
"bias_indices" in values
and "bias_values" in values
and "bias_size" in values
):
bias = torch.sparse_coo_tensor(
values["bias_indices"],
values["bias_values"],
tuple(values["bias_size"]),
).to(device=self.device, dtype=self.dtype)
# lora and locon
if "lora_down.weight" in values:
value_down = values["lora_down.weight"]
value_mid = values.get("lora_mid.weight", None)
value_up = values["lora_up.weight"]
value_mid = values.get("lora_mid.weight", None)
value_up = values["lora_up.weight"]
if type(wrapped) == torch.nn.Conv2d:
if value_mid is not None:
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], (1, 1), bias=False)
layer_mid = torch.nn.Conv2d(value_mid.shape[1], value_mid.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
layer_down = torch.nn.Conv2d(
value_down.shape[1], value_down.shape[0], (1, 1), bias=False
)
layer_mid = torch.nn.Conv2d(
value_mid.shape[1],
value_mid.shape[0],
wrapped.kernel_size,
wrapped.stride,
wrapped.padding,
bias=False,
)
else:
layer_down = torch.nn.Conv2d(value_down.shape[1], value_down.shape[0], wrapped.kernel_size, wrapped.stride, wrapped.padding, bias=False)
layer_mid = None
layer_down = torch.nn.Conv2d(
value_down.shape[1],
value_down.shape[0],
wrapped.kernel_size,
wrapped.stride,
wrapped.padding,
bias=False,
)
layer_mid = None
layer_up = torch.nn.Conv2d(value_up.shape[1], value_up.shape[0], (1, 1), bias=False)
layer_up = torch.nn.Conv2d(
value_up.shape[1], value_up.shape[0], (1, 1), bias=False
)
elif type(wrapped) == torch.nn.Linear:
layer_down = torch.nn.Linear(value_down.shape[1], value_down.shape[0], bias=False)
layer_mid = None
layer_up = torch.nn.Linear(value_up.shape[1], value_up.shape[0], bias=False)
layer_down = torch.nn.Linear(
value_down.shape[1], value_down.shape[0], bias=False
)
layer_mid = None
layer_up = torch.nn.Linear(
value_up.shape[1], value_up.shape[0], bias=False
)
else:
print(
@ -261,52 +362,90 @@ class LoRA:
)
return
with torch.no_grad():
layer_down.weight.copy_(value_down)
if layer_mid is not None:
layer_mid.weight.copy_(value_mid)
layer_up.weight.copy_(value_up)
layer_down.to(device=self.device, dtype=self.dtype)
if layer_mid is not None:
layer_mid.to(device=self.device, dtype=self.dtype)
layer_up.to(device=self.device, dtype=self.dtype)
rank = value_down.shape[0]
layer = LoRALayer(self.name, stem, rank, alpha)
#layer.bias = bias # TODO: find and debug lora/locon with bias
# layer.bias = bias # TODO: find and debug lora/locon with bias
layer.down = layer_down
layer.mid = layer_mid
layer.up = layer_up
# loha
elif "hada_w1_b" in values:
rank = values["hada_w1_b"].shape[0]
layer = LoHALayer(self.name, stem, rank, alpha)
layer.org_module = wrapped
layer.bias = bias
layer.w1_a = values["hada_w1_a"].to(device=self.device, dtype=self.dtype)
layer.w1_b = values["hada_w1_b"].to(device=self.device, dtype=self.dtype)
layer.w2_a = values["hada_w2_a"].to(device=self.device, dtype=self.dtype)
layer.w2_b = values["hada_w2_b"].to(device=self.device, dtype=self.dtype)
layer.w1_a = values["hada_w1_a"].to(
device=self.device, dtype=self.dtype
)
layer.w1_b = values["hada_w1_b"].to(
device=self.device, dtype=self.dtype
)
layer.w2_a = values["hada_w2_a"].to(
device=self.device, dtype=self.dtype
)
layer.w2_b = values["hada_w2_b"].to(
device=self.device, dtype=self.dtype
)
if "hada_t1" in values:
layer.t1 = values["hada_t1"].to(device=self.device, dtype=self.dtype)
layer.t1 = values["hada_t1"].to(
device=self.device, dtype=self.dtype
)
else:
layer.t1 = None
if "hada_t2" in values:
layer.t2 = values["hada_t2"].to(device=self.device, dtype=self.dtype)
layer.t2 = values["hada_t2"].to(
device=self.device, dtype=self.dtype
)
else:
layer.t2 = None
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
if "lokr_w1_b" in values:
rank = values["lokr_w1_b"].shape[0]
elif "lokr_w2_b" in values:
rank = values["lokr_w2_b"].shape[0]
else:
rank = None # unscaled
layer = LoKRLayer(self.name, stem, rank, alpha)
layer.org_module = wrapped
layer.bias = bias
if "lokr_w1" in values:
layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
else:
layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
if "lokr_w2" in values:
layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
else:
layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
if "lokr_t2" in values:
layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
else:
print(
f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
@ -317,9 +456,11 @@ class LoRA:
class KohyaLoraManager:
def __init__(self, pipe, lora_path):
lora_path = Path(global_lora_models_dir())
vector_length_cache_path = lora_path / '.vectorlength.cache'
def __init__(self, pipe):
self.unet = pipe.unet
self.lora_path = lora_path
self.wrapper = LoRAModuleWrapper(pipe.unet, pipe.text_encoder)
self.text_encoder = pipe.text_encoder
self.device = torch.device(choose_torch_device())
@ -332,6 +473,9 @@ class KohyaLoraManager:
else:
checkpoint = torch.load(path_file, map_location="cpu")
if not self.check_model_compatibility(checkpoint):
raise IncompatibleModelException
lora = LoRA(name, self.device, self.dtype, self.wrapper, multiplier)
lora.load_from_dict(checkpoint)
self.wrapper.loaded_loras[name] = lora
@ -339,12 +483,14 @@ class KohyaLoraManager:
return lora
def apply_lora_model(self, name, mult: float = 1.0):
path_file = None
for suffix in ["ckpt", "safetensors", "pt"]:
path_file = Path(self.lora_path, f"{name}.{suffix}")
if path_file.exists():
path_files = [x for x in Path(self.lora_path).glob(f"**/{name}.{suffix}")]
if len(path_files):
path_file = path_files[0]
print(f" | Loading lora {path_file.name} with weight {mult}")
break
if not path_file.exists():
if not path_file:
print(f" ** Unable to find lora: {name}")
return
@ -355,13 +501,89 @@ class KohyaLoraManager:
lora.multiplier = mult
self.wrapper.applied_loras[name] = lora
def unload_applied_lora(self, lora_name: str):
def unload_applied_lora(self, lora_name: str) -> bool:
"""If the indicated LoRA has previously been applied then
unload it and return True. Return False if the LoRA was
not previously applied (for status reporting)
"""
if lora_name in self.wrapper.applied_loras:
del self.wrapper.applied_loras[lora_name]
return True
return False
def unload_lora(self, lora_name: str):
def unload_lora(self, lora_name: str) -> bool:
if lora_name in self.wrapper.loaded_loras:
del self.wrapper.loaded_loras[lora_name]
return True
return False
def clear_loras(self):
self.wrapper.clear_applied_loras()
def check_model_compatibility(self, checkpoint) -> bool:
"""Checks whether the LoRA checkpoint is compatible with the token vector
length of the model that this manager is associated with.
"""
model_token_vector_length = (
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
)
lora_token_vector_length = self.vector_length_from_checkpoint(checkpoint)
return model_token_vector_length == lora_token_vector_length
@staticmethod
def vector_length_from_checkpoint(checkpoint: dict) -> int:
"""Return the vector token length for the passed LoRA checkpoint object.
This is used to determine which SD model version the LoRA was based on.
768 -> SDv1
1024-> SDv2
"""
key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
lora_token_vector_length = (
checkpoint[key1].shape[1]
if key1 in checkpoint
else checkpoint[key2].shape[0]
if key2 in checkpoint
else 768
)
return lora_token_vector_length
@classmethod
def vector_length_from_checkpoint_file(self, checkpoint_path: Path) -> int:
with LoraVectorLengthCache(self.vector_length_cache_path) as cache:
if str(checkpoint_path) not in cache:
if checkpoint_path.suffix == ".safetensors":
checkpoint = load_file(
checkpoint_path.absolute().as_posix(), device="cpu"
)
else:
checkpoint = torch.load(checkpoint_path, map_location="cpu")
cache[str(checkpoint_path)] = KohyaLoraManager.vector_length_from_checkpoint(
checkpoint
)
return cache[str(checkpoint_path)]
class LoraVectorLengthCache(object):
def __init__(self, cache_path: Path):
self.cache_path = cache_path
self.lock = FileLock(Path(cache_path.parent, ".cachelock"))
self.cache = {}
def __enter__(self):
self.lock.acquire(timeout=10)
try:
if self.cache_path.exists():
with open(self.cache_path, "r") as json_file:
self.cache = json.load(json_file)
except Timeout:
print(
"** Can't acquire lock on lora vector length cache. Operations will be slower"
)
except (json.JSONDecodeError, OSError):
self.cache_path.unlink()
return self.cache
def __exit__(self, type, value, traceback):
with open(self.cache_path, "w") as json_file:
json.dump(self.cache, json_file)
self.lock.release()
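The compatibility check added above boils down to reading the width of a text-encoder LoRA weight from the checkpoint and comparing it with the loaded text encoder's embedding width (768 for SD 1.x, 1024 for SD 2.x). The snippet below condenses that lookup from the diff and exercises it against a toy checkpoint dictionary; the fake_sd2 tensor is invented purely to demonstrate the shape test.

import torch

def vector_length_from_checkpoint(checkpoint: dict) -> int:
    # 768 -> Stable Diffusion 1.x, 1024 -> Stable Diffusion 2.x
    key1 = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight"
    key2 = "lora_te_text_model_encoder_layers_0_self_attn_k_proj.hada_w1_a"
    if key1 in checkpoint:
        return checkpoint[key1].shape[1]
    if key2 in checkpoint:
        return checkpoint[key2].shape[0]
    return 768

# Toy checkpoint shaped like an SD-2 LoRA: the fc1 down-projection maps 1024 features to the rank.
fake_sd2 = {"lora_te_text_model_encoder_layers_0_mlp_fc1.lora_down.weight": torch.zeros(8, 1024)}
print(vector_length_from_checkpoint(fake_sd2))   # 1024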
@ -1,66 +1,101 @@
import os
from diffusers import StableDiffusionPipeline
from pathlib import Path
from diffusers import UNet2DConditionModel, StableDiffusionPipeline
from ldm.invoke.globals import global_lora_models_dir
from .kohya_lora_manager import KohyaLoraManager
from .kohya_lora_manager import KohyaLoraManager, IncompatibleModelException
from typing import Optional, Dict
class LoraCondition:
name: str
weight: float
def __init__(self, name, weight: float = 1.0, kohya_manager: Optional[KohyaLoraManager]=None):
def __init__(self,
name,
weight: float = 1.0,
unet: UNet2DConditionModel=None, # for diffusers format LoRAs
kohya_manager: Optional[KohyaLoraManager]=None, # for KohyaLoraManager-compatible LoRAs
):
self.name = name
self.weight = weight
self.kohya_manager = kohya_manager
self.unet = unet
def __call__(self, model):
def __call__(self):
# TODO: make model able to load from huggingface, rather then just local files
path = Path(global_lora_models_dir(), self.name)
if path.is_dir():
if model.load_attn_procs:
if not self.unet:
print(f" ** Unable to load diffusers-format LoRA {self.name}: unet is None")
return
if self.unet.load_attn_procs:
file = Path(path, "pytorch_lora_weights.bin")
if file.is_file():
print(f">> Loading LoRA: {path}")
model.load_attn_procs(path.absolute().as_posix())
self.unet.load_attn_procs(path.absolute().as_posix())
else:
print(f" ** Unable to find valid LoRA at: {path}")
else:
print(" ** Invalid Model to load LoRA")
elif self.kohya_manager:
self.kohya_manager.apply_lora_model(self.name,self.weight)
try:
self.kohya_manager.apply_lora_model(self.name,self.weight)
except IncompatibleModelException:
print(f" ** LoRA {self.name} is incompatible with this model; will generate without the LoRA applied.")
else:
print(" ** Unable to load LoRA")
def unload(self):
if self.kohya_manager:
if self.kohya_manager and self.kohya_manager.unload_applied_lora(self.name):
print(f'>> unloading LoRA {self.name}')
self.kohya_manager.unload_applied_lora(self.name)
class LoraManager:
def __init__(self, pipe):
def __init__(self, pipe: StableDiffusionPipeline):
# Kohya class handles lora not generated through diffusers
self.kohya = KohyaLoraManager(pipe, global_lora_models_dir())
self.kohya = KohyaLoraManager(pipe)
self.unet = pipe.unet
def set_loras_conditions(self, lora_weights: list):
conditions = []
if len(lora_weights) > 0:
for lora in lora_weights:
conditions.append(LoraCondition(lora.model, lora.weight, self.kohya))
conditions.append(LoraCondition(lora.model, lora.weight, self.unet, self.kohya))
if len(conditions) > 0:
return conditions
return None
def list_compatible_loras(self)->Dict[str, Path]:
'''
List all the LoRAs in the global lora directory that
are compatible with the current model. Return a dictionary
of the lora basename and its path.
'''
model_length = self.kohya.text_encoder.get_input_embeddings().weight.data[0].shape[0]
return self.list_loras(model_length)
@classmethod
def list_loras(self)->Dict[str, Path]:
@staticmethod
def list_loras(token_vector_length:int=None)->Dict[str, Path]:
'''List the LoRAS in the global lora directory.
If token_vector_length is provided, then only return
LoRAS that have the indicated length:
768: v1 models
1024: v2 models
'''
path = Path(global_lora_models_dir())
models_found = dict()
for root,_,files in os.walk(path):
for x in files:
name = Path(x).stem
suffix = Path(x).suffix
if suffix in [".ckpt", ".pt", ".safetensors"]:
models_found[name]=Path(root,x)
if suffix not in [".ckpt", ".pt", ".safetensors"]:
continue
path = Path(root,x)
if token_vector_length is None:
models_found[name]=Path(root,x) # unconditional addition
elif token_vector_length == KohyaLoraManager.vector_length_from_checkpoint_file(path):
models_found[name]=Path(root,x) # conditional on the base model matching
return models_found
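The updated list_loras walks the LoRA directory, keeps checkpoint-like files, and optionally filters them by token-vector length so only LoRAs matching the active base model are offered. The standalone sketch below mirrors that filtering; the length_of callable is a stand-in for KohyaLoraManager.vector_length_from_checkpoint_file, and the default value of 768 is only there to keep the example runnable.

import os
from pathlib import Path
from typing import Dict, Optional

def list_loras(lora_dir: Path, token_vector_length: Optional[int] = None,
               length_of=lambda p: 768) -> Dict[str, Path]:
    # Walk the directory, keep .ckpt/.pt/.safetensors files, and optionally filter
    # to those whose token-vector length matches the loaded model.
    found: Dict[str, Path] = {}
    for root, _, files in os.walk(lora_dir):
        for fname in files:
            p = Path(root, fname)
            if p.suffix not in {".ckpt", ".pt", ".safetensors"}:
                continue
            if token_vector_length is None or token_vector_length == length_of(p):
                found[p.stem] = p
    return found

# With no length given every checkpoint is listed; pass 768 or 1024 (plus a real
# probe function) to restrict the listing to LoRAs compatible with the current model.
print(list_loras(Path(".")))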
@ -34,7 +34,7 @@ dependencies = [
"clip_anytorch",
"compel~=1.1.0",
"datasets",
"diffusers[torch]~=0.14",
"diffusers[torch]==0.14",
"dnspython==2.2.1",
"einops",
"eventlet",