mirror of https://github.com/invoke-ai/InvokeAI

Commit 5ce62e00c9: Merge branch 'main' into diffusers_cross_attention_control_reimplementation
@@ -41,6 +41,8 @@ waifu-diffusion-1.4:
     description: Latest waifu diffusion 1.4 (diffusers version)
     format: diffusers
     repo_id: hakurei/waifu-diffusion
+    vae:
+      repo_id: stabilityai/sd-vae-ft-mse
     recommended: True
 waifu-diffusion-1.3:
     description: Stable Diffusion 1.4 fine tuned on anime-styled images (ckpt version) (4.27 GB)
@@ -77,6 +79,8 @@ anything-4.0:
     description: High-quality, highly detailed anime style images with just a few prompts
     format: diffusers
     repo_id: andite/anything-v4.0
+    vae:
+      repo_id: stabilityai/sd-vae-ft-mse
     recommended: False
 papercut-1.0:
     description: SD 1.5 fine-tuned for papercut art (use "PaperCut" in your prompts) (2.13 GB)
@@ -18,10 +18,9 @@ prompts you to select the models to merge, how to merge them, and the
 merged model name.

 Alternatively you may activate InvokeAI's virtual environment from the
-command line, and call the script via `merge_models_fe.py` (the "fe"
-stands for "front end"). There is also a version that accepts
-command-line arguments, which you can run with the command
-`merge_models.py`.
+command line, and call the script via `merge_models --gui` to open up
+a version that has a nice graphical front end. To get the commandline-
+only version, omit `--gui`.

 The user interface for the text-based interactive script is
 straightforward. It shows you a series of setting fields. Use control-N (^N)
@@ -47,7 +46,7 @@ under the selected name and register it with InvokeAI.
 display all the diffusers-style models that InvokeAI knows about.
 If you do not see the model you are looking for, then it is probably
 a legacy checkpoint model and needs to be converted using the
-`invoke.py` command-line client and its `!optimize` command. You
+`invoke` command-line client and its `!optimize` command. You
 must select at least two models to merge. The third can be left at
 "None" if you desire.
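Under the hood, `merge_models` relies on the diffusers `checkpoint_merger` community pipeline; the full implementation appears later in this commit (`ldm/invoke/merge_diffusers.py`). A minimal sketch of that mechanism, with placeholder model names:

```py
# Minimal sketch of the merge mechanism used by merge_models; the
# model paths here are placeholders, not values from this commit.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "path-or-repo-of-first-model",
    custom_pipeline="checkpoint_merger",  # community merge pipeline
)
merged = pipe.merge(
    pretrained_model_name_or_path_list=[
        "path-or-repo-of-first-model",
        "path-or-repo-of-second-model",
    ],
    alpha=0.5,    # interpolation weight for the second model
    interp=None,  # None selects the default weighted-sum merge
)
merged.save_pretrained("merged_models/my-merged-model")
```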
@@ -54,8 +54,8 @@ Please enter 1, 2, 3, or 4: [1] 3
 ```

 From the command line, with the InvokeAI virtual environment active,
-you can launch the front end with the command
-`textual_inversion_fe`.
+you can launch the front end with the command `textual_inversion
+--gui`.

 This will launch a text-based front end that will look like this:

@@ -219,11 +219,9 @@ term. For example `a plate of banana sushi in <psychedelic> style`.

 ## **Training with the Command-Line Script**

-InvokeAI also comes with a traditional command-line script for
-launching textual inversion training. It is named
-`textual_inversion`, and can be launched from within the
-"developer's console", or from the command line after activating
-InvokeAI's virtual environment.
+Training can also be done using a traditional command-line script. It
+can be launched from within the "developer's console", or from the
+command line after activating InvokeAI's virtual environment.

 It accepts a large number of arguments, which can be summarized by
 passing the `--help` argument:
@@ -234,7 +232,7 @@ textual_inversion --help

 Typical usage is shown here:
 ```sh
-python textual_inversion.py \
+textual_inversion \
        --model=stable-diffusion-1.5 \
        --resolution=512 \
        --learnable_property=style \
@@ -2,7 +2,7 @@
 accelerate
 albumentations
 datasets
-diffusers[torch]~=0.11
+diffusers[torch]~=0.12
 dnspython==2.2.1
 einops
 eventlet
@@ -37,7 +37,7 @@ taming-transformers-rom1504
 test-tube>=0.7.5
 torch-fidelity
 torchmetrics
-transformers~=4.25
+transformers~=4.26
 windows-curses; sys_platform == 'win32'
 https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip#egg=k-diffusion
 https://github.com/invoke-ai/PyPatchMatch/archive/refs/tags/0.1.5.zip#egg=pypatchmatch
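Note that `~=` is the PEP 440 "compatible release" operator: `~=0.12` admits any later release in the same major series, not just 0.12.x. A quick way to check what a specifier admits, using the `packaging` library (illustration only, not part of the commit):

```py
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=0.12")   # equivalent to >=0.12, ==0.*
print(spec.contains("0.12.1"))  # True
print(spec.contains("0.13.0"))  # True; pin with ~=0.12.0 to exclude 0.13
print(spec.contains("1.0.0"))   # False
```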
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
4 frontend/dist/index.html vendored
@@ -7,7 +7,7 @@
     <meta name="viewport" content="width=device-width, initial-scale=1.0" />
     <title>InvokeAI - A Stable Diffusion Toolkit</title>
     <link rel="shortcut icon" type="icon" href="./assets/favicon.0d253ced.ico" />
-    <script type="module" crossorigin src="./assets/index.d5dcf0c5.js"></script>
+    <script type="module" crossorigin src="./assets/index.be0f03f1.js"></script>
     <link rel="stylesheet" href="./assets/index.8badc8b4.css">
     <script type="module">try{import.meta.url;import("_").catch(()=>1);}catch(e){}window.__vite_is_modern_browser=true;</script>
     <script type="module">!function(){if(window.__vite_is_modern_browser)return;console.warn("vite: loading legacy build because dynamic import or import.meta.url is unsupported, syntax error above should be ignored");var e=document.getElementById("vite-legacy-polyfill"),n=document.createElement("script");n.src=e.src,n.onload=function(){System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))},document.body.appendChild(n)}();</script>
@@ -18,6 +18,6 @@

     <script nomodule>!function(){var e=document,t=e.createElement("script");if(!("noModule"in t)&&"onbeforeload"in t){var n=!1;e.addEventListener("beforeload",(function(e){if(e.target===t)n=!0;else if(!e.target.hasAttribute("nomodule")||!n)return;e.preventDefault()}),!0),t.type="module",t.src=".",e.head.appendChild(t),t.remove()}}();</script>
     <script nomodule crossorigin id="vite-legacy-polyfill" src="./assets/polyfills-legacy-dde3a68a.js"></script>
-    <script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-dad2eee9.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
+    <script nomodule crossorigin id="vite-legacy-entry" data-src="./assets/index-legacy-279e042c.js">System.import(document.getElementById('vite-legacy-entry').getAttribute('data-src'))</script>
   </body>
 </html>
2 frontend/dist/locales/modelmanager/en.json vendored
@@ -22,7 +22,7 @@
     "config": "Config",
     "configValidationMsg": "Path to the config file of your model.",
     "modelLocation": "Model Location",
-    "modelLocationValidationMsg": "Path to where your model is located.",
+    "modelLocationValidationMsg": "Path to where your model is located locally.",
     "repo_id": "Repo ID",
     "repoIDValidationMsg": "Online repository of your model",
     "vaeLocation": "VAE Location",
@@ -23,7 +23,7 @@ import {
   Tooltip,
   TooltipProps,
 } from '@chakra-ui/react';
-import React, { FocusEvent, useEffect, useMemo, useState } from 'react';
+import React, { FocusEvent, useMemo, useState, useEffect } from 'react';
 import { BiReset } from 'react-icons/bi';
 import IAIIconButton, { IAIIconButtonProps } from './IAIIconButton';
 import _ from 'lodash';
@@ -81,7 +81,7 @@ export default function IAISlider(props: IAIFullSliderProps) {
     withInput = false,
     isInteger = false,
     inputWidth = '5.5rem',
-    inputReadOnly = true,
+    inputReadOnly = false,
     withReset = false,
     hideTooltip = false,
     isCompact = false,
@@ -103,32 +103,35 @@ export default function IAISlider(props: IAIFullSliderProps) {
     ...rest
   } = props;

-  const [localInputValue, setLocalInputValue] = useState<string>(String(value));
+  const [localInputValue, setLocalInputValue] = useState<
+    string | number | undefined
+  >(String(value));
+
+  useEffect(() => {
+    setLocalInputValue(value);
+  }, [value]);

   const numberInputMax = useMemo(
     () => (sliderNumberInputProps?.max ? sliderNumberInputProps.max : max),
     [max, sliderNumberInputProps?.max]
   );

-  useEffect(() => {
-    if (String(value) !== localInputValue && localInputValue !== '') {
-      setLocalInputValue(String(value));
-    }
-  }, [value, localInputValue, setLocalInputValue]);
   const handleSliderChange = (v: number) => {
     onChange(v);
   };

   const handleInputBlur = (e: FocusEvent<HTMLInputElement>) => {
-    if (e.target.value === '') e.target.value = String(min);
     const clamped = _.clamp(
-      isInteger ? Math.floor(Number(e.target.value)) : Number(e.target.value),
+      isInteger ? Math.floor(Number(e.target.value)) : Number(localInputValue),
       min,
       numberInputMax
     );
     setLocalInputValue(String(clamped));
     onChange(clamped);
   };

   const handleInputChange = (v: number | string) => {
-    setLocalInputValue(String(v));
-    onChange(Number(v));
+    setLocalInputValue(v);
   };
@@ -172,7 +175,7 @@ export default function IAISlider(props: IAIFullSliderProps) {
         min={min}
         max={max}
         step={step}
-        onChange={handleInputChange}
+        onChange={handleSliderChange}
         onMouseEnter={() => setShowTooltip(true)}
         onMouseLeave={() => setShowTooltip(false)}
         focusThumbOnChange={false}
@@ -236,13 +239,19 @@ export default function IAISlider(props: IAIFullSliderProps) {
           <NumberInputField
             className="invokeai__slider-number-input"
             width={inputWidth}
-            minWidth={inputWidth}
             readOnly={inputReadOnly}
+            minWidth={inputWidth}
             {...sliderNumberInputFieldProps}
           />
           <NumberInputStepper {...sliderNumberInputStepperProps}>
-            <NumberIncrementStepper className="invokeai__slider-number-stepper" />
-            <NumberDecrementStepper className="invokeai__slider-number-stepper" />
+            <NumberIncrementStepper
+              onClick={() => onChange(Number(localInputValue))}
+              className="invokeai__slider-number-stepper"
+            />
+            <NumberDecrementStepper
+              onClick={() => onChange(Number(localInputValue))}
+              className="invokeai__slider-number-stepper"
+            />
           </NumberInputStepper>
         </NumberInput>
       )}
28 frontend/src/common/hooks/useSingleAndDoubleClick.ts Normal file
@@ -0,0 +1,28 @@
// https://stackoverflow.com/a/73731908

import { useEffect, useState } from 'react';

export function useSingleAndDoubleClick(
  handleSingleClick: () => void,
  handleDoubleClick: () => void,
  delay = 250
) {
  const [click, setClick] = useState(0);

  useEffect(() => {
    const timer = setTimeout(() => {
      if (click === 1) {
        handleSingleClick();
      }
      setClick(0);
    }, delay);

    if (click === 2) {
      handleDoubleClick();
    }

    return () => clearTimeout(timer);
  }, [click, handleSingleClick, handleDoubleClick, delay]);

  return () => setClick((prev) => prev + 1);
}
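The hook counts clicks and lets a timer decide the outcome: a second click within `delay` milliseconds fires the double-click handler immediately, otherwise the expiring timer fires the single-click handler. The same timing logic, sketched in Python purely for illustration (nothing below is part of the commit):

```py
import threading

def make_click_handler(on_single, on_double, delay=0.25):
    """Return a callable that mirrors useSingleAndDoubleClick's logic:
    each click arms a timer; a second click within `delay` seconds fires
    the double-click handler, otherwise the timer fires the single one."""
    state = {"clicks": 0, "timer": None}

    def expire():
        if state["clicks"] == 1:
            on_single()
        state["clicks"] = 0

    def click():
        state["clicks"] += 1
        if state["timer"]:
            state["timer"].cancel()
        if state["clicks"] == 2:
            on_double()
            state["clicks"] = 0
            return
        state["timer"] = threading.Timer(delay, expire)
        state["timer"].start()

    return click
```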
@@ -42,6 +42,7 @@ import {
 } from 'features/canvas/store/canvasTypes';
 import { ChangeEvent } from 'react';
 import { useTranslation } from 'react-i18next';
+import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';

 export const selector = createSelector(
   [systemSelector, canvasSelector, isStagingSelector],
@@ -156,7 +157,12 @@ const IAICanvasOutpaintingControls = () => {

   const handleSelectMoveTool = () => dispatch(setTool('move'));

-  const handleResetCanvasView = () => {
+  const handleClickResetCanvasView = useSingleAndDoubleClick(
+    () => handleResetCanvasView(false),
+    () => handleResetCanvasView(true)
+  );
+
+  const handleResetCanvasView = (shouldScaleTo1 = false) => {
     const canvasBaseLayer = getCanvasBaseLayer();
     if (!canvasBaseLayer) return;
     const clientRect = canvasBaseLayer.getClientRect({
@@ -165,6 +171,7 @@ const IAICanvasOutpaintingControls = () => {
     dispatch(
       resetCanvasView({
         contentRect: clientRect,
+        shouldScaleTo1,
       })
     );
   };
@@ -247,7 +254,7 @@ const IAICanvasOutpaintingControls = () => {
         aria-label={`${t('unifiedcanvas:resetView')} (R)`}
         tooltip={`${t('unifiedcanvas:resetView')} (R)`}
         icon={<FaCrosshairs />}
-        onClick={handleResetCanvasView}
+        onClick={handleClickResetCanvasView}
       />
     </ButtonGroup>
@@ -602,9 +602,10 @@ export const canvasSlice = createSlice({
       state,
       action: PayloadAction<{
         contentRect: IRect;
+        shouldScaleTo1?: boolean;
       }>
     ) => {
-      const { contentRect } = action.payload;
+      const { contentRect, shouldScaleTo1 } = action.payload;
       const {
         stageDimensions: { width: stageWidth, height: stageHeight },
       } = state;
@@ -612,13 +613,15 @@ export const canvasSlice = createSlice({
       const { x, y, width, height } = contentRect;

       if (width !== 0 && height !== 0) {
-        const newScale = calculateScale(
-          stageWidth,
-          stageHeight,
-          width,
-          height,
-          STAGE_PADDING_PERCENTAGE
-        );
+        const newScale = shouldScaleTo1
+          ? 1
+          : calculateScale(
+              stageWidth,
+              stageHeight,
+              width,
+              height,
+              STAGE_PADDING_PERCENTAGE
+            );

         const newCoordinates = calculateCoordinates(
           stageWidth,
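`calculateScale` itself is not shown in this diff. For context, a contain-fit scale of this kind is typically computed as in the sketch below; this is an assumption about the helper, not code from the commit. The new `shouldScaleTo1` flag (triggered by a double click on Reset View) bypasses the fit and pins the zoom to 100%.

```py
def calculate_scale(stage_w, stage_h, content_w, content_h, padding_pct):
    # Largest uniform scale at which the content rectangle fits inside
    # the stage, shrunk by a padding factor so the content does not sit
    # flush against the stage edges. A hypothetical reconstruction.
    return min(stage_w / content_w, stage_h / content_h) * padding_pct
```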
@@ -399,11 +399,11 @@ const CurrentImageButtons = () => {
           {t('options:copyImageToLink')}
         </IAIButton>

-        <IAIButton leftIcon={<FaDownload />} size={'sm'}>
-          <Link download={true} href={currentImage?.url}>
+        <Link download={true} href={currentImage?.url}>
+          <IAIButton leftIcon={<FaDownload />} size={'sm'} w="100%">
             {t('options:downloadImage')}
-          </Link>
-        </IAIButton>
+          </IAIButton>
+        </Link>
       </div>
     </IAIPopover>
     <IAIIconButton
@@ -75,11 +75,12 @@ const BoundingBoxSettings = () => {
         step={64}
         value={boundingBoxDimensions.width}
         onChange={handleChangeWidth}
-        handleReset={handleResetWidth}
         sliderNumberInputProps={{ max: 4096 }}
         withSliderMarks
         withInput
+        inputReadOnly
         withReset
+        handleReset={handleResetWidth}
       />
       <IAISlider
         label={t('options:height')}
@@ -88,11 +89,12 @@ const BoundingBoxSettings = () => {
         step={64}
         value={boundingBoxDimensions.height}
         onChange={handleChangeHeight}
-        handleReset={handleResetHeight}
         sliderNumberInputProps={{ max: 4096 }}
         withSliderMarks
         withInput
+        inputReadOnly
         withReset
+        handleReset={handleResetHeight}
       />
     </Flex>
   );
@@ -124,11 +124,12 @@ const InfillAndScalingOptions = () => {
         step={64}
         value={scaledBoundingBoxDimensions.width}
         onChange={handleChangeScaledWidth}
-        handleReset={handleResetScaledWidth}
         sliderNumberInputProps={{ max: 4096 }}
         withSliderMarks
         withInput
+        inputReadOnly
         withReset
+        handleReset={handleResetScaledWidth}
       />
       <IAISlider
         isInputDisabled={!isManual}
@@ -140,11 +141,12 @@ const InfillAndScalingOptions = () => {
         step={64}
         value={scaledBoundingBoxDimensions.height}
         onChange={handleChangeScaledHeight}
-        handleReset={handleResetScaledHeight}
         sliderNumberInputProps={{ max: 4096 }}
         withSliderMarks
         withInput
+        inputReadOnly
         withReset
+        handleReset={handleResetScaledHeight}
       />
       <InpaintReplace />
       <IAISelect
@@ -166,12 +168,12 @@ const InfillAndScalingOptions = () => {
         onChange={(v) => {
           dispatch(setTileSize(v));
         }}
-        handleReset={() => {
-          dispatch(setTileSize(32));
-        }}
         withInput
         withSliderMarks
         withReset
+        handleReset={() => {
+          dispatch(setTileSize(32));
+        }}
       />
     </Flex>
   );
@@ -36,7 +36,7 @@ export default function InpaintReplace() {
   const { t } = useTranslation();

   return (
-    <Flex alignItems={'center'} columnGap={'1rem'}>
+    <Flex alignItems={'center'} columnGap={'0.2rem'}>
       <IAISlider
         label={t('options:inpaintReplace')}
         value={inpaintReplace}
@@ -51,7 +51,8 @@ export default function InpaintReplace() {
         withSliderMarks
         sliderMarkRightOffset={-2}
         withReset
-        handleReset={() => dispatch(setInpaintReplace(1))}
+        handleReset={() => dispatch(setInpaintReplace(0.1))}
+        withInput
         isResetDisabled={!shouldUseInpaintReplace}
       />
       <IAISwitch
@@ -1,113 +0,0 @@
import { Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { optionsSelector } from 'features/options/store/optionsSelectors';
import {
  setSeamBlur,
  setSeamSize,
  setSeamSteps,
  setSeamStrength,
} from 'features/options/store/optionsSlice';
import _ from 'lodash';
import { useTranslation } from 'react-i18next';

const selector = createSelector(
  [optionsSelector],
  (options) => {
    const { seamSize, seamBlur, seamStrength, seamSteps } = options;

    return {
      seamSize,
      seamBlur,
      seamStrength,
      seamSteps,
    };
  },
  {
    memoizeOptions: {
      resultEqualityCheck: _.isEqual,
    },
  }
);

const SeamCorrectionOptions = () => {
  const dispatch = useAppDispatch();
  const { seamSize, seamBlur, seamStrength, seamSteps } =
    useAppSelector(selector);

  const { t } = useTranslation();

  return (
    <Flex direction="column" gap="1rem">
      <IAISlider
        sliderMarkRightOffset={-6}
        label={t('options:seamSize')}
        min={1}
        max={256}
        sliderNumberInputProps={{ max: 512 }}
        value={seamSize}
        onChange={(v) => {
          dispatch(setSeamSize(v));
        }}
        handleReset={() => dispatch(setSeamSize(96))}
        withInput
        withSliderMarks
        withReset
      />
      <IAISlider
        sliderMarkRightOffset={-4}
        label={t('options:seamBlur')}
        min={0}
        max={64}
        sliderNumberInputProps={{ max: 512 }}
        value={seamBlur}
        onChange={(v) => {
          dispatch(setSeamBlur(v));
        }}
        handleReset={() => {
          dispatch(setSeamBlur(16));
        }}
        withInput
        withSliderMarks
        withReset
      />
      <IAISlider
        sliderMarkRightOffset={-7}
        label={t('options:seamStrength')}
        min={0.01}
        max={0.99}
        step={0.01}
        value={seamStrength}
        onChange={(v) => {
          dispatch(setSeamStrength(v));
        }}
        handleReset={() => {
          dispatch(setSeamStrength(0.7));
        }}
        withInput
        withSliderMarks
        withReset
      />
      <IAISlider
        sliderMarkRightOffset={-4}
        label={t('options:seamSteps')}
        min={1}
        max={32}
        sliderNumberInputProps={{ max: 100 }}
        value={seamSteps}
        onChange={(v) => {
          dispatch(setSeamSteps(v));
        }}
        handleReset={() => {
          dispatch(setSeamSteps(10));
        }}
        withInput
        withSliderMarks
        withReset
      />
    </Flex>
  );
};

export default SeamCorrectionOptions;
@@ -0,0 +1,32 @@
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamBlur } from 'features/options/store/optionsSlice';
import React from 'react';
import { useTranslation } from 'react-i18next';

export default function SeamBlur() {
  const dispatch = useAppDispatch();
  const seamBlur = useAppSelector((state: RootState) => state.options.seamBlur);
  const { t } = useTranslation();

  return (
    <IAISlider
      sliderMarkRightOffset={-4}
      label={t('options:seamBlur')}
      min={0}
      max={64}
      sliderNumberInputProps={{ max: 512 }}
      value={seamBlur}
      onChange={(v) => {
        dispatch(setSeamBlur(v));
      }}
      withInput
      withSliderMarks
      withReset
      handleReset={() => {
        dispatch(setSeamBlur(16));
      }}
    />
  );
}
@@ -0,0 +1,18 @@
import { Flex } from '@chakra-ui/react';
import SeamBlur from './SeamBlur';
import SeamSize from './SeamSize';
import SeamSteps from './SeamSteps';
import SeamStrength from './SeamStrength';

const SeamCorrectionOptions = () => {
  return (
    <Flex direction="column" gap="1rem">
      <SeamSize />
      <SeamBlur />
      <SeamStrength />
      <SeamSteps />
    </Flex>
  );
};

export default SeamCorrectionOptions;
@@ -0,0 +1,31 @@
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSize } from 'features/options/store/optionsSlice';
import React from 'react';
import { useTranslation } from 'react-i18next';

export default function SeamSize() {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();

  const seamSize = useAppSelector((state: RootState) => state.options.seamSize);

  return (
    <IAISlider
      sliderMarkRightOffset={-6}
      label={t('options:seamSize')}
      min={1}
      max={256}
      sliderNumberInputProps={{ max: 512 }}
      value={seamSize}
      onChange={(v) => {
        dispatch(setSeamSize(v));
      }}
      withInput
      withSliderMarks
      withReset
      handleReset={() => dispatch(setSeamSize(96))}
    />
  );
}
@@ -0,0 +1,34 @@
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamSteps } from 'features/options/store/optionsSlice';
import React from 'react';
import { useTranslation } from 'react-i18next';

export default function SeamSteps() {
  const { t } = useTranslation();
  const seamSteps = useAppSelector(
    (state: RootState) => state.options.seamSteps
  );
  const dispatch = useAppDispatch();

  return (
    <IAISlider
      sliderMarkRightOffset={-4}
      label={t('options:seamSteps')}
      min={1}
      max={100}
      sliderNumberInputProps={{ max: 999 }}
      value={seamSteps}
      onChange={(v) => {
        dispatch(setSeamSteps(v));
      }}
      withInput
      withSliderMarks
      withReset
      handleReset={() => {
        dispatch(setSeamSteps(30));
      }}
    />
  );
}
@@ -0,0 +1,34 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setSeamStrength } from 'features/options/store/optionsSlice';
import React from 'react';
import { useTranslation } from 'react-i18next';

export default function SeamStrength() {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const seamStrength = useAppSelector(
    (state: RootState) => state.options.seamStrength
  );

  return (
    <IAISlider
      sliderMarkRightOffset={-7}
      label={t('options:seamStrength')}
      min={0.01}
      max={0.99}
      step={0.01}
      value={seamStrength}
      onChange={(v) => {
        dispatch(setSeamStrength(v));
      }}
      withInput
      withSliderMarks
      withReset
      handleReset={() => {
        dispatch(setSeamStrength(0.7));
      }}
    />
  );
}
@@ -36,9 +36,9 @@ export default function ImageToImageStrength(props: ImageToImageStrengthProps) {
      isInteger={false}
      styleClass={styleClass}
      withInput
-     withReset
      withSliderMarks
      inputWidth={'5.5rem'}
+     withReset
      handleReset={handleImg2ImgStrengthReset}
    />
  );
@@ -81,7 +81,7 @@ const initialOptionsState: OptionsState = {
  seamBlur: 16,
  seamless: false,
  seamSize: 96,
- seamSteps: 10,
+ seamSteps: 30,
  seamStrength: 0.7,
  seed: 0,
  seedWeights: '',
@@ -77,17 +77,10 @@ export default function AddDiffusersModel() {
  ) => {
    const diffusersModelToAdd = values;

-   if (values.path === '') diffusersModelToAdd['path'] = undefined;
-   if (values.repo_id === '') diffusersModelToAdd['repo_id'] = undefined;
-   if (values.vae.path === '') {
-     if (values.path === undefined) {
-       diffusersModelToAdd['vae']['path'] = undefined;
-     } else {
-       diffusersModelToAdd['vae']['path'] = values.path + '/vae';
-     }
-   }
-   if (values.vae.repo_id === '')
-     diffusersModelToAdd['vae']['repo_id'] = undefined;
+   if (values.path === '') delete diffusersModelToAdd.path;
+   if (values.repo_id === '') delete diffusersModelToAdd.repo_id;
+   if (values.vae.path === '') delete diffusersModelToAdd.vae.path;
+   if (values.vae.repo_id === '') delete diffusersModelToAdd.vae.repo_id;

    dispatch(addNewModel(diffusersModelToAdd));
    dispatch(setAddNewModelUIOption(null));
@@ -72,8 +72,16 @@ export default function DiffusersModelEdit() {
      setEditModelFormValues({
        name: openModel,
        description: retrievedModel[openModel]?.description,
-       path: retrievedModel[openModel]?.path,
-       repo_id: retrievedModel[openModel]?.repo_id,
+       path:
+         retrievedModel[openModel]?.path &&
+         retrievedModel[openModel]?.path !== 'None'
+           ? retrievedModel[openModel]?.path
+           : '',
+       repo_id:
+         retrievedModel[openModel]?.repo_id &&
+         retrievedModel[openModel]?.repo_id !== 'None'
+           ? retrievedModel[openModel]?.repo_id
+           : '',
        vae: {
          repo_id: retrievedModel[openModel]?.vae?.repo_id
            ? retrievedModel[openModel]?.vae?.repo_id
@@ -91,6 +99,13 @@ export default function DiffusersModelEdit() {
  const editModelFormSubmitHandler = (
    values: InvokeDiffusersModelConfigProps
  ) => {
+   const diffusersModelToEdit = values;
+
+   if (values.path === '') delete diffusersModelToEdit.path;
+   if (values.repo_id === '') delete diffusersModelToEdit.repo_id;
+   if (values.vae.path === '') delete diffusersModelToEdit.vae.path;
+   if (values.vae.repo_id === '') delete diffusersModelToEdit.vae.repo_id;
+
    dispatch(addNewModel(values));
  };
@@ -1,5 +1,6 @@
import { useAppDispatch } from 'app/storeHooks';
import IAIIconButton from 'common/components/IAIIconButton';
+import { useSingleAndDoubleClick } from 'common/hooks/useSingleAndDoubleClick';
import { resetCanvasView } from 'features/canvas/store/canvasSlice';
import { getCanvasBaseLayer } from 'features/canvas/util/konvaInstanceProvider';
import React from 'react';
@@ -24,7 +25,12 @@ export default function UnifiedCanvasResetView() {
    [canvasBaseLayer]
  );

- const handleResetCanvasView = () => {
+ const handleClickResetCanvasView = useSingleAndDoubleClick(
+   () => handleResetCanvasView(false),
+   () => handleResetCanvasView(true)
+ );
+
+ const handleResetCanvasView = (shouldScaleTo1 = false) => {
    const canvasBaseLayer = getCanvasBaseLayer();
    if (!canvasBaseLayer) return;
    const clientRect = canvasBaseLayer.getClientRect({
@@ -33,6 +39,7 @@ export default function UnifiedCanvasResetView() {
    dispatch(
      resetCanvasView({
        contentRect: clientRect,
+       shouldScaleTo1,
      })
    );
  };
@@ -41,7 +48,7 @@ export default function UnifiedCanvasResetView() {
      aria-label={`${t('unifiedcanvas:resetView')} (R)`}
      tooltip={`${t('unifiedcanvas:resetView')} (R)`}
      icon={<FaCrosshairs />}
-     onClick={handleResetCanvasView}
+     onClick={handleClickResetCanvasView}
    />
  );
}
@@ -1,7 +1,7 @@
-// import { Feature } from 'app/features';
+import { Feature } from 'app/features';
 import ImageToImageStrength from 'features/options/components/AdvancedOptions/ImageToImage/ImageToImageStrength';
-import SeamCorrectionOptions from 'features/options/components/AdvancedOptions/Canvas/SeamCorrectionOptions';
+import SeamCorrectionOptions from 'features/options/components/AdvancedOptions/Canvas/SeamCorrectionOptions/SeamCorrectionOptions';
 import SeedOptions from 'features/options/components/AdvancedOptions/Seed/SeedOptions';
 import GenerateVariationsToggle from 'features/options/components/AdvancedOptions/Variations/GenerateVariations';
 import VariationsOptions from 'features/options/components/AdvancedOptions/Variations/VariationsOptions';
@@ -11,23 +11,26 @@ echo 1. command-line
echo 2. browser-based UI
echo 3. run textual inversion training
echo 4. merge models (diffusers type only)
-echo 5. open the developer console
-echo 6. re-run the configure script to download new models
+echo 5. re-run the configure script to download new models
+echo 6. open the developer console
set /P restore="Please enter 1, 2, 3, 4 or 5: [5] "
if not defined restore set restore=2
IF /I "%restore%" == "1" (
    echo Starting the InvokeAI command-line..
-    python .venv\Scripts\invoke.py %*
+    python .venv\Scripts\invoke %*
) ELSE IF /I "%restore%" == "2" (
    echo Starting the InvokeAI browser-based UI..
-    python .venv\Scripts\invoke.py --web %*
+    python .venv\Scripts\invoke --web %*
) ELSE IF /I "%restore%" == "3" (
    echo Starting textual inversion training..
-    python .venv\Scripts\textual_inversion_fe.py --web %*
+    python .venv\Scripts\textual_inversion --gui %*
) ELSE IF /I "%restore%" == "4" (
    echo Starting model merging script..
-    python .venv\Scripts\merge_models_fe.py --web %*
+    python .venv\Scripts\merge_models --gui %*
+) ELSE IF /I "%restore%" == "5" (
+    echo Running configure_invokeai.py...
+    python .venv\Scripts\configure_invokeai %*
) ELSE IF /I "%restore%" == "6" (
    echo Developer Console
    echo Python command is:
    where python
@@ -39,9 +42,6 @@ IF /I "%restore%" == "1" (
    echo *************************
    echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
    call cmd /k
-) ELSE IF /I "%restore%" == "6" (
-    echo Running configure_invokeai.py...
-    python .venv\Scripts\configure_invokeai.py --web %*
) ELSE (
    echo Invalid selection
    pause
@@ -21,17 +21,17 @@ if [ "$0" != "bash" ]; then
    echo "2. browser-based UI"
    echo "3. run textual inversion training"
    echo "4. merge models (diffusers type only)"
-    echo "5. re-run the configure script to download new models"
-    echo "6. open the developer console"
+    echo "5. open the developer console"
+    echo "6. re-run the configure script to download new models"
    read -p "Please enter 1, 2, 3, 4 or 5: [1] " yn
    choice=${yn:='2'}
    case $choice in
-        1 ) printf "\nStarting the InvokeAI command-line..\n"; .venv/bin/python .venv/bin/invoke.py $*;;
-        2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; .venv/bin/python .venv/bin/invoke.py --web $*;;
-        3 ) printf "\nStarting Textual Inversion:\n"; .venv/bin/python .venv/bin/textual_inversion_fe.py $*;;
-        4 ) printf "\nMerging Models:\n"; .venv/bin/python .venv/bin/merge_models_fe.py $*;;
+        1 ) printf "\nStarting the InvokeAI command-line..\n"; invoke $*;;
+        2 ) printf "\nStarting the InvokeAI browser-based UI..\n"; invoke --web $*;;
+        3 ) printf "\nStarting Textual Inversion:\n"; textual_inversion --gui $*;;
+        4 ) printf "\nMerging Models:\n"; merge_models --gui $*;;
        5 ) printf "\nDeveloper Console:\n"; file_name=$(basename "${BASH_SOURCE[0]}"); bash --init-file "$file_name";;
-        6 ) printf "\nRunning configure_invokeai.py:\n"; .venv/bin/python .venv/bin/configure_invokeai.py $*;;
+        6 ) printf "\nRunning configure_invokeai.py:\n"; configure_invokeai $*;;
        * ) echo "Invalid selection"; exit;;
    esac
else # in developer console
@@ -708,7 +708,7 @@ def optimize_model(model_name_or_path:str, gen, opt, completer):
    if not ckpt_path.is_absolute():
        ckpt_path = Path(Globals.root,ckpt_path)

-    diffuser_path = Path(Globals.root, 'models','optimized-ckpts',model_name)
+    diffuser_path = Path(Globals.root, 'models',Globals.converted_ckpts_dir,model_name)
    if diffuser_path.exists():
        print(f'** {model_name_or_path} is already optimized. Will not overwrite. If this is an error, please remove the directory {diffuser_path} and try again.')
        return
@@ -30,7 +30,7 @@ from huggingface_hub.utils._errors import RevisionNotFoundError
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from tqdm import tqdm
-from transformers import CLIPTokenizer, CLIPTextModel
+from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, CLIPSegForImageSegmentation

 from ldm.invoke.globals import Globals, global_cache_dir
 from ldm.invoke.readline import generic_completer
@@ -601,31 +601,10 @@ def download_codeformer():
 #---------------------------------------------
 def download_clipseg():
     print('Installing clipseg model for text-based masking...',end='', file=sys.stderr)
-    import zipfile
+    CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined'
     try:
-        model_url  = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
-        model_dest = os.path.join(Globals.root,'models/clipseg/clipseg_weights')
-        weights_zip = 'models/clipseg/weights.zip'
-
-        if not os.path.exists(model_dest):
-            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
-        if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
-            dest = os.path.join(Globals.root,weights_zip)
-            request.urlretrieve(model_url,dest)
-            with zipfile.ZipFile(dest,'r') as zip:
-                zip.extractall(os.path.join(Globals.root,'models/clipseg'))
-            os.remove(dest)
-
-        from clipseg.clipseg import CLIPDensePredT
-        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
-        model.eval()
-        model.load_state_dict(
-            torch.load(
-                os.path.join(Globals.root,'models/clipseg/clipseg_weights/rd64-uni-refined.pth'),
-                map_location=torch.device('cpu')
-            ),
-            strict=False,
-        )
+        download_from_hf(AutoProcessor,CLIPSEG_MODEL)
+        download_from_hf(CLIPSegForImageSegmentation,CLIPSEG_MODEL)
     except Exception:
        print('Error installing clipseg model:')
        print(traceback.format_exc())
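The commit replaces the hand-rolled zip download of the original clipseg weights with the CLIPSeg port that ships in transformers. For context, inference against the checkpoint that `download_clipseg()` now fetches looks roughly like this (a sketch using the standard transformers API, not code from the commit; the image path is a placeholder):

```py
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation

processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.open("photo.png").convert("RGB")  # placeholder input image
inputs = processor(text=["a cat"], images=[image], return_tensors="pt")
outputs = model(**inputs)
mask_logits = outputs.logits  # low-resolution mask logits for the text query
```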
@@ -38,10 +38,6 @@ class Txt2Img2Img(Generator):
                 uc, c, cfg_scale, extra_conditioning_info,
                 threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
-        scale_dim = min(width, height)
-        scale = 512 / scale_dim
-
-        init_width, init_height = trim_to_multiple_of(scale * width, scale * height)

        def make_image(x_T):

@@ -54,6 +50,10 @@ class Txt2Img2Img(Generator):
                # TODO: threshold = threshold,
            )

+           # Get our initial generation width and height directly from the latent output so
+           # the message below is accurate.
+           init_width = first_pass_latent_output.size()[3] * self.downsampling_factor
+           init_height = first_pass_latent_output.size()[2] * self.downsampling_factor
            print(
                f"\n>> Interpolating from {init_width}x{init_height} to {width}x{height} using DDIM sampling"
            )
@@ -106,11 +106,24 @@ class Txt2Img2Img(Generator):
    def get_noise(self,width,height,scale = True):
        # print(f"Get noise: {width}x{height}")
        if scale:
-           trained_square = 512 * 512
-           actual_square = width * height
-           scale = math.sqrt(trained_square / actual_square)
-           scaled_width = math.ceil(scale * width / 64) * 64
-           scaled_height = math.ceil(scale * height / 64) * 64
+           # Scale the input width and height for the initial generation
+           # Make their area equivalent to the model's resolution area (e.g. 512*512 = 262144),
+           # while keeping the minimum dimension at least 0.5 * resolution (e.g. 512*0.5 = 256)
+
+           aspect = width / height
+           dimension = self.model.unet.config.sample_size * self.model.vae_scale_factor
+           min_dimension = math.floor(dimension * 0.5)
+           model_area = dimension * dimension  # hardcoded for now since all models are trained on square images
+
+           if aspect > 1.0:
+               init_height = max(min_dimension, math.sqrt(model_area / aspect))
+               init_width = init_height * aspect
+           else:
+               init_width = max(min_dimension, math.sqrt(model_area * aspect))
+               init_height = init_width / aspect
+
+           scaled_width, scaled_height = trim_to_multiple_of(math.floor(init_width), math.floor(init_height))
+
        else:
            scaled_width = width
            scaled_height = height
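A worked example of the aspect-preserving scaling introduced above, assuming a model trained at 512x512 (that is, sample_size * vae_scale_factor == 512); illustration only, the concrete numbers are not from the commit:

```py
import math

width, height = 1024, 512           # requested output size
aspect = width / height             # 2.0
dimension = 512                     # assumed model resolution
model_area = dimension * dimension  # 262144
min_dimension = math.floor(dimension * 0.5)  # 256

# aspect > 1.0 branch: fix the area, solve for height
init_height = max(min_dimension, math.sqrt(model_area / aspect))  # ~362.0
init_width = init_height * aspect                                 # ~724.1
# trim_to_multiple_of then rounds both down to the model's latent granularity,
# so the first pass renders near 262144 pixels at the requested 2:1 aspect.
```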
@@ -33,6 +33,7 @@ Globals.models_file = 'models.yaml'
 Globals.models_dir = 'models'
 Globals.config_dir = 'configs'
 Globals.autoscan_dir = 'weights'
+Globals.converted_ckpts_dir = 'converted-ckpts'

 # Try loading patchmatch
 Globals.try_patchmatch = True
@@ -71,7 +72,14 @@ def global_cache_dir(subdir:Union[str,Path]='')->Path:
    is provided, it will be appended to the end of the path, allowing
    for huggingface-style conventions:
         global_cache_dir('diffusers')
         global_cache_dir('hub')
+   Current HuggingFace documentation (mid-Jan 2023) indicates that
+   transformers models will be cached into a "transformers" subdirectory,
+   but in practice they seem to go into "hub". But if needed:
+        global_cache_dir('transformers')
+   One other caveat is that HuggingFace is moving some diffusers models
+   into the "hub" subdirectory as well, so this will need to be revisited
+   from time to time.
    '''
    home: str = os.getenv('HF_HOME')
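A brief illustration of the convention the docstring describes (usage sketch only; the resulting paths depend on the HF_HOME environment variable):

```py
from ldm.invoke.globals import global_cache_dir

hub_models = global_cache_dir("hub")              # e.g. <HF_HOME>/hub
diffusers_models = global_cache_dir("diffusers")  # e.g. <HF_HOME>/diffusers
```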
@@ -1,21 +1,74 @@
-'''
+"""
 ldm.invoke.merge_diffusers exports a single function call merge_diffusion_models()
 used to merge 2-3 models together and create a new InvokeAI-registered diffusion model.
-'''
+
+Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
+"""
 import argparse
 import os
-from typing import List
+import sys
+from argparse import Namespace
+from pathlib import Path
+from typing import List, Union

+import npyscreen
 from diffusers import DiffusionPipeline
-from ldm.invoke.globals import global_config_file, global_models_dir, global_cache_dir
-from ldm.invoke.model_manager import ModelManager
 from omegaconf import OmegaConf

-def merge_diffusion_models(models:List['str'],
-                           merged_model_name:str,
-                           alpha:float=0.5,
-                           interp:str=None,
-                           force:bool=False,
-                           **kwargs):
-    '''
+from ldm.invoke.globals import (
+    Globals,
+    global_cache_dir,
+    global_config_file,
+    global_models_dir,
+    global_set_root,
+)
+from ldm.invoke.model_manager import ModelManager
+
+DEST_MERGED_MODEL_DIR = "merged_models"
+
+
+def merge_diffusion_models(
+    model_ids_or_paths: List[Union[str, Path]],
+    alpha: float = 0.5,
+    interp: str = None,
+    force: bool = False,
+    **kwargs,
+) -> DiffusionPipeline:
+    """
+    model_ids_or_paths - up to three models, designated by their local paths or HuggingFace repo_ids
+    alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
+        would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
+    interp - The interpolation method to use for the merging. Supports "sigmoid", "inv_sigmoid", "add_difference" and None.
+        Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported.
+    force - Whether to ignore mismatch in model_config.json for the current models. Defaults to False.
+
+    **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
+        cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
+    """
+    pipe = DiffusionPipeline.from_pretrained(
+        model_ids_or_paths[0],
+        cache_dir=kwargs.get("cache_dir", global_cache_dir()),
+        custom_pipeline="checkpoint_merger",
+    )
+    merged_pipe = pipe.merge(
+        pretrained_model_name_or_path_list=model_ids_or_paths,
+        alpha=alpha,
+        interp=interp,
+        force=force,
+        **kwargs,
+    )
+    return merged_pipe
+
+
+def merge_diffusion_models_and_commit(
+    models: List["str"],
+    merged_model_name: str,
+    alpha: float = 0.5,
+    interp: str = None,
+    force: bool = False,
+    **kwargs,
+):
+    """
+    models - up to three models, designated by their InvokeAI models.yaml model name
+    merged_model_name = name for new model
     alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
@@ -26,37 +79,303 @@ def merge_diffusion_models(models:List['str'],

     **kwargs - the default DiffusionPipeline.get_config_dict kwargs:
         cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map
-    '''
+    """
     config_file = global_config_file()
     model_manager = ModelManager(OmegaConf.load(config_file))
     for mod in models:
-        assert (mod in model_manager.model_names()), f'** Unknown model "{mod}"'
-        assert (model_manager.model_info(mod).get('format',None) == 'diffusers'), f'** {mod} is not a diffusers model. It must be optimized before merging.'
+        assert mod in model_manager.model_names(), f'** Unknown model "{mod}"'
+        assert (
+            model_manager.model_info(mod).get("format", None) == "diffusers"
+        ), f"** {mod} is not a diffusers model. It must be optimized before merging."
     model_ids_or_paths = [model_manager.model_name_or_path(x) for x in models]

-    pipe = DiffusionPipeline.from_pretrained(model_ids_or_paths[0],
-                                             cache_dir=kwargs.get('cache_dir',global_cache_dir()),
-                                             custom_pipeline='checkpoint_merger')
-    merged_pipe = pipe.merge(pretrained_model_name_or_path_list=model_ids_or_paths,
-                             alpha=alpha,
-                             interp=interp,
-                             force=force,
-                             **kwargs)
-    dump_path = global_models_dir() / 'merged_diffusers'
-    os.makedirs(dump_path,exist_ok=True)
-    dump_path = dump_path / merged_model_name
-    merged_pipe.save_pretrained (
-        dump_path,
-        safe_serialization=1
+    merged_pipe = merge_diffusion_models(
+        model_ids_or_paths, alpha, interp, force, **kwargs
     )
-    model_manager.import_diffuser_model(
-        dump_path,
-        model_name = merged_model_name,
-        description = f'Merge of models {", ".join(models)}'
-    )
-    print('REMINDER: When PR 2369 is merged, replace merge_diffusers.py line 56 with vae= argument to impormodel()')
-    if vae := model_manager.config[models[0]].get('vae',None):
-        print(f'>> Using configured VAE assigned to {models[0]}')
-        model_manager.config[merged_model_name]['vae'] = vae
+    dump_path = global_models_dir() / DEST_MERGED_MODEL_DIR
+
+    os.makedirs(dump_path, exist_ok=True)
+    dump_path = dump_path / merged_model_name
+    merged_pipe.save_pretrained(dump_path, safe_serialization=1)
+    import_args = dict(
+        model_name=merged_model_name, description=f'Merge of models {", ".join(models)}'
+    )
+    if vae := model_manager.config[models[0]].get("vae", None):
+        print(f">> Using configured VAE assigned to {models[0]}")
+        import_args.update(vae=vae)
+    model_manager.import_diffuser_model(dump_path, **import_args)
     model_manager.commit(config_file)
+
+
+def _parse_args() -> Namespace:
+    parser = argparse.ArgumentParser(description="InvokeAI model merging")
+    parser.add_argument(
+        "--root_dir",
+        type=Path,
+        default=Globals.root,
+        help="Path to the invokeai runtime directory",
+    )
+    parser.add_argument(
+        "--front_end",
+        "--gui",
+        dest="front_end",
+        action="store_true",
+        default=False,
+        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
+    )
+    parser.add_argument(
+        "--models",
+        type=str,
+        nargs="+",
+        help="Two to three model names to be merged",
+    )
+    parser.add_argument(
+        "--merged_model_name",
+        "--destination",
+        dest="merged_model_name",
+        type=str,
+        help="Name of the output model. If not specified, will be the concatenation of the input model names.",
+    )
+    parser.add_argument(
+        "--alpha",
+        type=float,
+        default=0.5,
+        help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the 2d and 3d models",
+    )
+    parser.add_argument(
+        "--interpolation",
+        dest="interp",
+        type=str,
+        choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
+        default="weighted_sum",
+        help='Interpolation method to use. If three models are present, only "add_difference" will work.',
+    )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Try to merge models even if they are incompatible with each other",
+    )
+    parser.add_argument(
+        "--clobber",
+        "--overwrite",
+        dest="clobber",
+        action="store_true",
+        help="Overwrite the merged model if --merged_model_name already exists",
+    )
+    return parser.parse_args()
+
+
+# ------------------------- GUI HERE -------------------------
+class FloatSlider(npyscreen.Slider):
+    # this is supposed to adjust display precision, but doesn't
+    def translate_value(self):
+        stri = "%3.2f / %3.2f" % (self.value, self.out_of)
+        l = (len(str(self.out_of))) * 2 + 4
+        stri = stri.rjust(l)
+        return stri
+
+
+class FloatTitleSlider(npyscreen.TitleText):
+    _entry_type = FloatSlider
+
+
+class mergeModelsForm(npyscreen.FormMultiPageAction):
+
+    interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"]
+
+    def __init__(self, parentApp, name):
+        self.parentApp = parentApp
+        super().__init__(parentApp, name)
+
+    @property
+    def model_manager(self):
+        return self.parentApp.model_manager
+
+    def afterEditing(self):
+        self.parentApp.setNextForm(None)
+
+    def create(self):
+        self.model_names = self.get_model_names()
+
+        self.add_widget_intelligent(
+            npyscreen.FixedText, name="Select up to three models to merge", value=""
+        )
+        self.models = self.add_widget_intelligent(
+            npyscreen.TitleMultiSelect,
+            name="Select two to three models to merge:",
+            values=self.model_names,
+            value=None,
+            max_height=len(self.model_names) + 1,
+            scroll_exit=True,
+        )
+        self.models.when_value_edited = self.models_changed
+        self.merged_model_name = self.add_widget_intelligent(
+            npyscreen.TitleText,
+            name="Name for merged model:",
+            value="",
+            scroll_exit=True,
+        )
+        self.force = self.add_widget_intelligent(
+            npyscreen.Checkbox,
+            name="Force merge of incompatible models",
+            value=False,
+            scroll_exit=True,
+        )
+        self.merge_method = self.add_widget_intelligent(
+            npyscreen.TitleSelectOne,
+            name="Merge Method:",
+            values=self.interpolations,
+            value=0,
+            max_height=len(self.interpolations) + 1,
+            scroll_exit=True,
+        )
+        self.alpha = self.add_widget_intelligent(
+            FloatTitleSlider,
+            name="Weight (alpha) to assign to second and third models:",
+            out_of=1,
+            step=0.05,
+            lowest=0,
+            value=0.5,
+            scroll_exit=True,
+        )
+        self.models.editing = True
+
+    def models_changed(self):
+        model_names = self.models.values
+        selected_models = self.models.value
+        if len(selected_models) > 3:
+            npyscreen.notify_confirm(
+                "Too many models selected for merging. Select two to three."
+            )
+            return
+        elif len(selected_models) > 2:
+            self.merge_method.values = ["add_difference"]
+            self.merge_method.value = 0
+        else:
+            self.merge_method.values = self.interpolations
+        self.merged_model_name.value = "+".join(
+            [model_names[x] for x in selected_models]
+        )
+
+    def on_ok(self):
+        if self.validate_field_values() and self.check_for_overwrite():
+            self.parentApp.setNextForm(None)
+            self.editing = False
+            self.parentApp.merge_arguments = self.marshall_arguments()
+            npyscreen.notify("Starting the merge...")
+        else:
+            self.editing = True
+
+    def on_cancel(self):
+        sys.exit(0)
+
+    def marshall_arguments(self) -> dict:
+        models = [self.models.values[x] for x in self.models.value]
+        args = dict(
+            models=models,
+            alpha=self.alpha.value,
+            interp=self.interpolations[self.merge_method.value[0]],
+            force=self.force.value,
+            merged_model_name=self.merged_model_name.value,
+        )
+        return args
+
+    def check_for_overwrite(self) -> bool:
+        model_out = self.merged_model_name.value
+        if model_out not in self.model_names:
+            return True
+        else:
+            return npyscreen.notify_yes_no(
+                f"The chosen merged model destination, {model_out}, is already in use. Overwrite?"
+            )
+
+    def validate_field_values(self) -> bool:
+        bad_fields = []
+        selected_models = self.models.value
+        if len(selected_models) < 2 or len(selected_models) > 3:
+            bad_fields.append("Please select two or three models to merge.")
+        if len(bad_fields) > 0:
+            message = "The following problems were detected and must be corrected:"
+            for problem in bad_fields:
+                message += f"\n* {problem}"
+            npyscreen.notify_confirm(message)
+            return False
+        else:
+            return True
+
+    def get_model_names(self) -> List[str]:
+        model_names = [
+            name
+            for name in self.model_manager.model_names()
+            if self.model_manager.model_info(name).get("format") == "diffusers"
+        ]
+        print(model_names)
+        return sorted(model_names)
+
+
+class Mergeapp(npyscreen.NPSAppManaged):
+    def __init__(self):
+        super().__init__()
+        conf = OmegaConf.load(global_config_file())
+        self.model_manager = ModelManager(
+            conf, "cpu", "float16"
+        )  # precision doesn't really matter here
+
+    def onStart(self):
+        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
+        self.main = self.addForm("MAIN", mergeModelsForm, name="Merge Models Settings")
+
+
+def run_gui(args: Namespace):
+    mergeapp = Mergeapp()
+    mergeapp.run()
+
+    args = mergeapp.merge_arguments
+    merge_diffusion_models_and_commit(**args)
+    print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
+
+
+def run_cli(args: Namespace):
+    assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
+    assert (
+        len(args.models) >= 1 and len(args.models) <= 3
+    ), "provide 2 or 3 models to merge"
+
+    if not args.merged_model_name:
+        args.merged_model_name = "+".join(args.models)
+        print(
+            f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
+        )
+
+    model_manager = ModelManager(OmegaConf.load(global_config_file()))
+    assert (
+        args.clobber or args.merged_model_name not in model_manager.model_names()
+    ), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'
+
+    merge_diffusion_models_and_commit(**vars(args))
+
+
+def main():
+    args = _parse_args()
+    global_set_root(args.root_dir)
+
+    cache_dir = str(global_cache_dir("diffusers"))
+    os.environ[
+        "HF_HOME"
+    ] = cache_dir  # because not clear the merge pipeline is honoring cache_dir
+    args.cache_dir = cache_dir
+
+    try:
+        if args.front_end:
+            run_gui(args)
+        else:
+            run_cli(args)
+        print(f">> Conversion successful. New model is named {args.merged_model_name}")
+    except Exception as e:
+        print(f"** An error occurred while merging the pipelines: {str(e)}")
+        sys.exit(-1)
+    except KeyboardInterrupt:
+        sys.exit(-1)
+
+if __name__ == "__main__":
+    main()
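A hypothetical invocation of the API defined above (the model names are placeholders and must already exist as diffusers-format entries in models.yaml):

```py
# Sketch only: merge two registered models and commit the result to
# models.yaml via the function defined in ldm/invoke/merge_diffusers.py.
from ldm.invoke.merge_diffusers import merge_diffusion_models_and_commit

merge_diffusion_models_and_commit(
    models=["stable-diffusion-1.5", "waifu-diffusion-1.4"],
    merged_model_name="sd15+wd14",
    alpha=0.5,          # equal weighting of the two checkpoints
    interp="sigmoid",   # or None for the default weighted sum
    force=False,
)
```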
@@ -14,7 +14,6 @@ import os
 import sys
 import textwrap
 import time
 import traceback
 import warnings
 import safetensors.torch
 from pathlib import Path
@@ -639,7 +638,7 @@ class ModelManager(object):
        and import.
        '''
        weights_directory = weights_directory or global_autoscan_dir()
-        dest_directory = dest_directory or Path(global_models_dir(), 'optimized-ckpts')
+        dest_directory = dest_directory or Path(global_models_dir(), Globals.converted_ckpts_dir)

        print('>> Checking for unconverted .ckpt files in {weights_directory}')
        ckpt_files = dict()
447 ldm/invoke/textual_inversion.py Executable file
@ -0,0 +1,447 @@
|
||||
#!/usr/bin/env python

"""
This is the frontend to "textual_inversion_training.py".

Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team
"""


import os
import re
import shutil
import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import List

import npyscreen
from omegaconf import OmegaConf

from ldm.invoke.globals import Globals, global_set_root
from ldm.invoke.textual_inversion_training import (
    do_textual_inversion_training,
    parse_args,
)

TRAINING_DATA = "text-inversion-training-data"
TRAINING_DIR = "text-inversion-output"
CONF_FILE = "preferences.conf"


class textualInversionForm(npyscreen.FormMultiPageAction):
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
        "constant",
        "constant_with_warmup",
    ]
    precisions = ["no", "fp16", "bf16"]
    learnable_properties = ["object", "style"]

    def __init__(self, parentApp, name, saved_args=None):
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        self.model_names, default = self.get_model_names()
        default_initializer_token = "★"
        default_placeholder_token = ""
        saved_args = self.saved_args

        try:
            default = self.model_names.index(saved_args["model"])
        except:
            pass

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value="Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.",
            editable=False,
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Model Name:",
            values=self.model_names,
            value=default,
            max_height=len(self.model_names) + 1,
            scroll_exit=True,
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Trigger Term:",
            value="",  # saved_args.get('placeholder_token',''), # to restore previous term
            scroll_exit=True,
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value="",
            editable=False,
            scroll_exit=True,
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Initializer:",
            value=saved_args.get("initializer_token", default_initializer_token),
            scroll_exit=True,
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
            scroll_exit=True,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self.learnable_properties.index(
                saved_args.get("learnable_property", "object")
            ),
            max_height=4,
            scroll_exit=True,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Data Training Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "train_data_dir",
                    Path(Globals.root) / TRAINING_DATA / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name="Output Destination Directory:",
            select_dir=True,
            must_exist=False,
            value=str(
                saved_args.get(
                    "output_dir",
                    Path(Globals.root) / TRAINING_DIR / default_placeholder_token,
                )
            ),
            scroll_exit=True,
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Image resolution (pixels):",
            values=self.resolutions,
            value=self.resolutions.index(saved_args.get("resolution", 512)),
            max_height=4,
            scroll_exit=True,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get("center_crop", False),
            scroll_exit=True,
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Mixed Precision:",
            values=self.precisions,
            value=self.precisions.index(saved_args.get("mixed_precision", "fp16")),
            max_height=4,
            scroll_exit=True,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Number of training epochs:",
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get("num_train_epochs", 100),
            scroll_exit=True,
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Max Training Steps:",
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get("max_train_steps", 3000),
            scroll_exit=True,
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Batch Size (reduce if you run out of memory):",
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get("train_batch_size", 8),
            scroll_exit=True,
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):",
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get("gradient_accumulation_steps", 4),
            scroll_exit=True,
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name="Warmup Steps:",
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get("lr_warmup_steps", 0),
            scroll_exit=True,
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(
                saved_args.get("learning_rate", "5.0e-04"),
            ),
            scroll_exit=True,
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number GPUs, steps and batch size",
            value=saved_args.get("scale_lr", True),
            scroll_exit=True,
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get("enable_xformers_memory_efficient_attention", False),
            scroll_exit=True,
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learning rate scheduler:",
            values=self.lr_schedulers,
            max_height=7,
            value=self.lr_schedulers.index(saved_args.get("lr_scheduler", "constant")),
            scroll_exit=True,
        )
        self.model.editing = True

    def initializer_changed(self):
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f"(Trigger by using <{placeholder}> in your prompts)"
        self.train_data_dir.value = str(
            Path(Globals.root) / TRAINING_DATA / placeholder
        )
        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify(
                "Launching textual inversion training. This will take a while..."
            )
        else:
            self.editing = True

    def ok_cancel(self):
        sys.exit(0)

    def validate_field_values(self) -> bool:
        bad_fields = []
        if self.model.value is None:
            bad_fields.append(
                "Model Name must correspond to a known model in models.yaml"
            )
        if not re.match("^[a-zA-Z0-9.-]+$", self.placeholder_token.value):
            bad_fields.append(
                "Trigger term must only contain alphanumeric characters, the dot and hyphen"
            )
        if self.train_data_dir.value is None:
            bad_fields.append("Data Training Directory cannot be empty")
        if self.output_dir.value is None:
            bad_fields.append("The Output Destination Directory cannot be empty")
        if len(bad_fields) > 0:
            message = "The following problems were detected and must be corrected:"
            for problem in bad_fields:
                message += f"\n* {problem}"
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self) -> (List[str], int):
        conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
        model_names = [
            idx
            for idx in sorted(list(conf.keys()))
            if conf[idx].get("format", None) == "diffusers"
        ]
        defaults = [
            idx
            for idx in range(len(model_names))
            if "default" in conf[model_names[idx]]
        ]
        return (model_names, defaults[0])

    def marshall_arguments(self) -> dict:
        args = dict()

        # the choices
        args.update(
            model=self.model_names[self.model.value[0]],
            resolution=self.resolutions[self.resolution.value[0]],
            lr_scheduler=self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision=self.precisions[self.mixed_precision.value[0]],
            learnable_property=self.learnable_properties[
                self.learnable_property.value[0]
            ],
        )

        # all the strings and booleans
        for attr in (
            "initializer_token",
            "placeholder_token",
            "train_data_dir",
            "output_dir",
            "scale_lr",
            "center_crop",
            "enable_xformers_memory_efficient_attention",
        ):
            args[attr] = getattr(self, attr).value

        # all the integers
        for attr in (
            "train_batch_size",
            "gradient_accumulation_steps",
            "num_train_epochs",
            "max_train_steps",
            "lr_warmup_steps",
        ):
            args[attr] = int(getattr(self, attr).value)

        # the floats (just one)
        args.update(learning_rate=float(self.learning_rate.value))

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args["resume_from_checkpoint"] = "latest"

        return args


class MyApplication(npyscreen.NPSAppManaged):
    def __init__(self, saved_args=None):
        super().__init__()
        self.ti_arguments = None
        self.saved_args = saved_args

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm(
            "MAIN",
            textualInversionForm,
            name="Textual Inversion Settings",
            saved_args=self.saved_args,
        )


def copy_to_embeddings_folder(args: dict):
    """
    Copy learned_embeds.bin into the embeddings folder, and offer to
    delete the full model and checkpoints.
    """
    source = Path(args["output_dir"], "learned_embeds.bin")
    dest_dir_name = args["placeholder_token"].strip("<>")
    destination = Path(Globals.root, "embeddings", dest_dir_name)
    os.makedirs(destination, exist_ok=True)
    print(f">> Training completed. Copying learned_embeds.bin into {str(destination)}")
    shutil.copy(source, destination)
    if (
        input("Delete training logs and intermediate checkpoints? [y] ") or "y"
    ).startswith(("y", "Y")):
        shutil.rmtree(Path(args["output_dir"]))
    else:
        print(f'>> Keeping {args["output_dir"]}')


def save_args(args: dict):
    """
    Save the current argument values to an omegaconf file
    """
    dest_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(dest_dir, exist_ok=True)
    conf_file = dest_dir / CONF_FILE
    conf = OmegaConf.create(args)
    OmegaConf.save(config=conf, f=conf_file)


def previous_args() -> dict:
    """
    Get the previous arguments used.
    """
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf["placeholder_token"] = conf["placeholder_token"].strip("<>")
    except:
        conf = None

    return conf


def do_front_end(args: Namespace):
    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    if args := myapplication.ti_arguments:
        os.makedirs(args["output_dir"], exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match("^<.+>$", args["placeholder_token"]):
            args["placeholder_token"] = f"<{args['placeholder_token']}>"

        args["only_save_embeds"] = True
        save_args(args)

        try:
            print(f"DEBUG: args = {args}")
            do_textual_inversion_training(**args)
            copy_to_embeddings_folder(args)
        except Exception as e:
            print("** An exception occurred during training. The exception was:")
            print(str(e))
            print("** DETAILS:")
            print(traceback.format_exc())


def main():
    args = parse_args()
    global_set_root(args.root_dir or Globals.root)
    try:
        if args.front_end:
            do_front_end(args)
        else:
            do_textual_inversion_training(**vars(args))
    except AssertionError as e:
        print(str(e))
        sys.exit(-1)
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    main()

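The save_args()/previous_args() pair above is just an OmegaConf round trip. A minimal, self-contained sketch of the same pattern (the root path and argument values here are assumptions for illustration):

    from pathlib import Path
    from omegaconf import OmegaConf

    conf_file = Path.home() / "invokeai" / "text-inversion-output" / "preferences.conf"
    conf_file.parent.mkdir(parents=True, exist_ok=True)

    # what save_args() does: serialize the dict of settings
    OmegaConf.save(
        config=OmegaConf.create(
            {"model": "stable-diffusion-1.5", "placeholder_token": "<my-token>"}
        ),
        f=conf_file,
    )

    # what previous_args() does: reload and strip the angle brackets
    restored = OmegaConf.load(conf_file)
    restored["placeholder_token"] = restored["placeholder_token"].strip("<>")
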
@ -3,6 +3,10 @@
# on January 2, 2023
# and modified slightly by Lincoln Stein (@lstein) to work with InvokeAI

"""
This is the backend to "textual_inversion.py"
"""

import argparse
import logging
import math
@ -11,36 +15,41 @@ import random
from pathlib import Path
from typing import Optional

import datasets
import diffusers
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import datasets
import diffusers
import PIL
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, whoami

# invokeai stuff
from ldm.invoke.globals import Globals, global_cache_dir
from omegaconf import OmegaConf

# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

# invokeai stuff
from ldm.invoke.args import ArgFormatter, PagingArgumentParser
from ldm.invoke.globals import Globals, global_cache_dir

if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
@ -67,152 +76,46 @@ check_min_version("0.10.0.dev0")
logger = get_logger(__name__)


def save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path):
def save_progress(
    text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path
):
    logger.info("Saving embeddings")
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
    learned_embeds = (
        accelerator.unwrap_model(text_encoder)
        .get_input_embeddings()
        .weight[placeholder_token_id]
    )
    learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
    torch.save(learned_embeds_dict, save_path)


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    parser = PagingArgumentParser(
        description="Textual inversion training", formatter_class=ArgFormatter
    )
    parser.add_argument(
        '--root_dir','--root',
    general_group = parser.add_argument_group("General")
    model_group = parser.add_argument_group("Models and Paths")
    image_group = parser.add_argument_group("Training Image Location and Options")
    trigger_group = parser.add_argument_group("Trigger Token")
    training_group = parser.add_argument_group("Training Parameters")
    checkpointing_group = parser.add_argument_group("Checkpointing and Resume")
    integration_group = parser.add_argument_group("Integration")
    general_group.add_argument(
        "--front_end",
        "--gui",
        dest="front_end",
        action="store_true",
        default=False,
        help="Activate the text-based graphical front end for collecting parameters. Aside from --root_dir, other parameters will be ignored.",
    )
    general_group.add_argument(
        "--root_dir",
        "--root",
        type=Path,
        default=Globals.root,
        help="Path to the invokeai runtime directory",
    )
    parser.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        required=True,
        help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_dir",
        type=Path,
        default=None,
        required=True,
        help="A folder containing the training data."
    )
    parser.add_argument(
        "--placeholder_token",
        type=str,
        default=None,
        required=True,
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_token",
        type=str,
        default=None,
        required=False,
        help="A token to use as initializer word."
    )
    parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
    parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=f'{Globals.root}/text-inversion-model',
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
    general_group.add_argument(
        "--logging_dir",
        type=Path,
        default="logs",
@ -221,7 +124,179 @@ def parse_args():
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
    general_group.add_argument(
        "--output_dir",
        type=Path,
        default=f"{Globals.root}/text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    model_group.add_argument(
        "--model",
        type=str,
        default="stable-diffusion-1.5",
        help="Name of the diffusers model to train against, as defined in configs/models.yaml.",
    )
    model_group.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )

    model_group.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    image_group.add_argument(
        "--train_data_dir",
        type=Path,
        default=None,
        help="A folder containing the training data.",
    )
    image_group.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    image_group.add_argument(
        "--center_crop",
        action="store_true",
        help="Whether to center crop images before resizing to resolution",
    )
    trigger_group.add_argument(
        "--placeholder_token",
        "--trigger_term",
        dest="placeholder_token",
        type=str,
        default=None,
        help='A token to use as a placeholder for the concept. This token will trigger the concept when included in the prompt as "<trigger>".',
    )
    trigger_group.add_argument(
        "--learnable_property",
        type=str,
        choices=["object", "style"],
        default="object",
        help="Choose between 'object' and 'style'",
    )
    trigger_group.add_argument(
        "--initializer_token",
        type=str,
        default="*",
        help="A symbol to use as the initializer word.",
    )
    checkpointing_group.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    checkpointing_group.add_argument(
        "--resume_from_checkpoint",
        type=Path,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    checkpointing_group.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save learned_embeds.bin every X updates steps.",
    )
    training_group.add_argument(
        "--repeats",
        type=int,
        default=100,
        help="How many times to repeat the training data.",
    )
    training_group.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    training_group.add_argument(
        "--train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    training_group.add_argument("--num_train_epochs", type=int, default=100)
    training_group.add_argument(
        "--max_train_steps",
        type=int,
        default=5000,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    training_group.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    training_group.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    training_group.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    training_group.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    training_group.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    training_group.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    training_group.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    training_group.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    training_group.add_argument(
        "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
    )
    training_group.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    training_group.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
@ -232,7 +307,7 @@ def parse_args():
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
    training_group.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
@ -240,7 +315,31 @@ def parse_args():
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    training_group.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers.",
    )

    integration_group.add_argument(
        "--only_save_embeds",
        action="store_true",
        default=False,
        help="Save only the embeddings for the new concept.",
    )
    integration_group.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    integration_group.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
@ -249,29 +348,17 @@ def parse_args():
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    integration_group.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=Path,
    integration_group.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    args = parser.parse_args()
    return args

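Worth noting: the new parser aliases --gui to --front_end via dest=. A self-contained illustration of that argparse pattern (this snippet is mine, not code from the commit):

    import argparse

    parser = argparse.ArgumentParser()
    general = parser.add_argument_group("General")
    # Both spellings populate the same attribute, args.front_end
    general.add_argument("--front_end", "--gui", dest="front_end", action="store_true")
    print(parser.parse_args(["--gui"]).front_end)  # prints: True
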
@ -351,7 +438,10 @@ class TextualInversionDataset(Dataset):
        self.center_crop = center_crop
        self.flip_p = flip_p

        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
        self.image_paths = [
            os.path.join(self.data_root, file_path)
            for file_path in os.listdir(self.data_root)
        ]

        self.num_images = len(self.image_paths)
        self._length = self.num_images
@ -366,7 +456,11 @@ class TextualInversionDataset(Dataset):
            "lanczos": PIL_INTERPOLATION["lanczos"],
        }[interpolation]

        self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
        self.templates = (
            imagenet_style_templates_small
            if learnable_property == "style"
            else imagenet_templates_small
        )
        self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)

    def __len__(self):
@ -399,7 +493,9 @@ class TextualInversionDataset(Dataset):
                img.shape[0],
                img.shape[1],
            )
            img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
            img = img[
                (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
            ]

            image = Image.fromarray(img)
            image = image.resize((self.size, self.size), resample=self.interpolation)
@ -412,7 +508,9 @@ class TextualInversionDataset(Dataset):
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
def get_full_repo_name(
    model_id: str, organization: Optional[str] = None, token: Optional[str] = None
):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
@ -423,54 +521,60 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:


def do_textual_inversion_training(
    model:str,
    train_data_dir:Path,
    output_dir:Path,
    placeholder_token:str,
    initializer_token:str,
    save_steps:int=500,
    only_save_embeds:bool=False,
    revision:str=None,
    tokenizer_name:str=None,
    learnable_property:str='object',
    repeats:int=100,
    seed:int=None,
    resolution:int=512,
    center_crop:bool=False,
    train_batch_size:int=16,
    num_train_epochs:int=100,
    max_train_steps:int=5000,
    gradient_accumulation_steps:int=1,
    gradient_checkpointing:bool=False,
    learning_rate:float=1e-4,
    scale_lr:bool=True,
    lr_scheduler:str='constant',
    lr_warmup_steps:int=500,
    adam_beta1:float=0.9,
    adam_beta2:float=0.999,
    adam_weight_decay:float=1e-02,
    adam_epsilon:float=1e-08,
    push_to_hub:bool=False,
    hub_token:str=None,
    logging_dir:Path=Path('logs'),
    mixed_precision:str='fp16',
    allow_tf32:bool=False,
    report_to:str='tensorboard',
    local_rank:int=-1,
    checkpointing_steps:int=500,
    resume_from_checkpoint:Path=None,
    enable_xformers_memory_efficient_attention:bool=False,
    root_dir:Path=None,
    hub_model_id:str=None,
    model: str,
    train_data_dir: Path,
    output_dir: Path,
    placeholder_token: str,
    initializer_token: str,
    save_steps: int = 500,
    only_save_embeds: bool = False,
    revision: str = None,
    tokenizer_name: str = None,
    learnable_property: str = "object",
    repeats: int = 100,
    seed: int = None,
    resolution: int = 512,
    center_crop: bool = False,
    train_batch_size: int = 16,
    num_train_epochs: int = 100,
    max_train_steps: int = 5000,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    learning_rate: float = 1e-4,
    scale_lr: bool = True,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 500,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-02,
    adam_epsilon: float = 1e-08,
    push_to_hub: bool = False,
    hub_token: str = None,
    logging_dir: Path = Path("logs"),
    mixed_precision: str = "fp16",
    allow_tf32: bool = False,
    report_to: str = "tensorboard",
    local_rank: int = -1,
    checkpointing_steps: int = 500,
    resume_from_checkpoint: Path = None,
    enable_xformers_memory_efficient_attention: bool = False,
    root_dir: Path = None,
    hub_model_id: str = None,
    **kwargs,
):
    assert model, "Please specify a base model with --model"
    assert (
        train_data_dir
    ), "Please specify a directory containing the training images using --train_data_dir"
    assert placeholder_token, "Please specify a trigger term using --placeholder_token"
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != local_rank:
        local_rank = env_local_rank

    # setting up things the way invokeai expects them
    if not os.path.isabs(output_dir):
        output_dir = os.path.join(Globals.root,output_dir)

        output_dir = os.path.join(Globals.root, output_dir)

    logging_dir = output_dir / logging_dir

    accelerator = Accelerator(
@ -517,28 +621,49 @@ def do_textual_inversion_training(
    elif output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

    models_conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
    model_conf = models_conf.get(model,None)
    assert model_conf is not None,f'Unknown model: {model}'
    assert model_conf.get('format','diffusers')=='diffusers', "This script only works with models of type 'diffusers'"
    pretrained_model_name_or_path = model_conf.get('repo_id',None) or Path(model_conf.get('path'))
    assert pretrained_model_name_or_path, f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
    pipeline_args = dict(cache_dir=global_cache_dir('diffusers'))
    models_conf = OmegaConf.load(os.path.join(Globals.root, "configs/models.yaml"))
    model_conf = models_conf.get(model, None)
    assert model_conf is not None, f"Unknown model: {model}"
    assert (
        model_conf.get("format", "diffusers") == "diffusers"
    ), "This script only works with models of type 'diffusers'"
    pretrained_model_name_or_path = model_conf.get("repo_id", None) or Path(
        model_conf.get("path")
    )
    assert (
        pretrained_model_name_or_path
    ), f"models.yaml error: neither 'repo_id' nor 'path' is defined for {model}"
    pipeline_args = dict(cache_dir=global_cache_dir("diffusers"))

    # Load tokenizer
    if tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name,**pipeline_args)
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_name, **pipeline_args)
    else:
        tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args)
        tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer", **pipeline_args
        )

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args)
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, **pipeline_args
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler", **pipeline_args
    )
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
        **pipeline_args,
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="vae",
        revision=revision,
        **pipeline_args,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision, **pipeline_args)
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet", revision=revision, **pipeline_args
        pretrained_model_name_or_path,
        subfolder="unet",
        revision=revision,
        **pipeline_args,
    )

    # Add the placeholder token in tokenizer
@ -553,7 +678,9 @@ def do_textual_inversion_training(
    token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    # Check if initializer_token is a single token or a sequence of tokens
    if len(token_ids) > 1:
        raise ValueError(f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}")
        raise ValueError(
            f"The initializer token must be a single token. Provided initializer={initializer_token}. Token ids={token_ids}"
        )

    initializer_token_id = token_ids[0]
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
@ -584,7 +711,9 @@ def do_textual_inversion_training(
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@ -593,7 +722,10 @@ def do_textual_inversion_training(

    if scale_lr:
        learning_rate = (
            learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes
            learning_rate
            * gradient_accumulation_steps
            * train_batch_size
            * accelerator.num_processes
        )

    # Initialize the optimizer
@ -616,11 +748,15 @@ def do_textual_inversion_training(
        center_crop=center_crop,
        set="train",
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if max_train_steps is None:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
@ -650,7 +786,9 @@ def do_textual_inversion_training(
    vae.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
@ -660,18 +798,22 @@ def do_textual_inversion_training(
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        params = locals()
        for k in params: # init_trackers() doesn't like objects
            params[k] = str(params[k]) if isinstance(params[k],object) else params[k]
        for k in params:  # init_trackers() doesn't like objects
            params[k] = str(params[k]) if isinstance(params[k], object) else params[k]
        accelerator.init_trackers("textual_inversion", config=params)

    # Train!
    total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    total_batch_size = (
        train_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
@ -688,7 +830,7 @@ def do_textual_inversion_training(
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
@ -701,34 +843,57 @@ def do_textual_inversion_training(

        resume_global_step = global_step * gradient_accumulation_steps
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)

        resume_step = resume_global_step % (
            num_update_steps_per_epoch * gradient_accumulation_steps
        )

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar = tqdm(
        range(global_step, max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # keep original embeddings as reference
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
    orig_embeds_params = (
        accelerator.unwrap_model(text_encoder)
        .get_input_embeddings()
        .weight.data.clone()
    )

    for epoch in range(first_epoch, num_train_epochs):
        text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if resume_step and resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if (
                resume_step
                and resume_from_checkpoint
                and epoch == first_epoch
                and step < resume_step
            ):
                if step % gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(text_encoder):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                latents = (
                    vae.encode(batch["pixel_values"].to(dtype=weight_dtype))
                    .latent_dist.sample()
                    .detach()
                )
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = torch.randint(
                    0,
                    noise_scheduler.config.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
@ -736,10 +901,14 @@ def do_textual_inversion_training(
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(
                    dtype=weight_dtype
                )

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                model_pred = unet(
                    noisy_latents, timesteps, encoder_hidden_states
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
@ -747,7 +916,9 @@ def do_textual_inversion_training(
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
                    raise ValueError(
                        f"Unknown prediction type {noise_scheduler.config.prediction_type}"
                    )

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

@ -760,21 +931,35 @@ def do_textual_inversion_training(
                # Let's make sure we don't update any embedding weights besides the newly added token
                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                with torch.no_grad():
                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                    accelerator.unwrap_model(
                        text_encoder
                    ).get_input_embeddings().weight[
                        index_no_updates
                    ] = orig_embeds_params[index_no_updates]
                    ] = orig_embeds_params[
                        index_no_updates
                    ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % save_steps == 0:
                    save_path = os.path.join(output_dir, f"learned_embeds-steps-{global_step}.bin")
                    save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
                    save_path = os.path.join(
                        output_dir, f"learned_embeds-steps-{global_step}.bin"
                    )
                    save_progress(
                        text_encoder,
                        placeholder_token_id,
                        accelerator,
                        placeholder_token,
                        save_path,
                    )

                if global_step % checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                        save_path = os.path.join(
                            output_dir, f"checkpoint-{global_step}"
                        )
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

@ -789,7 +974,9 @@ def do_textual_inversion_training(
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if push_to_hub and only_save_embeds:
            logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
            logger.warn(
                "Enabling full model saving because --push_to_hub=True was specified."
            )
            save_full_model = True
        else:
            save_full_model = not only_save_embeds
@ -805,9 +992,17 @@ def do_textual_inversion_training(
        pipeline.save_pretrained(output_dir)
        # Save the newly trained embeddings
        save_path = os.path.join(output_dir, "learned_embeds.bin")
        save_progress(text_encoder, placeholder_token_id, accelerator, placeholder_token, save_path)
        save_progress(
            text_encoder,
            placeholder_token_id,
            accelerator,
            placeholder_token,
            save_path,
        )

        if push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
            repo.push_to_hub(
                commit_message="End of training", blocking=False, auto_lfs_prune=True
            )

    accelerator.end_training()

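The TUI and CLI both funnel into this one function, so it can also be called directly. A hedged sketch of a direct invocation, assuming a diffusers model named "stable-diffusion-1.5" exists in configs/models.yaml and that the training-image folder is one you supply:

    from pathlib import Path
    from ldm.invoke.textual_inversion_training import do_textual_inversion_training

    do_textual_inversion_training(
        model="stable-diffusion-1.5",
        train_data_dir=Path("text-inversion-training-data/my-token"),  # your images
        output_dir=Path("text-inversion-output/my-token"),
        placeholder_token="<my-token>",
        initializer_token="*",
        max_train_steps=3000,   # the TUI default shown earlier
        train_batch_size=8,
        only_save_embeds=True,  # skip saving the full pipeline
    )

Relative paths are joined to the InvokeAI root by the function itself, per the "setting up things the way invokeai expects them" block above.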
@ -29,16 +29,12 @@ work fine.

import torch
import numpy as np
import os
from clipseg.clipseg import CLIPDensePredT
from einops import rearrange, repeat
from transformers import AutoProcessor, CLIPSegForImageSegmentation
from PIL import Image, ImageOps
from torchvision import transforms
from ldm.invoke.globals import Globals
from ldm.invoke.globals import global_cache_dir

CLIP_VERSION = 'ViT-B/16'
CLIPSEG_WEIGHTS = 'models/clipseg/clipseg_weights/rd64-uni.pth'
CLIPSEG_WEIGHTS_REFINED = 'models/clipseg/clipseg_weights/rd64-uni-refined.pth'
CLIPSEG_MODEL = 'CIDAS/clipseg-rd64-refined'
CLIPSEG_SIZE = 352

class SegmentedGrayscale(object):
@ -77,16 +73,15 @@ class Txt2Mask(object):
    '''
    def __init__(self,device='cpu',refined=False):
        print('>> Initializing clipseg model for text to mask inference')

        # BUG: we are not doing anything with the device option at this time
        self.device = device
        self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, complex_trans_conv=refined)
        self.model.eval()
        # initially we keep everything in cpu to conserve space
        self.model.to('cpu')
        self.model.load_state_dict(torch.load(os.path.join(Globals.root,CLIPSEG_WEIGHTS_REFINED)
                                              if refined
                                              else os.path.join(Globals.root,CLIPSEG_WEIGHTS),
                                              map_location=torch.device('cpu')), strict=False
                                   )
        self.processor = AutoProcessor.from_pretrained(CLIPSEG_MODEL,
                                                       cache_dir=global_cache_dir('hub')
                                                       )
        self.model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_MODEL,
                                                                 cache_dir=global_cache_dir('hub')
                                                                 )

    @torch.no_grad()
    def segment(self, image, prompt:str) -> SegmentedGrayscale:
@ -95,9 +90,6 @@ class Txt2Mask(object):
        provided image and returns a SegmentedGrayscale object in which the brighter
        pixels indicate where the object is inferred to be.
        '''
        self._to_device(self.device)
        prompts = [prompt] # right now we operate on just a single prompt at a time

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
@ -111,14 +103,14 @@ class Txt2Mask(object):
        img = self._scale_and_crop(image)
        img = transform(img).unsqueeze(0)

        preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
        heatmap = torch.sigmoid(preds[0][0]).cpu()
        self._to_device('cpu')
        inputs = self.processor(text=[prompt],
                                images=[image],
                                padding=True,
                                return_tensors='pt')
        outputs = self.model(**inputs)
        heatmap = torch.sigmoid(outputs.logits)
        return SegmentedGrayscale(image, heatmap)

    def _to_device(self, device):
        self.model.to(device)

    def _scale_and_crop(self, image:Image)->Image:
        scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
        if image.width > image.height: # width is constraint

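With this rewrite, Txt2Mask no longer hand-loads the clipseg weights; it defers to the transformers CLIPSeg checkpoint. A hedged usage sketch (the module path, "sample.png", and the prompt are assumptions for illustration):

    from PIL import Image
    from ldm.invoke.txt2mask import Txt2Mask

    image = Image.open("sample.png").convert("RGB")
    result = Txt2Mask().segment(image, "a cat")
    # result is the SegmentedGrayscale shown above: a sigmoid heatmap in which
    # brighter pixels mark where "a cat" is inferred to be.
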
@ -36,8 +36,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"accelerate",
|
||||
"albumentations",
|
||||
"clip_anytorch", # replaceing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"clipseg @ https://github.com/invoke-ai/clipseg/archive/relaxed-python-requirement.zip", # is this still necesarry with diffusers?
|
||||
"clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"datasets",
|
||||
"diffusers[torch]~=0.11",
|
||||
"dnspython==2.2.1",
|
||||
@ -53,7 +52,7 @@ dependencies = [
|
||||
"huggingface-hub>=0.11.1",
|
||||
"imageio",
|
||||
"imageio-ffmpeg",
|
||||
"k-diffusion", # replaceing "k-diffusion @ https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip",
|
||||
"k-diffusion", # replacing "k-diffusion @ https://github.com/Birch-san/k-diffusion/archive/refs/heads/mps.zip",
|
||||
"kornia",
|
||||
"npyscreen",
|
||||
"numpy~=1.23",
|
||||
@ -92,12 +91,9 @@ test = ["pytest>6.0.0", "pytest-cov"]
|
||||
|
||||
[project.scripts]
|
||||
"configure_invokeai" = "ldm.invoke.configure_invokeai:main"
|
||||
"dream" = "ldm.invoke:CLI.main"
|
||||
"invoke" = "ldm.invoke:CLI.main"
|
||||
"legacy_api" = "scripts:legacy_api.main"
|
||||
"load_models" = "scripts:configure_invokeai.main"
|
||||
"merge_embeddings" = "scripts:merge_embeddings.main"
|
||||
"preload_models" = "ldm.invoke.configure_invokeai:main"
|
||||
"invoke" = "ldm.invoke.CLI:main"
|
||||
"textual_inversion" = "ldm.invoke.textual_inversion:main"
|
||||
"merge_models" = "ldm.invoke.merge_diffusers:main" # note name munging
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||
|
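Each `[project.scripts]` entry of the form `command = "module:function"` becomes a console command at install time that imports the module and calls the function; that is what makes the new `invoke`, `textual_inversion`, and `merge_models` commands available. Roughly, the wrapper pip generates for `merge_models` is equivalent to this sketch:

```python
# Rough equivalent of the console script generated for
# "merge_models" = "ldm.invoke.merge_diffusers:main"
import sys
from ldm.invoke.merge_diffusers import main

if __name__ == '__main__':
    sys.exit(main())
```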
@ -1,92 +0,0 @@
#!/usr/bin/env python

import argparse
import os
import sys
import traceback
from pathlib import Path

from omegaconf import OmegaConf

from ldm.invoke.globals import (Globals, global_cache_dir, global_config_file,
                                global_set_root)
from ldm.invoke.model_manager import ModelManager

parser = argparse.ArgumentParser(description="InvokeAI model merging")
parser.add_argument(
    "--root_dir",
    "--root-dir",
    type=Path,
    default=Globals.root,
    help="Path to the invokeai runtime directory",
)
parser.add_argument(
    "--models",
    required=True,
    type=str,
    nargs="+",
    help="Two to three model names to be merged",
)
parser.add_argument(
    "--merged_model_name",
    "--destination",
    dest="merged_model_name",
    type=str,
    help="Name of the output model. If not specified, will be the concatenation of the input model names.",
)
parser.add_argument(
    "--alpha",
    type=float,
    default=0.5,
    help="The interpolation parameter, ranging from 0 to 1. It affects the ratio in which the checkpoints are merged. Higher values give more weight to the second and third models",
)
parser.add_argument(
    "--interpolation",
    dest="interp",
    type=str,
    choices=["weighted_sum", "sigmoid", "inv_sigmoid", "add_difference"],
    default="weighted_sum",
    help='Interpolation method to use. If three models are present, only "add_difference" will work.',
)
parser.add_argument(
    "--force",
    action="store_true",
    help="Try to merge models even if they are incompatible with each other",
)
parser.add_argument(
    "--clobber",
    "--overwrite",
    dest='clobber',
    action="store_true",
    help="Overwrite the merged model if --merged_model_name already exists",
)

args = parser.parse_args()
global_set_root(args.root_dir)

assert args.alpha >= 0 and args.alpha <= 1.0, "alpha must be between 0 and 1"
assert len(args.models) >= 2 and len(args.models) <= 3, "provide 2 or 3 models to merge"

if not args.merged_model_name:
    args.merged_model_name = "+".join(args.models)
    print(
        f'>> No --merged_model_name provided. Defaulting to "{args.merged_model_name}"'
    )

model_manager = ModelManager(OmegaConf.load(global_config_file()))
assert (args.clobber or args.merged_model_name not in model_manager.model_names()), f'A model named "{args.merged_model_name}" already exists. Use --clobber to overwrite.'

# It seems that the merge pipeline is not honoring cache_dir, so we set the
# HF_HOME environment variable here *before* we load diffusers.
cache_dir = str(global_cache_dir("diffusers"))
os.environ["HF_HOME"] = cache_dir
from ldm.invoke.merge_diffusers import merge_diffusion_models

try:
    merge_diffusion_models(**vars(args))
    print(f'>> Models merged into new model: "{args.merged_model_name}".')
except Exception as e:
    print(f"** An error occurred while merging the pipelines: {str(e)}")
    print("** DETAILS:")
    print(traceback.format_exc())
    sys.exit(-1)
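The deleted script above survives in spirit as the `merge_models` entry point (`ldm.invoke.merge_diffusers:main`). Its final call shows the programmatic interface; a hedged sketch of driving the merge directly, assuming `merge_diffusion_models` keeps the keyword arguments used here (model names and paths below are hypothetical):

```python
# Sketch: calling the merge step without the CLI wrapper.
# Assumes merge_diffusion_models still accepts these keywords.
from ldm.invoke.globals import global_set_root
from ldm.invoke.merge_diffusers import merge_diffusion_models

global_set_root('/path/to/invokeai')       # hypothetical runtime directory
merge_diffusion_models(
    models=['stable-diffusion-1.5', 'anything-4.0'],  # hypothetical names
    merged_model_name='sd15+anything4',
    alpha=0.5,                  # weight given to the later models
    interp='weighted_sum',      # or sigmoid / inv_sigmoid / add_difference
    force=False,
)
```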
@ -1,217 +0,0 @@
#!/usr/bin/env python

import npyscreen
import os
import sys
import traceback
import argparse
from ldm.invoke.globals import Globals, global_set_root, global_cache_dir, global_config_file
from ldm.invoke.model_manager import ModelManager
from omegaconf import OmegaConf
from pathlib import Path
from typing import List

class FloatSlider(npyscreen.Slider):
    # this is supposed to adjust display precision, but doesn't
    def translate_value(self):
        stri = "%3.2f / %3.2f" %(self.value, self.out_of)
        l = (len(str(self.out_of)))*2+4
        stri = stri.rjust(l)
        return stri

class FloatTitleSlider(npyscreen.TitleText):
    _entry_type = FloatSlider

class mergeModelsForm(npyscreen.FormMultiPageAction):

    interpolations = ['weighted_sum',
                      'sigmoid',
                      'inv_sigmoid',
                      'add_difference']

    def __init__(self, parentApp, name):
        self.parentApp = parentApp
        super().__init__(parentApp, name)

    @property
    def model_manager(self):
        return self.parentApp.model_manager

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        self.model_names = self.get_model_names()

        self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Select up to three models to merge",
            value=''
        )
        self.model1 = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='First Model:',
            values=self.model_names,
            value=0,
            max_height=len(self.model_names)+1
        )
        self.model2 = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Second Model:',
            values=self.model_names,
            value=1,
            max_height=len(self.model_names)+1
        )
        models_plus_none = self.model_names.copy()
        models_plus_none.insert(0,'None')
        self.model3 = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Third Model:',
            values=models_plus_none,
            value=0,
            max_height=len(self.model_names)+1,
        )

        for m in [self.model1,self.model2,self.model3]:
            m.when_value_edited = self.models_changed

        self.merge_method = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Merge Method:',
            values=self.interpolations,
            value=0,
            max_height=len(self.interpolations),
        )
        self.alpha = self.add_widget_intelligent(
            FloatTitleSlider,
            name='Weight (alpha) to assign to second and third models:',
            out_of=1,
            step=0.05,
            lowest=0,
            value=0.5,
        )
        self.force = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name='Force merge of incompatible models',
            value=False,
        )
        self.merged_model_name = self.add_widget_intelligent(
            npyscreen.TitleText,
            name='Name for merged model',
            value='',
        )

    def models_changed(self):
        models = self.model1.values
        selected_model1 = self.model1.value[0]
        selected_model2 = self.model2.value[0]
        selected_model3 = self.model3.value[0]
        merged_model_name = f'{models[selected_model1]}+{models[selected_model2]}'
        self.merged_model_name.value = merged_model_name

        if selected_model3 > 0:
            self.merge_method.values=['add_difference']
            self.merged_model_name.value += f'+{models[selected_model3]}'
        else:
            self.merge_method.values=self.interpolations
            self.merge_method.value=0

    def on_ok(self):
        if self.validate_field_values() and self.check_for_overwrite():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.merge_arguments = self.marshall_arguments()
            npyscreen.notify('Starting the merge...')
            import ldm.invoke.merge_diffusers # this keeps the message up while diffusers loads
        else:
            self.editing = True

    def on_cancel(self):
        sys.exit(0)

    def marshall_arguments(self)->dict:
        model_names = self.model_names
        models = [
            model_names[self.model1.value[0]],
            model_names[self.model2.value[0]],
        ]
        if self.model3.value[0] > 0:
            models.append(model_names[self.model3.value[0]-1])

        args = dict(
            models=models,
            alpha = self.alpha.value,
            interp = self.interpolations[self.merge_method.value[0]],
            force = self.force.value,
            merged_model_name = self.merged_model_name.value,
        )
        return args

    def check_for_overwrite(self)->bool:
        model_out = self.merged_model_name.value
        if model_out not in self.model_names:
            return True
        else:
            return npyscreen.notify_yes_no(f'The chosen merged model destination, {model_out}, is already in use. Overwrite?')

    def validate_field_values(self)->bool:
        bad_fields = []
        model_names = self.model_names
        selected_models = set((model_names[self.model1.value[0]],model_names[self.model2.value[0]]))
        if self.model3.value[0] > 0:
            selected_models.add(model_names[self.model3.value[0]-1])
        if len(selected_models) < 2:
            bad_fields.append(f'Please select two or three DIFFERENT models to merge. You selected {selected_models}')
        if len(bad_fields) > 0:
            message = 'The following problems were detected and must be corrected:'
            for problem in bad_fields:
                message += f'\n* {problem}'
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self)->List[str]:
        model_names = [name for name in self.model_manager.model_names() if self.model_manager.model_info(name).get('format') == 'diffusers']
        print(model_names)
        return sorted(model_names)

class Mergeapp(npyscreen.NPSAppManaged):
    def __init__(self):
        super().__init__()
        conf = OmegaConf.load(global_config_file())
        self.model_manager = ModelManager(conf,'cpu','float16') # precision doesn't really matter here

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm('MAIN', mergeModelsForm, name='Merge Models Settings')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='InvokeAI model merging')
    parser.add_argument(
        '--root_dir','--root-dir',
        type=Path,
        default=Globals.root,
        help='Path to the invokeai runtime directory',
    )
    args = parser.parse_args()
    global_set_root(args.root_dir)

    cache_dir = str(global_cache_dir('diffusers')) # because not clear the merge pipeline is honoring cache_dir
    os.environ['HF_HOME'] = cache_dir

    mergeapp = Mergeapp()
    mergeapp.run()

    args = mergeapp.merge_arguments
    args.update(cache_dir = cache_dir)
    from ldm.invoke.merge_diffusers import merge_diffusion_models

    try:
        merge_diffusion_models(**args)
        print(f'>> Models merged into new model: "{args["merged_model_name"]}".')
    except Exception as e:
        print(f'** An error occurred while merging the pipelines: {str(e)}')
        print('** DETAILS:')
        print(traceback.format_exc())
        sys.exit(-1)
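Both text-based front ends removed by this commit follow the same npyscreen shape: an `NPSAppManaged` application registers a `MAIN` form, the form's `on_ok` stashes the marshalled values on the parent app, and the caller reads them back after `run()` returns. A stripped-down sketch of that pattern (the widget names here are illustrative, not from the repo):

```python
# Minimal sketch of the npyscreen pattern shared by the deleted front ends.
import npyscreen

class SettingsForm(npyscreen.FormMultiPageAction):
    def create(self):
        # add_widget_intelligent flows widgets across pages as needed
        self.name_field = self.add_widget_intelligent(
            npyscreen.TitleText, name='Model name:', value='')

    def on_ok(self):
        # stash results on the app, then end the form loop
        self.parentApp.results = {'name': self.name_field.value}
        self.parentApp.setNextForm(None)

    def on_cancel(self):
        self.parentApp.setNextForm(None)

class App(npyscreen.NPSAppManaged):
    results = None
    def onStart(self):
        self.addForm('MAIN', SettingsForm, name='Settings')

app = App()
app.run()           # blocks until the form exits
print(app.results)  # values marshalled by on_ok
```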
@ -1,11 +0,0 @@
#!/usr/bin/env python

# Copyright 2023, Lincoln Stein @lstein
from ldm.invoke.globals import Globals, global_set_root
from ldm.invoke.textual_inversion_training import parse_args, do_textual_inversion_training

if __name__ == "__main__":
    args = parse_args()
    global_set_root(args.root_dir or Globals.root)
    kwargs = vars(args)
    do_textual_inversion_training(**kwargs)
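The launcher above is the entire command-line entry point: parse the arguments, set the runtime root, and hand everything to the training routine. Invoked programmatically, the equivalent is roughly the sketch below; the keyword names mirror those marshalled by the front end that follows, and the concrete values are hypothetical:

```python
# Hedged sketch: calling the training routine directly.
from ldm.invoke.globals import Globals, global_set_root
from ldm.invoke.textual_inversion_training import do_textual_inversion_training

global_set_root(Globals.root)
do_textual_inversion_training(
    model='stable-diffusion-1.5',       # hypothetical model name
    placeholder_token='<my-style>',
    initializer_token='★',
    learnable_property='style',
    train_data_dir='/path/to/images',   # hypothetical
    output_dir='/path/to/output',       # hypothetical
    resolution=512,
    train_batch_size=8,
    max_train_steps=3000,
    learning_rate=5.0e-04,
    only_save_embeds=True,
)
```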
@ -1,350 +0,0 @@
#!/usr/bin/env python

import npyscreen
import os
import sys
import re
import shutil
import traceback
import curses
from ldm.invoke.globals import Globals, global_set_root
from omegaconf import OmegaConf
from pathlib import Path
from typing import List
import argparse

TRAINING_DATA = 'text-inversion-training-data'
TRAINING_DIR = 'text-inversion-output'
CONF_FILE = 'preferences.conf'

class textualInversionForm(npyscreen.FormMultiPageAction):
    resolutions = [512, 768, 1024]
    lr_schedulers = [
        "linear", "cosine", "cosine_with_restarts",
        "polynomial","constant", "constant_with_warmup"
    ]
    precisions = ['no','fp16','bf16']
    learnable_properties = ['object','style']

    def __init__(self, parentApp, name, saved_args=None):
        self.saved_args = saved_args or {}
        super().__init__(parentApp, name)

    def afterEditing(self):
        self.parentApp.setNextForm(None)

    def create(self):
        self.model_names, default = self.get_model_names()
        default_initializer_token = '★'
        default_placeholder_token = ''
        saved_args = self.saved_args

        try:
            default = self.model_names.index(saved_args['model'])
        except:
            pass

        self.add_widget_intelligent(
            npyscreen.FixedText,
            value='Use ctrl-N and ctrl-P to move to the <N>ext and <P>revious fields, cursor arrows to make a selection, and space to toggle checkboxes.'
        )

        self.model = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Model Name:',
            values=self.model_names,
            value=default,
            max_height=len(self.model_names)+1
        )
        self.placeholder_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name='Trigger Term:',
            value='', # saved_args.get('placeholder_token',''), # to restore previous term
        )
        self.placeholder_token.when_value_edited = self.initializer_changed
        self.nextrely -= 1
        self.nextrelx += 30
        self.prompt_token = self.add_widget_intelligent(
            npyscreen.FixedText,
            name="Trigger term for use in prompt",
            value='',
        )
        self.nextrelx -= 30
        self.initializer_token = self.add_widget_intelligent(
            npyscreen.TitleText,
            name='Initializer:',
            value=saved_args.get('initializer_token',default_initializer_token),
        )
        self.resume_from_checkpoint = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Resume from last saved checkpoint",
            value=False,
        )
        self.learnable_property = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name="Learnable property:",
            values=self.learnable_properties,
            value=self.learnable_properties.index(saved_args.get('learnable_property','object')),
            max_height=4,
        )
        self.train_data_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name='Data Training Directory:',
            select_dir=True,
            must_exist=False,
            value=str(saved_args.get('train_data_dir',Path(Globals.root) / TRAINING_DATA / default_placeholder_token))
        )
        self.output_dir = self.add_widget_intelligent(
            npyscreen.TitleFilename,
            name='Output Destination Directory:',
            select_dir=True,
            must_exist=False,
            value=str(saved_args.get('output_dir',Path(Globals.root) / TRAINING_DIR / default_placeholder_token))
        )
        self.resolution = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Image resolution (pixels):',
            values = self.resolutions,
            value=self.resolutions.index(saved_args.get('resolution',512)),
            scroll_exit = True,
            max_height=4,
        )
        self.center_crop = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Center crop images before resizing to resolution",
            value=saved_args.get('center_crop',False)
        )
        self.mixed_precision = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Mixed Precision:',
            values=self.precisions,
            value=self.precisions.index(saved_args.get('mixed_precision','fp16')),
            max_height=4,
        )
        self.num_train_epochs = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Number of training epochs:',
            out_of=1000,
            step=50,
            lowest=1,
            value=saved_args.get('num_train_epochs',100)
        )
        self.max_train_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Max Training Steps:',
            out_of=10000,
            step=500,
            lowest=1,
            value=saved_args.get('max_train_steps',3000)
        )
        self.train_batch_size = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Batch Size (reduce if you run out of memory):',
            out_of=50,
            step=1,
            lowest=1,
            value=saved_args.get('train_batch_size',8),
        )
        self.gradient_accumulation_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Gradient Accumulation Steps (may need to decrease this to resume from a checkpoint):',
            out_of=10,
            step=1,
            lowest=1,
            value=saved_args.get('gradient_accumulation_steps',4)
        )
        self.lr_warmup_steps = self.add_widget_intelligent(
            npyscreen.TitleSlider,
            name='Warmup Steps:',
            out_of=100,
            step=1,
            lowest=0,
            value=saved_args.get('lr_warmup_steps',0),
        )
        self.learning_rate = self.add_widget_intelligent(
            npyscreen.TitleText,
            name="Learning Rate:",
            value=str(saved_args.get('learning_rate','5.0e-04'),)
        )
        self.scale_lr = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Scale learning rate by number of GPUs, steps and batch size",
            value=saved_args.get('scale_lr',True),
        )
        self.enable_xformers_memory_efficient_attention = self.add_widget_intelligent(
            npyscreen.Checkbox,
            name="Use xformers acceleration",
            value=saved_args.get('enable_xformers_memory_efficient_attention',False),
        )
        self.lr_scheduler = self.add_widget_intelligent(
            npyscreen.TitleSelectOne,
            name='Learning rate scheduler:',
            values = self.lr_schedulers,
            max_height=7,
            scroll_exit = True,
            value=self.lr_schedulers.index(saved_args.get('lr_scheduler','constant')),
        )

    def initializer_changed(self):
        placeholder = self.placeholder_token.value
        self.prompt_token.value = f'(Trigger by using <{placeholder}> in your prompts)'
        self.train_data_dir.value = str(Path(Globals.root) / TRAINING_DATA / placeholder)
        self.output_dir.value = str(Path(Globals.root) / TRAINING_DIR / placeholder)
        self.resume_from_checkpoint.value = Path(self.output_dir.value).exists()

    def on_ok(self):
        if self.validate_field_values():
            self.parentApp.setNextForm(None)
            self.editing = False
            self.parentApp.ti_arguments = self.marshall_arguments()
            npyscreen.notify('Launching textual inversion training. This will take a while...')
            # The module load takes a while, so we do it while the form and message are still up
            import ldm.invoke.textual_inversion_training
        else:
            self.editing = True

    def ok_cancel(self):
        sys.exit(0)

    def validate_field_values(self)->bool:
        bad_fields = []
        if self.model.value is None:
            bad_fields.append('Model Name must correspond to a known model in models.yaml')
        if not re.match('^[a-zA-Z0-9.-]+$',self.placeholder_token.value):
            bad_fields.append('Trigger term must only contain alphanumeric characters, the dot and hyphen')
        if self.train_data_dir.value is None:
            bad_fields.append('Data Training Directory cannot be empty')
        if self.output_dir.value is None:
            bad_fields.append('The Output Destination Directory cannot be empty')
        if len(bad_fields) > 0:
            message = 'The following problems were detected and must be corrected:'
            for problem in bad_fields:
                message += f'\n* {problem}'
            npyscreen.notify_confirm(message)
            return False
        else:
            return True

    def get_model_names(self)->(List[str],int):
        conf = OmegaConf.load(os.path.join(Globals.root,'configs/models.yaml'))
        model_names = [idx for idx in sorted(list(conf.keys())) if conf[idx].get('format',None)=='diffusers']
        defaults = [idx for idx in range(len(model_names)) if 'default' in conf[model_names[idx]]]
        return (model_names,defaults[0])

    def marshall_arguments(self)->dict:
        args = dict()

        # the choices
        args.update(
            model = self.model_names[self.model.value[0]],
            resolution = self.resolutions[self.resolution.value[0]],
            lr_scheduler = self.lr_schedulers[self.lr_scheduler.value[0]],
            mixed_precision = self.precisions[self.mixed_precision.value[0]],
            learnable_property = self.learnable_properties[self.learnable_property.value[0]],
        )

        # all the strings and booleans
        for attr in ('initializer_token','placeholder_token','train_data_dir',
                     'output_dir','scale_lr','center_crop','enable_xformers_memory_efficient_attention'):
            args[attr] = getattr(self,attr).value

        # all the integers
        for attr in ('train_batch_size','gradient_accumulation_steps',
                     'num_train_epochs','max_train_steps','lr_warmup_steps'):
            args[attr] = int(getattr(self,attr).value)

        # the floats (just one)
        args.update(
            learning_rate = float(self.learning_rate.value)
        )

        # a special case
        if self.resume_from_checkpoint.value and Path(self.output_dir.value).exists():
            args['resume_from_checkpoint'] = 'latest'

        return args

class MyApplication(npyscreen.NPSAppManaged):
    def __init__(self, saved_args=None):
        super().__init__()
        self.ti_arguments=None
        self.saved_args=saved_args

    def onStart(self):
        npyscreen.setTheme(npyscreen.Themes.DefaultTheme)
        self.main = self.addForm('MAIN', textualInversionForm, name='Textual Inversion Settings', saved_args=self.saved_args)

def copy_to_embeddings_folder(args:dict):
    '''
    Copy learned_embeds.bin into the embeddings folder, and offer to
    delete the full model and checkpoints.
    '''
    source = Path(args['output_dir'],'learned_embeds.bin')
    dest_dir_name = args['placeholder_token'].strip('<>')
    destination = Path(Globals.root,'embeddings',dest_dir_name)
    os.makedirs(destination,exist_ok=True)
    print(f'>> Training completed. Copying learned_embeds.bin into {str(destination)}')
    shutil.copy(source,destination)
    if (input('Delete training logs and intermediate checkpoints? [y] ') or 'y').startswith(('y','Y')):
        shutil.rmtree(Path(args['output_dir']))
    else:
        print(f'>> Keeping {args["output_dir"]}')

def save_args(args:dict):
    '''
    Save the current argument values to an omegaconf file
    '''
    dest_dir = Path(Globals.root) / TRAINING_DIR
    os.makedirs(dest_dir, exist_ok=True)
    conf_file = dest_dir / CONF_FILE
    conf = OmegaConf.create(args)
    OmegaConf.save(config=conf, f=conf_file)

def previous_args()->dict:
    '''
    Get the previous arguments used.
    '''
    conf_file = Path(Globals.root) / TRAINING_DIR / CONF_FILE
    try:
        conf = OmegaConf.load(conf_file)
        conf['placeholder_token'] = conf['placeholder_token'].strip('<>')
    except:
        conf = None

    return conf

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='InvokeAI textual inversion training')
    parser.add_argument(
        '--root_dir','--root-dir',
        type=Path,
        default=Globals.root,
        help='Path to the invokeai runtime directory',
    )
    args = parser.parse_args()
    global_set_root(args.root_dir)

    saved_args = previous_args()
    myapplication = MyApplication(saved_args=saved_args)
    myapplication.run()

    from ldm.invoke.textual_inversion_training import do_textual_inversion_training
    if args := myapplication.ti_arguments:
        os.makedirs(args['output_dir'],exist_ok=True)

        # Automatically add angle brackets around the trigger
        if not re.match('^<.+>$',args['placeholder_token']):
            args['placeholder_token'] = f"<{args['placeholder_token']}>"

        args['only_save_embeds'] = True
        save_args(args)

        try:
            print(f'DEBUG: args = {args}')
            do_textual_inversion_training(**args)
            copy_to_embeddings_folder(args)
        except Exception as e:
            print('** An exception occurred during training. The exception was:')
            print(str(e))
            print('** DETAILS:')
            print(traceback.format_exc())
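One detail worth keeping from this front end: it persists the last-used settings with OmegaConf (`save_args`/`previous_args`) so the next run can pre-populate the form. The round trip reduces to this sketch (the file location here is illustrative):

```python
# Sketch of the OmegaConf preferences round trip used above.
from pathlib import Path
from omegaconf import OmegaConf

conf_file = Path('preferences.conf')  # illustrative path

def save_args(args: dict):
    OmegaConf.save(config=OmegaConf.create(args), f=conf_file)

def previous_args():
    try:
        return OmegaConf.load(conf_file)
    except Exception:
        return None

save_args({'placeholder_token': '<my-style>', 'max_train_steps': 3000})
print(previous_args().placeholder_token)  # -> <my-style>
```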