Merge branch 'lstein/model-manager-refactor' of github.com:invoke-ai/InvokeAI into lstein/model-manager-refactor

This commit is contained in:
Lincoln Stein
2023-10-04 08:42:07 -04:00
11 changed files with 123 additions and 31 deletions

View File

@ -1,5 +1,4 @@
import math
import os
import re
from pathlib import Path
from typing import Optional, TypedDict
@ -11,6 +10,7 @@ from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
from PIL.Image import Image as ImageType
from pydantic import validator
import invokeai.assets.fonts as font_assets
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InputField,
@ -138,6 +138,7 @@ def generate_face_box_mask(
chunk_x_offset: int = 0,
chunk_y_offset: int = 0,
draw_mesh: bool = True,
check_bounds: bool = True,
) -> list[FaceResultData]:
result = []
mask_pil = None
@ -217,7 +218,7 @@ def generate_face_box_mask(
im_width, im_height = pil_image.size
over_w = im_width * 0.1
over_h = im_height * 0.1
if (
if not check_bounds or (
(left_side >= -over_w)
and (right_side < im_width + over_w)
and (top_side >= -over_h)
@ -345,6 +346,7 @@ def get_faces_list(
chunk_x_offset=0,
chunk_y_offset=0,
draw_mesh=draw_mesh,
check_bounds=False,
)
if should_chunk or len(result) == 0:
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
@ -402,7 +404,7 @@ def get_faces_list(
return all_faces
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.0")
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.0.1")
class FaceOffInvocation(BaseInvocation):
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
@ -496,7 +498,7 @@ class FaceOffInvocation(BaseInvocation):
return output
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.0")
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.0.1")
class FaceMaskInvocation(BaseInvocation):
"""Face mask creation using mediapipe face detection"""
@ -614,7 +616,7 @@ class FaceMaskInvocation(BaseInvocation):
@invocation(
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.0"
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.0.1"
)
class FaceIdentifierInvocation(BaseInvocation):
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
@ -641,9 +643,9 @@ class FaceIdentifierInvocation(BaseInvocation):
draw_mesh=False,
)
path = Path(__file__).resolve().parent.parent.parent
font_path = os.path.abspath(path / "assets/fonts/inter/Inter-Regular.ttf")
font = ImageFont.truetype(font_path, FONT_SIZE)
# Note - font may be found either in the repo if running an editable install, or in the venv if running a package install
font_path = [x for x in [Path(y, "inter/Inter-Regular.ttf") for y in font_assets.__path__] if x.exists()]
font = ImageFont.truetype(font_path[0].as_posix(), FONT_SIZE)
# Paste face IDs on the output image
draw = ImageDraw.Draw(image)

View File

@ -584,7 +584,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
ORDER BY images.created_at DESC
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),

View File

@ -697,7 +697,7 @@
"noLoRAsAvailable": "No LoRAs available",
"noMatchingLoRAs": "No matching LoRAs",
"noMatchingModels": "No matching Models",
"noModelsAvailable": "No Modelss available",
"noModelsAvailable": "No models available",
"selectLoRA": "Select a LoRA",
"selectModel": "Select a Model"
},

View File

@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import {
controlNetRemoved,
ipAdapterModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
@ -16,12 +19,14 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
ipAdapterModelsAdapter,
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => {
startAppListening({
@ -234,6 +239,50 @@ export const addModelsLoadedListener = () => {
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getIPAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`IP Adapter models loaded (${action.payload.ids.length})`
);
const { model } = getState().controlNet.ipAdapterInfo;
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === model?.model_name &&
m?.base_model === model?.base_model
);
if (isModelAvailable) {
return;
}
const firstModel = ipAdapterModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
);
return;
}
dispatch(ipAdapterModelChanged(result.data));
},
});
startAppListening({
matcher: modelsApi.endpoints.getTextualInversionModels.matchFulfilled,
effect: async (action) => {

View File

@ -151,15 +151,10 @@ const IAICanvasStagingAreaToolbar = () => {
isDisabled={!shouldShowStagingImage}
/>
<IAIButton
colorScheme="accent"
colorScheme="base"
pointerEvents="none"
isDisabled={!shouldShowStagingImage}
sx={{
background: 'base.600',
_dark: {
background: 'base.800',
},
}}
minW={20}
>{`${currentIndex + 1}/${total}`}</IAIButton>
<IAIIconButton
tooltip={`${t('unifiedCanvas.next')} (Right)`}

View File

@ -6,7 +6,9 @@ export const canvasSelector = (state: RootState): CanvasState => state.canvas;
export const isStagingSelector = createSelector(
[stateSelector],
({ canvas }) => canvas.batchIds.length > 0
({ canvas }) =>
canvas.batchIds.length > 0 ||
canvas.layerState.stagingArea.images.length > 0
);
export const initialCanvasImageSelector = (

View File

@ -5,8 +5,23 @@ import ParamIPAdapterFeatureToggle from './ParamIPAdapterFeatureToggle';
import ParamIPAdapterImage from './ParamIPAdapterImage';
import ParamIPAdapterModelSelect from './ParamIPAdapterModelSelect';
import ParamIPAdapterWeight from './ParamIPAdapterWeight';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from '../../../../app/store/store';
import { defaultSelectorOptions } from '../../../../app/store/util/defaultMemoizeOptions';
import { useAppSelector } from '../../../../app/store/storeHooks';
const selector = createSelector(
stateSelector,
(state) => {
const { isIPAdapterEnabled } = state.controlNet;
return { isIPAdapterEnabled };
},
defaultSelectorOptions
);
const IPAdapterPanel = () => {
const { isIPAdapterEnabled } = useAppSelector(selector);
return (
<Flex
sx={{
@ -14,7 +29,6 @@ const IPAdapterPanel = () => {
gap: 3,
paddingInline: 3,
paddingBlock: 2,
paddingBottom: 5,
borderRadius: 'base',
position: 'relative',
bg: 'base.250',
@ -24,10 +38,26 @@ const IPAdapterPanel = () => {
}}
>
<ParamIPAdapterFeatureToggle />
<ParamIPAdapterImage />
<ParamIPAdapterModelSelect />
<ParamIPAdapterWeight />
<ParamIPAdapterBeginEnd />
{isIPAdapterEnabled && (
<>
<ParamIPAdapterModelSelect />
<Flex gap="3">
<Flex
flexDirection="column"
sx={{
h: 28,
w: 'full',
gap: 4,
mb: 4,
}}
>
<ParamIPAdapterWeight />
<ParamIPAdapterBeginEnd />
</Flex>
<ParamIPAdapterImage />
</Flex>
</>
)}
</Flex>
);
};

View File

@ -66,7 +66,8 @@ const ParamIPAdapterImage = () => {
layerStyle="second"
sx={{
position: 'relative',
w: 'full',
h: 28,
w: 28,
alignItems: 'center',
justifyContent: 'center',
aspectRatio: '1/1',

View File

@ -88,12 +88,16 @@ const ParamIPAdapterModelSelect = () => {
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
placeholder={
data.length > 0
? t('models.selectModel')
: t('models.noModelsAvailable')
}
error={!selectedModel && data.length > 0}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
disabled={!isEnabled}
disabled={!isEnabled || data.length === 0}
/>
);
};

View File

@ -473,6 +473,7 @@ export const imagesApi = api.injectEndpoints({
if (images[0]) {
const categories = getCategories(images[0]);
const boardId = images[0].board_id;
return [
{
type: 'ImageList',
@ -481,6 +482,10 @@ export const imagesApi = api.injectEndpoints({
categories,
}),
},
{
type: 'Board',
id: boardId,
},
];
}
return [];
@ -595,6 +600,10 @@ export const imagesApi = api.injectEndpoints({
categories,
}),
},
{
type: 'Board',
id: boardId,
},
];
}
return [];

View File

@ -163,16 +163,16 @@ version = { attr = "invokeai.version.__version__" }
[tool.setuptools.packages.find]
"where" = ["."]
"include" = [
"invokeai.assets.web*","invokeai.version*",
"invokeai.assets.fonts*","invokeai.version*",
"invokeai.generator*","invokeai.backend*",
"invokeai.frontend*", "invokeai.frontend.web.dist*",
"invokeai.frontend.web.static*",
"invokeai.configs*",
"invokeai.app*","ldm*",
"invokeai.app*",
]
[tool.setuptools.package-data]
"invokeai.assets.web" = ["**.png","**.js","**.woff2","**.css"]
"invokeai.assets.fonts" = ["**/*.ttf"]
"invokeai.backend" = ["**.png"]
"invokeai.configs" = ["*.example", "**/*.yaml", "*.txt"]
"invokeai.frontend.web.dist" = ["**"]