Merge branch 'main' into api/add-trigger-string-retrieval

commit fc14ac7faa
Author: Lincoln Stein, 2023-02-17 15:53:57 -05:00, committed by GitHub
GPG signature: no known key found for this signature in database (GPG Key ID: 4AEE18F83AFDEB23)
56 changed files with 1592 additions and 1112 deletions


@ -6,12 +6,12 @@ on:
- 'update/ci/docker/*'
- 'update/docker/*'
paths:
- '/pyproject.toml'
- '/ldm/**'
- '/invokeai/backend/**'
- '/invokeai/configs/**'
- '/invokeai/frontend/dist/**'
- '/docker/Dockerfile'
- 'pyproject.toml'
- 'ldm/**'
- 'invokeai/backend/**'
- 'invokeai/configs/**'
- 'invokeai/frontend/dist/**'
- 'docker/Dockerfile'
tags:
- 'v*.*.*'
workflow_dispatch:


@ -2,11 +2,11 @@ name: Test invoke.py pip
on:
pull_request:
paths-ignore:
- '/pyproject.toml'
- '/ldm/**'
- '/invokeai/backend/**'
- '/invokeai/configs/**'
- '/invokeai/frontend/dist/**'
- 'pyproject.toml'
- 'ldm/**'
- 'invokeai/backend/**'
- 'invokeai/configs/**'
- 'invokeai/frontend/dist/**'
merge_group:
workflow_dispatch:


@ -4,18 +4,18 @@ on:
branches:
- 'main'
paths:
- '/pyproject.toml'
- '/ldm/**'
- '/invokeai/backend/**'
- '/invokeai/configs/**'
- '/invokeai/frontend/dist/**'
- 'pyproject.toml'
- 'ldm/**'
- 'invokeai/backend/**'
- 'invokeai/configs/**'
- 'invokeai/frontend/dist/**'
pull_request:
paths:
- '/pyproject.toml'
- '/ldm/**'
- '/invokeai/backend/**'
- '/invokeai/configs/**'
- '/invokeai/frontend/dist/**'
- 'pyproject.toml'
- 'ldm/**'
- 'invokeai/backend/**'
- 'invokeai/configs/**'
- 'invokeai/frontend/dist/**'
types:
- 'ready_for_review'
- 'opened'


@ -80,6 +80,13 @@ only `.safetensors` and `.ckpt` models, but they can be easily loaded
into InvokeAI and/or converted into optimized `diffusers` models. Be
aware that CIVITAI hosts many models that generate NSFW content.
!!! note

    InvokeAI 2.3.x does not support directly importing and
    running Stable Diffusion version 2 checkpoint models. You may instead
    convert them into `diffusers` models using the conversion methods
    described below.
## Installation
There are multiple ways to install and manage models:
@ -90,7 +97,7 @@ There are multiple ways to install and manage models:
models files.
3. The web interface (WebUI) has a GUI for importing and managing
models.
models.
### Installation via `invokeai-configure`
@ -106,7 +113,7 @@ confirm that the files are complete.
You can install a new model, including any of the community-supported ones, via
the command-line client's `!import_model` command.
#### Installing `.ckpt` and `.safetensors` models
#### Installing individual `.ckpt` and `.safetensors` models
If the model is already downloaded to your local disk, use
`!import_model /path/to/file.ckpt` to load it. For example:
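(The path below is illustrative.)

```console
invoke> !import_model /home/fred/Downloads/martians.ckpt
```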
@ -131,15 +138,40 @@ invoke> !import_model https://example.org/sd_models/martians.safetensors
For this to work, the URL must not be password-protected. Otherwise
you will receive a 404 error.
When you import a legacy model, the CLI will ask you a few questions
about the model, including what size image it was trained on (usually
512x512), what name and description you wish to use for it, what
configuration file to use for it (usually the default
`v1-inference.yaml`), whether you'd like to make this model the
default at startup time, and whether you would like to install a
custom VAE (variable autoencoder) file for the model. For recent
models, the answer to the VAE question is usually "no," but it won't
hurt to answer "yes".
When you import a legacy model, the CLI will first ask you what type
of model this is. You can indicate whether it is a model based on
Stable Diffusion 1.x (1.4 or 1.5), one based on Stable Diffusion 2.x,
or a 1.x inpainting model. Be careful to indicate the correct model
type, or it will not load correctly. You can correct the model type
after the fact using the `!edit_model` command.
The system will then ask you a few other questions about the model,
including what size image it was trained on (usually 512x512), what
name and description you wish to use for it, and whether you would
like to install a custom VAE (variable autoencoder) file for the
model. For recent models, the answer to the VAE question is usually
"no," but it won't hurt to answer "yes".
After importing, the model will load. If this is successful, you will
be asked if you want to keep the model loaded in memory to start
generating immediately. You'll also be asked if you wish to make this
the default model on startup. You can change this later using
`!edit_model`.
#### Importing a batch of `.ckpt` and `.safetensors` models from a directory
You may also point `!import_model` to a directory containing a set of
`.ckpt` or `.safetensors` files. They will be imported _en masse_.
!!! example

    ```console
    invoke> !import_model C:/Users/fred/Downloads/civitai_models/
    ```
You will be given the option to import all models found in the
directory, or select which ones to import. If there are subfolders
within the directory, they will be searched for models to import.
#### Installing `diffusers` models
@ -279,19 +311,23 @@ After you save the modified `models.yaml` file relaunch
### Installation via the WebUI
To access the WebUI Model Manager, click on the button that looks like
a cute in the upper right side of the browser screen. This will bring
a cube in the upper right side of the browser screen. This will bring
up a dialogue that lists the models you have already installed, and
allows you to load, delete or edit them:
<figure markdown>
![model-manager](../assets/installing-models/webui-models-1.png)
</figure>
To add a new model, click on **+ Add New** and select either a
checkpoint/safetensors model or a diffusers model:
<figure markdown>
![model-manager-add-new](../assets/installing-models/webui-models-2.png)
</figure>
In this example, we chose **Add Diffusers**. As shown in the figure
@ -302,7 +338,9 @@ choose to enter a path to disk, the system will autocomplete for you
as you type:
<figure markdown>
![model-manager-add-diffusers](../assets/installing-models/webui-models-3.png)
</figure>
Press **Add Model** at the bottom of the dialogue (scrolled out of
@ -317,7 +355,9 @@ directory and press the "Search" icon. This will display the
subfolders, and allow you to choose which ones to import:
<figure markdown>
![model-manager-add-checkpoint](../assets/installing-models/webui-models-4.png)
</figure>
## Model Management Startup Options
@ -342,9 +382,8 @@ invoke.sh --autoconvert /home/fred/stable-diffusion-checkpoints
And here is what the same argument looks like in `invokeai.init`:
```
```bash
--outdir="/home/fred/invokeai/outputs"
--no-nsfw_checker
--autoconvert /home/fred/stable-diffusion-checkpoints
```
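For reference, the same settings can equivalently be passed on the command line when launching the CLI directly (the paths here are illustrative):

```bash
invokeai --outdir="/home/fred/invokeai/outputs" --no-nsfw_checker \
         --autoconvert /home/fred/stable-diffusion-checkpoints
```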

File diff suppressed because one or more lines are too long



@ -5,8 +5,8 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<script type="module" crossorigin src="./assets/index-12bd70ca.js"></script>
<link rel="stylesheet" href="./assets/index-c1af841f.css">
<script type="module" crossorigin src="./assets/index-9237ac63.js"></script>
<link rel="stylesheet" href="./assets/index-14cb2922.css">
</head>
<body>


@ -1,10 +1,12 @@
{
"general": "General",
"images": "Images",
"steps": "Steps",
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"imageToImage": "Image To Image",
"seed": "Seed",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",


@ -1,10 +1,12 @@
{
"general": "General",
"images": "Images",
"steps": "Steps",
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"imageToImage": "Image To Image",
"seed": "Seed",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",


@ -5,6 +5,7 @@
"confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons",
"useCanvasBeta": "Use Canvas Beta Layout",
"useSlidersForAll": "Use Sliders For All Options",
"enableImageDebugging": "Enable Image Debugging",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",


@ -1,10 +1,12 @@
{
"general": "General",
"images": "Images",
"steps": "Steps",
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"imageToImage": "Image To Image",
"seed": "Seed",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",


@ -1,10 +1,12 @@
{
"general": "General",
"images": "Images",
"steps": "Steps",
"cfgScale": "CFG Scale",
"width": "Width",
"height": "Height",
"sampler": "Sampler",
"imageToImage": "Image To Image",
"seed": "Seed",
"randomizeSeed": "Randomize Seed",
"shuffle": "Shuffle",


@ -5,6 +5,7 @@
"confirmOnDelete": "Confirm On Delete",
"displayHelpIcons": "Display Help Icons",
"useCanvasBeta": "Use Canvas Beta Layout",
"useSlidersForAll": "Use Sliders For All Options",
"enableImageDebugging": "Enable Image Debugging",
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",


@ -6,7 +6,6 @@
min-width: max-content;
margin: 0;
font-weight: bold;
font-size: 0.9rem;
color: var(--text-color-secondary);
}


@ -78,7 +78,7 @@ export default function IAISlider(props: IAIFullSliderProps) {
tooltipSuffix = '',
withSliderMarks = false,
sliderMarkLeftOffset = 0,
sliderMarkRightOffset = -7,
sliderMarkRightOffset = -1,
withInput = false,
isInteger = false,
inputWidth = '5.5rem',
@ -164,6 +164,7 @@ export default function IAISlider(props: IAIFullSliderProps) {
>
<FormLabel
className="invokeai__slider-component-label"
fontSize="sm"
{...sliderFormLabelProps}
>
{label}


@ -0,0 +1,55 @@
import { Box } from '@chakra-ui/react';
interface SubItemHookProps {
active?: boolean;
width?: string | number;
height?: string | number;
side?: 'left' | 'right';
}
export default function SubItemHook(props: SubItemHookProps) {
const {
active = true,
width = '1rem',
height = '1.3rem',
side = 'right',
} = props;
return (
<>
{side === 'right' && (
<Box
width={width}
height={height}
margin="-0.5rem 0.5rem 0 0.5rem"
borderLeft={
active
? '3px solid var(--subhook-color)'
: '3px solid var(--tab-hover-color)'
}
borderBottom={
active
? '3px solid var(--subhook-color)'
: '3px solid var(--tab-hover-color)'
}
/>
)}
{side === 'left' && (
<Box
width={width}
height={height}
margin="-0.5rem 0.5rem 0 0.5rem"
borderRight={
active
? '3px solid var(--subhook-color)'
: '3px solid var(--tab-hover-color)'
}
borderBottom={
active
? '3px solid var(--subhook-color)'
: '3px solid var(--tab-hover-color)'
}
/>
)}
</>
);
}


@ -170,6 +170,9 @@ export const frontendToBackendParameters = (
let esrganParameters: false | BackendEsrGanParameters = false;
let facetoolParameters: false | BackendFacetoolParameters = false;
// Multiplying by 1000 so the slider can use values between 0 and 1, which makes more sense
generationParameters.threshold = threshold * 1000;
if (negativePrompt !== '') {
generationParameters.prompt = `${prompt} [${negativePrompt}]`;
}
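The scaling above can be sketched in isolation (an illustrative helper, not part of the project's code): the UI slider produces values in [0, 1], which are mapped onto the backend's [0, 1000] threshold range before dispatch.

```typescript
// Backend expects threshold in [0, 1000]; the UI slider produces [0, 1].
const THRESHOLD_SCALE = 1000;

// Convert a slider value to the value sent as generationParameters.threshold.
function sliderToThreshold(sliderValue: number): number {
  return sliderValue * THRESHOLD_SCALE;
}
```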


@ -68,7 +68,7 @@ const BoundingBoxSettings = () => {
};
return (
<Flex direction="column" gap="1rem">
<Flex direction="column" gap={2}>
<IAISlider
label={t('parameters:width')}
min={64}
@ -82,6 +82,7 @@ const BoundingBoxSettings = () => {
inputReadOnly
withReset
handleReset={handleResetWidth}
sliderMarkRightOffset={-7}
/>
<IAISlider
label={t('parameters:height')}
@ -96,6 +97,7 @@ const BoundingBoxSettings = () => {
inputReadOnly
withReset
handleReset={handleResetHeight}
sliderMarkRightOffset={-7}
/>
</Flex>
);


@ -107,7 +107,7 @@ const InfillAndScalingSettings = () => {
};
return (
<Flex direction="column" gap="1rem">
<Flex direction="column" gap={4}>
<IAISelect
label={t('parameters:scaleBeforeProcessing')}
validValues={BOUNDING_BOX_SCALES_DICT}
@ -130,6 +130,7 @@ const InfillAndScalingSettings = () => {
inputReadOnly
withReset
handleReset={handleResetScaledWidth}
sliderMarkRightOffset={-7}
/>
<IAISlider
isInputDisabled={!isManual}
@ -147,6 +148,7 @@ const InfillAndScalingSettings = () => {
inputReadOnly
withReset
handleReset={handleResetScaledHeight}
sliderMarkRightOffset={-7}
/>
<IAISelect
label={t('parameters:infillMethod')}


@ -6,7 +6,7 @@ import SeamStrength from './SeamStrength';
const SeamCorrectionSettings = () => {
return (
<Flex direction="column" gap="1rem">
<Flex direction="column" gap={2}>
<SeamSize />
<SeamBlur />
<SeamStrength />


@ -0,0 +1,36 @@
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setCodeformerFidelity } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function CodeformerFidelity() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const codeformerFidelity = useAppSelector(
(state: RootState) => state.postprocessing.codeformerFidelity
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isSliderDisabled={!isGFPGANAvailable}
isInputDisabled={!isGFPGANAvailable}
isResetDisabled={!isGFPGANAvailable}
label={t('parameters:codeformerFidelity')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setCodeformerFidelity(v))}
handleReset={() => dispatch(setCodeformerFidelity(1))}
value={codeformerFidelity}
withReset
withSliderMarks
withInput
/>
);
}


@ -1,99 +1,23 @@
import { Flex } from '@chakra-ui/react';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import { FacetoolType } from 'features/parameters/store/postprocessingSlice';
import {
setCodeformerFidelity,
setFacetoolStrength,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { createSelector } from '@reduxjs/toolkit';
import { FACETOOL_TYPES } from 'app/constants';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISelect from 'common/components/IAISelect';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
const optionsSelector = createSelector(
[postprocessingSelector, systemSelector],
(
{ facetoolStrength, facetoolType, codeformerFidelity },
{ isGFPGANAvailable }
) => {
return {
facetoolStrength,
facetoolType,
codeformerFidelity,
isGFPGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
import { useAppSelector } from 'app/storeHooks';
import type { RootState } from 'app/store';
import FaceRestoreType from './FaceRestoreType';
import FaceRestoreStrength from './FaceRestoreStrength';
import CodeformerFidelity from './CodeformerFidelity';
/**
* Displays face-fixing/GFPGAN options (strength).
*/
const FaceRestoreSettings = () => {
const dispatch = useAppDispatch();
const {
facetoolStrength,
facetoolType,
codeformerFidelity,
isGFPGANAvailable,
} = useAppSelector(optionsSelector);
const handleChangeStrength = (v: number) => dispatch(setFacetoolStrength(v));
const handleChangeCodeformerFidelity = (v: number) =>
dispatch(setCodeformerFidelity(v));
const handleChangeFacetoolType = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setFacetoolType(e.target.value as FacetoolType));
const { t } = useTranslation();
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
return (
<Flex direction="column" gap={2}>
<IAISelect
label={t('parameters:type')}
validValues={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>
<IAINumberInput
isDisabled={!isGFPGANAvailable}
label={t('parameters:strength')}
step={0.05}
min={0}
max={1}
onChange={handleChangeStrength}
value={facetoolStrength}
width="90px"
isInteger={false}
/>
{facetoolType === 'codeformer' && (
<IAINumberInput
isDisabled={!isGFPGANAvailable}
label={t('parameters:codeformerFidelity')}
step={0.05}
min={0}
max={1}
onChange={handleChangeCodeformerFidelity}
value={codeformerFidelity}
width="90px"
isInteger={false}
/>
)}
<Flex direction="column" gap={2} minWidth="20rem">
<FaceRestoreType />
<FaceRestoreStrength />
{facetoolType === 'codeformer' && <CodeformerFidelity />}
</Flex>
);
};


@ -0,0 +1,36 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setFacetoolStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreStrength() {
const isGFPGANAvailable = useAppSelector(
(state: RootState) => state.system.isGFPGANAvailable
);
const facetoolStrength = useAppSelector(
(state: RootState) => state.postprocessing.facetoolStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
isSliderDisabled={!isGFPGANAvailable}
isInputDisabled={!isGFPGANAvailable}
isResetDisabled={!isGFPGANAvailable}
label={t('parameters:strength')}
step={0.05}
min={0}
max={1}
onChange={(v) => dispatch(setFacetoolStrength(v))}
handleReset={() => dispatch(setFacetoolStrength(0.75))}
value={facetoolStrength}
withReset
withSliderMarks
withInput
/>
);
}


@ -0,0 +1,31 @@
import { FACETOOL_TYPES } from 'app/constants';
import { type RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import {
type FacetoolType,
setFacetoolType,
} from 'features/parameters/store/postprocessingSlice';
import { type ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function FaceRestoreType() {
const facetoolType = useAppSelector(
(state: RootState) => state.postprocessing.facetoolType
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeFacetoolType = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setFacetoolType(e.target.value as FacetoolType));
return (
<IAISelect
label={t('parameters:type')}
validValues={FACETOOL_TYPES.concat()}
value={facetoolType}
onChange={handleChangeFacetoolType}
/>
);
}


@ -4,6 +4,7 @@ import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import IAISwitch from 'common/components/IAISwitch';
import SubItemHook from 'common/components/SubItemHook';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import {
setHiresFix,
@ -39,23 +40,27 @@ const HiresStrength = () => {
};
return (
<IAISlider
label={t('parameters:hiresStrength')}
step={0.01}
min={0.01}
max={0.99}
onChange={handleHiresStrength}
value={hiresStrength}
isInteger={false}
withInput
withSliderMarks
inputWidth="5.5rem"
withReset
handleReset={handleHiResStrengthReset}
isSliderDisabled={!hiresFix}
isInputDisabled={!hiresFix}
isResetDisabled={!hiresFix}
/>
<Flex>
<SubItemHook active={hiresFix} />
<IAISlider
label={t('parameters:hiresStrength')}
step={0.01}
min={0.01}
max={0.99}
onChange={handleHiresStrength}
value={hiresStrength}
isInteger={false}
withInput
withSliderMarks
inputWidth={'5.5rem'}
withReset
handleReset={handleHiResStrengthReset}
isSliderDisabled={!hiresFix}
isInputDisabled={!hiresFix}
isResetDisabled={!hiresFix}
sliderMarkRightOffset={-7}
/>
</Flex>
);
};
@ -75,7 +80,7 @@ const HiresSettings = () => {
dispatch(setHiresFix(e.target.checked));
return (
<Flex gap={2} direction="column">
<Flex rowGap="0.8rem" direction={'column'}>
<IAISwitch
label={t('parameters:hiresOptim')}
fontSize="md"


@ -1,6 +1,6 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setPerlin } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
@ -9,17 +9,18 @@ export default function Perlin() {
const perlin = useAppSelector((state: RootState) => state.generation.perlin);
const { t } = useTranslation();
const handleChangePerlin = (v: number) => dispatch(setPerlin(v));
return (
<IAINumberInput
<IAISlider
label={t('parameters:perlinNoise')}
min={0}
max={1}
step={0.05}
onChange={handleChangePerlin}
onChange={(v) => dispatch(setPerlin(v))}
handleReset={() => dispatch(setPerlin(0))}
value={perlin}
isInteger={false}
withInput
withReset
withSliderMarks
/>
);
}


@ -1,6 +1,6 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setThreshold } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
@ -11,17 +11,19 @@ export default function Threshold() {
);
const { t } = useTranslation();
const handleChangeThreshold = (v: number) => dispatch(setThreshold(v));
return (
<IAINumberInput
<IAISlider
label={t('parameters:noiseThreshold')}
min={0}
max={1000}
step={0.1}
onChange={handleChangeThreshold}
max={1}
step={0.005}
onChange={(v) => dispatch(setThreshold(v))}
handleReset={() => dispatch(setThreshold(0))}
value={threshold}
isInteger={false}
withInput
withReset
withSliderMarks
inputWidth="6rem"
/>
);
}


@ -0,0 +1,38 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingDenoising } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleDenoisingStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingDenoising = useAppSelector(
(state: RootState) => state.postprocessing.upscalingDenoising
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={t('parameters:denoisingStrength')}
value={upscalingDenoising}
min={0}
max={1}
step={0.01}
onChange={(v) => {
dispatch(setUpscalingDenoising(v));
}}
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
withSliderMarks
withInput
withReset
isSliderDisabled={!isESRGANAvailable}
isInputDisabled={!isESRGANAvailable}
isResetDisabled={!isESRGANAvailable}
/>
);
}


@ -0,0 +1,36 @@
import { UPSCALING_LEVELS } from 'app/constants';
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import {
setUpscalingLevel,
type UpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import type { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function UpscaleScale() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingLevel = useAppSelector(
(state: RootState) => state.postprocessing.upscalingLevel
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
return (
<IAISelect
isDisabled={!isESRGANAvailable}
label={t('parameters:scale')}
value={upscalingLevel}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
/>
);
}


@ -1,104 +1,17 @@
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import {
setUpscalingDenoising,
setUpscalingLevel,
setUpscalingStrength,
UpscalingLevel,
} from 'features/parameters/store/postprocessingSlice';
import { createSelector } from '@reduxjs/toolkit';
import { UPSCALING_LEVELS } from 'app/constants';
import IAISelect from 'common/components/IAISelect';
import { postprocessingSelector } from 'features/parameters/store/postprocessingSelectors';
import { systemSelector } from 'features/system/store/systemSelectors';
import { isEqual } from 'lodash';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
import IAISlider from 'common/components/IAISlider';
import { Flex } from '@chakra-ui/react';
const parametersSelector = createSelector(
[postprocessingSelector, systemSelector],
(
{ upscalingLevel, upscalingStrength, upscalingDenoising },
{ isESRGANAvailable }
) => {
return {
upscalingLevel,
upscalingDenoising,
upscalingStrength,
isESRGANAvailable,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
import UpscaleDenoisingStrength from './UpscaleDenoisingStrength';
import UpscaleStrength from './UpscaleStrength';
import UpscaleScale from './UpscaleScale';
/**
* Displays upscaling/ESRGAN options (level and strength).
*/
const UpscaleSettings = () => {
const dispatch = useAppDispatch();
const {
upscalingLevel,
upscalingStrength,
upscalingDenoising,
isESRGANAvailable,
} = useAppSelector(parametersSelector);
const { t } = useTranslation();
const handleChangeLevel = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setUpscalingLevel(Number(e.target.value) as UpscalingLevel));
const handleChangeStrength = (v: number) => dispatch(setUpscalingStrength(v));
return (
<Flex flexDir="column" rowGap="1rem" minWidth="20rem">
<IAISelect
isDisabled={!isESRGANAvailable}
label={t('parameters:scale')}
value={upscalingLevel}
onChange={handleChangeLevel}
validValues={UPSCALING_LEVELS}
/>
<IAISlider
label={t('parameters:denoisingStrength')}
value={upscalingDenoising}
min={0}
max={1}
step={0.01}
onChange={(v) => {
dispatch(setUpscalingDenoising(v));
}}
handleReset={() => dispatch(setUpscalingDenoising(0.75))}
withSliderMarks
withInput
withReset
isSliderDisabled={!isESRGANAvailable}
isInputDisabled={!isESRGANAvailable}
isResetDisabled={!isESRGANAvailable}
/>
<IAISlider
label={`${t('parameters:upscale')} ${t('parameters:strength')}`}
value={upscalingStrength}
min={0}
max={1}
step={0.05}
onChange={handleChangeStrength}
handleReset={() => dispatch(setUpscalingStrength(0.75))}
withSliderMarks
withInput
withReset
isSliderDisabled={!isESRGANAvailable}
isInputDisabled={!isESRGANAvailable}
isResetDisabled={!isESRGANAvailable}
/>
<Flex flexDir="column" rowGap={2} minWidth="20rem">
<UpscaleScale />
<UpscaleDenoisingStrength />
<UpscaleStrength />
</Flex>
);
};


@ -0,0 +1,35 @@
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { setUpscalingStrength } from 'features/parameters/store/postprocessingSlice';
import { useTranslation } from 'react-i18next';
export default function UpscaleStrength() {
const isESRGANAvailable = useAppSelector(
(state: RootState) => state.system.isESRGANAvailable
);
const upscalingStrength = useAppSelector(
(state: RootState) => state.postprocessing.upscalingStrength
);
const { t } = useTranslation();
const dispatch = useAppDispatch();
return (
<IAISlider
label={`${t('parameters:upscale')} ${t('parameters:strength')}`}
value={upscalingStrength}
min={0}
max={1}
step={0.05}
onChange={(v) => dispatch(setUpscalingStrength(v))}
handleReset={() => dispatch(setUpscalingStrength(0.75))}
withSliderMarks
withInput
withReset
isSliderDisabled={!isESRGANAvailable}
isInputDisabled={!isESRGANAvailable}
isResetDisabled={!isESRGANAvailable}
/>
);
}


@ -1,6 +1,6 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setVariationAmount } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
@ -16,19 +16,22 @@ export default function VariationAmount() {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangevariationAmount = (v: number) =>
dispatch(setVariationAmount(v));
return (
<IAINumberInput
<IAISlider
label={t('parameters:variationAmount')}
value={variationAmount}
step={0.01}
min={0}
max={1}
isDisabled={!shouldGenerateVariations}
onChange={handleChangevariationAmount}
isInteger={false}
isSliderDisabled={!shouldGenerateVariations}
isInputDisabled={!shouldGenerateVariations}
isResetDisabled={!shouldGenerateVariations}
onChange={(v) => dispatch(setVariationAmount(v))}
handleReset={() => dispatch(setVariationAmount(0.1))}
withInput
withReset
withSliderMarks
/>
);
}


@ -1,6 +1,7 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setCfgScale } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
@ -9,11 +10,29 @@ export default function MainCFGScale() {
const cfgScale = useAppSelector(
(state: RootState) => state.generation.cfgScale
);
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const { t } = useTranslation();
const handleChangeCfgScale = (v: number) => dispatch(setCfgScale(v));
return (
return shouldUseSliders ? (
<IAISlider
label={t('parameters:cfgScale')}
step={0.5}
min={1.01}
max={30}
onChange={handleChangeCfgScale}
handleReset={() => dispatch(setCfgScale(7.5))}
value={cfgScale}
sliderMarkRightOffset={-5}
sliderNumberInputProps={{ max: 200 }}
withInput
withReset
withSliderMarks
/>
) : (
<IAINumberInput
label={t('parameters:cfgScale')}
step={0.5}


@ -2,29 +2,50 @@ import { HEIGHTS } from 'app/constants';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAISlider from 'common/components/IAISlider';
import { setHeight } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function MainHeight() {
const height = useAppSelector((state: RootState) => state.generation.height);
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const activeTabName = useAppSelector(activeTabNameSelector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChangeHeight = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setHeight(Number(e.target.value)));
return (
return shouldUseSliders ? (
<IAISlider
isSliderDisabled={activeTabName === 'unifiedCanvas'}
isInputDisabled={activeTabName === 'unifiedCanvas'}
isResetDisabled={activeTabName === 'unifiedCanvas'}
label={t('parameters:height')}
value={height}
min={64}
step={64}
max={2048}
onChange={(v) => dispatch(setHeight(v))}
handleReset={() => dispatch(setHeight(512))}
withInput
withReset
withSliderMarks
sliderMarkRightOffset={-8}
inputWidth="6.2rem"
sliderNumberInputProps={{ max: 15360 }}
/>
) : (
<IAISelect
isDisabled={activeTabName === 'unifiedCanvas'}
label={t('parameters:height')}
value={height}
flexGrow={1}
onChange={handleChangeHeight}
onChange={(e) => dispatch(setHeight(Number(e.target.value)))}
validValues={HEIGHTS}
styleClass="main-settings-block"
width="5.5rem"
/>
);
}


@ -1,39 +1,41 @@
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import {
GenerationState,
setIterations,
} from 'features/parameters/store/generationSlice';
import { isEqual } from 'lodash';
import IAISlider from 'common/components/IAISlider';
import { setIterations } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
const mainIterationsSelector = createSelector(
[(state: RootState) => state.generation],
(parameters: GenerationState) => {
const { iterations } = parameters;
return {
iterations,
};
},
{
memoizeOptions: {
resultEqualityCheck: isEqual,
},
}
);
export default function MainIterations() {
const iterations = useAppSelector(
(state: RootState) => state.generation.iterations
);
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const dispatch = useAppDispatch();
const { iterations } = useAppSelector(mainIterationsSelector);
const { t } = useTranslation();
const handleChangeIterations = (v: number) => dispatch(setIterations(v));
return (
return shouldUseSliders ? (
<IAISlider
label={t('parameters:images')}
step={1}
min={1}
max={16}
onChange={handleChangeIterations}
handleReset={() => dispatch(setIterations(1))}
value={iterations}
withInput
withReset
withSliderMarks
sliderMarkRightOffset={-5}
sliderNumberInputProps={{ max: 9999 }}
/>
) : (
<IAINumberInput
label={t('parameters:images')}
step={1}

View File

@ -1,3 +1,8 @@
import { Flex } from '@chakra-ui/react';
import { type RootState } from 'app/store';
import { useAppSelector } from 'app/storeHooks';
import { useTranslation } from 'react-i18next';
import ParametersAccordion from '../ParametersAccordion';
import MainCFGScale from './MainCFGScale';
import MainHeight from './MainHeight';
import MainIterations from './MainIterations';
@ -8,20 +13,40 @@ import MainWidth from './MainWidth';
export const inputWidth = 'auto';
export default function MainSettings() {
return (
<div className="main-settings">
<div className="main-settings-list">
<div className="main-settings-row">
const { t } = useTranslation();
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const accordionItems = {
main: {
header: `${t('parameters:general')}`,
feature: undefined,
content: shouldUseSliders ? (
<Flex flexDir="column" rowGap={2}>
<MainIterations />
<MainSteps />
<MainCFGScale />
</div>
<div className="main-settings-row">
<MainWidth />
<MainHeight />
<MainSampler />
</div>
</div>
</div>
);
</Flex>
) : (
<Flex flexDirection="column" rowGap={2}>
<Flex gap={2}>
<MainIterations />
<MainSteps />
<MainCFGScale />
</Flex>
<Flex>
<MainWidth />
<MainHeight />
<MainSampler />
</Flex>
</Flex>
),
},
};
return <ParametersAccordion accordionInfo={accordionItems} />;
}

View File

@ -27,6 +27,7 @@ export default function MainSampler() {
activeModel.format === 'diffusers' ? DIFFUSERS_SAMPLERS : SAMPLERS
}
styleClass="main-settings-block"
minWidth="9rem"
/>
);
}

View File

@ -1,17 +1,36 @@
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setSteps } from 'features/parameters/store/generationSlice';
import { useTranslation } from 'react-i18next';
export default function MainSteps() {
const dispatch = useAppDispatch();
const steps = useAppSelector((state: RootState) => state.generation.steps);
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const { t } = useTranslation();
const handleChangeSteps = (v: number) => dispatch(setSteps(v));
return (
return shouldUseSliders ? (
<IAISlider
label={t('parameters:steps')}
min={1}
step={1}
onChange={handleChangeSteps}
handleReset={() => dispatch(setSteps(20))}
value={steps}
withInput
withReset
withSliderMarks
sliderMarkRightOffset={-6}
sliderNumberInputProps={{ max: 9999 }}
/>
) : (
<IAINumberInput
label={t('parameters:steps')}
min={1}

View File

@ -2,30 +2,51 @@ import { WIDTHS } from 'app/constants';
import { RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAISelect from 'common/components/IAISelect';
import IAISlider from 'common/components/IAISlider';
import { setWidth } from 'features/parameters/store/generationSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { ChangeEvent } from 'react';
import { useTranslation } from 'react-i18next';
export default function MainWidth() {
const width = useAppSelector((state: RootState) => state.generation.width);
const shouldUseSliders = useAppSelector(
(state: RootState) => state.ui.shouldUseSliders
);
const activeTabName = useAppSelector(activeTabNameSelector);
const { t } = useTranslation();
const dispatch = useAppDispatch();
const handleChangeWidth = (e: ChangeEvent<HTMLSelectElement>) =>
dispatch(setWidth(Number(e.target.value)));
return (
return shouldUseSliders ? (
<IAISlider
isSliderDisabled={activeTabName === 'unifiedCanvas'}
isInputDisabled={activeTabName === 'unifiedCanvas'}
isResetDisabled={activeTabName === 'unifiedCanvas'}
label={t('parameters:width')}
value={width}
min={64}
step={64}
max={2048}
onChange={(v) => dispatch(setWidth(v))}
handleReset={() => dispatch(setWidth(512))}
withInput
withReset
withSliderMarks
sliderMarkRightOffset={-8}
inputWidth="6.2rem"
inputReadOnly
sliderNumberInputProps={{ max: 15360 }}
/>
) : (
<IAISelect
isDisabled={activeTabName === 'unifiedCanvas'}
label={t('parameters:width')}
value={width}
flexGrow={1}
onChange={handleChangeWidth}
onChange={(e) => dispatch(setWidth(Number(e.target.value)))}
validValues={WIDTHS}
styleClass="main-settings-block"
width="5.5rem"
/>
);
}

View File

@ -22,7 +22,7 @@ export interface PostprocessingState {
const initialPostprocessingState: PostprocessingState = {
codeformerFidelity: 0.75,
facetoolStrength: 0.8,
facetoolStrength: 0.75,
facetoolType: 'gfpgan',
hiresFix: false,
hiresStrength: 0.75,

View File

@ -14,7 +14,7 @@ import {
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { IN_PROGRESS_IMAGE_TYPES } from 'app/constants';
import { RootState } from 'app/store';
import { type RootState } from 'app/store';
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISelect from 'common/components/IAISelect';
@ -27,9 +27,14 @@ import {
setShouldConfirmOnDelete,
setShouldDisplayGuides,
setShouldDisplayInProgressType,
type SystemState,
} from 'features/system/store/systemSlice';
import { uiSelector } from 'features/ui/store/uiSelectors';
import { setShouldUseCanvasBetaLayout } from 'features/ui/store/uiSlice';
import {
setShouldUseCanvasBetaLayout,
setShouldUseSliders,
} from 'features/ui/store/uiSlice';
import { type UIState } from 'features/ui/store/uiTypes';
import { isEqual, map } from 'lodash';
import { persistor } from 'persistor';
import { ChangeEvent, cloneElement, ReactElement } from 'react';
@ -37,7 +42,7 @@ import { useTranslation } from 'react-i18next';
const selector = createSelector(
[systemSelector, uiSelector],
(system, ui) => {
(system: SystemState, ui: UIState) => {
const {
shouldDisplayInProgressType,
shouldConfirmOnDelete,
@ -47,7 +52,7 @@ const selector = createSelector(
enableImageDebugging,
} = system;
const { shouldUseCanvasBetaLayout } = ui;
const { shouldUseCanvasBetaLayout, shouldUseSliders } = ui;
return {
shouldDisplayInProgressType,
@ -57,6 +62,7 @@ const selector = createSelector(
saveIntermediatesInterval,
enableImageDebugging,
shouldUseCanvasBetaLayout,
shouldUseSliders,
};
},
{
@ -100,6 +106,7 @@ const SettingsModal = ({ children }: SettingsModalProps) => {
saveIntermediatesInterval,
enableImageDebugging,
shouldUseCanvasBetaLayout,
shouldUseSliders,
} = useAppSelector(selector);
/**
@ -191,6 +198,14 @@ const SettingsModal = ({ children }: SettingsModalProps) => {
dispatch(setShouldUseCanvasBetaLayout(e.target.checked))
}
/>
<IAISwitch
styleClass="settings-modal-item"
label={t('settings:useSlidersForAll')}
isChecked={shouldUseSliders}
onChange={(e: ChangeEvent<HTMLInputElement>) =>
dispatch(setShouldUseSliders(e.target.checked))
}
/>
</div>
<div className="settings-modal-items">

View File

@ -0,0 +1,26 @@
import { Flex } from '@chakra-ui/react';
import ImageFit from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageFit';
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
import ParametersAccordion from 'features/parameters/components/ParametersAccordion';
import { useTranslation } from 'react-i18next';
export default function ImageToImageOptions() {
const { t } = useTranslation();
const imageToImageAccordionItems = {
imageToImage: {
header: `${t('parameters:imageToImage')}`,
feature: undefined,
content: (
<Flex gap={2} flexDir="column">
<ImageToImageStrength
label={t('parameters:img2imgStrength')}
styleClass="main-settings-block image-to-image-strength-main-option"
/>
<ImageFit />
</Flex>
),
},
};
return <ParametersAccordion accordionInfo={imageToImageAccordionItems} />;
}

View File

@ -2,8 +2,6 @@ import { Flex } from '@chakra-ui/react';
import { Feature } from 'app/features';
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
import FaceRestoreToggle from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreToggle';
import ImageFit from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageFit';
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
import ImageToImageOutputSettings from 'features/parameters/components/AdvancedParameters/Output/ImageToImageOutputSettings';
import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
@ -17,6 +15,7 @@ import NegativePromptInput from 'features/parameters/components/PromptInput/Nega
import PromptInput from 'features/parameters/components/PromptInput/PromptInput';
import InvokeOptionsPanel from 'features/ui/components/InvokeParametersPanel';
import { useTranslation } from 'react-i18next';
import ImageToImageOptions from './ImageToImageOptions';
export default function ImageToImagePanel() {
const { t } = useTranslation();
@ -60,11 +59,7 @@ export default function ImageToImagePanel() {
</Flex>
<ProcessButtons />
<MainSettings />
<ImageToImageStrength
label={t('parameters:img2imgStrength')}
styleClass="main-settings-block image-to-image-strength-main-option"
/>
<ImageFit />
<ImageToImageOptions />
<ParametersAccordion accordionInfo={imageToImageAccordions} />
</InvokeOptionsPanel>
);

View File

@ -32,7 +32,7 @@
.parameters-panel {
display: flex;
flex-direction: column;
row-gap: 1rem;
row-gap: 0.5rem;
height: 100%;
@include HideScrollbar;
background-color: var(--background-color);

View File

@ -20,6 +20,11 @@ export default function UnifiedCanvasPanel() {
const { t } = useTranslation();
const unifiedCanvasAccordions = {
seed: {
header: `${t('parameters:seed')}`,
feature: Feature.SEED,
content: <SeedSettings />,
},
boundingBox: {
header: `${t('parameters:boundingBoxHeader')}`,
feature: Feature.BOUNDING_BOX,
@ -35,11 +40,6 @@ export default function UnifiedCanvasPanel() {
feature: Feature.INFILL_AND_SCALING,
content: <InfillAndScalingSettings />,
},
seed: {
header: `${t('parameters:seed')}`,
feature: Feature.SEED,
content: <SeedSettings />,
},
variations: {
header: `${t('parameters:variations')}`,
feature: Feature.VARIATIONS,
@ -48,6 +48,19 @@ export default function UnifiedCanvasPanel() {
},
};
const unifiedCanvasImg2ImgAccordion = {
unifiedCanvasImg2Img: {
header: `${t('parameters:imageToImage')}`,
feature: undefined,
content: (
<ImageToImageStrength
label={t('parameters:img2imgStrength')}
styleClass="main-settings-block image-to-image-strength-main-option"
/>
),
},
};
return (
<InvokeOptionsPanel>
<Flex flexDir="column" rowGap="0.5rem">
@ -56,10 +69,7 @@ export default function UnifiedCanvasPanel() {
</Flex>
<ProcessButtons />
<MainSettings />
<ImageToImageStrength
label={t('parameters:img2imgStrength')}
styleClass="main-settings-block image-to-image-strength-main-option"
/>
<ParametersAccordion accordionInfo={unifiedCanvasImg2ImgAccordion} />
<ParametersAccordion accordionInfo={unifiedCanvasAccordions} />
</InvokeOptionsPanel>
);

View File

@ -14,6 +14,7 @@ const initialtabsState: UIState = {
shouldShowImageDetails: false,
shouldUseCanvasBetaLayout: false,
shouldShowExistingModelsInSearch: false,
shouldUseSliders: false,
addNewModelUIOption: null,
};
@ -66,6 +67,9 @@ export const uiSlice = createSlice({
) => {
state.shouldShowExistingModelsInSearch = action.payload;
},
setShouldUseSliders: (state, action: PayloadAction<boolean>) => {
state.shouldUseSliders = action.payload;
},
setAddNewModelUIOption: (state, action: PayloadAction<AddNewModelType>) => {
state.addNewModelUIOption = action.payload;
},
@ -83,6 +87,7 @@ export const {
setShouldShowImageDetails,
setShouldUseCanvasBetaLayout,
setShouldShowExistingModelsInSearch,
setShouldUseSliders,
setAddNewModelUIOption,
} = uiSlice.actions;

View File

@ -11,5 +11,6 @@ export interface UIState {
shouldShowImageDetails: boolean;
shouldUseCanvasBetaLayout: boolean;
shouldShowExistingModelsInSearch: boolean;
shouldUseSliders: boolean;
addNewModelUIOption: AddNewModelType;
}

View File

@ -137,4 +137,7 @@
// Scrollbar
--scrollbar-color: var(--accent-color);
--scrollbar-color-hover: var(--accent-color-bright);
// SubHook
--subhook-color: var(--accent-color);
}

View File

@ -135,4 +135,7 @@
// Scrollbar
--scrollbar-color: var(--accent-color);
--scrollbar-color-hover: var(--accent-color-bright);
// SubHook
--subhook-color: var(--accent-color);
}

View File

@ -132,4 +132,7 @@
// Scrollbar
--scrollbar-color: rgb(180, 180, 184);
--scrollbar-color-hover: rgb(150, 150, 154);
// SubHook
--subhook-color: rgb(0, 0, 0);
}

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,9 @@ import sys
import traceback
from argparse import Namespace
from pathlib import Path
from typing import Optional, Union
from typing import List, Optional, Union
import click
if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@ -24,6 +26,7 @@ from ldm.invoke.model_manager import ModelManager
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import Completer, get_completer
from ldm.util import url_attachment_name
# global used in multiple functions (fix)
infile = None
@ -78,7 +81,6 @@ def main():
import transformers # type: ignore
from ldm.generate import Generate
transformers.logging.set_verbosity_error()
import diffusers
@ -623,10 +625,11 @@ def set_default_output_dir(opt: Args, completer: Completer):
def import_model(model_path: str, gen, opt, completer):
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path;
(3) a huggingface repository id; or (4) a local directory containing a
diffusers model.
"""
model_path can be (1) a URL to a .ckpt file; (2) a local .ckpt file path; or
(3) a huggingface repository id
"""
model_path = model_path.replace('\\','/') # windows
model_name = None
if model_path.startswith(("http:", "https:", "ftp:")):
@ -669,7 +672,7 @@ def import_model(model_path: str, gen, opt, completer):
print("** model failed to load. Discarding configuration entry")
gen.model_manager.del_model(model_name)
return
if input("Make this the default model? [n] ").strip() in ("y", "Y"):
if click.confirm('Make this the default model?', default=False):
gen.model_manager.set_default_model(model_name)
gen.model_manager.commit(opt.conf)
@ -677,9 +680,46 @@ def import_model(model_path: str, gen, opt, completer):
print(f">> {model_name} successfully installed")
def import_checkpoint_list(models: List[Path], gen, opt, completer)->List[str]:
'''
Does a mass import of all the checkpoint/safetensors on a path list
'''
model_names = list()
choice = input('** Directory of checkpoint/safetensors models detected. Install <a>ll or <s>elected models? [a] ') or 'a'
do_all = choice.startswith('a')
if do_all:
config_file = _ask_for_config_file(models[0], completer, plural=True)
manager = gen.model_manager
for model in sorted(models):
model_name = f'{model.stem}'
model_description = f'Imported model {model_name}'
if model_name in manager.model_names():
print(f'** {model_name} is already imported. Skipping.')
elif manager.import_ckpt_model(
model,
config = config_file,
model_name = model_name,
model_description = model_description,
commit_to_conf = opt.conf):
model_names.append(model_name)
print(f'>> Model {model_name} imported successfully')
else:
print(f'** Model {model} failed to import')
else:
for model in sorted(models):
if click.confirm(f'Import {model.stem} ?', default=True):
if model_name := import_ckpt_model(model, gen, opt, completer):
print(f'>> Model {model.stem} imported successfully')
model_names.append(model_name)
else:
print(f'** Model {model} failed to import')
print()
return model_names
def import_diffuser_model(
path_or_repo: Union[Path, str], gen, _, completer
) -> Optional[str]:
path_or_repo = path_or_repo.replace('\\','/') # windows
manager = gen.model_manager
default_name = Path(path_or_repo).stem
default_description = f"Imported model {default_name}"
@ -690,10 +730,8 @@ def import_diffuser_model(
model_description=default_description,
)
vae = None
if input(
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
).strip() in ("y", "Y"):
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
if not manager.import_diffuser_model(
path_or_repo, model_name=model_name, vae=vae, description=model_description
@ -702,13 +740,16 @@ def import_diffuser_model(
return None
return model_name
def import_ckpt_model(
path_or_url: Union[Path, str], gen, opt, completer
) -> Optional[str]:
path_or_url = path_or_url.replace('\\','/')
manager = gen.model_manager
default_name = Path(path_or_url).stem
is_a_url = str(path_or_url).startswith(('http:','https:'))
base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name
default_name = Path(base_name).stem
default_description = f"Imported model {default_name}"
model_name, model_description = _get_model_name_and_desc(
manager,
completer,
@ -758,10 +799,14 @@ def import_ckpt_model(
def _verify_load(model_name: str, gen) -> bool:
print(">> Verifying that new model loads...")
current_model = gen.model_name
if not gen.model_manager.get_model(model_name):
try:
if not gen.model_manager.get_model(model_name):
return False
except Exception as e:
print(f'** model failed to load: {str(e)}')
print('** note that importing 2.X checkpoints is not supported. Please use !convert_model instead.')
return False
do_switch = input("Keep model loaded? [y] ")
if len(do_switch) == 0 or do_switch[0] in ("y", "Y"):
if click.confirm('Keep model loaded?', default=True):
gen.set_model(model_name)
else:
print(">> Restoring previous model")
@ -780,18 +825,44 @@ def _get_model_name_and_desc(
)
return model_name, model_description
def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=False)->Path:
default = '1'
if re.search('inpaint',str(model_path),flags=re.IGNORECASE):
default = '3'
choices={
'1': 'v1-inference.yaml',
'2': 'v2-inference-v.yaml',
'3': 'v1-inpainting-inference.yaml',
}
prompt = '''What type of models are these?:
[1] Models based on Stable Diffusion 1.X
[2] Models based on Stable Diffusion 2.X
[3] Inpainting models based on Stable Diffusion 1.X
[4] Something else''' if plural else '''What type of model is this?:
[1] A model based on Stable Diffusion 1.X
[2] A model based on Stable Diffusion 2.X
[3] An inpainting model based on Stable Diffusion 1.X
[4] Something else'''
print(prompt)
choice = input(f'Your choice: [{default}] ')
choice = choice.strip() or default
if config_file := choices.get(choice,None):
return Path('configs','stable-diffusion',config_file)
def _is_inpainting(model_name_or_path: str) -> bool:
if re.search("inpaint", model_name_or_path, flags=re.IGNORECASE):
return not input("Is this an inpainting model? [y] ").startswith(("n", "N"))
else:
return not input("Is this an inpainting model? [n] ").startswith(("y", "Y"))
# otherwise ask user to select
done = False
completer.complete_extensions(('.yaml','.yml'))
completer.set_line(str(Path(Globals.root,'configs/stable-diffusion/')))
while not done:
config_path = input('Configuration file for this model (leave blank to abort): ').strip()
done = not config_path or os.path.exists(config_path)
return config_path
def optimize_model(model_name_or_path: str, gen, opt, completer):
def optimize_model(model_name_or_path: Union[Path,str], gen, opt, completer):
model_name_or_path = model_name_or_path.replace('\\','/') # windows
manager = gen.model_manager
ckpt_path = None
original_config_file = None
if model_name_or_path == gen.model_name:
print("** Can't convert the active model. !switch to another model first. **")
@ -806,16 +877,13 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
print(f"** {model_name_or_path} is not a legacy .ckpt weights file")
return
elif os.path.exists(model_name_or_path):
original_config_file = original_config_file or _ask_for_config_file(model_name_or_path, completer)
if not original_config_file:
return
ckpt_path = Path(model_name_or_path)
model_name, model_description = _get_model_name_and_desc(
manager, completer, ckpt_path.stem, f"Converted model {ckpt_path.stem}"
)
is_inpainting = _is_inpainting(model_name_or_path)
original_config_file = Path(
"configs",
"stable-diffusion",
"v1-inpainting-inference.yaml" if is_inpainting else "v1-inference.yaml",
)
else:
print(
f"** {model_name_or_path} is neither an existing model nor the path to a .ckpt file"
@ -838,10 +906,8 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
vae = None
if input(
'Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"? [n] '
).strip() in ("y", "Y"):
vae = dict(repo_id="stabilityai/sd-vae-ft-mse")
if click.confirm('Replace this model\'s VAE with "stabilityai/sd-vae-ft-mse"?', default=False):
vae = dict(repo_id='stabilityai/sd-vae-ft-mse')
new_config = gen.model_manager.convert_and_import(
ckpt_path,
@ -856,11 +922,10 @@ def optimize_model(model_name_or_path: str, gen, opt, completer):
return
completer.update_models(gen.model_manager.list_models())
if input(f"Load optimized model {model_name}? [y] ").strip() not in ("n", "N"):
if click.confirm(f'Load optimized model {model_name}?', default=True):
gen.set_model(model_name)
response = input(f"Delete the original .ckpt file at ({ckpt_path} ? [n] ")
if response.startswith(("y", "Y")):
if click.confirm(f'Delete the original .ckpt file at {ckpt_path}?',default=False):
ckpt_path.unlink(missing_ok=True)
print(f"{ckpt_path} deleted")
@ -874,17 +939,11 @@ def del_config(model_name: str, gen, opt, completer):
print(f"** Unknown model {model_name}")
return
if (
input(f"Remove {model_name} from the list of models known to InvokeAI? [y] ")
.strip()
.startswith(("n", "N"))
):
if not click.confirm(f'Remove {model_name} from the list of models known to InvokeAI?',default=True):
return
delete_completely = input(
"Completely remove the model file or directory from disk? [n] "
).startswith(("y", "Y"))
gen.model_manager.del_model(model_name, delete_files=delete_completely)
delete_completely = click.confirm('Completely remove the model file or directory from disk?',default=False)
gen.model_manager.del_model(model_name,delete_files=delete_completely)
gen.model_manager.commit(opt.conf)
print(f"** {model_name} deleted")
completer.update_models(gen.model_manager.list_models())
@ -913,7 +972,7 @@ def edit_model(model_name: str, gen, opt, completer):
# this does the update
manager.add_model(new_name, info, True)
if input("Make this the default model? [n] ").startswith(("y", "Y")):
if click.confirm('Make this the default model?',default=False):
manager.set_default_model(new_name)
manager.commit(opt.conf)
completer.update_models(manager.list_models())
@ -1288,10 +1347,7 @@ def report_model_error(opt: Namespace, e: Exception):
"** Reconfiguration is being forced by environment variable INVOKE_MODEL_RECONFIGURE"
)
else:
response = input(
"Do you want to run invokeai-configure script to select and/or reinstall models? [y] "
)
if response.startswith(("n", "N")):
if click.confirm('Do you want to run invokeai-configure script to select and/or reinstall models?', default=True):
return
print("invokeai-configure is launching....\n")
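The `_ask_for_config_file` helper added above maps a numeric menu choice to a YAML config path, defaulting to the inpainting config when the model name suggests one. A minimal standalone sketch of that selection logic (file names and defaults taken from the diff; the function name here is hypothetical):

```python
import re
from pathlib import Path
from typing import Optional

# Menu-choice-to-config mapping, as declared in _ask_for_config_file above.
CHOICES = {
    '1': 'v1-inference.yaml',
    '2': 'v2-inference-v.yaml',
    '3': 'v1-inpainting-inference.yaml',
}

def config_for_choice(model_path: str, choice: str = '') -> Optional[Path]:
    # Default to the inpainting config when the filename mentions "inpaint".
    default = '3' if re.search('inpaint', model_path, flags=re.IGNORECASE) else '1'
    choice = choice.strip() or default
    if config_file := CHOICES.get(choice):
        return Path('configs', 'stable-diffusion', config_file)
    return None  # "Something else": caller falls back to asking for a path
```

With no explicit choice, `config_for_choice('foo-inpainting.ckpt')` resolves to the inpainting config, mirroring the interactive default in the CLI.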

View File

@ -34,8 +34,8 @@ from ldm.invoke.generator.diffusers_pipeline import \
StableDiffusionGeneratorPipeline
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
global_models_dir)
from ldm.util import (ask_user, download_with_progress_bar,
instantiate_from_config)
from ldm.util import (ask_user, download_with_resume,
url_attachment_name, instantiate_from_config)
DEFAULT_MAX_MODELS = 2
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
@ -673,15 +673,18 @@ class ModelManager(object):
path to the configuration file, then the new entry will be committed to the
models.yaml file.
"""
if str(weights).startswith(("http:", "https:")):
model_name = model_name or url_attachment_name(weights)
weights_path = self._resolve_path(weights, "models/ldm/stable-diffusion-v1")
config_path = self._resolve_path(config, "configs/stable-diffusion")
config_path = self._resolve_path(config, "configs/stable-diffusion")
if weights_path is None or not weights_path.exists():
return False
if config_path is None or not config_path.exists():
return False
model_name = model_name or Path(weights).stem
model_name = model_name or Path(weights).stem # note this gives ugly pathnames if used on a URL without a Content-Disposition header
model_description = (
model_description or f"imported stable diffusion weights file {model_name}"
)
@ -971,16 +974,15 @@ class ModelManager(object):
print("** Migration is done. Continuing...")
def _resolve_path(
self, source: Union[str, Path], dest_directory: str
self, source: Union[str, Path], dest_directory: str
) -> Optional[Path]:
resolved_path = None
if str(source).startswith(("http:", "https:", "ftp:")):
basename = os.path.basename(source)
if not os.path.isabs(dest_directory):
dest_directory = os.path.join(Globals.root, dest_directory)
dest = os.path.join(dest_directory, basename)
if download_with_progress_bar(str(source), Path(dest)):
resolved_path = Path(dest)
dest_directory = Path(dest_directory)
if not dest_directory.is_absolute():
dest_directory = Globals.root / dest_directory
dest_directory.mkdir(parents=True, exist_ok=True)
resolved_path = download_with_resume(str(source), dest_directory)
else:
if not os.path.isabs(source):
source = os.path.join(Globals.root, source)
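The newly imported `url_attachment_name` is used above to derive a model name from a URL (avoiding the "ugly pathnames" noted in the comment); its implementation is not part of this diff. One way such a helper can be sketched with the standard library, assuming the server sends a `Content-Disposition` header:

```python
from email.message import Message

def attachment_filename(content_disposition: str) -> str:
    # Hypothetical sketch, not the actual url_attachment_name implementation.
    # Parses a Content-Disposition header value (e.g. from
    # response.headers['Content-Disposition']) and returns its filename
    # parameter, or '' if the header carries none.
    msg = Message()
    msg['Content-Disposition'] = content_disposition
    return msg.get_filename() or ''
```

When the header is absent or has no `filename` parameter, a real helper would still need the fallback to `Path(weights).stem` used above.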

View File

@ -1,20 +1,21 @@
import importlib
import math
import multiprocessing as mp
import os
import re
from collections import abc
from inspect import isfunction
from pathlib import Path
from queue import Queue
from threading import Thread
from urllib import request
from tqdm import tqdm
from pathlib import Path
from ldm.invoke.devices import torch_dtype
import numpy as np
import requests
import torch
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
from ldm.invoke.devices import torch_dtype
def log_txt_as_img(wh, xc, size=10):
@ -23,18 +24,18 @@ def log_txt_as_img(wh, xc, size=10):
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new('RGB', wh, color='white')
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.load_default()
nc = int(40 * (wh[0] / 256))
lines = '\n'.join(
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill='black', font=font)
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print('Cant encode string for logging. Skipping.')
print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@ -77,25 +78,23 @@ def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f' | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.'
f" | {model.__class__.__name__} has {total_params * 1.e-6:.2f} M params."
)
return total_params
def instantiate_from_config(config, **kwargs):
if not 'target' in config:
if config == '__is_first_stage__':
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == '__is_unconditional__':
elif config == "__is_unconditional__":
return None
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(
**config.get('params', dict()), **kwargs
)
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()), **kwargs)
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit('.', 1)
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
@ -111,14 +110,14 @@ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
else:
res = func(data)
Q.put([idx, res])
Q.put('Done')
Q.put("Done")
def parallel_data_prefetch(
func: callable,
data,
n_proc,
target_data_type='ndarray',
target_data_type="ndarray",
cpu_intensive=True,
use_worker_id=False,
):
@ -126,21 +125,21 @@ def parallel_data_prefetch(
# raise ValueError(
# "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
# )
if isinstance(data, np.ndarray) and target_data_type == 'list':
raise ValueError('list expected but function got ndarray.')
if isinstance(data, np.ndarray) and target_data_type == "list":
raise ValueError("list expected but function got ndarray.")
elif isinstance(data, abc.Iterable):
if isinstance(data, dict):
print(
f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
)
data = list(data.values())
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
data = np.asarray(data)
else:
data = list(data)
else:
raise TypeError(
f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.'
f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
)
if cpu_intensive:
@ -150,7 +149,7 @@ def parallel_data_prefetch(
Q = Queue(1000)
proc = Thread
# spawn processes
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
arguments = [
[func, Q, part, i, use_worker_id]
for i, part in enumerate(np.array_split(data, n_proc))
@ -173,7 +172,7 @@ def parallel_data_prefetch(
processes += [p]
# start processes
print(f'Start prefetching...')
print("Start prefetching...")
import time
start = time.time()
@ -186,13 +185,13 @@ def parallel_data_prefetch(
while k < n_proc:
# get result
res = Q.get()
if res == 'Done':
if res == "Done":
k += 1
else:
gather_res[res[0]] = res[1]
except Exception as e:
print('Exception: ', e)
print("Exception: ", e)
for p in processes:
p.terminate()
@ -200,15 +199,15 @@ def parallel_data_prefetch(
finally:
for p in processes:
p.join()
print(f'Prefetching complete. [{time.time() - start} sec.]')
print(f"Prefetching complete. [{time.time() - start} sec.]")
if target_data_type == 'ndarray':
if target_data_type == "ndarray":
if not isinstance(gather_res[0], np.ndarray):
return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
# order outputs
return np.concatenate(gather_res, axis=0)
elif target_data_type == 'list':
elif target_data_type == "list":
out = []
for r in gather_res:
out.extend(r)
@@ -216,49 +215,79 @@ def parallel_data_prefetch(
else:
return gather_res
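The hunks above are a quote-style/formatting pass over `parallel_data_prefetch`, which gathers worker output from a shared `Queue` and counts `"Done"` sentinels to know when every worker has finished. A minimal, self-contained sketch of that sentinel pattern (the names `_worker` and `gather` are illustrative, not from the codebase):

```python
import threading
from queue import Queue

def _worker(q: Queue, part: list) -> None:
    # Push one result per input item, then a "Done" sentinel so the
    # consumer knows this worker has finished (as in the loop above).
    for item in part:
        q.put(("result", item * 2))
    q.put("Done")

def gather(parts: list) -> list:
    q: Queue = Queue(1000)
    workers = [threading.Thread(target=_worker, args=(q, p)) for p in parts]
    for w in workers:
        w.start()
    done, out = 0, []
    while done < len(workers):
        res = q.get()
        if res == "Done":
            done += 1
        else:
            out.append(res[1])
    for w in workers:
        w.join()
    return out
```

Results arrive in arbitrary order across workers, which is why the real function indexes `gather_res` by worker id before concatenating.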
def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
def rand_perlin_2d(
shape, res, device, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
):
delta = (res[0] / shape[0], res[1] / shape[1])
d = (shape[0] // res[0], shape[1] // res[1])
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1]), indexing='ij'), dim = -1).to(device) % 1
grid = (
torch.stack(
torch.meshgrid(
torch.arange(0, res[0], delta[0]),
torch.arange(0, res[1], delta[1]),
indexing="ij",
),
dim=-1,
).to(device)
% 1
)
rand_val = torch.rand(res[0]+1, res[1]+1)
rand_val = torch.rand(res[0] + 1, res[1] + 1)
angles = 2*math.pi*rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim = -1).to(device)
angles = 2 * math.pi * rand_val
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1).to(device)
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)
tile_grads = (
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
.repeat_interleave(d[0], 0)
.repeat_interleave(d[1], 1)
)
dot = lambda grad, shift: (torch.stack((grid[:shape[0],:shape[1],0] + shift[0], grid[:shape[0],:shape[1], 1] + shift[1] ), dim = -1) * grad[:shape[0], :shape[1]]).sum(dim = -1)
dot = lambda grad, shift: (
torch.stack(
(
grid[: shape[0], : shape[1], 0] + shift[0],
grid[: shape[0], : shape[1], 1] + shift[1],
),
dim=-1,
)
* grad[: shape[0], : shape[1]]
).sum(dim=-1)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0]).to(device)
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0]).to(device)
n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
t = fade(grid[:shape[0], :shape[1]])
noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1]).to(device)
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1]).to(device)
t = fade(grid[: shape[0], : shape[1]])
noise = math.sqrt(2) * torch.lerp(
torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
).to(device)
return noise.to(dtype=torch_dtype(device))
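The default `fade` argument reformatted above is the classic Perlin quintic smoothstep, 6t⁵ − 15t⁴ + 10t³. It maps 0 → 0 and 1 → 1 with zero first and second derivatives at both endpoints, which is what lets adjacent noise cells blend without visible seams. A dependency-free check of those properties:

```python
def fade(t: float) -> float:
    # Quintic smoothstep used as the default interpolant in
    # rand_perlin_2d: flat at t=0 and t=1, symmetric about t=0.5.
    return 6 * t**5 - 15 * t**4 + 10 * t**3
```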
def ask_user(question: str, answers: list):
from itertools import chain, repeat
user_prompt = f'\n>> {question} {answers}: '
invalid_answer_msg = 'Invalid answer. Please try again.'
pose_question = chain([user_prompt], repeat('\n'.join([invalid_answer_msg, user_prompt])))
user_prompt = f"\n>> {question} {answers}: "
invalid_answer_msg = "Invalid answer. Please try again."
pose_question = chain(
[user_prompt], repeat("\n".join([invalid_answer_msg, user_prompt]))
)
user_answers = map(input, pose_question)
valid_response = next(filter(answers.__contains__, user_answers))
return valid_response
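`ask_user` above builds an infinite iterator of prompts (the question once, then the retry message forever) and lets `filter` pull the first reply contained in `answers`. The same shape can be exercised non-interactively by swapping `input` for a canned reply iterator (`first_valid` is a hypothetical name for this sketch):

```python
from itertools import chain, repeat

def first_valid(question: str, answers: list, replies) -> str:
    # Same structure as ask_user, but reads from an iterator of canned
    # replies instead of input(), so it can run without a terminal.
    replies = iter(replies)
    prompts = chain([question], repeat("Invalid answer. " + question))
    user_answers = (next(replies) for _ in prompts)
    return next(filter(answers.__contains__, user_answers))
```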
def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False ):
def debug_image(
debug_image, debug_text, debug_show=True, debug_result=False, debug_status=False
):
if not debug_status:
return
image_copy = debug_image.copy().convert("RGBA")
ImageDraw.Draw(image_copy).text(
(5, 5),
debug_text,
(255, 0, 0)
)
ImageDraw.Draw(image_copy).text((5, 5), debug_text, (255, 0, 0))
if debug_show:
image_copy.show()
@@ -266,31 +295,84 @@ def debug_image(debug_image, debug_text, debug_show=True, debug_result=False, de
if debug_result:
return image_copy
#-------------------------------------
class ProgressBar():
def __init__(self,model_name='file'):
self.pbar = None
self.name = model_name
def __call__(self, block_num, block_size, total_size):
if not self.pbar:
self.pbar=tqdm(desc=self.name,
initial=0,
unit='iB',
unit_scale=True,
unit_divisor=1000,
total=total_size)
self.pbar.update(block_size)
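`ProgressBar` above is a `reporthook` for `urllib.request.urlretrieve`, which invokes it as `hook(block_num, block_size, total_size)` once per chunk. A tqdm-free stand-in showing the calling convention (the `ByteCounter` name is illustrative):

```python
class ByteCounter:
    # Minimal reporthook in the shape urlretrieve expects:
    # hook(block_num, block_size, total_size), called once per chunk.
    def __init__(self) -> None:
        self.received = 0
        self.total = 0

    def __call__(self, block_num: int, block_size: int, total_size: int) -> None:
        self.total = total_size
        self.received += block_size
```

Note that, like `ProgressBar`, this adds the full `block_size` on every call, so the count can overshoot slightly on a short final block.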
# -------------------------------------
def download_with_resume(url: str, dest: Path, access_token: str = None) -> Path:
'''
Download a model file.
:param url: https, http or ftp URL
:param dest: A Path object. If path exists and is a directory, then we try to derive the filename
from the URL's Content-Disposition header and copy the URL contents into
dest/filename
:param access_token: Access token to access this resource
'''
resp = requests.get(url, stream=True)
total = int(resp.headers.get("content-length", 0))
if dest.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition")).group(1)
except:
file_name = os.path.basename(url)
dest = dest / file_name
else:
dest.parent.mkdir(parents=True, exist_ok=True)
print(f'DEBUG: after many manipulations, dest={dest}')
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
open_mode = "wb"
exist_size = 0
if dest.exists():
exist_size = dest.stat().st_size
header["Range"] = f"bytes={exist_size}-"
open_mode = "ab"
if (
resp.status_code == 416
): # "range not satisfiable", which means nothing to return
print(f"* {dest}: complete file found. Skipping.")
return dest
elif resp.status_code != 200:
print(f"** An error occurred during downloading {dest}: {resp.reason}")
elif exist_size > 0:
print(f"* {dest}: partial file found. Resuming...")
else:
print(f"* {dest}: Downloading...")
def download_with_progress_bar(url:str, dest:Path)->bool:
try:
if not dest.exists():
dest.parent.mkdir(parents=True, exist_ok=True)
request.urlretrieve(url,dest,ProgressBar(dest.stem))
return True
else:
return True
except OSError:
print(traceback.format_exc())
return False
if total < 2000:
print(f"*** ERROR DOWNLOADING {url}: {resp.text}")
return None
with open(dest, open_mode) as file, tqdm(
desc=str(dest),
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
print(f"An error occurred while downloading {dest}: {str(e)}")
return None
return dest
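The resume behaviour added above hinges on a small decision: if a partial file exists, send an HTTP `Range` header for the missing bytes and open the file in append mode, otherwise start fresh. That logic, factored into a pure function for illustration (`resume_plan` is not a name from the codebase):

```python
def resume_plan(exist_size: int, access_token: str = None) -> tuple:
    # Mirrors the header/open-mode setup in download_with_resume:
    # request only the missing byte range and append to the partial file.
    headers = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    if exist_size > 0:
        headers["Range"] = f"bytes={exist_size}-"
        return "ab", headers
    return "wb", headers
```

A server that cannot satisfy the range replies 416, which the function above treats as "file already complete".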
def url_attachment_name(url: str) -> str:
try:
resp = requests.get(url, stream=True)
match = re.search('filename="(.+)"', resp.headers.get("Content-Disposition"))
return match.group(1)
except:
return None
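The new `url_attachment_name` helper pulls the server-suggested filename out of a `Content-Disposition` header with a regex. The same extraction applied to a header value directly, without a live HTTP request (`attachment_name` is a hypothetical wrapper for the sketch):

```python
import re

def attachment_name(content_disposition: str):
    # Same regex as url_attachment_name, applied to a header value
    # such as: attachment; filename="model.safetensors"
    match = re.search('filename="(.+)"', content_disposition)
    return match.group(1) if match else None
```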
def download_with_progress_bar(url: str, dest: Path) -> bool:
result = download_with_resume(url, dest, access_token=None)
return result is not None