mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into enhance/update-menu
This commit is contained in:
commit
833079140b
61
.github/CODEOWNERS
vendored
61
.github/CODEOWNERS
vendored
@ -1,50 +1,51 @@
|
||||
# continuous integration
|
||||
/.github/workflows/ @mauwii
|
||||
/.github/workflows/ @mauwii @lstein @blessedcoolant
|
||||
|
||||
# documentation
|
||||
/docs/ @lstein @mauwii @tildebyte
|
||||
mkdocs.yml @lstein @mauwii
|
||||
/docs/ @lstein @mauwii @tildebyte @blessedcoolant
|
||||
mkdocs.yml @lstein @mauwii @blessedcoolant
|
||||
|
||||
# installation and configuration
|
||||
/pyproject.toml @mauwii @lstein @ebr
|
||||
/docker/ @mauwii
|
||||
/scripts/ @ebr @lstein
|
||||
/installer/ @ebr @lstein @tildebyte
|
||||
ldm/invoke/config @lstein @ebr
|
||||
invokeai/assets @lstein @ebr
|
||||
invokeai/configs @lstein @ebr
|
||||
/pyproject.toml @mauwii @lstein @ebr @blessedcoolant
|
||||
/docker/ @mauwii @lstein @blessedcoolant
|
||||
/scripts/ @ebr @lstein @blessedcoolant
|
||||
/installer/ @ebr @lstein @tildebyte @blessedcoolant
|
||||
ldm/invoke/config @lstein @ebr @blessedcoolant
|
||||
invokeai/assets @lstein @ebr @blessedcoolant
|
||||
invokeai/configs @lstein @ebr @blessedcoolant
|
||||
/ldm/invoke/_version.py @lstein @blessedcoolant
|
||||
|
||||
# web ui
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious
|
||||
/invokeai/backend @blessedcoolant @psychedelicious
|
||||
/invokeai/frontend @blessedcoolant @psychedelicious @lstein
|
||||
/invokeai/backend @blessedcoolant @psychedelicious @lstein
|
||||
|
||||
# generation and model management
|
||||
/ldm/*.py @lstein
|
||||
/ldm/generate.py @lstein @keturn
|
||||
/ldm/*.py @lstein @blessedcoolant
|
||||
/ldm/generate.py @lstein @keturn @blessedcoolant
|
||||
/ldm/invoke/args.py @lstein @blessedcoolant
|
||||
/ldm/invoke/ckpt* @lstein
|
||||
/ldm/invoke/ckpt_generator @lstein
|
||||
/ldm/invoke/CLI.py @lstein
|
||||
/ldm/invoke/config @lstein @ebr @mauwii
|
||||
/ldm/invoke/generator @keturn @damian0815
|
||||
/ldm/invoke/globals.py @lstein @blessedcoolant
|
||||
/ldm/invoke/merge_diffusers.py @lstein
|
||||
/ldm/invoke/ckpt* @lstein @blessedcoolant
|
||||
/ldm/invoke/ckpt_generator @lstein @blessedcoolant
|
||||
/ldm/invoke/CLI.py @lstein @blessedcoolant
|
||||
/ldm/invoke/config @lstein @ebr @mauwii @blessedcoolant
|
||||
/ldm/invoke/generator @keturn @damian0815 @blessedcoolant
|
||||
/ldm/invoke/globals.py @lstein @blessedcoolant
|
||||
/ldm/invoke/merge_diffusers.py @lstein @blessedcoolant
|
||||
/ldm/invoke/model_manager.py @lstein @blessedcoolant
|
||||
/ldm/invoke/txt2mask.py @lstein
|
||||
/ldm/invoke/patchmatch.py @Kyle0654
|
||||
/ldm/invoke/txt2mask.py @lstein @blessedcoolant
|
||||
/ldm/invoke/patchmatch.py @Kyle0654 @blessedcoolant @lstein
|
||||
/ldm/invoke/restoration @lstein @blessedcoolant
|
||||
|
||||
# attention, textual inversion, model configuration
|
||||
/ldm/models @damian0815 @keturn
|
||||
/ldm/modules @damian0815 @keturn
|
||||
/ldm/models @damian0815 @keturn @lstein @blessedcoolant
|
||||
/ldm/modules @damian0815 @keturn @lstein @blessedcoolant
|
||||
|
||||
# Nodes
|
||||
apps/ @Kyle0654
|
||||
apps/ @Kyle0654 @lstein @blessedcoolant
|
||||
|
||||
# legacy REST API
|
||||
# is CapableWeb still engaged?
|
||||
/ldm/invoke/pngwriter.py @CapableWeb
|
||||
/ldm/invoke/server_legacy.py @CapableWeb
|
||||
/scripts/legacy_api.py @CapableWeb
|
||||
/tests/legacy_tests.sh @CapableWeb
|
||||
/ldm/invoke/pngwriter.py @CapableWeb @lstein @blessedcoolant
|
||||
/ldm/invoke/server_legacy.py @CapableWeb @lstein @blessedcoolant
|
||||
/scripts/legacy_api.py @CapableWeb @lstein @blessedcoolant
|
||||
/tests/legacy_tests.sh @CapableWeb @lstein @blessedcoolant
|
||||
|
||||
|
File diff suppressed because one or more lines are too long
2
invokeai/frontend/dist/index.html
vendored
2
invokeai/frontend/dist/index.html
vendored
@ -5,7 +5,7 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>InvokeAI - A Stable Diffusion Toolkit</title>
|
||||
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
|
||||
<script type="module" crossorigin src="./assets/index-53ecf883.js"></script>
|
||||
<script type="module" crossorigin src="./assets/index-762ec810.js"></script>
|
||||
<link rel="stylesheet" href="./assets/index-14cb2922.css">
|
||||
</head>
|
||||
|
||||
|
3
invokeai/frontend/dist/locales/en.json
vendored
3
invokeai/frontend/dist/locales/en.json
vendored
@ -441,6 +441,9 @@
|
||||
"infillScalingHeader": "Infill and Scaling",
|
||||
"img2imgStrength": "Image To Image Strength",
|
||||
"toggleLoopback": "Toggle Loopback",
|
||||
"symmetry": "Symmetry",
|
||||
"hSymmetryStep": "H Symmetry Step",
|
||||
"vSymmetryStep": "V Symmetry Step",
|
||||
"invoke": "Invoke",
|
||||
"cancel": {
|
||||
"immediate": "Cancel immediately",
|
||||
|
@ -441,6 +441,9 @@
|
||||
"infillScalingHeader": "Infill and Scaling",
|
||||
"img2imgStrength": "Image To Image Strength",
|
||||
"toggleLoopback": "Toggle Loopback",
|
||||
"symmetry": "Symmetry",
|
||||
"hSymmetryStep": "H Symmetry Step",
|
||||
"vSymmetryStep": "V Symmetry Step",
|
||||
"invoke": "Invoke",
|
||||
"cancel": {
|
||||
"immediate": "Cancel immediately",
|
||||
|
@ -65,6 +65,8 @@ export type BackendGenerationParameters = {
|
||||
with_variations?: Array<Array<number>>;
|
||||
variation_amount?: number;
|
||||
enable_image_debugging?: boolean;
|
||||
h_symmetry_time_pct?: number;
|
||||
v_symmetry_time_pct?: number;
|
||||
};
|
||||
|
||||
export type BackendEsrGanParameters = {
|
||||
@ -141,6 +143,9 @@ export const frontendToBackendParameters = (
|
||||
tileSize,
|
||||
variationAmount,
|
||||
width,
|
||||
shouldUseSymmetry,
|
||||
horizontalSymmetryTimePercentage,
|
||||
verticalSymmetryTimePercentage,
|
||||
} = generationState;
|
||||
|
||||
const {
|
||||
@ -170,9 +175,6 @@ export const frontendToBackendParameters = (
|
||||
let esrganParameters: false | BackendEsrGanParameters = false;
|
||||
let facetoolParameters: false | BackendFacetoolParameters = false;
|
||||
|
||||
// Multiplying it by 10000 so the Slider can have values between 0 and 1 which makes more sense
|
||||
generationParameters.threshold = threshold * 1000;
|
||||
|
||||
if (negativePrompt !== '') {
|
||||
generationParameters.prompt = `${prompt} [${negativePrompt}]`;
|
||||
}
|
||||
@ -181,6 +183,23 @@ export const frontendToBackendParameters = (
|
||||
? randomInt(NUMPY_RAND_MIN, NUMPY_RAND_MAX)
|
||||
: seed;
|
||||
|
||||
// Symmetry Settings
|
||||
if (shouldUseSymmetry) {
|
||||
if (horizontalSymmetryTimePercentage > 0) {
|
||||
generationParameters.h_symmetry_time_pct = Math.max(
|
||||
0,
|
||||
Math.min(1, horizontalSymmetryTimePercentage / steps)
|
||||
);
|
||||
}
|
||||
|
||||
if (horizontalSymmetryTimePercentage > 0) {
|
||||
generationParameters.v_symmetry_time_pct = Math.max(
|
||||
0,
|
||||
Math.min(1, verticalSymmetryTimePercentage / steps)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// txt2img exclusive parameters
|
||||
if (generationMode === 'txt2img') {
|
||||
generationParameters.hires_fix = hiresFix;
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import {
|
||||
setHorizontalSymmetryTimePercentage,
|
||||
setVerticalSymmetryTimePercentage,
|
||||
} from 'features/parameters/store/generationSlice';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
export default function SymmetrySettings() {
|
||||
const horizontalSymmetryTimePercentage = useAppSelector(
|
||||
(state: RootState) => state.generation.horizontalSymmetryTimePercentage
|
||||
);
|
||||
|
||||
const verticalSymmetryTimePercentage = useAppSelector(
|
||||
(state: RootState) => state.generation.verticalSymmetryTimePercentage
|
||||
);
|
||||
|
||||
const steps = useAppSelector((state: RootState) => state.generation.steps);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
<>
|
||||
<IAISlider
|
||||
label={t('parameters.hSymmetryStep')}
|
||||
value={horizontalSymmetryTimePercentage}
|
||||
onChange={(v) => dispatch(setHorizontalSymmetryTimePercentage(v))}
|
||||
min={0}
|
||||
max={steps}
|
||||
step={1}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => dispatch(setHorizontalSymmetryTimePercentage(0))}
|
||||
sliderMarkRightOffset={-6}
|
||||
></IAISlider>
|
||||
<IAISlider
|
||||
label={t('parameters.vSymmetryStep')}
|
||||
value={verticalSymmetryTimePercentage}
|
||||
onChange={(v) => dispatch(setVerticalSymmetryTimePercentage(v))}
|
||||
min={0}
|
||||
max={steps}
|
||||
step={1}
|
||||
withInput
|
||||
withSliderMarks
|
||||
withReset
|
||||
handleReset={() => dispatch(setVerticalSymmetryTimePercentage(0))}
|
||||
sliderMarkRightOffset={-6}
|
||||
></IAISlider>
|
||||
</>
|
||||
);
|
||||
}
|
@ -0,0 +1,19 @@
|
||||
import { RootState } from 'app/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/storeHooks';
|
||||
import IAISwitch from 'common/components/IAISwitch';
|
||||
import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice';
|
||||
|
||||
export default function SymmetryToggle() {
|
||||
const shouldUseSymmetry = useAppSelector(
|
||||
(state: RootState) => state.generation.shouldUseSymmetry
|
||||
);
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
return (
|
||||
<IAISwitch
|
||||
isChecked={shouldUseSymmetry}
|
||||
onChange={(e) => dispatch(setShouldUseSymmetry(e.target.checked))}
|
||||
/>
|
||||
);
|
||||
}
|
@ -15,15 +15,15 @@ export default function Threshold() {
|
||||
<IAISlider
|
||||
label={t('parameters.noiseThreshold')}
|
||||
min={0}
|
||||
max={1}
|
||||
step={0.005}
|
||||
max={20}
|
||||
step={0.1}
|
||||
onChange={(v) => dispatch(setThreshold(v))}
|
||||
handleReset={() => dispatch(setThreshold(0))}
|
||||
value={threshold}
|
||||
withInput
|
||||
withReset
|
||||
withSliderMarks
|
||||
inputWidth="6rem"
|
||||
sliderMarkRightOffset={-4}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
@ -32,6 +32,9 @@ export interface GenerationState {
|
||||
tileSize: number;
|
||||
variationAmount: number;
|
||||
width: number;
|
||||
shouldUseSymmetry: boolean;
|
||||
horizontalSymmetryTimePercentage: number;
|
||||
verticalSymmetryTimePercentage: number;
|
||||
}
|
||||
|
||||
const initialGenerationState: GenerationState = {
|
||||
@ -60,6 +63,9 @@ const initialGenerationState: GenerationState = {
|
||||
tileSize: 32,
|
||||
variationAmount: 0.1,
|
||||
width: 512,
|
||||
shouldUseSymmetry: false,
|
||||
horizontalSymmetryTimePercentage: 0,
|
||||
verticalSymmetryTimePercentage: 0,
|
||||
};
|
||||
|
||||
const initialState: GenerationState = initialGenerationState;
|
||||
@ -325,6 +331,21 @@ export const generationSlice = createSlice({
|
||||
setInfillMethod: (state, action: PayloadAction<string>) => {
|
||||
state.infillMethod = action.payload;
|
||||
},
|
||||
setShouldUseSymmetry: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldUseSymmetry = action.payload;
|
||||
},
|
||||
setHorizontalSymmetryTimePercentage: (
|
||||
state,
|
||||
action: PayloadAction<number>
|
||||
) => {
|
||||
state.horizontalSymmetryTimePercentage = action.payload;
|
||||
},
|
||||
setVerticalSymmetryTimePercentage: (
|
||||
state,
|
||||
action: PayloadAction<number>
|
||||
) => {
|
||||
state.verticalSymmetryTimePercentage = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@ -362,6 +383,9 @@ export const {
|
||||
setTileSize,
|
||||
setVariationAmount,
|
||||
setWidth,
|
||||
setShouldUseSymmetry,
|
||||
setHorizontalSymmetryTimePercentage,
|
||||
setVerticalSymmetryTimePercentage,
|
||||
} = generationSlice.actions;
|
||||
|
||||
export default generationSlice.reducer;
|
||||
|
@ -3,6 +3,8 @@ import { Feature } from 'app/features';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import FaceRestoreToggle from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreToggle';
|
||||
import ImageToImageOutputSettings from 'features/parameters/components/AdvancedParameters/Output/ImageToImageOutputSettings';
|
||||
import SymmetrySettings from 'features/parameters/components/AdvancedParameters/Output/SymmetrySettings';
|
||||
import SymmetryToggle from 'features/parameters/components/AdvancedParameters/Output/SymmetryToggle';
|
||||
import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import UpscaleToggle from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleToggle';
|
||||
@ -44,6 +46,11 @@ export default function ImageToImagePanel() {
|
||||
content: <UpscaleSettings />,
|
||||
additionalHeaderComponents: <UpscaleToggle />,
|
||||
},
|
||||
symmetry: {
|
||||
header: `${t('parameters.symmetry')}`,
|
||||
content: <SymmetrySettings />,
|
||||
additionalHeaderComponents: <SymmetryToggle />,
|
||||
},
|
||||
other: {
|
||||
header: `${t('parameters.otherOptions')}`,
|
||||
feature: Feature.OTHER,
|
||||
|
@ -3,6 +3,8 @@ import { Feature } from 'app/features';
|
||||
import FaceRestoreSettings from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreSettings';
|
||||
import FaceRestoreToggle from 'features/parameters/components/AdvancedParameters/FaceRestore/FaceRestoreToggle';
|
||||
import OutputSettings from 'features/parameters/components/AdvancedParameters/Output/OutputSettings';
|
||||
import SymmetrySettings from 'features/parameters/components/AdvancedParameters/Output/SymmetrySettings';
|
||||
import SymmetryToggle from 'features/parameters/components/AdvancedParameters/Output/SymmetryToggle';
|
||||
import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
|
||||
import UpscaleSettings from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleSettings';
|
||||
import UpscaleToggle from 'features/parameters/components/AdvancedParameters/Upscale/UpscaleToggle';
|
||||
@ -43,6 +45,11 @@ export default function TextToImagePanel() {
|
||||
content: <UpscaleSettings />,
|
||||
additionalHeaderComponents: <UpscaleToggle />,
|
||||
},
|
||||
symmetry: {
|
||||
header: `${t('parameters.symmetry')}`,
|
||||
content: <SymmetrySettings />,
|
||||
additionalHeaderComponents: <SymmetryToggle />,
|
||||
},
|
||||
other: {
|
||||
header: `${t('parameters.otherOptions')}`,
|
||||
feature: Feature.OTHER,
|
||||
|
@ -5,6 +5,8 @@ import BoundingBoxSettings from 'features/parameters/components/AdvancedParamete
|
||||
import InfillAndScalingSettings from 'features/parameters/components/AdvancedParameters/Canvas/InfillAndScalingSettings';
|
||||
import SeamCorrectionSettings from 'features/parameters/components/AdvancedParameters/Canvas/SeamCorrection/SeamCorrectionSettings';
|
||||
import ImageToImageStrength from 'features/parameters/components/AdvancedParameters/ImageToImage/ImageToImageStrength';
|
||||
import SymmetrySettings from 'features/parameters/components/AdvancedParameters/Output/SymmetrySettings';
|
||||
import SymmetryToggle from 'features/parameters/components/AdvancedParameters/Output/SymmetryToggle';
|
||||
import SeedSettings from 'features/parameters/components/AdvancedParameters/Seed/SeedSettings';
|
||||
import GenerateVariationsToggle from 'features/parameters/components/AdvancedParameters/Variations/GenerateVariations';
|
||||
import VariationsSettings from 'features/parameters/components/AdvancedParameters/Variations/VariationsSettings';
|
||||
@ -46,6 +48,11 @@ export default function UnifiedCanvasPanel() {
|
||||
content: <VariationsSettings />,
|
||||
additionalHeaderComponents: <GenerateVariationsToggle />,
|
||||
},
|
||||
symmetry: {
|
||||
header: `${t('parameters.symmetry')}`,
|
||||
content: <SymmetrySettings />,
|
||||
additionalHeaderComponents: <SymmetryToggle />,
|
||||
},
|
||||
};
|
||||
|
||||
const unifiedCanvasImg2ImgAccordion = {
|
||||
|
File diff suppressed because one or more lines are too long
@ -964,6 +964,7 @@ class Generate:
|
||||
|
||||
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
|
||||
if self.embedding_path is not None:
|
||||
print(f'>> Loading embeddings from {self.embedding_path}')
|
||||
for root, _, files in os.walk(self.embedding_path):
|
||||
for name in files:
|
||||
ti_path = os.path.join(root, name)
|
||||
@ -971,7 +972,7 @@ class Generate:
|
||||
ti_path, defer_injecting_tokens=True
|
||||
)
|
||||
print(
|
||||
f'>> Textual inversions available: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
|
||||
f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
|
||||
)
|
||||
|
||||
self.model_name = model_name
|
||||
|
@ -15,17 +15,18 @@ if sys.platform == "darwin":
|
||||
import pyparsing # type: ignore
|
||||
|
||||
import ldm.invoke
|
||||
from ldm.generate import Generate
|
||||
from ldm.invoke.args import (Args, dream_cmd_from_png, metadata_dumps,
|
||||
from ..generate import Generate
|
||||
from .args import (Args, dream_cmd_from_png, metadata_dumps,
|
||||
metadata_from_png)
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.image_util import make_grid
|
||||
from ldm.invoke.log import write_log
|
||||
from ldm.invoke.model_manager import ModelManager
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import Completer, get_completer
|
||||
from ldm.util import url_attachment_name
|
||||
from .generator.diffusers_pipeline import PipelineIntermediateState
|
||||
from .globals import Globals
|
||||
from .image_util import make_grid
|
||||
from .log import write_log
|
||||
from .model_manager import ModelManager
|
||||
from .pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
from .prompt_parser import PromptParser
|
||||
from .readline import Completer, get_completer
|
||||
from ..util import url_attachment_name
|
||||
|
||||
# global used in multiple functions (fix)
|
||||
infile = None
|
||||
@ -1262,10 +1263,13 @@ def make_step_callback(gen, opt, prefix):
|
||||
os.makedirs(destination, exist_ok=True)
|
||||
print(f">> Intermediate images will be written into {destination}")
|
||||
|
||||
def callback(img, step):
|
||||
def callback(state: PipelineIntermediateState):
|
||||
latents = state.latents
|
||||
step = state.step
|
||||
if step % opt.save_intermediates == 0 or step == opt.steps - 1:
|
||||
filename = os.path.join(destination, f"{step:04}.png")
|
||||
image = gen.sample_to_image(img)
|
||||
image = gen.sample_to_lowres_estimated_image(latents)
|
||||
image = image.resize((image.size[0]*8,image.size[1]*8))
|
||||
image.save(filename, "PNG")
|
||||
|
||||
return callback
|
||||
@ -1387,3 +1391,7 @@ def check_internet() -> bool:
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
|
@ -441,6 +441,7 @@ class TextualInversionDataset(Dataset):
|
||||
self.image_paths = [
|
||||
os.path.join(self.data_root, file_path)
|
||||
for file_path in os.listdir(self.data_root)
|
||||
if os.path.isfile(file_path) and file_path.endswith(('.png','.PNG','.jpg','.JPG','.jpeg','.JPEG','.gif','.GIF'))
|
||||
]
|
||||
|
||||
self.num_images = len(self.image_paths)
|
||||
|
@ -584,7 +584,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
# print(f"SwapCrossAttnContext for {attention_type} active")
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
attention_mask = attn.prepare_attention_mask(
|
||||
attention_mask=attention_mask, target_length=sequence_length,
|
||||
batch_size=batch_size)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
dim = query.shape[-1]
|
||||
|
@ -61,8 +61,13 @@ class TextualInversionManager:
|
||||
|
||||
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
||||
ckpt_path = Path(ckpt_path)
|
||||
|
||||
if not ckpt_path.is_file():
|
||||
return
|
||||
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
try:
|
||||
scan_result = scan_file_path(str(ckpt_path))
|
||||
if scan_result.infected_files == 1:
|
||||
@ -84,10 +89,10 @@ class TextualInversionManager:
|
||||
return
|
||||
elif (
|
||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||
!= embedding_info["embedding"].shape[0]
|
||||
!= embedding_info['token_dim']
|
||||
):
|
||||
print(
|
||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with a different token dimension. It can't be used with this model."
|
||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||
)
|
||||
return
|
||||
|
||||
@ -333,7 +338,6 @@ class TextualInversionManager:
|
||||
# .pt files found at https://cyberes.github.io/stable-diffusion-textual-inversion-models/
|
||||
# They are actually .bin files
|
||||
elif len(embedding_ckpt.keys()) == 1:
|
||||
print(">> Detected .bin file masquerading as .pt file")
|
||||
embedding_info = self._parse_embedding_bin(embedding_file)
|
||||
|
||||
else:
|
||||
@ -372,9 +376,6 @@ class TextualInversionManager:
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
print(
|
||||
">> Detected .pt file variant 1"
|
||||
) # example at https://github.com/invoke-ai/InvokeAI/issues/1829
|
||||
for token in list(embedding_ckpt["string_to_token"].keys()):
|
||||
embedding_info["name"] = (
|
||||
token
|
||||
@ -387,7 +388,7 @@ class TextualInversionManager:
|
||||
embedding_info["num_vectors_per_token"] = embedding_info[
|
||||
"embedding"
|
||||
].shape[0]
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[0]
|
||||
embedding_info["token_dim"] = embedding_info["embedding"].size()[1]
|
||||
else:
|
||||
print(">> Invalid embedding format")
|
||||
embedding_info = None
|
||||
|
Loading…
Reference in New Issue
Block a user