From 0f029150126d17732bddc9e4f9c2032d1113b8f0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 1 Jul 2023 21:15:42 -0400 Subject: [PATCH 01/28] remove hardcoded cuda device in model manager init --- invokeai/backend/model_management/model_cache.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 77b6ac5115..df5a2f9272 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -100,8 +100,6 @@ class ModelCache(object): :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' #max_cache_size = 9999 - execution_device = torch.device('cuda') - self.model_infos: Dict[str, ModelBase] = dict() self.lazy_offloading = lazy_offloading #self.sequential_offload: bool=sequential_offload From 08d428a5e7fd5fd0d2db3ad90a80f24c69af9a26 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 21:11:50 +1000 Subject: [PATCH 02/28] feat(nodes): add lora field, update lora loader --- invokeai/app/invocations/baseinvocation.py | 25 ++++-- invokeai/app/invocations/model.py | 97 +++++++++++++--------- 2 files changed, 77 insertions(+), 45 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1bf9353368..4c7314bd2b 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -4,9 +4,10 @@ from __future__ import annotations from abc import ABC, abstractmethod from inspect import signature -from typing import get_args, get_type_hints, Dict, List, Literal, TypedDict, TYPE_CHECKING +from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, + get_type_hints) -from pydantic import BaseModel, Field +from pydantic import BaseConfig, BaseModel, Field if TYPE_CHECKING: from ..services.invocation_services import InvocationServices @@ -65,8 +66,13 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls): # Get the type strings out of the literals and into a dictionary - return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseInvocation.get_all_subclasses())) - + return dict( + map( + lambda t: (get_args(get_type_hints(t)["type"])[0], t), + BaseInvocation.get_all_subclasses(), + ) + ) + @classmethod def get_output_type(cls): return signature(cls.invoke).return_annotation @@ -75,11 +81,11 @@ class BaseInvocation(ABC, BaseModel): def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - - #fmt: off + + # fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") - #fmt: on + # fmt: on # TODO: figure out a better way to provide these hints @@ -98,16 +104,19 @@ class UIConfig(TypedDict, total=False): "model", "control", "image_collection", + "vae_model", + "lora_model", ], ] tags: List[str] title: str + class CustomisedSchemaExtra(TypedDict): ui: UIConfig -class InvocationConfig(BaseModel.Config): +class InvocationConfig(BaseConfig): """Customizes pydantic's BaseModel.Config class for use by Invocations. Provide `schema_extra` a `ui` dict to add hints for generated UIs. diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index e51873c59e..17297ba417 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -1,5 +1,5 @@ import copy -from typing import List, Literal, Optional +from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field @@ -12,35 +12,42 @@ class ModelInfo(BaseModel): model_name: str = Field(description="Info to load submodel") base_model: BaseModelType = Field(description="Base model") model_type: ModelType = Field(description="Info to load submodel") - submodel: Optional[SubModelType] = Field(description="Info to load submodel") + submodel: Optional[SubModelType] = Field( + default=None, description="Info to load submodel" + ) + class LoraInfo(ModelInfo): weight: float = Field(description="Lora's weight which to use when apply to model") + class UNetField(BaseModel): unet: ModelInfo = Field(description="Info to load unet submodel") scheduler: ModelInfo = Field(description="Info to load scheduler submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class ClipField(BaseModel): tokenizer: ModelInfo = Field(description="Info to load tokenizer submodel") text_encoder: ModelInfo = Field(description="Info to load text_encoder submodel") loras: List[LoraInfo] = Field(description="Loras to apply on model loading") + class VaeField(BaseModel): # TODO: better naming? vae: ModelInfo = Field(description="Info to load vae submodel") + class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["model_loader_output"] = "model_loader_output" unet: UNetField = Field(default=None, description="UNet submodel") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") vae: VaeField = Field(default=None, description="Vae submodel") - #fmt: on + # fmt: on class MainModelField(BaseModel): @@ -50,6 +57,13 @@ class MainModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") +class LoRAModelField(BaseModel): + """LoRA model field""" + + model_name: str = Field(description="Name of the LoRA model") + base_model: BaseModelType = Field(description="Base model") + + class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" @@ -64,14 +78,11 @@ class MainModelLoaderInvocation(BaseInvocation): "ui": { "title": "Model Loader", "tags": ["model", "loader"], - "type_hints": { - "model": "model" - } + "type_hints": {"model": "model"}, }, } def invoke(self, context: InvocationContext) -> ModelLoaderOutput: - base_model = self.model.base_model model_name = self.model.model_name model_type = ModelType.Main @@ -113,7 +124,6 @@ class MainModelLoaderInvocation(BaseInvocation): ) """ - return ModelLoaderOutput( unet=UNetField( unet=ModelInfo( @@ -152,25 +162,29 @@ class MainModelLoaderInvocation(BaseInvocation): model_type=model_type, submodel=SubModelType.Vae, ), - ) + ), ) + class LoraLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" unet: Optional[UNetField] = Field(default=None, description="UNet submodel") clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - #fmt: on + # fmt: on + class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora_name: str = Field(description="Lora model name") + lora: Union[LoRAModelField, None] = Field( + default=None, description="Lora model name" + ) weight: float = Field(default=0.75, description="With what weight to apply lora") unet: Optional[UNetField] = Field(description="UNet model for applying lora") @@ -181,26 +195,33 @@ class LoraLoaderInvocation(BaseInvocation): "ui": { "title": "Lora Loader", "tags": ["lora", "loader"], + "type_hints": {"lora": "lora_model"}, }, } def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + if self.lora is None: + raise Exception("No LoRA provided") - # TODO: ui rewrite - base_model = BaseModelType.StableDiffusion1 + base_model = self.lora.base_model + lora_name = self.lora.model_name if not context.services.model_manager.model_exists( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, ): - raise Exception(f"Unkown lora name: {self.lora_name}!") + raise Exception(f"Unkown lora name: {lora_name}!") - if self.unet is not None and any(lora.model_name == self.lora_name for lora in self.unet.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to unet") + if self.unet is not None and any( + lora.model_name == lora_name for lora in self.unet.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to unet') - if self.clip is not None and any(lora.model_name == self.lora_name for lora in self.clip.loras): - raise Exception(f"Lora \"{self.lora_name}\" already applied to clip") + if self.clip is not None and any( + lora.model_name == lora_name for lora in self.clip.loras + ): + raise Exception(f'Lora "{lora_name}" already applied to clip') output = LoraLoaderOutput() @@ -209,7 +230,7 @@ class LoraLoaderInvocation(BaseInvocation): output.unet.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -221,7 +242,7 @@ class LoraLoaderInvocation(BaseInvocation): output.clip.loras.append( LoraInfo( base_model=base_model, - model_name=self.lora_name, + model_name=lora_name, model_type=ModelType.Lora, submodel=None, weight=self.weight, @@ -230,25 +251,29 @@ class LoraLoaderInvocation(BaseInvocation): return output + class VAEModelField(BaseModel): """Vae model field""" model_name: str = Field(description="Name of the model") base_model: BaseModelType = Field(description="Base model") + class VaeLoaderOutput(BaseInvocationOutput): """Model loader output""" - #fmt: off + # fmt: off type: Literal["vae_loader_output"] = "vae_loader_output" vae: VaeField = Field(default=None, description="Vae model") - #fmt: on + # fmt: on + class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" + type: Literal["vae_loader"] = "vae_loader" - + vae_model: VAEModelField = Field(description="The VAE to load") # Schema customisation @@ -257,29 +282,27 @@ class VaeLoaderInvocation(BaseInvocation): "ui": { "title": "VAE Loader", "tags": ["vae", "loader"], - "type_hints": { - "vae_model": "vae_model" - } + "type_hints": {"vae_model": "vae_model"}, }, } - + def invoke(self, context: InvocationContext) -> VaeLoaderOutput: base_model = self.vae_model.base_model model_name = self.vae_model.model_name model_type = ModelType.Vae if not context.services.model_manager.model_exists( - base_model=base_model, - model_name=model_name, - model_type=model_type, + base_model=base_model, + model_name=model_name, + model_type=model_type, ): raise Exception(f"Unkown vae name: {model_name}!") return VaeLoaderOutput( vae=VaeField( - vae = ModelInfo( - model_name = model_name, - base_model = base_model, - model_type = model_type, + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=model_type, ) ) ) From d537b9f0cb7c1a24c24895f8a1180824e7e930fe Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 21:12:49 +1000 Subject: [PATCH 03/28] chore(ui): regen types --- .../frontend/web/src/services/api/schema.d.ts | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index 7dce36d6b3..d7e50d004e 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -2690,6 +2690,19 @@ export type components = { model_format: components["schemas"]["LoRAModelFormat"]; error?: components["schemas"]["ModelError"]; }; + /** + * LoRAModelField + * @description LoRA model field + */ + LoRAModelField: { + /** + * Model Name + * @description Name of the LoRA model + */ + model_name: string; + /** @description Base model */ + base_model: components["schemas"]["BaseModelType"]; + }; /** * LoRAModelFormat * @description An enumeration. @@ -2766,10 +2779,10 @@ export type components = { */ type?: "lora_loader"; /** - * Lora Name + * Lora * @description Lora model name */ - lora_name: string; + lora?: components["schemas"]["LoRAModelField"]; /** * Weight * @description With what weight to apply lora @@ -3115,7 +3128,7 @@ export type components = { /** ModelsList */ ModelsList: { /** Models */ - models: (components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; + models: (components["schemas"]["StableDiffusion1ModelDiffusersConfig"] | components["schemas"]["StableDiffusion1ModelCheckpointConfig"] | components["schemas"]["VaeModelConfig"] | components["schemas"]["LoRAModelConfig"] | components["schemas"]["ControlNetModelConfig"] | components["schemas"]["TextualInversionModelConfig"] | components["schemas"]["StableDiffusion2ModelDiffusersConfig"] | components["schemas"]["StableDiffusion2ModelCheckpointConfig"])[]; }; /** * MultiplyInvocation @@ -4448,18 +4461,18 @@ export type components = { */ image?: components["schemas"]["ImageField"]; }; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusion1ModelFormat * @description An enumeration. * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; }; responses: never; parameters: never; From db8862d86069ca9a4baea6e080ebda0bcf92e917 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 21:12:52 +1000 Subject: [PATCH 04/28] feat(ui): add LoRA ui & update graphs --- .../enhancers/reduxRemember/serialize.ts | 2 - invokeai/frontend/web/src/app/store/store.ts | 23 +-- .../features/lora/components/ParamLora.tsx | 58 +++++++ .../lora/components/ParamLoraCollapse.tsx | 20 +++ .../lora/components/ParamLoraList.tsx | 19 +++ .../lora/components/ParamLoraSelect.tsx | 103 ++++++++++++ .../web/src/features/lora/store/loraSlice.ts | 44 +++++ .../nodes/components/InputFieldComponent.tsx | 11 ++ .../fields/LoRAModelInputFieldComponent.tsx | 104 ++++++++++++ .../src/features/nodes/store/nodesSlice.ts | 8 +- .../web/src/features/nodes/types/constants.ts | 9 +- .../web/src/features/nodes/types/types.ts | 13 ++ .../nodes/util/fieldTemplateBuilders.ts | 19 +++ .../features/nodes/util/fieldValueBuilders.ts | 4 + .../util/graphBuilders/addLoRAsToGraph.ts | 150 ++++++++++++++++++ .../buildCanvasImageToImageGraph.ts | 3 + .../buildCanvasTextToImageGraph.ts | 3 + .../buildLinearImageToImageGraph.ts | 4 + .../buildLinearTextToImageGraph.ts | 3 + .../util/graphBuilders/buildNodesGraph.ts | 7 + .../nodes/util/graphBuilders/constants.ts | 1 + .../features/nodes/util/modelIdToLoRAName.ts | 12 ++ .../ImageToImageTabParameters.tsx | 22 +-- .../TextToImage/TextToImageTabParameters.tsx | 20 +-- .../web/src/services/api/endpoints/models.ts | 4 +- .../frontend/web/src/services/api/types.d.ts | 3 + 26 files changed, 630 insertions(+), 39 deletions(-) create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLora.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx create mode 100644 invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx create mode 100644 invokeai/frontend/web/src/features/lora/store/loraSlice.ts create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts diff --git a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts index cb18d48301..ac1b9c5205 100644 --- a/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts +++ b/invokeai/frontend/web/src/app/store/enhancers/reduxRemember/serialize.ts @@ -20,10 +20,8 @@ const serializationDenylist: { nodes: nodesPersistDenylist, postprocessing: postprocessingPersistDenylist, system: systemPersistDenylist, - // config: configPersistDenyList, ui: uiPersistDenylist, controlNet: controlNetDenylist, - // hotkeys: hotkeysPersistDenylist, }; export const serialize: SerializeFunction = (data, key) => { diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 2fd071bd23..5208933e7b 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -8,31 +8,32 @@ import { import dynamicMiddlewares from 'redux-dynamic-middlewares'; import { rememberEnhancer, rememberReducer } from 'redux-remember'; +import batchReducer from 'features/batch/store/batchSlice'; import canvasReducer from 'features/canvas/store/canvasSlice'; import controlNetReducer from 'features/controlNet/store/controlNetSlice'; +import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; +import boardsReducer from 'features/gallery/store/boardSlice'; import galleryReducer from 'features/gallery/store/gallerySlice'; +import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import lightboxReducer from 'features/lightbox/store/lightboxSlice'; +import loraReducer from 'features/lora/store/loraSlice'; +import nodesReducer from 'features/nodes/store/nodesSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; -import systemReducer from 'features/system/store/systemSlice'; -import nodesReducer from 'features/nodes/store/nodesSlice'; -import boardsReducer from 'features/gallery/store/boardSlice'; import configReducer from 'features/system/store/configSlice'; +import systemReducer from 'features/system/store/systemSlice'; import hotkeysReducer from 'features/ui/store/hotkeysSlice'; import uiReducer from 'features/ui/store/uiSlice'; -import dynamicPromptsReducer from 'features/dynamicPrompts/store/slice'; -import batchReducer from 'features/batch/store/batchSlice'; -import imageDeletionReducer from 'features/imageDeletion/store/imageDeletionSlice'; import { listenerMiddleware } from './middleware/listenerMiddleware'; -import { actionSanitizer } from './middleware/devtools/actionSanitizer'; -import { actionsDenylist } from './middleware/devtools/actionsDenylist'; -import { stateSanitizer } from './middleware/devtools/stateSanitizer'; +import { api } from 'services/api'; import { LOCALSTORAGE_PREFIX } from './constants'; import { serialize } from './enhancers/reduxRemember/serialize'; import { unserialize } from './enhancers/reduxRemember/unserialize'; -import { api } from 'services/api'; +import { actionSanitizer } from './middleware/devtools/actionSanitizer'; +import { actionsDenylist } from './middleware/devtools/actionsDenylist'; +import { stateSanitizer } from './middleware/devtools/stateSanitizer'; const allReducers = { canvas: canvasReducer, @@ -50,6 +51,7 @@ const allReducers = { dynamicPrompts: dynamicPromptsReducer, batch: batchReducer, imageDeletion: imageDeletionReducer, + lora: loraReducer, [api.reducerPath]: api.reducer, }; @@ -69,6 +71,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [ 'controlNet', 'dynamicPrompts', 'batch', + 'lora', // 'boards', // 'hotkeys', // 'config', diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx new file mode 100644 index 0000000000..c7d1c44fd3 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -0,0 +1,58 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISlider from 'common/components/IAISlider'; +import { memo, useCallback } from 'react'; +import { FaTrash } from 'react-icons/fa'; +import { Lora, loraRemoved, loraWeightChanged } from '../store/loraSlice'; + +type Props = { + lora: Lora; +}; + +const ParamLora = (props: Props) => { + const dispatch = useAppDispatch(); + const { lora } = props; + + const handleChange = useCallback( + (v: number) => { + dispatch(loraWeightChanged({ name: lora.name, weight: v })); + }, + [dispatch, lora.name] + ); + + const handleReset = useCallback(() => { + dispatch(loraWeightChanged({ name: lora.name, weight: 1 })); + }, [dispatch, lora.name]); + + const handleRemoveLora = useCallback(() => { + dispatch(loraRemoved(lora.name)); + }, [dispatch, lora.name]); + + return ( + + + } + colorScheme="error" + /> + + ); +}; + +export default memo(ParamLora); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx new file mode 100644 index 0000000000..fb088bef8a --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraCollapse.tsx @@ -0,0 +1,20 @@ +import { Flex, useDisclosure } from '@chakra-ui/react'; +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import ParamLoraList from './ParamLoraList'; +import ParamLoraSelect from './ParamLoraSelect'; + +const ParamLoraCollapse = () => { + const { isOpen, onToggle } = useDisclosure(); + + return ( + + + + + + + ); +}; + +export default memo(ParamLoraCollapse); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx new file mode 100644 index 0000000000..8d6ff98498 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -0,0 +1,19 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { map } from 'lodash-es'; +import ParamLora from './ParamLora'; + +const selector = createSelector(stateSelector, ({ lora }) => { + const { loras } = lora; + + return { loras }; +}); + +const ParamLoraList = () => { + const { loras } = useAppSelector(selector); + + return map(loras, (lora) => ); +}; + +export default ParamLoraList; diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx new file mode 100644 index 0000000000..8e44e7d8f1 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -0,0 +1,103 @@ +import { Text } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; +import { forEach } from 'lodash-es'; +import { forwardRef, useCallback, useMemo } from 'react'; +import { useListModelsQuery } from 'services/api/endpoints/models'; +import { loraAdded } from '../store/loraSlice'; + +type LoraSelectItem = { + label: string; + value: string; + description?: string; +}; + +const selector = createSelector( + stateSelector, + ({ lora }) => ({ + loras: lora.loras, + }), + defaultSelectorOptions +); + +const ParamLoraSelect = () => { + const dispatch = useAppDispatch(); + const { loras } = useAppSelector(selector); + const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' }); + + const data = useMemo(() => { + if (!lorasQueryData) { + return []; + } + + const data: LoraSelectItem[] = []; + + forEach(lorasQueryData.entities, (lora, id) => { + if (!lora || Boolean(id in loras)) { + return; + } + + data.push({ + value: id, + label: lora.name, + description: lora.description, + }); + }); + + return data; + }, [loras, lorasQueryData]); + + const handleChange = useCallback( + (v: string[]) => { + v[0] && dispatch(loraAdded(v[0])); + }, + [dispatch] + ); + + return ( + + item.label.toLowerCase().includes(value.toLowerCase().trim()) || + item.value.toLowerCase().includes(value.toLowerCase().trim()) + } + onChange={handleChange} + /> + ); +}; + +interface ItemProps extends React.ComponentPropsWithoutRef<'div'> { + value: string; + label: string; + description?: string; +} + +const SelectItem = forwardRef( + ({ label, description, ...others }: ItemProps, ref) => { + return ( +
+
+ {label} + {description && ( + + {description} + + )} +
+
+ ); + } +); + +SelectItem.displayName = 'SelectItem'; + +export default ParamLoraSelect; diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts new file mode 100644 index 0000000000..49b316b054 --- /dev/null +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -0,0 +1,44 @@ +import { PayloadAction, createSlice } from '@reduxjs/toolkit'; + +export type Lora = { + name: string; + weight: number; +}; + +export const defaultLoRAConfig: Omit = { + weight: 1, +}; + +export type LoraState = { + loras: Record; +}; + +export const intialLoraState: LoraState = { + loras: {}, +}; + +export const loraSlice = createSlice({ + name: 'lora', + initialState: intialLoraState, + reducers: { + loraAdded: (state, action: PayloadAction) => { + const name = action.payload; + state.loras[name] = { name, ...defaultLoRAConfig }; + }, + loraRemoved: (state, action: PayloadAction) => { + const name = action.payload; + delete state.loras[name]; + }, + loraWeightChanged: ( + state, + action: PayloadAction<{ name: string; weight: number }> + ) => { + const { name, weight } = action.payload; + state.loras[name].weight = weight; + }, + }, +}); + +export const { loraAdded, loraRemoved, loraWeightChanged } = loraSlice.actions; + +export default loraSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx index 062fec2fdc..9925a48381 100644 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx @@ -12,6 +12,7 @@ import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFie import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; +import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; import StringInputFieldComponent from './fields/StringInputFieldComponent'; @@ -163,6 +164,16 @@ const InputFieldComponent = (props: InputFieldComponentProps) => { ); } + if (type === 'lora_model' && template.type === 'lora_model') { + return ( + + ); + } + if (type === 'array' && template.type === 'array') { return ( +) => { + const { nodeId, field } = props; + + const dispatch = useAppDispatch(); + const { t } = useTranslation(); + + const { data: loraModels } = useListModelsQuery({ + model_type: 'lora', + }); + + const selectedModel = useMemo( + () => loraModels?.entities[field.value ?? loraModels.ids[0]], + [loraModels?.entities, loraModels?.ids, field.value] + ); + + const data = useMemo(() => { + if (!loraModels) { + return []; + } + + const data: SelectItem[] = []; + + forEach(loraModels.entities, (model, id) => { + if (!model) { + return; + } + + data.push({ + value: id, + label: model.name, + group: BASE_MODEL_NAME_MAP[model.base_model], + }); + }); + + return data; + }, [loraModels]); + + const handleValueChanged = useCallback( + (v: string | null) => { + if (!v) { + return; + } + + dispatch( + fieldValueChanged({ + nodeId, + fieldName: field.name, + value: v, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + useEffect(() => { + if (field.value && loraModels?.ids.includes(field.value)) { + return; + } + + const firstLora = loraModels?.ids[0]; + + if (!isString(firstLora)) { + return; + } + + handleValueChanged(firstLora); + }, [field.value, handleValueChanged, loraModels?.ids]); + + return ( + + ); +}; + +export default memo(LoRAModelInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index ffc93db2ba..4fa69c626b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,5 +1,8 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'; +import { RootState } from 'app/store/store'; +import { cloneDeep, uniqBy } from 'lodash-es'; import { OpenAPIV3 } from 'openapi-types'; +import { RgbaColor } from 'react-colorful'; import { addEdge, applyEdgeChanges, @@ -11,12 +14,9 @@ import { NodeChange, OnConnectStartParams, } from 'reactflow'; -import { ImageField } from 'services/api/types'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; +import { ImageField } from 'services/api/types'; import { InvocationTemplate, InvocationValue } from '../types/types'; -import { RgbaColor } from 'react-colorful'; -import { RootState } from 'app/store/store'; -import { cloneDeep, isArray, uniq, uniqBy } from 'lodash-es'; export type NodesState = { nodes: Node[]; diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index b864501803..5fe780a286 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -18,6 +18,7 @@ export const FIELD_TYPE_MAP: Record = { VaeField: 'vae', model: 'model', vae_model: 'vae_model', + lora_model: 'lora_model', array: 'array', item: 'item', ColorField: 'color', @@ -120,7 +121,13 @@ export const FIELDS: Record = { vae_model: { color: 'teal', colorCssVar: getColorTokenCssVariable('teal'), - title: 'Model', + title: 'VAE', + description: 'Models are models.', + }, + lora_model: { + color: 'teal', + colorCssVar: getColorTokenCssVariable('teal'), + title: 'LoRA', description: 'Models are models.', }, array: { diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts index c7e573ace2..3de8cae9ff 100644 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ b/invokeai/frontend/web/src/features/nodes/types/types.ts @@ -65,6 +65,7 @@ export type FieldType = | 'control' | 'model' | 'vae_model' + | 'lora_model' | 'array' | 'item' | 'color' @@ -93,6 +94,7 @@ export type InputFieldValue = | EnumInputFieldValue | ModelInputFieldValue | VaeModelInputFieldValue + | LoRAModelInputFieldValue | ArrayInputFieldValue | ItemInputFieldValue | ColorInputFieldValue @@ -119,6 +121,7 @@ export type InputFieldTemplate = | EnumInputFieldTemplate | ModelInputFieldTemplate | VaeModelInputFieldTemplate + | LoRAModelInputFieldTemplate | ArrayInputFieldTemplate | ItemInputFieldTemplate | ColorInputFieldTemplate @@ -236,6 +239,11 @@ export type VaeModelInputFieldValue = FieldValueBase & { value?: string; }; +export type LoRAModelInputFieldValue = FieldValueBase & { + type: 'lora_model'; + value?: string; +}; + export type ArrayInputFieldValue = FieldValueBase & { type: 'array'; value?: (string | number)[]; @@ -350,6 +358,11 @@ export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { type: 'vae_model'; }; +export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { + default: string; + type: 'lora_model'; +}; + export type ArrayInputFieldTemplate = InputFieldTemplateBase & { default: []; type: 'array'; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts index c71618175a..1c2dbc0c3e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts @@ -18,6 +18,7 @@ import { IntegerInputFieldTemplate, ItemInputFieldTemplate, LatentsInputFieldTemplate, + LoRAModelInputFieldTemplate, ModelInputFieldTemplate, OutputFieldTemplate, StringInputFieldTemplate, @@ -191,6 +192,21 @@ const buildVaeModelInputFieldTemplate = ({ return template; }; +const buildLoRAModelInputFieldTemplate = ({ + schemaObject, + baseField, +}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { + const template: LoRAModelInputFieldTemplate = { + ...baseField, + type: 'lora_model', + inputRequirement: 'always', + inputKind: 'direct', + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildImageInputFieldTemplate = ({ schemaObject, baseField, @@ -460,6 +476,9 @@ export const buildInputFieldTemplate = ( if (['vae_model'].includes(fieldType)) { return buildVaeModelInputFieldTemplate({ schemaObject, baseField }); } + if (['lora_model'].includes(fieldType)) { + return buildLoRAModelInputFieldTemplate({ schemaObject, baseField }); + } if (['enum'].includes(fieldType)) { return buildEnumInputFieldTemplate({ schemaObject, baseField }); } diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts index a94d3ddef2..950038b691 100644 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts @@ -79,6 +79,10 @@ export const buildInputFieldValue = ( if (template.type === 'vae_model') { fieldValue.value = undefined; } + + if (template.type === 'lora_model') { + fieldValue.value = undefined; + } } return fieldValue; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts new file mode 100644 index 0000000000..a105a123d8 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -0,0 +1,150 @@ +import { RootState } from 'app/store/store'; +import { NonNullableGraph } from 'features/nodes/types/types'; +import { forEach, size } from 'lodash-es'; +import { LoraLoaderInvocation } from 'services/api/types'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; +import { + LORA_LOADER, + MAIN_MODEL_LOADER, + NEGATIVE_CONDITIONING, + POSITIVE_CONDITIONING, +} from './constants'; + +export const addLoRAsToGraph = ( + graph: NonNullableGraph, + state: RootState, + baseNodeId: string +): void => { + /** + * LoRA nodes get the UNet and CLIP models from the main model loader and apply the LoRA to them. + * They then output the UNet and CLIP models references on to either the next LoRA in the chain, + * or to the inference/conditioning nodes. + * + * So we need to inject a LoRA chain into the graph. + */ + + const { loras } = state.lora; + const loraCount = size(loras); + + if (loraCount > 0) { + // remove any existing connections from main model loader, we need to insert the lora nodes + graph.edges = graph.edges.filter( + (e) => + !( + e.source.node_id === MAIN_MODEL_LOADER && + ['unet', 'clip'].includes(e.source.field) + ) + ); + } + + // we need to remember the last lora so we can chain from it + let lastLoraNodeId = ''; + let currentLoraIndex = 0; + + forEach(loras, (lora) => { + const { name, weight } = lora; + const loraField = modelIdToLoRAModelField(name); + const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( + '.', + '_' + )}`; + + console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField); + + const loraLoaderNode: LoraLoaderInvocation = { + type: 'lora_loader', + id: currentLoraNodeId, + lora: loraField, + weight, + }; + + graph.nodes[currentLoraNodeId] = loraLoaderNode; + + if (currentLoraIndex === 0) { + // first lora = start the lora chain, attach directly to model loader + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: MAIN_MODEL_LOADER, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } else { + // we are in the middle of the lora chain, instead connect to the previous lora + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'unet', + }, + destination: { + node_id: currentLoraNodeId, + field: 'unet', + }, + }); + graph.edges.push({ + source: { + node_id: lastLoraNodeId, + field: 'clip', + }, + destination: { + node_id: currentLoraNodeId, + field: 'clip', + }, + }); + } + + if (currentLoraIndex === loraCount - 1) { + // final lora, end the lora chain - we need to connect up to inference and conditioning nodes + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'unet', + }, + destination: { + node_id: baseNodeId, + field: 'unet', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: POSITIVE_CONDITIONING, + field: 'clip', + }, + }); + + graph.edges.push({ + source: { + node_id: currentLoraNodeId, + field: 'clip', + }, + destination: { + node_id: NEGATIVE_CONDITIONING, + field: 'clip', + }, + }); + } + + // increment the lora for the next one in the chain + lastLoraNodeId = currentLoraNodeId; + currentLoraIndex += 1; + }); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index 5cf9882ac1..1843efef84 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -9,6 +9,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_TO_IMAGE_GRAPH, @@ -252,6 +253,8 @@ export const buildCanvasImageToImageGraph = ( }); } + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index cfe5e62805..976ea4fd01 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -157,6 +158,8 @@ export const buildCanvasTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index 2e4383c3e7..fe6d1292e4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -10,6 +10,7 @@ import { import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { IMAGE_COLLECTION, @@ -304,6 +305,9 @@ export const buildLinearImageToImageGraph = ( }, }); } + + addLoRAsToGraph(graph, state, LATENTS_TO_LATENTS); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index e0e71a00a2..04dccf4983 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -3,6 +3,7 @@ import { NonNullableGraph } from 'features/nodes/types/types'; import { addControlNetToLinearGraph } from '../addControlNetToLinearGraph'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { LATENTS_TO_IMAGE, @@ -150,6 +151,8 @@ export const buildLinearTextToImageGraph = ( ], }; + addLoRAsToGraph(graph, state, TEXT_TO_LATENTS); + // Add Custom VAE Support addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index 3265a0f889..12a567b009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -4,6 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; +import { modelIdToLoRAModelField } from '../modelIdToLoRAName'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; import { modelIdToVAEModelField } from '../modelIdToVAEModelField'; @@ -38,6 +39,12 @@ export const parseFieldValue = (field: InputFieldValue) => { } } + if (field.type === 'lora_model') { + if (field.value) { + return modelIdToLoRAModelField(field.value); + } + } + return field.value; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts index 58a7d0335b..7aace48def 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts @@ -9,6 +9,7 @@ export const RANGE_OF_SIZE = 'range_of_size'; export const ITERATE = 'iterate'; export const MAIN_MODEL_LOADER = 'main_model_loader'; export const VAE_LOADER = 'vae_loader'; +export const LORA_LOADER = 'lora_loader'; export const IMAGE_TO_LATENTS = 'image_to_latents'; export const LATENTS_TO_LATENTS = 'latents_to_latents'; export const RESIZE = 'resize_image'; diff --git a/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts new file mode 100644 index 0000000000..052b58484b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/modelIdToLoRAName.ts @@ -0,0 +1,12 @@ +import { BaseModelType, LoRAModelField } from 'services/api/types'; + +export const modelIdToLoRAModelField = (loraId: string): LoRAModelField => { + const [base_model, model_type, model_name] = loraId.split('/'); + + const field: LoRAModelField = { + base_model: base_model as BaseModelType, + model_name, + }; + + return field; +}; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx index 4f04abffa1..32b71d6187 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabParameters.tsx @@ -1,14 +1,15 @@ -import { memo } from 'react'; -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; -import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import ImageToImageTabCoreParameters from './ImageToImageTabCoreParameters'; const ImageToImageTabParameters = () => { return ( @@ -17,6 +18,7 @@ const ImageToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx index bcc6c91ae6..6291b69a8e 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabParameters.tsx @@ -1,15 +1,16 @@ +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; +import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; +import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; +import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamNoiseCollapse from 'features/parameters/components/Parameters/Noise/ParamNoiseCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; -import ParamHiresCollapse from 'features/parameters/components/Parameters/Hires/ParamHiresCollapse'; -import ParamSeamlessCollapse from 'features/parameters/components/Parameters/Seamless/ParamSeamlessCollapse'; import TextToImageTabCoreParameters from './TextToImageTabCoreParameters'; -import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; const TextToImageTabParameters = () => { return ( @@ -18,6 +19,7 @@ const TextToImageTabParameters = () => { + diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index 39e4e46d3b..bff412bacb 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,6 +1,6 @@ -import { ModelsList } from 'services/api/types'; import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; import { keyBy } from 'lodash-es'; +import { ModelsList } from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; import { paths } from '../schema'; @@ -24,11 +24,9 @@ export const modelsApi = api.injectEndpoints({ listModels: build.query, ListModelsArg>({ query: (arg) => ({ url: 'models/', params: arg }), providesTags: (result, error, arg) => { - // any list of boards const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; if (result) { - // and individual tags for each board tags.push( ...result.ids.map((id) => ({ type: 'Model' as const, diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 18942a47d6..6f97dd1dbb 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -35,7 +35,9 @@ export type ModelType = S<'ModelType'>; export type BaseModelType = S<'BaseModelType'>; export type MainModelField = S<'MainModelField'>; export type VAEModelField = S<'VAEModelField'>; +export type LoRAModelField = S<'LoRAModelField'>; export type ModelsList = S<'ModelsList'>; +export type LoRAModelConfig = S<'LoRAModelConfig'>; // Graphs export type Graph = S<'Graph'>; @@ -60,6 +62,7 @@ export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>; +export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>; // ControlNet Nodes export type ControlNetInvocation = N<'ControlNetInvocation'>; From bf895221c2371b94ec2b1546e864229d7a66c7dc Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 04:13:05 +1200 Subject: [PATCH 05/28] fix: Tab index not being correct This probably needs to be updated to an object over an array so the index of item in the array doesnt break the rest of it. --- .../web/src/features/ui/components/InvokeTabs.tsx | 10 +++++----- invokeai/frontend/web/src/features/ui/store/tabMap.ts | 5 +---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx index 6986ded0a7..c618997f03 100644 --- a/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx +++ b/invokeai/frontend/web/src/features/ui/components/InvokeTabs.tsx @@ -66,16 +66,16 @@ const tabs: InvokeTabInfo[] = [ icon: , content: , }, - // { - // id: 'batch', - // icon: , - // content: , - // }, { id: 'modelManager', icon: , content: , }, + // { + // id: 'batch', + // icon: , + // content: , + // }, ]; const enabledTabsSelector = createSelector( diff --git a/invokeai/frontend/web/src/features/ui/store/tabMap.ts b/invokeai/frontend/web/src/features/ui/store/tabMap.ts index 7c85805fe7..0cae8eac43 100644 --- a/invokeai/frontend/web/src/features/ui/store/tabMap.ts +++ b/invokeai/frontend/web/src/features/ui/store/tabMap.ts @@ -1,13 +1,10 @@ export const tabMap = [ 'txt2img', 'img2img', - // 'generate', 'unifiedCanvas', 'nodes', - 'batch', - // 'postprocessing', - // 'training', 'modelManager', + 'batch', ] as const; export type InvokeTabName = (typeof tabMap)[number]; From c21b56ba310331447c821dbf583c80c680dd0c68 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 11:50:21 +1000 Subject: [PATCH 06/28] fix(ui): fix mantine disabled styles --- .../web/src/common/components/IAIMantineMultiSelect.tsx | 2 +- .../frontend/web/src/common/components/IAIMantineSelect.tsx | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx index 39ec6fd245..97e33f300b 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx @@ -61,7 +61,7 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx index 9b023fd2d7..585dc106a8 100644 --- a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx +++ b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx @@ -64,7 +64,7 @@ const IAIMantineSelect = (props: IAISelectProps) => { '&:focus-within': { borderColor: mode(accent200, accent600)(colorMode), }, - '&:disabled': { + '&[data-disabled]': { backgroundColor: mode(base300, base700)(colorMode), color: mode(base600, base400)(colorMode), }, From 52a09422c7dbac58db99c127ad369e810aa7df33 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 11:52:02 +1000 Subject: [PATCH 07/28] feat(ui): create rtk-query hooks for individual model types Eg `useGetMainModelsQuery()`, `useGetLoRAModelsQuery()` instead of `useListModelsQuery({base_type})`. Add specific adapters for each model type. Just more organised and easier to consume models now. Also updated LoRA UI to use the model name. --- .../features/lora/components/ParamLora.tsx | 12 +- .../lora/components/ParamLoraSelect.tsx | 12 +- .../web/src/features/lora/store/loraSlice.ts | 20 +- .../fields/LoRAModelInputFieldComponent.tsx | 6 +- .../fields/ModelInputFieldComponent.tsx | 6 +- .../fields/VaeModelInputFieldComponent.tsx | 6 +- .../system/components/ModelSelect.tsx | 6 +- .../features/system/components/VAESelect.tsx | 6 +- .../subpanels/MergeModelsPanel.tsx | 6 +- .../subpanels/ModelManagerPanel.tsx | 6 +- .../subpanels/ModelManagerPanel/ModelList.tsx | 6 +- .../web/src/services/api/endpoints/models.ts | 235 ++++++++++++++++-- .../frontend/web/src/services/api/types.d.ts | 214 ++++++++++------ 13 files changed, 395 insertions(+), 146 deletions(-) diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx index c7d1c44fd3..277a316fae 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -16,18 +16,18 @@ const ParamLora = (props: Props) => { const handleChange = useCallback( (v: number) => { - dispatch(loraWeightChanged({ name: lora.name, weight: v })); + dispatch(loraWeightChanged({ id: lora.id, weight: v })); }, - [dispatch, lora.name] + [dispatch, lora.id] ); const handleReset = useCallback(() => { - dispatch(loraWeightChanged({ name: lora.name, weight: 1 })); - }, [dispatch, lora.name]); + dispatch(loraWeightChanged({ id: lora.id, weight: 1 })); + }, [dispatch, lora.id]); const handleRemoveLora = useCallback(() => { - dispatch(loraRemoved(lora.name)); - }, [dispatch, lora.name]); + dispatch(loraRemoved(lora.id)); + }, [dispatch, lora.id]); return ( diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx index 8e44e7d8f1..54ac3d615d 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraSelect.tsx @@ -6,7 +6,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; import { forEach } from 'lodash-es'; import { forwardRef, useCallback, useMemo } from 'react'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { loraAdded } from '../store/loraSlice'; type LoraSelectItem = { @@ -26,7 +26,7 @@ const selector = createSelector( const ParamLoraSelect = () => { const dispatch = useAppDispatch(); const { loras } = useAppSelector(selector); - const { data: lorasQueryData } = useListModelsQuery({ model_type: 'lora' }); + const { data: lorasQueryData } = useGetLoRAModelsQuery(); const data = useMemo(() => { if (!lorasQueryData) { @@ -52,9 +52,13 @@ const ParamLoraSelect = () => { const handleChange = useCallback( (v: string[]) => { - v[0] && dispatch(loraAdded(v[0])); + const loraEntity = lorasQueryData?.entities[v[0]]; + if (!loraEntity) { + return; + } + v[0] && dispatch(loraAdded(loraEntity)); }, - [dispatch] + [dispatch, lorasQueryData?.entities] ); return ( diff --git a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts index 49b316b054..c9b290eb2d 100644 --- a/invokeai/frontend/web/src/features/lora/store/loraSlice.ts +++ b/invokeai/frontend/web/src/features/lora/store/loraSlice.ts @@ -1,11 +1,13 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; +import { LoRAModelConfigEntity } from 'services/api/endpoints/models'; export type Lora = { + id: string; name: string; weight: number; }; -export const defaultLoRAConfig: Omit = { +export const defaultLoRAConfig: Omit = { weight: 1, }; @@ -21,20 +23,20 @@ export const loraSlice = createSlice({ name: 'lora', initialState: intialLoraState, reducers: { - loraAdded: (state, action: PayloadAction) => { - const name = action.payload; - state.loras[name] = { name, ...defaultLoRAConfig }; + loraAdded: (state, action: PayloadAction) => { + const { name, id } = action.payload; + state.loras[id] = { id, name, ...defaultLoRAConfig }; }, loraRemoved: (state, action: PayloadAction) => { - const name = action.payload; - delete state.loras[name]; + const id = action.payload; + delete state.loras[id]; }, loraWeightChanged: ( state, - action: PayloadAction<{ name: string; weight: number }> + action: PayloadAction<{ id: string; weight: number }> ) => { - const { name, weight } = action.payload; - state.loras[name].weight = weight; + const { id, weight } = action.payload; + state.loras[id].weight = weight; }, }, }); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx index 5919fabaac..02cdfd454d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LoRAModelInputFieldComponent.tsx @@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component import { forEach, isString } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const LoRAModelInputFieldComponent = ( @@ -24,9 +24,7 @@ const LoRAModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: loraModels } = useListModelsQuery({ - model_type: 'lora', - }); + const { data: loraModels } = useGetLoRAModelsQuery(); const selectedModel = useMemo( () => loraModels?.entities[field.value ?? loraModels.ids[0]], diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx index b5bb9c5b74..ee739e1002 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ModelInputFieldComponent.tsx @@ -11,7 +11,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component import { forEach, isString } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const ModelInputFieldComponent = ( @@ -22,9 +22,7 @@ const ModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx index 74d9942c84..b4408e41b2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/VaeModelInputFieldComponent.tsx @@ -10,7 +10,7 @@ import { MODEL_TYPE_MAP as BASE_MODEL_NAME_MAP } from 'features/system/component import { forEach } from 'lodash-es'; import { memo, useCallback, useEffect, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { FieldComponentProps } from './types'; const VaeModelInputFieldComponent = ( @@ -24,9 +24,7 @@ const VaeModelInputFieldComponent = ( const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: vaeModels } = useListModelsQuery({ - model_type: 'vae', - }); + const { data: vaeModels } = useGetVaeModelsQuery(); const selectedModel = useMemo( () => vaeModels?.entities[field.value ?? vaeModels.ids[0]], diff --git a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx index 4232858621..4eeee3e4c6 100644 --- a/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/ModelSelect.tsx @@ -8,7 +8,7 @@ import { modelSelected } from 'features/parameters/store/generationSlice'; import { SelectItem } from '@mantine/core'; import { RootState } from 'app/store/store'; import { forEach, isString } from 'lodash-es'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export const MODEL_TYPE_MAP = { 'sd-1': 'Stable Diffusion 1.x', @@ -23,9 +23,7 @@ const ModelSelect = () => { (state: RootState) => state.generation.model ); - const { data: mainModels, isLoading } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels, isLoading } = useGetMainModelsQuery(); const data = useMemo(() => { if (!mainModels) { diff --git a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx index 19b508d30f..33901b5bef 100644 --- a/invokeai/frontend/web/src/features/system/components/VAESelect.tsx +++ b/invokeai/frontend/web/src/features/system/components/VAESelect.tsx @@ -6,7 +6,7 @@ import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { SelectItem } from '@mantine/core'; import { forEach } from 'lodash-es'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; import { RootState } from 'app/store/store'; import { vaeSelected } from 'features/parameters/store/generationSlice'; @@ -16,9 +16,7 @@ const VAESelect = () => { const dispatch = useAppDispatch(); const { t } = useTranslation(); - const { data: vaeModels } = useListModelsQuery({ - model_type: 'vae', - }); + const { data: vaeModels } = useGetVaeModelsQuery(); const selectedModelId = useAppSelector( (state: RootState) => state.generation.vae diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx index 0cd90a9492..b71b5636b4 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/MergeModelsPanel.tsx @@ -9,16 +9,14 @@ import IAISlider from 'common/components/IAISlider'; import { pickBy } from 'lodash-es'; import { useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; export default function MergeModelsPanel() { const { t } = useTranslation(); const dispatch = useAppDispatch(); - const { data } = useListModelsQuery({ - model_type: 'main', - }); + const { data } = useGetMainModelsQuery(); const diffusersModels = pickBy( data?.entities, diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx index 228fb79c2e..b22a303571 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel.tsx @@ -2,15 +2,13 @@ import { Flex } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; import CheckpointModelEdit from './ModelManagerPanel/CheckpointModelEdit'; import DiffusersModelEdit from './ModelManagerPanel/DiffusersModelEdit'; import ModelList from './ModelManagerPanel/ModelList'; export default function ModelManagerPanel() { - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const openModel = useAppSelector( (state: RootState) => state.system.openModel diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx index fac89b7edc..eb05e70357 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ModelManager/subpanels/ModelManagerPanel/ModelList.tsx @@ -8,7 +8,7 @@ import { useTranslation } from 'react-i18next'; import type { ChangeEvent, ReactNode } from 'react'; import React, { useMemo, useState, useTransition } from 'react'; -import { useListModelsQuery } from 'services/api/endpoints/models'; +import { useGetMainModelsQuery } from 'services/api/endpoints/models'; function ModelFilterButton({ label, @@ -36,9 +36,7 @@ function ModelFilterButton({ } const ModelList = () => { - const { data: mainModels } = useListModelsQuery({ - model_type: 'main', - }); + const { data: mainModels } = useGetMainModelsQuery(); const [renderModelList, setRenderModelList] = React.useState(false); diff --git a/invokeai/frontend/web/src/services/api/endpoints/models.ts b/invokeai/frontend/web/src/services/api/endpoints/models.ts index bff412bacb..a9a914f0f2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/models.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/models.ts @@ -1,35 +1,85 @@ import { EntityState, createEntityAdapter } from '@reduxjs/toolkit'; -import { keyBy } from 'lodash-es'; -import { ModelsList } from 'services/api/types'; +import { cloneDeep } from 'lodash-es'; +import { + AnyModelConfig, + ControlNetModelConfig, + LoRAModelConfig, + MainModelConfig, + TextualInversionModelConfig, + VaeModelConfig, +} from 'services/api/types'; import { ApiFullTagDescription, LIST_TAG, api } from '..'; -import { paths } from '../schema'; -type ModelConfig = ModelsList['models'][number]; +export type MainModelConfigEntity = MainModelConfig & { id: string }; -type ListModelsArg = NonNullable< - paths['/api/v1/models/']['get']['parameters']['query'] ->; +export type LoRAModelConfigEntity = LoRAModelConfig & { id: string }; -const modelsAdapter = createEntityAdapter({ - selectId: (model) => getModelId(model), +export type ControlNetModelConfigEntity = ControlNetModelConfig & { + id: string; +}; + +export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { + id: string; +}; + +export type VaeModelConfigEntity = VaeModelConfig & { id: string }; + +type AnyModelConfigEntity = + | MainModelConfigEntity + | LoRAModelConfigEntity + | ControlNetModelConfigEntity + | TextualInversionModelConfigEntity + | VaeModelConfigEntity; + +const mainModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const loraModelsAdapter = createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), +}); +const controlNetModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const textualInversionModelsAdapter = + createEntityAdapter({ + sortComparer: (a, b) => a.name.localeCompare(b.name), + }); +const vaeModelsAdapter = createEntityAdapter({ sortComparer: (a, b) => a.name.localeCompare(b.name), }); -const getModelId = ({ base_model, type, name }: ModelConfig) => +export const getModelId = ({ base_model, type, name }: AnyModelConfig) => `${base_model}/${type}/${name}`; +const createModelEntities = ( + models: AnyModelConfig[] +): T[] => { + const entityArray: T[] = []; + models.forEach((model) => { + const entity = { + ...cloneDeep(model), + id: getModelId(model), + } as T; + entityArray.push(entity); + }); + return entityArray; +}; + export const modelsApi = api.injectEndpoints({ endpoints: (build) => ({ - listModels: build.query, ListModelsArg>({ - query: (arg) => ({ url: 'models/', params: arg }), + getMainModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'main' } }), providesTags: (result, error, arg) => { - const tags: ApiFullTagDescription[] = [{ id: 'Model', type: LIST_TAG }]; + const tags: ApiFullTagDescription[] = [ + { id: 'MainModel', type: LIST_TAG }, + ]; if (result) { tags.push( ...result.ids.map((id) => ({ - type: 'Model' as const, + type: 'MainModel' as const, id, })) ); @@ -37,14 +87,161 @@ export const modelsApi = api.injectEndpoints({ return tags; }, - transformResponse: (response: ModelsList, meta, arg) => { - return modelsAdapter.setAll( - modelsAdapter.getInitialState(), - keyBy(response.models, getModelId) + transformResponse: ( + response: { models: MainModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return mainModelsAdapter.setAll( + mainModelsAdapter.getInitialState(), + entities + ); + }, + }), + getLoRAModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'lora' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'LoRAModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'LoRAModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: LoRAModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return loraModelsAdapter.setAll( + loraModelsAdapter.getInitialState(), + entities + ); + }, + }), + getControlNetModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'controlnet' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'ControlNetModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'ControlNetModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: ControlNetModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return controlNetModelsAdapter.setAll( + controlNetModelsAdapter.getInitialState(), + entities + ); + }, + }), + getVaeModels: build.query, void>({ + query: () => ({ url: 'models/', params: { model_type: 'vae' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'VaeModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'VaeModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: VaeModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return vaeModelsAdapter.setAll( + vaeModelsAdapter.getInitialState(), + entities + ); + }, + }), + getTextualInversionModels: build.query< + EntityState, + void + >({ + query: () => ({ url: 'models/', params: { model_type: 'embedding' } }), + providesTags: (result, error, arg) => { + const tags: ApiFullTagDescription[] = [ + { id: 'TextualInversionModel', type: LIST_TAG }, + ]; + + if (result) { + tags.push( + ...result.ids.map((id) => ({ + type: 'TextualInversionModel' as const, + id, + })) + ); + } + + return tags; + }, + transformResponse: ( + response: { models: TextualInversionModelConfig[] }, + meta, + arg + ) => { + const entities = createModelEntities( + response.models + ); + return textualInversionModelsAdapter.setAll( + textualInversionModelsAdapter.getInitialState(), + entities ); }, }), }), }); -export const { useListModelsQuery } = modelsApi; +export const { + useGetMainModelsQuery, + useGetControlNetModelsQuery, + useGetLoRAModelsQuery, + useGetTextualInversionModelsQuery, + useGetVaeModelsQuery, +} = modelsApi; diff --git a/invokeai/frontend/web/src/services/api/types.d.ts b/invokeai/frontend/web/src/services/api/types.d.ts index 6f97dd1dbb..3a0bdb71a7 100644 --- a/invokeai/frontend/web/src/services/api/types.d.ts +++ b/invokeai/frontend/web/src/services/api/types.d.ts @@ -4,94 +4,156 @@ import { components } from './schema'; type schemas = components['schemas']; /** - * Extracts the schema type from the schema. + * Marks the `type` property as required. Use for nodes. */ -type S = components['schemas'][T]; - -/** - * Extracts the node type from the schema. - * Also flags the `type` property as required. - */ -type N = O.Required< - components['schemas'][T], - 'type' ->; +type TypeReq = O.Required; // Images -export type ImageDTO = S<'ImageDTO'>; -export type BoardDTO = S<'BoardDTO'>; -export type BoardChanges = S<'BoardChanges'>; -export type ImageChanges = S<'ImageRecordChanges'>; -export type ImageCategory = S<'ImageCategory'>; -export type ResourceOrigin = S<'ResourceOrigin'>; -export type ImageField = S<'ImageField'>; +export type ImageDTO = components['schemas']['ImageDTO']; +export type BoardDTO = components['schemas']['BoardDTO']; +export type BoardChanges = components['schemas']['BoardChanges']; +export type ImageChanges = components['schemas']['ImageRecordChanges']; +export type ImageCategory = components['schemas']['ImageCategory']; +export type ResourceOrigin = components['schemas']['ResourceOrigin']; +export type ImageField = components['schemas']['ImageField']; export type OffsetPaginatedResults_BoardDTO_ = - S<'OffsetPaginatedResults_BoardDTO_'>; + components['schemas']['OffsetPaginatedResults_BoardDTO_']; export type OffsetPaginatedResults_ImageDTO_ = - S<'OffsetPaginatedResults_ImageDTO_'>; + components['schemas']['OffsetPaginatedResults_ImageDTO_']; // Models -export type ModelType = S<'ModelType'>; -export type BaseModelType = S<'BaseModelType'>; -export type MainModelField = S<'MainModelField'>; -export type VAEModelField = S<'VAEModelField'>; -export type LoRAModelField = S<'LoRAModelField'>; -export type ModelsList = S<'ModelsList'>; -export type LoRAModelConfig = S<'LoRAModelConfig'>; +export type ModelType = components['schemas']['ModelType']; +export type BaseModelType = components['schemas']['BaseModelType']; +export type MainModelField = components['schemas']['MainModelField']; +export type VAEModelField = components['schemas']['VAEModelField']; +export type LoRAModelField = components['schemas']['LoRAModelField']; +export type ModelsList = components['schemas']['ModelsList']; + +// Model Configs +export type LoRAModelConfig = components['schemas']['LoRAModelConfig']; +export type VaeModelConfig = components['schemas']['VaeModelConfig']; +export type ControlNetModelConfig = + components['schemas']['ControlNetModelConfig']; +export type TextualInversionModelConfig = + components['schemas']['TextualInversionModelConfig']; +export type MainModelConfig = + | components['schemas']['StableDiffusion1ModelCheckpointConfig'] + | components['schemas']['StableDiffusion1ModelDiffusersConfig'] + | components['schemas']['StableDiffusion2ModelCheckpointConfig'] + | components['schemas']['StableDiffusion2ModelDiffusersConfig']; +export type AnyModelConfig = + | LoRAModelConfig + | VaeModelConfig + | ControlNetModelConfig + | TextualInversionModelConfig + | MainModelConfig; // Graphs -export type Graph = S<'Graph'>; -export type Edge = S<'Edge'>; -export type GraphExecutionState = S<'GraphExecutionState'>; +export type Graph = components['schemas']['Graph']; +export type Edge = components['schemas']['Edge']; +export type GraphExecutionState = components['schemas']['GraphExecutionState']; // General nodes -export type CollectInvocation = N<'CollectInvocation'>; -export type IterateInvocation = N<'IterateInvocation'>; -export type RangeInvocation = N<'RangeInvocation'>; -export type RandomRangeInvocation = N<'RandomRangeInvocation'>; -export type RangeOfSizeInvocation = N<'RangeOfSizeInvocation'>; -export type InpaintInvocation = N<'InpaintInvocation'>; -export type ImageResizeInvocation = N<'ImageResizeInvocation'>; -export type RandomIntInvocation = N<'RandomIntInvocation'>; -export type CompelInvocation = N<'CompelInvocation'>; -export type DynamicPromptInvocation = N<'DynamicPromptInvocation'>; -export type NoiseInvocation = N<'NoiseInvocation'>; -export type TextToLatentsInvocation = N<'TextToLatentsInvocation'>; -export type LatentsToLatentsInvocation = N<'LatentsToLatentsInvocation'>; -export type ImageToLatentsInvocation = N<'ImageToLatentsInvocation'>; -export type LatentsToImageInvocation = N<'LatentsToImageInvocation'>; -export type ImageCollectionInvocation = N<'ImageCollectionInvocation'>; -export type MainModelLoaderInvocation = N<'MainModelLoaderInvocation'>; -export type LoraLoaderInvocation = N<'LoraLoaderInvocation'>; +export type CollectInvocation = TypeReq< + components['schemas']['CollectInvocation'] +>; +export type IterateInvocation = TypeReq< + components['schemas']['IterateInvocation'] +>; +export type RangeInvocation = TypeReq; +export type RandomRangeInvocation = TypeReq< + components['schemas']['RandomRangeInvocation'] +>; +export type RangeOfSizeInvocation = TypeReq< + components['schemas']['RangeOfSizeInvocation'] +>; +export type InpaintInvocation = TypeReq< + components['schemas']['InpaintInvocation'] +>; +export type ImageResizeInvocation = TypeReq< + components['schemas']['ImageResizeInvocation'] +>; +export type RandomIntInvocation = TypeReq< + components['schemas']['RandomIntInvocation'] +>; +export type CompelInvocation = TypeReq< + components['schemas']['CompelInvocation'] +>; +export type DynamicPromptInvocation = TypeReq< + components['schemas']['DynamicPromptInvocation'] +>; +export type NoiseInvocation = TypeReq; +export type TextToLatentsInvocation = TypeReq< + components['schemas']['TextToLatentsInvocation'] +>; +export type LatentsToLatentsInvocation = TypeReq< + components['schemas']['LatentsToLatentsInvocation'] +>; +export type ImageToLatentsInvocation = TypeReq< + components['schemas']['ImageToLatentsInvocation'] +>; +export type LatentsToImageInvocation = TypeReq< + components['schemas']['LatentsToImageInvocation'] +>; +export type ImageCollectionInvocation = TypeReq< + components['schemas']['ImageCollectionInvocation'] +>; +export type MainModelLoaderInvocation = TypeReq< + components['schemas']['MainModelLoaderInvocation'] +>; +export type LoraLoaderInvocation = TypeReq< + components['schemas']['LoraLoaderInvocation'] +>; // ControlNet Nodes -export type ControlNetInvocation = N<'ControlNetInvocation'>; -export type CannyImageProcessorInvocation = N<'CannyImageProcessorInvocation'>; -export type ContentShuffleImageProcessorInvocation = - N<'ContentShuffleImageProcessorInvocation'>; -export type HedImageProcessorInvocation = N<'HedImageProcessorInvocation'>; -export type LineartAnimeImageProcessorInvocation = - N<'LineartAnimeImageProcessorInvocation'>; -export type LineartImageProcessorInvocation = - N<'LineartImageProcessorInvocation'>; -export type MediapipeFaceProcessorInvocation = - N<'MediapipeFaceProcessorInvocation'>; -export type MidasDepthImageProcessorInvocation = - N<'MidasDepthImageProcessorInvocation'>; -export type MlsdImageProcessorInvocation = N<'MlsdImageProcessorInvocation'>; -export type NormalbaeImageProcessorInvocation = - N<'NormalbaeImageProcessorInvocation'>; -export type OpenposeImageProcessorInvocation = - N<'OpenposeImageProcessorInvocation'>; -export type PidiImageProcessorInvocation = N<'PidiImageProcessorInvocation'>; -export type ZoeDepthImageProcessorInvocation = - N<'ZoeDepthImageProcessorInvocation'>; +export type ControlNetInvocation = TypeReq< + components['schemas']['ControlNetInvocation'] +>; +export type CannyImageProcessorInvocation = TypeReq< + components['schemas']['CannyImageProcessorInvocation'] +>; +export type ContentShuffleImageProcessorInvocation = TypeReq< + components['schemas']['ContentShuffleImageProcessorInvocation'] +>; +export type HedImageProcessorInvocation = TypeReq< + components['schemas']['HedImageProcessorInvocation'] +>; +export type LineartAnimeImageProcessorInvocation = TypeReq< + components['schemas']['LineartAnimeImageProcessorInvocation'] +>; +export type LineartImageProcessorInvocation = TypeReq< + components['schemas']['LineartImageProcessorInvocation'] +>; +export type MediapipeFaceProcessorInvocation = TypeReq< + components['schemas']['MediapipeFaceProcessorInvocation'] +>; +export type MidasDepthImageProcessorInvocation = TypeReq< + components['schemas']['MidasDepthImageProcessorInvocation'] +>; +export type MlsdImageProcessorInvocation = TypeReq< + components['schemas']['MlsdImageProcessorInvocation'] +>; +export type NormalbaeImageProcessorInvocation = TypeReq< + components['schemas']['NormalbaeImageProcessorInvocation'] +>; +export type OpenposeImageProcessorInvocation = TypeReq< + components['schemas']['OpenposeImageProcessorInvocation'] +>; +export type PidiImageProcessorInvocation = TypeReq< + components['schemas']['PidiImageProcessorInvocation'] +>; +export type ZoeDepthImageProcessorInvocation = TypeReq< + components['schemas']['ZoeDepthImageProcessorInvocation'] +>; // Node Outputs -export type ImageOutput = S<'ImageOutput'>; -export type MaskOutput = S<'MaskOutput'>; -export type PromptOutput = S<'PromptOutput'>; -export type IterateInvocationOutput = S<'IterateInvocationOutput'>; -export type CollectInvocationOutput = S<'CollectInvocationOutput'>; -export type LatentsOutput = S<'LatentsOutput'>; -export type GraphInvocationOutput = S<'GraphInvocationOutput'>; +export type ImageOutput = components['schemas']['ImageOutput']; +export type MaskOutput = components['schemas']['MaskOutput']; +export type PromptOutput = components['schemas']['PromptOutput']; +export type IterateInvocationOutput = + components['schemas']['IterateInvocationOutput']; +export type CollectInvocationOutput = + components['schemas']['CollectInvocationOutput']; +export type LatentsOutput = components['schemas']['LatentsOutput']; +export type GraphInvocationOutput = + components['schemas']['GraphInvocationOutput']; From 0f0336b6ef3cff4d7ed86796117e0a71415a21f9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 12:04:07 +1000 Subject: [PATCH 08/28] fix(ui): fix incorrect lora id processing --- .../src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index a105a123d8..dd4b713196 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -42,8 +42,8 @@ export const addLoRAsToGraph = ( let currentLoraIndex = 0; forEach(loras, (lora) => { - const { name, weight } = lora; - const loraField = modelIdToLoRAModelField(name); + const { id, name, weight } = lora; + const loraField = modelIdToLoRAModelField(id); const currentLoraNodeId = `${LORA_LOADER}_${loraField.model_name.replace( '.', '_' From c0501ed5c243efd746c178366d41360922aad0a9 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 14:37:16 +1200 Subject: [PATCH 09/28] fix: Slow loading of Loras Co-Authored-By: StAlKeR7779 <7768370+StAlKeR7779@users.noreply.github.com> --- invokeai/app/invocations/compel.py | 166 ++++++----- invokeai/app/invocations/latent.py | 271 ++++++++++-------- invokeai/backend/model_management/lora.py | 20 +- .../util/graphBuilders/addLoRAsToGraph.ts | 2 - 4 files changed, 253 insertions(+), 206 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0421841e8a..d77269da20 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,28 +1,27 @@ -from typing import Literal, Optional, Union -from pydantic import BaseModel, Field -from contextlib import ExitStack import re +from contextlib import ExitStack +from typing import List, Literal, Optional, Union + import torch +from compel import Compel +from compel.prompt_parser import (Blend, Conjunction, + CrossAttentionControlSubstitute, + FlattenedPrompt, Fragment) +from pydantic import BaseModel, Field -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management.lora import ModelPatcher - -from compel import Compel -from compel.prompt_parser import ( - Blend, - CrossAttentionControlSubstitute, - FlattenedPrompt, - Fragment, Conjunction, -) +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .model import ClipField class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: Optional[str] = Field( + default=None, description="The name of conditioning data") + class Config: schema_extra = {"required": ["conditioning_name"]} @@ -52,84 +51,92 @@ class CompelInvocation(BaseInvocation): "title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": { - "model": "model" + "model": "model" } }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> CompelOutput: - tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), ) text_encoder_info = context.services.model_manager.get_model( **self.clip.text_encoder.dict(), ) - with tokenizer_info as orig_tokenizer,\ - text_encoder_info as text_encoder: - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] + def _lora_loader(): + for lora in self.clip.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - ti_list = [] - for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model - ) - except Exception: - #print(e) - #import traceback - #print(traceback.format_exc()) - print(f"Warn: trigger: \"{trigger}\" not found") + #loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras] - with ModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ - ModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): - - compel = Compel( - tokenizer=tokenizer, - text_encoder=text_encoder, - textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, - truncate_long_prompts=True, # TODO: + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ).context.model ) - - conjunction = Compel.parse_prompt_string(self.prompt) - prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + except Exception: + # print(e) + #import traceback + # print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") - if context.services.configuration.log_tokenization: - log_tokenization_for_prompt_object(prompt, tokenizer) + with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ + ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\ + text_encoder_info as text_encoder: - c, options = compel.build_conditioning_tensor_for_prompt_object(prompt) - - # TODO: long prompt support - #if not self.truncate_long_prompts: - # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) - ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction), - cross_attention_control_args=options.get("cross_attention_control", None), - ) - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (c, ec)) - - return CompelOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), + compel = Compel( + tokenizer=tokenizer, + text_encoder=text_encoder, + textual_inversion_manager=ti_manager, + dtype_for_device_getter=torch_dtype, + truncate_long_prompts=True, # TODO: ) + conjunction = Compel.parse_prompt_string(self.prompt) + prompt: Union[FlattenedPrompt, Blend] = conjunction.prompts[0] + + if context.services.configuration.log_tokenization: + log_tokenization_for_prompt_object(prompt, tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object( + prompt) + + # TODO: long prompt support + # if not self.truncate_long_prompts: + # [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc]) + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo( + tokens_count_including_eos_bos=get_max_token_count( + tokenizer, conjunction), + cross_attention_control_args=options.get( + "cross_attention_control", None),) + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.save(conditioning_name, (c, ec)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + def get_max_token_count( - tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False -) -> int: + tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], + truncate_if_too_long=False) -> int: if type(prompt) is Blend: blend: Blend = prompt return max( @@ -148,13 +155,13 @@ def get_max_token_count( ) else: return len( - get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long) - ) + get_tokens_for_prompt_object( + tokenizer, prompt, truncate_if_too_long)) def get_tokens_for_prompt_object( tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True -) -> [str]: +) -> List[str]: if type(parsed_prompt) is Blend: raise ValueError( "Blend is not supported here - you need to get tokens for each of its .children" @@ -183,7 +190,7 @@ def log_tokenization_for_conjunction( ): display_label_prefix = display_label_prefix or "" for i, p in enumerate(c.prompts): - if len(c.prompts)>1: + if len(c.prompts) > 1: this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})" else: this_display_label_prefix = display_label_prefix @@ -238,7 +245,8 @@ def log_tokenization_for_prompt_object( ) -def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False): +def log_tokenization_for_text( + text, tokenizer, display_label=None, truncate_if_too_long=False): """shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index a9576a2fe1..50c901f15f 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -4,18 +4,17 @@ from contextlib import ExitStack from typing import List, Literal, Optional, Union import einops - -from pydantic import BaseModel, Field, validator import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler +from pydantic import BaseModel, Field, validator from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.step_callback import stable_diffusion_step_callback -from ..models.image import ImageCategory, ImageField, ResourceOrigin from ...backend.image_util.seamless import configure_model_padding +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline, @@ -24,7 +23,7 @@ from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \ PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.util.devices import torch_dtype -from ...backend.model_management.lora import ModelPatcher +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import (BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext) from .compel import ConditioningField @@ -32,14 +31,17 @@ from .controlnet_image_processors import ControlField from .image import ImageOutput from .model import ModelInfo, UNetField, VaeField + class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") + latents_name: Optional[str] = Field( + default=None, description="The name of the latents") class Config: schema_extra = {"required": ["latents_name"]} + class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" #fmt: off @@ -53,11 +55,11 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): - return LatentsOutput( - latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, - ) + return LatentsOutput( + latents=LatentsField(latents_name=latents_name), + width=latents.size()[3] * 8, + height=latents.size()[2] * 8, + ) SAMPLER_NAME_VALUES = Literal[ @@ -70,16 +72,19 @@ def get_scheduler( scheduler_info: ModelInfo, scheduler_name: str, ) -> Scheduler: - scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim']) - orig_scheduler_info = context.services.model_manager.get_model(**scheduler_info.dict()) + scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get( + scheduler_name, SCHEDULER_MAP['ddim']) + orig_scheduler_info = context.services.model_manager.get_model( + **scheduler_info.dict()) with orig_scheduler_info as orig_scheduler: scheduler_config = orig_scheduler.config - + if "_backup" in scheduler_config: scheduler_config = scheduler_config["_backup"] - scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config} + scheduler_config = {**scheduler_config, ** + scheduler_extra_config, "_backup": scheduler_config} scheduler = scheduler_class.from_config(scheduler_config) - + # hack copied over from generate.py if not hasattr(scheduler, 'uses_inpainting_model'): scheduler.uses_inpainting_model = lambda: False @@ -124,18 +129,18 @@ class TextToLatentsInvocation(BaseInvocation): "ui": { "tags": ["latents"], "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number" + "model": "model", + "control": "control", + # "cfg_scale": "float", + "cfg_scale": "number" } }, } # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: + self, context: InvocationContext, source_node_id: str, + intermediate_state: PipelineIntermediateState) -> None: stable_diffusion_step_callback( context=context, intermediate_state=intermediate_state, @@ -143,9 +148,12 @@ class TextToLatentsInvocation(BaseInvocation): source_node_id=source_node_id, ) - def get_conditioning_data(self, context: InvocationContext, scheduler) -> ConditioningData: - c, extra_conditioning_info = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + def get_conditioning_data( + self, context: InvocationContext, scheduler) -> ConditioningData: + c, extra_conditioning_info = context.services.latents.get( + self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get( + self.negative_conditioning.conditioning_name) conditioning_data = ConditioningData( unconditioned_embeddings=uc, @@ -153,10 +161,10 @@ class TextToLatentsInvocation(BaseInvocation): guidance_scale=self.cfg_scale, extra=extra_conditioning_info, postprocessing_settings=PostprocessingSettings( - threshold=0.0,#threshold, - warmup=0.2,#warmup, - h_symmetry_time_pct=None,#h_symmetry_time_pct, - v_symmetry_time_pct=None#v_symmetry_time_pct, + threshold=0.0, # threshold, + warmup=0.2, # warmup, + h_symmetry_time_pct=None, # h_symmetry_time_pct, + v_symmetry_time_pct=None # v_symmetry_time_pct, ), ) @@ -164,31 +172,32 @@ class TextToLatentsInvocation(BaseInvocation): scheduler, # for ddim scheduler - eta=0.0, #ddim_eta + eta=0.0, # ddim_eta # for ancestral and sde schedulers generator=torch.Generator(device=uc.device).manual_seed(0), ) return conditioning_data - def create_pipeline(self, unet, scheduler) -> StableDiffusionGeneratorPipeline: + def create_pipeline( + self, unet, scheduler) -> StableDiffusionGeneratorPipeline: # TODO: - #configure_model_padding( + # configure_model_padding( # unet, # self.seamless, # self.seamless_axes, - #) + # ) class FakeVae: class FakeVaeConfig: def __init__(self): self.block_out_channels = [0] - + def __init__(self): self.config = FakeVae.FakeVaeConfig() return StableDiffusionGeneratorPipeline( - vae=FakeVae(), # TODO: oh... + vae=FakeVae(), # TODO: oh... text_encoder=None, tokenizer=None, unet=unet, @@ -198,11 +207,12 @@ class TextToLatentsInvocation(BaseInvocation): requires_safety_checker=False, precision="float16" if unet.dtype == torch.float16 else "float32", ) - + def prep_control_data( self, context: InvocationContext, - model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device + # really only need model for dtype and device + model: StableDiffusionGeneratorPipeline, control_input: List[ControlField], latents_shape: List[int], do_classifier_free_guidance: bool = True, @@ -238,15 +248,17 @@ class TextToLatentsInvocation(BaseInvocation): print("Using HF model subfolders") print(" control_name: ", control_name) print(" control_subfolder: ", control_subfolder) - control_model = ControlNetModel.from_pretrained(control_name, - subfolder=control_subfolder, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_name, subfolder=control_subfolder, + torch_dtype=model.unet.dtype).to( + model.device) else: - control_model = ControlNetModel.from_pretrained(control_info.control_model, - torch_dtype=model.unet.dtype).to(model.device) + control_model = ControlNetModel.from_pretrained( + control_info.control_model, torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_name) + input_image = context.services.images.get_pil_image( + control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -263,41 +275,50 @@ class TextToLatentsInvocation(BaseInvocation): dtype=control_model.dtype, control_mode=control_info.control_mode, ) - control_item = ControlNetData(model=control_model, - image_tensor=control_image, - weight=control_info.control_weight, - begin_step_percent=control_info.begin_step_percent, - end_step_percent=control_info.end_step_percent, - control_mode=control_info.control_mode, - ) + control_item = ControlNetData( + model=control_model, image_tensor=control_image, + weight=control_info.control_weight, + begin_step_percent=control_info.begin_step_percent, + end_step_percent=control_info.end_step_percent, + control_mode=control_info.control_mode,) control_data.append(control_item) # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) - with unet_info as unet: + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return + + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, scheduler_info=self.unet.scheduler, scheduler_name=self.scheduler, ) - + pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -305,16 +326,15 @@ class TextToLatentsInvocation(BaseInvocation): do_classifier_free_guidance=True, ) - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - # TODO: Verify the noise is the right size - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback, - ) + # TODO: Verify the noise is the right size + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)), + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback, + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -323,14 +343,18 @@ class TextToLatentsInvocation(BaseInvocation): context.services.latents.save(name, result_latents) return build_latents_output(latents_name=name, latents=result_latents) + class LatentsToLatentsInvocation(TextToLatentsInvocation): """Generates latents using latents as base image.""" type: Literal["l2l"] = "l2l" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use") + latents: Optional[LatentsField] = Field( + description="The latents to use as a base image") + strength: float = Field( + default=0.7, ge=0, le=1, + description="The strength of the latents to use") # Schema customisation class Config(InvocationConfig): @@ -345,22 +369,31 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) # Get the source node id (we are invoking the prepared node) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) + graph_execution_state = context.services.graph_execution_manager.get( + context.graph_execution_state_id) source_node_id = graph_execution_state.prepared_source_mapping[self.id] def step_callback(state: PipelineIntermediateState): self.dispatch_progress(context, source_node_id, state) - unet_info = context.services.model_manager.get_model( - **self.unet.unet.dict(), - ) + def _lora_loader(): + for lora in self.unet.loras: + lora_info = context.services.model_manager.get_model( + **lora.dict(exclude={"weight"})) + yield (lora_info.context.model, lora.weight) + del lora_info + return - with unet_info as unet: + unet_info = context.services.model_manager.get_model( + **self.unet.unet.dict()) + with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\ + unet_info as unet: scheduler = get_scheduler( context=context, @@ -370,7 +403,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): pipeline = self.create_pipeline(unet, scheduler) conditioning_data = self.get_conditioning_data(context, scheduler) - + control_data = self.prep_control_data( model=pipeline, context=context, control_input=self.control, latents_shape=noise.shape, @@ -380,8 +413,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): # TODO: Verify the noise is the right size initial_latents = latent if self.strength < 1.0 else torch.zeros_like( - latent, device=unet.device, dtype=latent.dtype - ) + latent, device=unet.device, dtype=latent.dtype) timesteps, _ = pipeline.get_img2img_timesteps( self.steps, @@ -389,18 +421,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): device=unet.device, ) - loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras] - - with ModelPatcher.apply_lora_unet(pipeline.unet, loras): - result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( - latents=initial_latents, - timesteps=timesteps, - noise=noise, - num_inference_steps=self.steps, - conditioning_data=conditioning_data, - control_data=control_data, # list[ControlNetData] - callback=step_callback - ) + result_latents, result_attention_map_saver = pipeline.latents_from_embeddings( + latents=initial_latents, + timesteps=timesteps, + noise=noise, + num_inference_steps=self.steps, + conditioning_data=conditioning_data, + control_data=control_data, # list[ControlNetData] + callback=step_callback + ) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -417,9 +446,12 @@ class LatentsToImageInvocation(BaseInvocation): type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + latents: Optional[LatentsField] = Field( + description="The latents to generate an image from") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Decode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -429,7 +461,7 @@ class LatentsToImageInvocation(BaseInvocation): }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -450,7 +482,7 @@ class LatentsToImageInvocation(BaseInvocation): # copied from diffusers pipeline latents = latents / vae.config.scaling_factor image = vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) # denormalize + image = (image / 2 + 0.5).clamp(0, 1) # denormalize # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 np_image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -473,9 +505,9 @@ class LatentsToImageInvocation(BaseInvocation): height=image_dto.height, ) -LATENTS_INTERPOLATION_MODE = Literal[ - "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact" -] + +LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", + "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] class ResizeLatentsInvocation(BaseInvocation): @@ -484,21 +516,25 @@ class ResizeLatentsInvocation(BaseInvocation): type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to resize") + width: int = Field( + ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = Field( + ge=64, multiple_of=8, description="The height to resize to (px)") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) resized_latents = torch.nn.functional.interpolate( - latents, - size=(self.height // 8, self.width // 8), - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, size=(self.height // 8, self.width // 8), + mode=self.mode, antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -515,21 +551,24 @@ class ScaleLatentsInvocation(BaseInvocation): type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)") + latents: Optional[LatentsField] = Field( + description="The latents to scale") + scale_factor: float = Field( + gt=0, description="The factor by which to scale the latents") + mode: LATENTS_INTERPOLATION_MODE = Field( + default="bilinear", description="The interpolation mode") + antialias: bool = Field( + default=False, + description="Whether or not to antialias (applied in bilinear and bicubic modes only)") def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) # resizing resized_latents = torch.nn.functional.interpolate( - latents, - scale_factor=self.scale_factor, - mode=self.mode, - antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False, - ) + latents, scale_factor=self.scale_factor, mode=self.mode, + antialias=self.antialias + if self.mode in ["bilinear", "bicubic"] else False,) # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 torch.cuda.empty_cache() @@ -548,7 +587,9 @@ class ImageToLatentsInvocation(BaseInvocation): # Inputs image: Union[ImageField, None] = Field(description="The image to encode") vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") + tiled: bool = Field( + default=False, + description="Encode latents by overlaping tiles(less memory consumption)") # Schema customisation class Config(InvocationConfig): @@ -558,7 +599,7 @@ class ImageToLatentsInvocation(BaseInvocation): }, } - @torch.no_grad() + @torch.inference_mode() def invoke(self, context: InvocationContext) -> LatentsOutput: # image = context.services.images.get( # self.image.image_type, self.image.image_name diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 6cfcb8dd8d..bcd47ff00a 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -1,18 +1,17 @@ from __future__ import annotations import copy -from pathlib import Path from contextlib import contextmanager -from typing import Optional, Dict, Tuple, Any +from pathlib import Path +from typing import Any, Dict, Optional, Tuple import torch +from compel.embeddings_provider import BaseTextualInversionManager +from diffusers.models import UNet2DConditionModel from safetensors.torch import load_file from torch.utils.hooks import RemovableHandle - -from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel -from compel.embeddings_provider import BaseTextualInversionManager class LoRALayerBase: #rank: Optional[int] @@ -527,7 +526,7 @@ class ModelPatcher: ): original_weights = dict() try: - with torch.no_grad(): + with torch.inference_mode(): for lora, lora_weight in loras: #assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): @@ -539,9 +538,10 @@ class ModelPatcher: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu - with torch.autocast(device_type="cpu"): - layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + #with torch.autocast(device_type="cpu"): + layer.to(dtype=torch.float32) + layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 + layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris @@ -552,7 +552,7 @@ class ModelPatcher: yield # wait for context manager exit finally: - with torch.no_grad(): + with torch.inference_mode(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index dd4b713196..9712ef4d5f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -49,8 +49,6 @@ export const addLoRAsToGraph = ( '_' )}`; - console.log(lastLoraNodeId, currentLoraNodeId, currentLoraIndex, loraField); - const loraLoaderNode: LoraLoaderInvocation = { type: 'lora_loader', id: currentLoraNodeId, From 1358c5eb7d5968177acc582a64018cfad5577335 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 4 Jul 2023 22:45:45 +1000 Subject: [PATCH 10/28] fix(ui): fix selector memoization Every `GalleryImage` was rerendering any time the app rerendered bc the selector function itself was not memoized. This resulted in the memoization cache inside the selector constantly being reset. Same for `BatchImage`. Also updated memoization for a few other selectors. --- .../features/batch/components/BatchImage.tsx | 40 +++++++++------- .../gallery/components/GalleryImage.tsx | 47 ++++++++++--------- .../lora/components/ParamLoraList.tsx | 13 +++-- .../Parameters/Core/ParamCFGScale.tsx | 4 +- .../Parameters/Core/ParamHeight.tsx | 4 +- .../Parameters/Core/ParamIterations.tsx | 47 ++++++++++--------- .../components/Parameters/Core/ParamSteps.tsx | 4 +- .../components/Parameters/Core/ParamWidth.tsx | 7 +-- 8 files changed, 92 insertions(+), 74 deletions(-) diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx index 822b1cf183..3394946972 100644 --- a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx +++ b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx @@ -1,28 +1,29 @@ import { Box, Icon, Skeleton } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { FaExclamationCircle } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; import { batchImageRangeEndSelected, batchImageSelected, batchImageSelectionToggled, imageRemovedFromBatch, } from 'features/batch/store/batchSlice'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { createSelector } from '@reduxjs/toolkit'; -import { RootState, stateSelector } from 'app/store/store'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { FaExclamationCircle } from 'react-icons/fa'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; -const isSelectedSelector = createSelector( - [stateSelector, (state: RootState, imageName: string) => imageName], - (state, imageName) => ({ - selection: state.batch.selection, - isSelected: state.batch.selection.includes(imageName), - }), - defaultSelectorOptions -); +const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + (state) => ({ + selection: state.batch.selection, + isSelected: state.batch.selection.includes(image_name), + }), + defaultSelectorOptions + ); type BatchImageProps = { imageName: string; @@ -37,10 +38,13 @@ const BatchImage = (props: BatchImageProps) => { } = useGetImageDTOQuery(props.imageName); const dispatch = useAppDispatch(); - const { isSelected, selection } = useAppSelector((state) => - isSelectedSelector(state, props.imageName) + const selector = useMemo( + () => makeSelector(props.imageName), + [props.imageName] ); + const { isSelected, selection } = useAppSelector(selector); + const handleClickRemove = useCallback(() => { dispatch(imageRemovedFromBatch(props.imageName)); }, [dispatch, props.imageName]); diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx index 30e1c5abf3..7b2e27ddbe 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx @@ -1,34 +1,35 @@ import { Box } from '@chakra-ui/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { MouseEvent, memo, useCallback, useMemo } from 'react'; -import { FaTrash } from 'react-icons/fa'; -import { useTranslation } from 'react-i18next'; import { createSelector } from '@reduxjs/toolkit'; -import { ImageDTO } from 'services/api/types'; import { TypesafeDraggableData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; -import ImageContextMenu from './ImageContextMenu'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIDndImage from 'common/components/IAIDndImage'; +import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import { MouseEvent, memo, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaTrash } from 'react-icons/fa'; +import { ImageDTO } from 'services/api/types'; import { imageRangeEndSelected, imageSelected, imageSelectionToggled, } from '../store/gallerySlice'; -import { imageToDeleteSelected } from 'features/imageDeletion/store/imageDeletionSlice'; +import ImageContextMenu from './ImageContextMenu'; -export const selector = createSelector( - [stateSelector, (state, { image_name }: ImageDTO) => image_name], - ({ gallery }, image_name) => { - const isSelected = gallery.selection.includes(image_name); - const selection = gallery.selection; - return { - isSelected, - selection, - }; - }, - defaultSelectorOptions -); +export const makeSelector = (image_name: string) => + createSelector( + [stateSelector], + ({ gallery }) => { + const isSelected = gallery.selection.includes(image_name); + const selection = gallery.selection; + return { + isSelected, + selection, + }; + }, + defaultSelectorOptions + ); interface HoverableImageProps { imageDTO: ImageDTO; @@ -38,13 +39,13 @@ interface HoverableImageProps { * Gallery image component with delete/use all/use seed buttons on hover. */ const GalleryImage = (props: HoverableImageProps) => { - const { isSelected, selection } = useAppSelector((state) => - selector(state, props.imageDTO) - ); - const { imageDTO } = props; const { image_url, thumbnail_url, image_name } = imageDTO; + const localSelector = useMemo(() => makeSelector(image_name), [image_name]); + + const { isSelected, selection } = useAppSelector(localSelector); + const dispatch = useAppDispatch(); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx index 8d6ff98498..89432ac862 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLoraList.tsx @@ -1,14 +1,19 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import ParamLora from './ParamLora'; -const selector = createSelector(stateSelector, ({ lora }) => { - const { loras } = lora; +const selector = createSelector( + stateSelector, + ({ lora }) => { + const { loras } = lora; - return { loras }; -}); + return { loras }; + }, + defaultSelectorOptions +); const ParamLoraList = () => { const { loras } = useAppSelector(selector); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx index 111e3d3ae8..d32ff960d5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamCFGScale.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; @@ -27,7 +28,8 @@ const selector = createSelector( shouldUseSliders, shift, }; - } + }, + defaultSelectorOptions ); const ParamCFGScale = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx index 9501c8b475..6939ede424 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamHeight.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setHeight } from 'features/parameters/store/generationSlice'; @@ -25,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamHeightProps = Omit< diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx index a8cdabc8c9..1e203a1e45 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamIterations.tsx @@ -1,37 +1,38 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; -import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setIterations } from 'features/parameters/store/generationSlice'; -import { configSelector } from 'features/system/store/configSelectors'; -import { hotkeysSelector } from 'features/ui/store/hotkeysSlice'; -import { uiSelector } from 'features/ui/store/uiSelectors'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createSelector([stateSelector], (state) => { - const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = - state.config.sd.iterations; - const { iterations } = state.generation; - const { shouldUseSliders } = state.ui; - const isDisabled = - state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; +const selector = createSelector( + [stateSelector], + (state) => { + const { initial, min, sliderMax, inputMax, fineStep, coarseStep } = + state.config.sd.iterations; + const { iterations } = state.generation; + const { shouldUseSliders } = state.ui; + const isDisabled = + state.dynamicPrompts.isEnabled && state.dynamicPrompts.combinatorial; - const step = state.hotkeys.shift ? fineStep : coarseStep; + const step = state.hotkeys.shift ? fineStep : coarseStep; - return { - iterations, - initial, - min, - sliderMax, - inputMax, - step, - shouldUseSliders, - isDisabled, - }; -}); + return { + iterations, + initial, + min, + sliderMax, + inputMax, + step, + shouldUseSliders, + isDisabled, + }; + }, + defaultSelectorOptions +); const ParamIterations = () => { const { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx index f43cdd425b..d939113c7c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamSteps.tsx @@ -1,5 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAINumberInput from 'common/components/IAINumberInput'; import IAISlider from 'common/components/IAISlider'; @@ -33,7 +34,8 @@ const selector = createSelector( step, shouldUseSliders, }; - } + }, + defaultSelectorOptions ); const ParamSteps = () => { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx index b7d63038d1..b4121184b5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamWidth.tsx @@ -1,7 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { IAIFullSliderProps } from 'common/components/IAISlider'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider, { IAIFullSliderProps } from 'common/components/IAISlider'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setWidth } from 'features/parameters/store/generationSlice'; import { configSelector } from 'features/system/store/configSelectors'; @@ -26,7 +26,8 @@ const selector = createSelector( inputMax, step, }; - } + }, + defaultSelectorOptions ); type ParamWidthProps = Omit; From f155887b7de58f63e84d09c384a26070a735e0fd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 10:24:48 +1000 Subject: [PATCH 11/28] fix(ui): change multi image drop to not have selection as payload This caused a lot of re-rendering whenever the selection changed, which caused a huge performance hit. It also made changing the current image lag a bit. Instead of providing an array of image names as a multi-select dnd payload, there is now no multi-select dnd payload at all - instead, the payload types are used by the `imageDropped` listener to pull the selection out of redux. Now, the only big re-renders are when the selectionCount changes. In the future I'll figure out a good way to do image names as payload without incurring re-renders. --- .../app/components/ImageDnd/DragPreview.tsx | 44 ++++++++++++++++++- .../app/components/ImageDnd/typesafeDnd.tsx | 20 ++++++--- .../listeners/imageDropped.ts | 20 ++++----- .../features/batch/components/BatchImage.tsx | 13 +++--- .../components/ControlNetImagePreview.tsx | 25 +++++------ .../components/CurrentImagePreview.tsx | 18 ++++---- .../gallery/components/GalleryImage.tsx | 13 +++--- .../fields/ImageInputFieldComponent.tsx | 9 ++-- 8 files changed, 100 insertions(+), 62 deletions(-) diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx index 5b6142d748..bf66c0ee08 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx @@ -1,4 +1,8 @@ import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { memo } from 'react'; import { TypesafeDraggableData } from './typesafeDnd'; @@ -28,7 +32,24 @@ const STYLES: ChakraProps['sx'] = { }, }; +const selector = createSelector( + stateSelector, + (state) => { + const gallerySelectionCount = state.gallery.selection.length; + const batchSelectionCount = state.batch.selection.length; + + return { + gallerySelectionCount, + batchSelectionCount, + }; + }, + defaultSelectorOptions +); + const DragPreview = (props: OverlayDragImageProps) => { + const { gallerySelectionCount, batchSelectionCount } = + useAppSelector(selector); + if (!props.dragData) { return; } @@ -57,7 +78,7 @@ const DragPreview = (props: OverlayDragImageProps) => { ); } - if (props.dragData.payloadType === 'IMAGE_NAMES') { + if (props.dragData.payloadType === 'BATCH_SELECTION') { return ( { ...STYLES, }} > - {props.dragData.payload.imageNames.length} + {batchSelectionCount} + Images + + ); + } + + if (props.dragData.payloadType === 'GALLERY_SELECTION') { + return ( + + {gallerySelectionCount} Images ); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx index e744a70750..1478ace748 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx @@ -77,14 +77,18 @@ export type ImageDraggableData = BaseDragData & { payload: { imageDTO: ImageDTO }; }; -export type ImageNamesDraggableData = BaseDragData & { - payloadType: 'IMAGE_NAMES'; - payload: { imageNames: string[] }; +export type GallerySelectionDraggableData = BaseDragData & { + payloadType: 'GALLERY_SELECTION'; +}; + +export type BatchSelectionDraggableData = BaseDragData & { + payloadType: 'BATCH_SELECTION'; }; export type TypesafeDraggableData = | ImageDraggableData - | ImageNamesDraggableData; + | GallerySelectionDraggableData + | BatchSelectionDraggableData; interface UseDroppableTypesafeArguments extends Omit { @@ -155,11 +159,13 @@ export const isValidDrop = ( case 'SET_NODES_IMAGE': return payloadType === 'IMAGE_DTO'; case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION'; case 'MOVE_BOARD': - return payloadType === 'IMAGE_DTO' || 'IMAGE_NAMES'; + return ( + payloadType === 'IMAGE_DTO' || 'GALLERY_SELECTION' || 'BATCH_SELECTION' + ); default: return false; } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 56f660a653..24a5bffec7 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -1,24 +1,23 @@ import { createAction } from '@reduxjs/toolkit'; -import { startAppListening } from '../'; -import { log } from 'app/logging/useLogger'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { initialImageChanged } from 'features/parameters/store/generationSlice'; +import { log } from 'app/logging/useLogger'; import { imageAddedToBatch, imagesAddedToBatch, } from 'features/batch/store/batchSlice'; -import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; +import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldValueChanged, imageCollectionFieldValueChanged, } from 'features/nodes/store/nodesSlice'; -import { boardsApi } from 'services/api/endpoints/boards'; +import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { boardImagesApi } from 'services/api/endpoints/boardImages'; +import { startAppListening } from '../'; const moduleLog = log.child({ namespace: 'dnd' }); @@ -33,6 +32,7 @@ export const addImageDroppedListener = () => { effect: (action, { dispatch, getState }) => { const { activeData, overData } = action.payload; const { actionType } = overData; + const state = getState(); // set current image if ( @@ -64,9 +64,9 @@ export const addImageDroppedListener = () => { // add multiple images to batch if ( actionType === 'ADD_TO_BATCH' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { - dispatch(imagesAddedToBatch(activeData.payload.imageNames)); + dispatch(imagesAddedToBatch(state.gallery.selection)); } // set control image @@ -128,14 +128,14 @@ export const addImageDroppedListener = () => { // set multiple nodes images (multiple images handler) if ( actionType === 'SET_MULTI_NODES_IMAGE' && - activeData.payloadType === 'IMAGE_NAMES' + activeData.payloadType === 'GALLERY_SELECTION' ) { const { fieldName, nodeId } = overData.context; dispatch( imageCollectionFieldValueChanged({ nodeId, fieldName, - value: activeData.payload.imageNames.map((image_name) => ({ + value: state.gallery.selection.map((image_name) => ({ image_name, })), }) diff --git a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx index 3394946972..4a6250f93a 100644 --- a/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx +++ b/invokeai/frontend/web/src/features/batch/components/BatchImage.tsx @@ -19,7 +19,7 @@ const makeSelector = (image_name: string) => createSelector( [stateSelector], (state) => ({ - selection: state.batch.selection, + selectionCount: state.batch.selection.length, isSelected: state.batch.selection.includes(image_name), }), defaultSelectorOptions @@ -43,7 +43,7 @@ const BatchImage = (props: BatchImageProps) => { [props.imageName] ); - const { isSelected, selection } = useAppSelector(selector); + const { isSelected, selectionCount } = useAppSelector(selector); const handleClickRemove = useCallback(() => { dispatch(imageRemovedFromBatch(props.imageName)); @@ -63,13 +63,10 @@ const BatchImage = (props: BatchImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'batch', - payloadType: 'IMAGE_NAMES', - payload: { - imageNames: selection, - }, + payloadType: 'BATCH_SELECTION', }; } @@ -80,7 +77,7 @@ const BatchImage = (props: BatchImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); if (isError) { return ; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index df73f1141d..c0c1030b79 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -1,25 +1,22 @@ -import { memo, useCallback, useMemo, useState } from 'react'; -import { ImageDTO } from 'services/api/types'; -import { - ControlNetConfig, - controlNetImageChanged, - controlNetSelector, -} from '../store/controlNetSlice'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { Box, Flex, SystemStyleObject } from '@chakra-ui/react'; -import IAIDndImage from 'common/components/IAIDndImage'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { FaUndo } from 'react-icons/fa'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { IAILoadingImageFallback } from 'common/components/IAIImageFallback'; +import { memo, useCallback, useMemo, useState } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; +import { + ControlNetConfig, + controlNetImageChanged, + controlNetSelector, +} from '../store/controlNetSlice'; const selector = createSelector( controlNetSelector, diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx index 112129ffa2..8018beea9a 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImagePreview.tsx @@ -1,19 +1,19 @@ import { Box, Flex, Image } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; -import { isEqual } from 'lodash-es'; -import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; -import NextPrevImageButtons from './NextPrevImageButtons'; -import { memo, useMemo } from 'react'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { stateSelector } from 'app/store/store'; -import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice'; import { TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { selectLastSelectedImage } from 'features/gallery/store/gallerySlice'; +import { isEqual } from 'lodash-es'; +import { memo, useMemo } from 'react'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import ImageMetadataViewer from './ImageMetaDataViewer/ImageMetadataViewer'; +import NextPrevImageButtons from './NextPrevImageButtons'; export const imagesSelector = createSelector( [stateSelector, selectLastSelectedImage], diff --git a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx index 7b2e27ddbe..ea0b3b0fd8 100644 --- a/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/GalleryImage.tsx @@ -22,10 +22,10 @@ export const makeSelector = (image_name: string) => [stateSelector], ({ gallery }) => { const isSelected = gallery.selection.includes(image_name); - const selection = gallery.selection; + const selectionCount = gallery.selection.length; return { isSelected, - selection, + selectionCount, }; }, defaultSelectorOptions @@ -44,7 +44,7 @@ const GalleryImage = (props: HoverableImageProps) => { const localSelector = useMemo(() => makeSelector(image_name), [image_name]); - const { isSelected, selection } = useAppSelector(localSelector); + const { isSelected, selectionCount } = useAppSelector(localSelector); const dispatch = useAppDispatch(); @@ -75,11 +75,10 @@ const GalleryImage = (props: HoverableImageProps) => { ); const draggableData = useMemo(() => { - if (selection.length > 1) { + if (selectionCount > 1) { return { id: 'gallery-image', - payloadType: 'IMAGE_NAMES', - payload: { imageNames: selection }, + payloadType: 'GALLERY_SELECTION', }; } @@ -90,7 +89,7 @@ const GalleryImage = (props: HoverableImageProps) => { payload: { imageDTO }, }; } - }, [imageDTO, selection]); + }, [imageDTO, selectionCount]); return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index 499946e3af..bfae89e931 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -7,18 +7,17 @@ import { } from 'features/nodes/types/types'; import { memo, useCallback, useMemo } from 'react'; -import { FieldComponentProps } from './types'; -import IAIDndImage from 'common/components/IAIDndImage'; -import { ImageDTO } from 'services/api/types'; import { Flex } from '@chakra-ui/react'; -import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { skipToken } from '@reduxjs/toolkit/dist/query'; import { - NodesImageDropData, TypesafeDraggableData, TypesafeDroppableData, } from 'app/components/ImageDnd/typesafeDnd'; +import IAIDndImage from 'common/components/IAIDndImage'; +import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; +import { ImageDTO } from 'services/api/types'; +import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( props: FieldComponentProps From 639d88afd6cfafc5039a668f09e974c8a218e1c7 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 16:39:15 +1200 Subject: [PATCH 12/28] revert: inference_mode to no_grad --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/latent.py | 8 ++++---- invokeai/backend/model_management/lora.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index d77269da20..d4ba7efeda 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -56,7 +56,7 @@ class CompelInvocation(BaseInvocation): }, } - @torch.inference_mode() + @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: tokenizer_info = context.services.model_manager.get_model( **self.clip.tokenizer.dict(), diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 50c901f15f..3e691c934e 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -285,7 +285,7 @@ class TextToLatentsInvocation(BaseInvocation): # MultiControlNetModel has been refactored out, just need list[ControlNetData] return control_data - @torch.inference_mode() + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) @@ -369,7 +369,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation): }, } - @torch.inference_mode() + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: noise = context.services.latents.get(self.noise.latents_name) latent = context.services.latents.get(self.latents.latents_name) @@ -461,7 +461,7 @@ class LatentsToImageInvocation(BaseInvocation): }, } - @torch.inference_mode() + @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -599,7 +599,7 @@ class ImageToLatentsInvocation(BaseInvocation): }, } - @torch.inference_mode() + @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: # image = context.services.images.get( # self.image.image_type, self.image.image_name diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index bcd47ff00a..5d27555ab3 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -526,7 +526,7 @@ class ModelPatcher: ): original_weights = dict() try: - with torch.inference_mode(): + with torch.no_grad(): for lora, lora_weight in loras: #assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): @@ -552,7 +552,7 @@ class ModelPatcher: yield # wait for context manager exit finally: - with torch.inference_mode(): + with torch.no_grad(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) From 1a29a3fe39aeb8b4623f69196e3fb952045c95d4 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 16:39:28 +1200 Subject: [PATCH 13/28] feat: Add Lora to Canvas --- .../graphBuilders/buildCanvasInpaintGraph.ts | 3 +++ .../UnifiedCanvas/UnifiedCanvasParameters.tsx | 18 ++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 82912de219..c4f9415067 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -8,6 +8,7 @@ import { RangeOfSizeInvocation, } from 'services/api/types'; import { modelIdToMainModelField } from '../modelIdToMainModelField'; +import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { INPAINT, @@ -194,6 +195,8 @@ export const buildCanvasInpaintGraph = ( ], }; + addLoRAsToGraph(graph, state, INPAINT); + // Add VAE addVAEToGraph(graph, state); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx index 061ebb962e..63ed4cc1cf 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasParameters.tsx @@ -1,14 +1,15 @@ -import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; -import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; -import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamLoraCollapse from 'features/lora/components/ParamLoraCollapse'; import ParamInfillAndScalingCollapse from 'features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse'; import ParamSeamCorrectionCollapse from 'features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse'; -import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; -import { memo } from 'react'; -import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; -import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; import ParamControlNetCollapse from 'features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse'; -import ParamDynamicPromptsCollapse from 'features/dynamicPrompts/components/ParamDynamicPromptsCollapse'; +import ParamNegativeConditioning from 'features/parameters/components/Parameters/Core/ParamNegativeConditioning'; +import ParamPositiveConditioning from 'features/parameters/components/Parameters/Core/ParamPositiveConditioning'; +import ParamSymmetryCollapse from 'features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse'; +import ParamVariationCollapse from 'features/parameters/components/Parameters/Variations/ParamVariationCollapse'; +import ProcessButtons from 'features/parameters/components/ProcessButtons/ProcessButtons'; +import { memo } from 'react'; +import UnifiedCanvasCoreParameters from './UnifiedCanvasCoreParameters'; const UnifiedCanvasParameters = () => { return ( @@ -17,6 +18,7 @@ const UnifiedCanvasParameters = () => { + From 1fb317243d9731e86d645e6d00bb806f54fc7bf4 Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 5 Jul 2023 18:07:14 +1200 Subject: [PATCH 14/28] fix: Change Lora weight bounds to -1 to 2 --- .../frontend/web/src/features/lora/components/ParamLora.tsx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx index 277a316fae..23459e9410 100644 --- a/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx +++ b/invokeai/frontend/web/src/features/lora/components/ParamLora.tsx @@ -35,13 +35,14 @@ const ParamLora = (props: Props) => { label={lora.name} value={lora.weight} onChange={handleChange} - min={0} - max={1} + min={-1} + max={2} step={0.01} withInput withReset handleReset={handleReset} withSliderMarks + sliderMarks={[-1, 0, 1, 2]} /> Date: Wed, 5 Jul 2023 09:43:46 +0300 Subject: [PATCH 15/28] Fix model detection --- invokeai/backend/install/model_install_backend.py | 2 +- invokeai/backend/model_management/model_manager.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py index a10fa852c0..00646e70e3 100644 --- a/invokeai/backend/install/model_install_backend.py +++ b/invokeai/backend/install/model_install_backend.py @@ -223,7 +223,7 @@ class ModelInstall(object): try: model_result = None info = info or ModelProbe().heuristic_probe(path,self.prediction_helper) - model_name = path.stem if info.format=='checkpoint' else path.name + model_name = path.stem if path.is_file() else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') attributes = self._make_attributes(path,info) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index a8cbb50474..8002ec9ba4 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -731,12 +731,12 @@ class ModelManager(object): if model_path.is_relative_to(self.app_config.root_path): model_path = model_path.relative_to(self.app_config.root_path) - try: - model_config: ModelConfigBase = model_class.probe_config(str(model_path)) - self.models[model_key] = model_config - new_models_found = True - except NotImplementedError as e: - self.logger.warning(e) + try: + model_config: ModelConfigBase = model_class.probe_config(str(model_path)) + self.models[model_key] = model_config + new_models_found = True + except NotImplementedError as e: + self.logger.warning(e) imported_models = self.autoimport() From 7170e82f73679f441f6b0927dfb88420d815158c Mon Sep 17 00:00:00 2001 From: Eugene Brodsky Date: Tue, 4 Jul 2023 17:05:35 -0400 Subject: [PATCH 16/28] expose max_cache_size in config --- invokeai/app/services/config.py | 19 ++++++----- .../app/services/model_manager_service.py | 34 ++++++++++--------- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/invokeai/app/services/config.py b/invokeai/app/services/config.py index e0f1ceeb25..e7f817fc0a 100644 --- a/invokeai/app/services/config.py +++ b/invokeai/app/services/config.py @@ -228,10 +228,10 @@ class InvokeAISettings(BaseSettings): upcase_environ = dict() for key,value in os.environ.items(): upcase_environ[key.upper()] = value - + fields = cls.__fields__ cls.argparse_groups = {} - + for name, field in fields.items(): if name not in cls._excluded(): current_default = field.default @@ -348,7 +348,7 @@ setting environment variables INVOKEAI_. ''' singleton_config: ClassVar[InvokeAIAppConfig] = None singleton_init: ClassVar[Dict] = None - + #fmt: off type: Literal["InvokeAI"] = "InvokeAI" host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server') @@ -367,7 +367,8 @@ setting environment variables INVOKEAI_. always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", category='Memory/Performance') free_gpu_mem : bool = Field(default=False, description="If true, purge model from GPU after each generation.", category='Memory/Performance') - max_loaded_models : int = Field(default=3, gt=0, description="Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_loaded_models : int = Field(default=3, gt=0, description="(DEPRECATED: use max_cache_size) Maximum number of models to keep in memory for rapid switching", category='Memory/Performance') + max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance') precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='float16',description='Floating point precision', category='Memory/Performance') sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance') xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance') @@ -385,9 +386,9 @@ setting environment variables INVOKEAI_. outdir : Path = Field(default='outputs', description='Default folder for output images', category='Paths') from_file : Path = Field(default=None, description='Take command input from the indicated file (command-line client only)', category='Paths') use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', category='Paths') - + model : str = Field(default='stable-diffusion-1.5', description='Initial model name', category='Models') - + log_handlers : List[str] = Field(default=["console"], description='Log handler. Valid options are "console", "file=", "syslog=path|address:host:port", "http="', category="Logging") # note - would be better to read the log_format values from logging.py, but this creates circular dependencies issues log_format : Literal[tuple(['plain','color','syslog','legacy'])] = Field(default="color", description='Log format. Use "plain" for text-only, "color" for colorized output, "legacy" for 2.3-style logging and "syslog" for syslog-style', category="Logging") @@ -396,7 +397,7 @@ setting environment variables INVOKEAI_. def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False): ''' - Update settings with contents of init file, environment, and + Update settings with contents of init file, environment, and command-line settings. :param conf: alternate Omegaconf dictionary object :param argv: aternate sys.argv list @@ -411,7 +412,7 @@ setting environment variables INVOKEAI_. except: pass InvokeAISettings.initconf = conf - + # parse args again in order to pick up settings in configuration file super().parse_args(argv) @@ -431,7 +432,7 @@ setting environment variables INVOKEAI_. cls.singleton_config = cls(**kwargs) cls.singleton_init = kwargs return cls.singleton_config - + @property def root_path(self)->Path: ''' diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 98b4d81ba8..455d9d021f 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -33,13 +33,13 @@ class ModelManagerServiceBase(ABC): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ pass - + @abstractmethod def get_model( self, @@ -50,8 +50,8 @@ class ModelManagerServiceBase(ABC): node: Optional[BaseInvocation] = None, context: Optional[InvocationContext] = None, ) -> ModelInfo: - """Retrieve the indicated model with name and type. - submodel can be used to get a part (such as the vae) + """Retrieve the indicated model with name and type. + submodel can be used to get a part (such as the vae) of a diffusers pipeline.""" pass @@ -115,8 +115,8 @@ class ModelManagerServiceBase(ABC): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ pass @@ -129,8 +129,8 @@ class ModelManagerServiceBase(ABC): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ pass @@ -176,7 +176,7 @@ class ModelManagerService(ModelManagerServiceBase): logger: types.ModuleType, ): """ - Initialize with the path to the models.yaml config file. + Initialize with the path to the models.yaml config file. Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. @@ -206,6 +206,8 @@ class ModelManagerService(ModelManagerServiceBase): if hasattr(config,'max_cache_size') \ else config.max_loaded_models * 2.5 + logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB") + sequential_offload = config.sequential_guidance self.mgr = ModelManager( @@ -261,7 +263,7 @@ class ModelManagerService(ModelManagerServiceBase): submodel=submodel, model_info=model_info ) - + return model_info def model_exists( @@ -314,8 +316,8 @@ class ModelManagerService(ModelManagerServiceBase): """ Update the named model with a dictionary of attributes. Will fail with an assertion error if the name already exists. Pass clobber=True to overwrite. - On a successful update, the config will be changed in memory. Will fail - with an assertion error if provided attributes are incorrect or + On a successful update, the config will be changed in memory. Will fail + with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber) @@ -328,8 +330,8 @@ class ModelManagerService(ModelManagerServiceBase): model_type: ModelType, ): """ - Delete the named model from configuration. If delete_files is true, - then the underlying weight file or diffusers directory will be deleted + Delete the named model from configuration. If delete_files is true, + then the underlying weight file or diffusers directory will be deleted as well. Call commit() to write to disk. """ self.mgr.del_model(model_name, base_model, model_type) @@ -383,7 +385,7 @@ class ModelManagerService(ModelManagerServiceBase): @property def logger(self): return self.mgr.logger - + def heuristic_import(self, items_to_import: Set[str], prediction_type_helper: Callable[[Path],SchedulerPredictionType]=None, @@ -404,4 +406,4 @@ class ModelManagerService(ModelManagerServiceBase): of the set is a dict corresponding to the newly-created OmegaConf stanza for that model. ''' - return self.mgr.heuristic_import(items_to_import, prediction_type_helper) + return self.mgr.heuristic_import(items_to_import, prediction_type_helper) From e41e8606b579b61c1ce2d34c102e6b5da728b79e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 17:33:03 +1000 Subject: [PATCH 17/28] feat(ui): improve accordion ux - Accordions now may be opened or closed regardless of whether or not their contents are enabled or active - Accordions have a short text indicator alerting the user if their contents are enabled, either a simple `Enabled` or, for accordions like LoRA or ControlNet, `X Active` if any are active --- .../web/src/common/components/IAICollapse.tsx | 59 +++++++++++++------ .../web/src/common/components/IAISwitch.tsx | 2 +- .../ParamControlNetFeatureToggle.tsx | 36 +++++++++++ .../controlNet/util/getValidControlNets.ts | 15 +++++ .../ParamDynamicPromptsCollapse.tsx | 26 +++----- .../ParamDynamicPromptsCombinatorial.tsx | 15 ++--- .../components/ParamDynamicPromptsEnabled.tsx | 36 +++++++++++ .../ParamDynamicPromptsMaxPrompts.tsx | 24 +++++--- .../lora/components/ParamLoraCollapse.tsx | 22 ++++++- .../nodes/util/addControlNetToLinearGraph.ts | 10 +--- .../BoundingBox/ParamBoundingBoxCollapse.tsx | 13 ++-- .../ParamInfillAndScalingCollapse.tsx | 11 +--- .../ParamSeamCorrectionCollapse.tsx | 14 ++--- .../ControlNet/ParamControlNetCollapse.tsx | 44 +++++++------- .../Parameters/Hires/ParamHiresCollapse.tsx | 40 +++++++------ .../Parameters/Hires/ParamHiresToggle.tsx | 1 - .../Parameters/Noise/ParamNoiseCollapse.tsx | 41 +++++++------ .../Parameters/Noise/ParamNoiseThreshold.tsx | 21 +++++-- .../Parameters/Noise/ParamNoiseToggle.tsx | 27 +++++++++ .../Parameters/Noise/ParamPerlinNoise.tsx | 19 +++++- .../Seamless/ParamSeamlessCollapse.tsx | 42 +++++++------ .../Symmetry/ParamSymmetryCollapse.tsx | 36 +++++------ .../Symmetry/ParamSymmetryToggle.tsx | 1 + .../Variations/ParamVariationCollapse.tsx | 45 +++++++------- .../Variations/ParamVariationToggle.tsx | 27 +++++++++ .../parameters/store/generationSlice.ts | 10 +--- .../ImageToImageTabCoreParameters.tsx | 18 ++++-- .../TextToImageTabCoreParameters.tsx | 22 ++++--- .../UnifiedCanvasCoreParameters.tsx | 22 ++++--- 29 files changed, 457 insertions(+), 242 deletions(-) create mode 100644 invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx create mode 100644 invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts create mode 100644 invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsEnabled.tsx create mode 100644 invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseToggle.tsx create mode 100644 invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx diff --git a/invokeai/frontend/web/src/common/components/IAICollapse.tsx b/invokeai/frontend/web/src/common/components/IAICollapse.tsx index 5db26f3841..09dc1392e2 100644 --- a/invokeai/frontend/web/src/common/components/IAICollapse.tsx +++ b/invokeai/frontend/web/src/common/components/IAICollapse.tsx @@ -4,22 +4,25 @@ import { Collapse, Flex, Spacer, - Switch, + Text, useColorMode, + useDisclosure, } from '@chakra-ui/react'; +import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo } from 'react'; import { mode } from 'theme/util/mode'; export type IAIToggleCollapseProps = PropsWithChildren & { label: string; - isOpen: boolean; - onToggle: () => void; - withSwitch?: boolean; + activeLabel?: string; + defaultIsOpen?: boolean; }; const IAICollapse = (props: IAIToggleCollapseProps) => { - const { label, isOpen, onToggle, children, withSwitch = false } = props; + const { label, activeLabel, children, defaultIsOpen = false } = props; + const { isOpen, onToggle } = useDisclosure({ defaultIsOpen }); const { colorMode } = useColorMode(); + return ( { alignItems: 'center', p: 2, px: 4, + gap: 2, borderTopRadius: 'base', borderBottomRadius: isOpen ? 0 : 'base', bg: isOpen @@ -48,19 +52,40 @@ const IAICollapse = (props: IAIToggleCollapseProps) => { }} > {label} + + {activeLabel && ( + + + {activeLabel} + + + )} + - {withSwitch && } - {!withSwitch && ( - - )} + { isDisabled={isDisabled} width={width} display="flex" - gap={4} alignItems="center" {...formControlProps} > @@ -47,6 +46,7 @@ const IAISwitch = (props: Props) => { sx={{ cursor: isDisabled ? 'not-allowed' : 'pointer', ...formLabelProps?.sx, + pe: 4, }} {...formLabelProps} > diff --git a/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx new file mode 100644 index 0000000000..3a7eea2fbf --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/components/parameters/ParamControlNetFeatureToggle.tsx @@ -0,0 +1,36 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISwitch from 'common/components/IAISwitch'; +import { isControlNetEnabledToggled } from 'features/controlNet/store/controlNetSlice'; +import { useCallback } from 'react'; + +const selector = createSelector( + stateSelector, + (state) => { + const { isEnabled } = state.controlNet; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamControlNetFeatureToggle = () => { + const { isEnabled } = useAppSelector(selector); + const dispatch = useAppDispatch(); + + const handleChange = useCallback(() => { + dispatch(isControlNetEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamControlNetFeatureToggle; diff --git a/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts new file mode 100644 index 0000000000..4bff39db63 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlNet/util/getValidControlNets.ts @@ -0,0 +1,15 @@ +import { filter } from 'lodash-es'; +import { ControlNetConfig } from '../store/controlNetSlice'; + +export const getValidControlNets = ( + controlNets: Record +) => { + const validControlNets = filter( + controlNets, + (c) => + c.isEnabled && + (Boolean(c.processedControlImage) || + (c.processorType === 'none' && Boolean(c.controlImage))) + ); + return validControlNets; +}; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx index 1aefecf3e6..0e41fad994 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCollapse.tsx @@ -1,40 +1,30 @@ +import { Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { useCallback } from 'react'; -import { isEnabledToggled } from '../store/slice'; -import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; import ParamDynamicPromptsCombinatorial from './ParamDynamicPromptsCombinatorial'; -import { Flex } from '@chakra-ui/react'; +import ParamDynamicPromptsToggle from './ParamDynamicPromptsEnabled'; +import ParamDynamicPromptsMaxPrompts from './ParamDynamicPromptsMaxPrompts'; const selector = createSelector( stateSelector, (state) => { const { isEnabled } = state.dynamicPrompts; - return { isEnabled }; + return { activeLabel: isEnabled ? 'Enabled' : undefined }; }, defaultSelectorOptions ); const ParamDynamicPromptsCollapse = () => { - const dispatch = useAppDispatch(); - const { isEnabled } = useAppSelector(selector); - - const handleToggleIsEnabled = useCallback(() => { - dispatch(isEnabledToggled()); - }, [dispatch]); + const { activeLabel } = useAppSelector(selector); return ( - + + diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx index 30c2240c37..cb930acd3b 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsCombinatorial.tsx @@ -1,23 +1,23 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { combinatorialToggled } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISwitch from 'common/components/IAISwitch'; +import { useCallback } from 'react'; +import { combinatorialToggled } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { combinatorial } = state.dynamicPrompts; + const { combinatorial, isEnabled } = state.dynamicPrompts; - return { combinatorial }; + return { combinatorial, isDisabled: !isEnabled }; }, defaultSelectorOptions ); const ParamDynamicPromptsCombinatorial = () => { - const { combinatorial } = useAppSelector(selector); + const { combinatorial, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); const handleChange = useCallback(() => { @@ -26,6 +26,7 @@ const ParamDynamicPromptsCombinatorial = () => { return ( { + const { isEnabled } = state.dynamicPrompts; + + return { isEnabled }; + }, + defaultSelectorOptions +); + +const ParamDynamicPromptsToggle = () => { + const dispatch = useAppDispatch(); + const { isEnabled } = useAppSelector(selector); + + const handleToggleIsEnabled = useCallback(() => { + dispatch(isEnabledToggled()); + }, [dispatch]); + + return ( + + ); +}; + +export default ParamDynamicPromptsToggle; diff --git a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx index 19f02ae3e5..172120fd1e 100644 --- a/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx +++ b/invokeai/frontend/web/src/features/dynamicPrompts/components/ParamDynamicPromptsMaxPrompts.tsx @@ -1,25 +1,31 @@ -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import IAISlider from 'common/components/IAISlider'; -import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; import { createSelector } from '@reduxjs/toolkit'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { useCallback } from 'react'; import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAISlider from 'common/components/IAISlider'; +import { useCallback } from 'react'; +import { maxPromptsChanged, maxPromptsReset } from '../store/slice'; const selector = createSelector( stateSelector, (state) => { - const { maxPrompts, combinatorial } = state.dynamicPrompts; + const { maxPrompts, combinatorial, isEnabled } = state.dynamicPrompts; const { min, sliderMax, inputMax } = state.config.sd.dynamicPrompts.maxPrompts; - return { maxPrompts, min, sliderMax, inputMax, combinatorial }; + return { + maxPrompts, + min, + sliderMax, + inputMax, + isDisabled: !isEnabled || !combinatorial, + }; }, defaultSelectorOptions ); const ParamDynamicPromptsMaxPrompts = () => { - const { maxPrompts, min, sliderMax, inputMax, combinatorial } = + const { maxPrompts, min, sliderMax, inputMax, isDisabled } = useAppSelector(selector); const dispatch = useAppDispatch(); @@ -37,7 +43,7 @@ const ParamDynamicPromptsMaxPrompts = () => { return ( { + const loraCount = size(state.lora.loras); + return { + activeLabel: loraCount > 0 ? `${loraCount} Active` : undefined, + }; + }, + defaultSelectorOptions +); + const ParamLoraCollapse = () => { - const { isOpen, onToggle } = useDisclosure(); + const { activeLabel } = useAppSelector(selector); return ( - + diff --git a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts index 11ceb23763..5c4d67ebd3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/addControlNetToLinearGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { filter } from 'lodash-es'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { CollectInvocation, ControlNetInvocation } from 'services/api/types'; import { NonNullableGraph } from '../types/types'; import { CONTROL_NET_COLLECT } from './graphBuilders/constants'; @@ -11,13 +11,7 @@ export const addControlNetToLinearGraph = ( ): void => { const { isEnabled: isControlNetEnabled, controlNets } = state.controlNet; - const validControlNets = filter( - controlNets, - (c) => - c.isEnabled && - (Boolean(c.processedControlImage) || - (c.processorType === 'none' && Boolean(c.controlImage))) - ); + const validControlNets = getValidControlNets(controlNets); if (isControlNetEnabled && Boolean(validControlNets.length)) { if (validControlNets.length > 1) { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx index fea0d8330a..b9cc8511aa 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/BoundingBox/ParamBoundingBoxCollapse.tsx @@ -1,20 +1,15 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; +import { Flex } from '@chakra-ui/react'; import IAICollapse from 'common/components/IAICollapse'; import { memo } from 'react'; -import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; +import { useTranslation } from 'react-i18next'; import ParamBoundingBoxHeight from './ParamBoundingBoxHeight'; +import ParamBoundingBoxWidth from './ParamBoundingBoxWidth'; const ParamBoundingBoxCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx index ed01da9876..a531eba57f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/InfillAndScaling/ParamInfillAndScalingCollapse.tsx @@ -1,4 +1,4 @@ -import { Flex, useDisclosure } from '@chakra-ui/react'; +import { Flex } from '@chakra-ui/react'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -6,19 +6,14 @@ import IAICollapse from 'common/components/IAICollapse'; import ParamInfillMethod from './ParamInfillMethod'; import ParamInfillTilesize from './ParamInfillTilesize'; import ParamScaleBeforeProcessing from './ParamScaleBeforeProcessing'; -import ParamScaledWidth from './ParamScaledWidth'; import ParamScaledHeight from './ParamScaledHeight'; +import ParamScaledWidth from './ParamScaledWidth'; const ParamInfillCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx index 992e8b6d02..88d839fa15 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/SeamCorrection/ParamSeamCorrectionCollapse.tsx @@ -1,22 +1,16 @@ +import IAICollapse from 'common/components/IAICollapse'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; import ParamSeamBlur from './ParamSeamBlur'; import ParamSeamSize from './ParamSeamSize'; import ParamSeamSteps from './ParamSeamSteps'; import ParamSeamStrength from './ParamSeamStrength'; -import { useDisclosure } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; const ParamSeamCorrectionCollapse = () => { const { t } = useTranslation(); - const { isOpen, onToggle } = useDisclosure(); return ( - + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx index 06c6108dcb..59bf7542eb 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/ControlNet/ParamControlNetCollapse.tsx @@ -1,41 +1,45 @@ import { Divider, Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import IAICollapse from 'common/components/IAICollapse'; -import { Fragment, memo, useCallback } from 'react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { createSelector } from '@reduxjs/toolkit'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import IAIButton from 'common/components/IAIButton'; +import IAICollapse from 'common/components/IAICollapse'; +import ControlNet from 'features/controlNet/components/ControlNet'; +import ParamControlNetFeatureToggle from 'features/controlNet/components/parameters/ParamControlNetFeatureToggle'; import { controlNetAdded, controlNetSelector, - isControlNetEnabledToggled, } from 'features/controlNet/store/controlNetSlice'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; -import { v4 as uuidv4 } from 'uuid'; +import { getValidControlNets } from 'features/controlNet/util/getValidControlNets'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; -import IAIButton from 'common/components/IAIButton'; -import ControlNet from 'features/controlNet/components/ControlNet'; +import { map } from 'lodash-es'; +import { Fragment, memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { v4 as uuidv4 } from 'uuid'; const selector = createSelector( controlNetSelector, (controlNet) => { const { controlNets, isEnabled } = controlNet; - return { controlNetsArray: map(controlNets), isEnabled }; + const validControlNets = getValidControlNets(controlNets); + + const activeLabel = + isEnabled && validControlNets.length > 0 + ? `${validControlNets.length} Active` + : undefined; + + return { controlNetsArray: map(controlNets), activeLabel }; }, defaultSelectorOptions ); const ParamControlNetCollapse = () => { const { t } = useTranslation(); - const { controlNetsArray, isEnabled } = useAppSelector(selector); + const { controlNetsArray, activeLabel } = useAppSelector(selector); const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const dispatch = useAppDispatch(); - const handleClickControlNetToggle = useCallback(() => { - dispatch(isControlNetEnabledToggled()); - }, [dispatch]); - const handleClickedAddControlNet = useCallback(() => { dispatch(controlNetAdded({ controlNetId: uuidv4() })); }, [dispatch]); @@ -45,13 +49,9 @@ const ParamControlNetCollapse = () => { } return ( - + + {controlNetsArray.map((c, i) => ( {i > 0 && } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx index b4b077ad6c..fa8606d610 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresCollapse.tsx @@ -1,37 +1,39 @@ import { Flex } from '@chakra-ui/react'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; -import { ParamHiresStrength } from './ParamHiresStrength'; -import { setHiresFix } from 'features/parameters/store/postprocessingSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { ParamHiresStrength } from './ParamHiresStrength'; +import { ParamHiresToggle } from './ParamHiresToggle'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.postprocessing.hiresFix ? 'Enabled' : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamHiresCollapse = () => { const { t } = useTranslation(); - const hiresFix = useAppSelector( - (state: RootState) => state.postprocessing.hiresFix - ); + const { activeLabel } = useAppSelector(selector); const isHiresEnabled = useFeatureStatus('hires').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setHiresFix(!hiresFix)); - if (!isHiresEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx index 0fc600e9e8..f8e6f22aa4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Hires/ParamHiresToggle.tsx @@ -23,7 +23,6 @@ export const ParamHiresToggle = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx index adb76d8da0..4dea1dad4f 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseCollapse.tsx @@ -1,27 +1,33 @@ -import { useTranslation } from 'react-i18next'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import ParamPerlinNoise from './ParamPerlinNoise'; -import ParamNoiseThreshold from './ParamNoiseThreshold'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseNoiseSettings } from 'features/parameters/store/generationSlice'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamNoiseThreshold from './ParamNoiseThreshold'; +import { ParamNoiseToggle } from './ParamNoiseToggle'; +import ParamPerlinNoise from './ParamPerlinNoise'; + +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings } = state.generation; + return { + activeLabel: shouldUseNoiseSettings ? 'Enabled' : undefined, + }; + }, + defaultSelectorOptions +); const ParamNoiseCollapse = () => { const { t } = useTranslation(); const isNoiseEnabled = useFeatureStatus('noise').isFeatureEnabled; - const shouldUseNoiseSettings = useAppSelector( - (state: RootState) => state.generation.shouldUseNoiseSettings - ); - - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldUseNoiseSettings(!shouldUseNoiseSettings)); + const { activeLabel } = useAppSelector(selector); if (!isNoiseEnabled) { return null; @@ -30,11 +36,10 @@ const ParamNoiseCollapse = () => { return ( + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx index e339734992..3abb7532b4 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamNoiseThreshold.tsx @@ -1,18 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setThreshold } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, threshold } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + threshold, + }; + }, + defaultSelectorOptions +); + export default function ParamNoiseThreshold() { const dispatch = useAppDispatch(); - const threshold = useAppSelector( - (state: RootState) => state.generation.threshold - ); + const { threshold, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + const dispatch = useAppDispatch(); + + const shouldUseNoiseSettings = useAppSelector( + (state: RootState) => state.generation.shouldUseNoiseSettings + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldUseNoiseSettings(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx index ad710eae54..afd676223c 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Noise/ParamPerlinNoise.tsx @@ -1,16 +1,31 @@ -import { RootState } from 'app/store/store'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAISlider from 'common/components/IAISlider'; import { setPerlin } from 'features/parameters/store/generationSlice'; import { useTranslation } from 'react-i18next'; +const selector = createSelector( + stateSelector, + (state) => { + const { shouldUseNoiseSettings, perlin } = state.generation; + return { + isDisabled: !shouldUseNoiseSettings, + perlin, + }; + }, + defaultSelectorOptions +); + export default function ParamPerlinNoise() { const dispatch = useAppDispatch(); - const perlin = useAppSelector((state: RootState) => state.generation.perlin); + const { perlin, isDisabled } = useAppSelector(selector); const { t } = useTranslation(); return ( { + if (seamlessXAxis && seamlessYAxis) { + return 'X & Y'; + } + + if (seamlessXAxis) { + return 'X'; + } + + if (seamlessYAxis) { + return 'Y'; + } +}; const selector = createSelector( generationSelector, (generation) => { - const { shouldUseSeamless, seamlessXAxis, seamlessYAxis } = generation; + const { seamlessXAxis, seamlessYAxis } = generation; - return { shouldUseSeamless, seamlessXAxis, seamlessYAxis }; + const activeLabel = getActiveLabel(seamlessXAxis, seamlessYAxis); + return { activeLabel }; }, defaultSelectorOptions ); const ParamSeamlessCollapse = () => { const { t } = useTranslation(); - const { shouldUseSeamless } = useAppSelector(selector); + const { activeLabel } = useAppSelector(selector); const isSeamlessEnabled = useFeatureStatus('seamless').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setSeamless(!shouldUseSeamless)); - if (!isSeamlessEnabled) { return null; } @@ -38,9 +48,7 @@ const ParamSeamlessCollapse = () => { return ( diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx index 59bdb39be1..f2ddd19768 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryCollapse.tsx @@ -1,39 +1,39 @@ -import { memo } from 'react'; import { Flex } from '@chakra-ui/react'; +import { memo } from 'react'; import ParamSymmetryHorizontal from './ParamSymmetryHorizontal'; import ParamSymmetryVertical from './ParamSymmetryVertical'; -import { useTranslation } from 'react-i18next'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { setShouldUseSymmetry } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { useTranslation } from 'react-i18next'; +import ParamSymmetryToggle from './ParamSymmetryToggle'; + +const selector = createSelector( + stateSelector, + (state) => ({ + activeLabel: state.generation.shouldUseSymmetry ? 'Enabled' : undefined, + }), + defaultSelectorOptions +); const ParamSymmetryCollapse = () => { const { t } = useTranslation(); - const shouldUseSymmetry = useAppSelector( - (state: RootState) => state.generation.shouldUseSymmetry - ); + const { activeLabel } = useAppSelector(selector); const isSymmetryEnabled = useFeatureStatus('symmetry').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => dispatch(setShouldUseSymmetry(!shouldUseSymmetry)); - if (!isSymmetryEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx index 7cc17c045e..59386ff526 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Symmetry/ParamSymmetryToggle.tsx @@ -12,6 +12,7 @@ export default function ParamSymmetryToggle() { return ( dispatch(setShouldUseSymmetry(e.target.checked))} /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx index 1564bd64e5..3cdfc3a06b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationCollapse.tsx @@ -1,39 +1,42 @@ -import ParamVariationWeights from './ParamVariationWeights'; -import ParamVariationAmount from './ParamVariationAmount'; -import { useTranslation } from 'react-i18next'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; import { Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAICollapse from 'common/components/IAICollapse'; -import { memo } from 'react'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import ParamVariationAmount from './ParamVariationAmount'; +import { ParamVariationToggle } from './ParamVariationToggle'; +import ParamVariationWeights from './ParamVariationWeights'; + +const selector = createSelector( + stateSelector, + (state) => { + const activeLabel = state.generation.shouldGenerateVariations + ? 'Enabled' + : undefined; + + return { activeLabel }; + }, + defaultSelectorOptions +); const ParamVariationCollapse = () => { const { t } = useTranslation(); - const shouldGenerateVariations = useAppSelector( - (state: RootState) => state.generation.shouldGenerateVariations - ); + const { activeLabel } = useAppSelector(selector); const isVariationEnabled = useFeatureStatus('variation').isFeatureEnabled; - const dispatch = useAppDispatch(); - - const handleToggle = () => - dispatch(setShouldGenerateVariations(!shouldGenerateVariations)); - if (!isVariationEnabled) { return null; } return ( - + + diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx new file mode 100644 index 0000000000..1c05468de0 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Variations/ParamVariationToggle.tsx @@ -0,0 +1,27 @@ +import type { RootState } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAISwitch from 'common/components/IAISwitch'; +import { setShouldGenerateVariations } from 'features/parameters/store/generationSlice'; +import { ChangeEvent } from 'react'; +import { useTranslation } from 'react-i18next'; + +export const ParamVariationToggle = () => { + const dispatch = useAppDispatch(); + + const shouldGenerateVariations = useAppSelector( + (state: RootState) => state.generation.shouldGenerateVariations + ); + + const { t } = useTranslation(); + + const handleChange = (e: ChangeEvent) => + dispatch(setShouldGenerateVariations(e.target.checked)); + + return ( + + ); +}; diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 209cf4b639..960a41bb45 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -49,7 +49,6 @@ export interface GenerationState { verticalSymmetrySteps: number; model: ModelParam; vae: VAEParam; - shouldUseSeamless: boolean; seamlessXAxis: boolean; seamlessYAxis: boolean; } @@ -84,9 +83,8 @@ export const initialGenerationState: GenerationState = { verticalSymmetrySteps: 0, model: '', vae: '', - shouldUseSeamless: false, - seamlessXAxis: true, - seamlessYAxis: true, + seamlessXAxis: false, + seamlessYAxis: false, }; const initialState: GenerationState = initialGenerationState; @@ -144,9 +142,6 @@ export const generationSlice = createSlice({ setImg2imgStrength: (state, action: PayloadAction) => { state.img2imgStrength = action.payload; }, - setSeamless: (state, action: PayloadAction) => { - state.shouldUseSeamless = action.payload; - }, setSeamlessXAxis: (state, action: PayloadAction) => { state.seamlessXAxis = action.payload; }, @@ -268,7 +263,6 @@ export const { modelSelected, vaeSelected, setShouldUseNoiseSettings, - setSeamless, setSeamlessXAxis, setSeamlessYAxis, } = generationSlice.actions; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index 89286232c6..5f5c7ad46b 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -1,4 +1,4 @@ -import { Box, Flex, useDisclosure } from '@chakra-ui/react'; +import { Box, Flex } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; @@ -21,19 +21,25 @@ const selector = createSelector( [uiSelector, generationSelector], (ui, generation) => { const { shouldUseSliders } = ui; - const { shouldFitToWidthHeight } = generation; + const { shouldFitToWidthHeight, shouldRandomizeSeed } = generation; - return { shouldUseSliders, shouldFitToWidthHeight }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, shouldFitToWidthHeight, activeLabel }; }, defaultSelectorOptions ); const ImageToImageTabCoreParameters = () => { - const { shouldUseSliders, shouldFitToWidthHeight } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, shouldFitToWidthHeight, activeLabel } = + useAppSelector(selector); return ( - + { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const TextToImageTabCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + { + stateSelector, + ({ ui, generation }) => { const { shouldUseSliders } = ui; + const { shouldRandomizeSeed } = generation; - return { shouldUseSliders }; + const activeLabel = !shouldRandomizeSeed ? 'Manual Seed' : undefined; + + return { shouldUseSliders, activeLabel }; }, defaultSelectorOptions ); const UnifiedCanvasCoreParameters = () => { - const { shouldUseSliders } = useAppSelector(selector); - const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true }); + const { shouldUseSliders, activeLabel } = useAppSelector(selector); return ( - + Date: Wed, 5 Jul 2023 18:21:46 +1000 Subject: [PATCH 18/28] fix(ui): deleting image selects first image --- .../listeners/imageDeleted.ts | 26 ++++++++++--------- .../features/gallery/store/gallerySlice.ts | 6 ----- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index ca20170c5d..f083a716a4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -1,21 +1,21 @@ -import { startAppListening } from '..'; -import { imageDeleted } from 'services/api/thunks/image'; import { log } from 'app/logging/useLogger'; -import { clamp } from 'lodash-es'; -import { - imageSelected, - imageRemoved, - selectImagesIds, -} from 'features/gallery/store/gallerySlice'; import { resetCanvas } from 'features/canvas/store/canvasSlice'; import { controlNetReset } from 'features/controlNet/store/controlNetSlice'; -import { clearInitialImage } from 'features/parameters/store/generationSlice'; -import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; -import { api } from 'services/api'; +import { + imageRemoved, + imageSelected, + selectFilteredImages, +} from 'features/gallery/store/gallerySlice'; import { imageDeletionConfirmed, isModalOpenChanged, } from 'features/imageDeletion/store/imageDeletionSlice'; +import { nodeEditorReset } from 'features/nodes/store/nodesSlice'; +import { clearInitialImage } from 'features/parameters/store/generationSlice'; +import { clamp } from 'lodash-es'; +import { api } from 'services/api'; +import { imageDeleted } from 'services/api/thunks/image'; +import { startAppListening } from '..'; const moduleLog = log.child({ namespace: 'image' }); @@ -37,7 +37,9 @@ export const addRequestedImageDeletionListener = () => { state.gallery.selection[state.gallery.selection.length - 1]; if (lastSelectedImage === image_name) { - const ids = selectImagesIds(state); + const filteredImages = selectFilteredImages(state); + + const ids = filteredImages.map((i) => i.image_name); const deletedImageIndex = ids.findIndex( (result) => result.toString() === image_name diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts index f4d2babf38..41a52e3452 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySlice.ts @@ -7,7 +7,6 @@ import { import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { dateComparator } from 'common/util/dateComparator'; -import { imageDeletionConfirmed } from 'features/imageDeletion/store/imageDeletionSlice'; import { keyBy, uniq } from 'lodash-es'; import { boardsApi } from 'services/api/endpoints/boards'; import { @@ -174,11 +173,6 @@ export const gallerySlice = createSlice({ state.limit = limit; state.total = total; }); - builder.addCase(imageDeletionConfirmed, (state, action) => { - // Image deleted - const { image_name } = action.payload.imageDTO; - imagesAdapter.removeOne(state, image_name); - }); builder.addCase(imageUrlsReceived.fulfilled, (state, action) => { const { image_name, image_url, thumbnail_url } = action.payload; From 2a7dee17bef16f72437e5b0cf49ebbce7534c299 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 19:06:40 +1000 Subject: [PATCH 19/28] fix(ui): fix dnd on nodes I had broken this earlier today --- .../components/ControlNetImagePreview.tsx | 17 +++++---- .../fields/ImageInputFieldComponent.tsx | 35 +++++-------------- 2 files changed, 16 insertions(+), 36 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index c0c1030b79..dde449a464 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -80,15 +80,14 @@ const ControlNetImagePreview = (props: Props) => { } }, [controlImage, controlNetId]); - const droppableData = useMemo(() => { - if (controlNetId) { - return { - id: controlNetId, - actionType: 'SET_CONTROLNET_IMAGE', - context: { controlNetId }, - }; - } - }, [controlNetId]); + const droppableData = useMemo( + () => ({ + id: controlNetId, + actionType: 'SET_CONTROLNET_IMAGE', + context: { controlNetId }, + }), + [controlNetId] + ); const postUploadAction = useMemo( () => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }), diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx index bfae89e931..34e403f9cc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/fields/ImageInputFieldComponent.tsx @@ -16,7 +16,6 @@ import { import IAIDndImage from 'common/components/IAIDndImage'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/thunks/image'; -import { ImageDTO } from 'services/api/types'; import { FieldComponentProps } from './types'; const ImageInputFieldComponent = ( @@ -33,23 +32,6 @@ const ImageInputFieldComponent = ( isSuccess, } = useGetImageDTOQuery(field.value?.image_name ?? skipToken); - const handleDrop = useCallback( - ({ image_name }: ImageDTO) => { - if (field.value?.image_name === image_name) { - return; - } - - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: { image_name }, - }) - ); - }, - [dispatch, field.name, field.value, nodeId] - ); - const handleReset = useCallback(() => { dispatch( fieldValueChanged({ @@ -70,15 +52,14 @@ const ImageInputFieldComponent = ( } }, [field.name, imageDTO, nodeId]); - const droppableData = useMemo(() => { - if (imageDTO) { - return { - id: `node-${nodeId}-${field.name}`, - actionType: 'SET_NODES_IMAGE', - context: { nodeId, fieldName: field.name }, - }; - } - }, [field.name, imageDTO, nodeId]); + const droppableData = useMemo( + () => ({ + id: `node-${nodeId}-${field.name}`, + actionType: 'SET_NODES_IMAGE', + context: { nodeId, fieldName: field.name }, + }), + [field.name, nodeId] + ); const postUploadAction = useMemo( () => ({ From 307a01d6049c012c2ba0338ac28c3890017f729a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 1 Jul 2023 21:08:59 -0400 Subject: [PATCH 20/28] when migrating models, changes / to _ in model names to avoid breaking model name keys --- invokeai/backend/install/migrate_to_3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index c8e024f484..fb3d964c7b 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -76,6 +76,10 @@ class MigrateTo3(object): Create a unique name for a model for use within models.yaml. ''' done = False + + # some model names have slashes in them, which really screws things up + name = name.replace('/','_') + key = ModelManager.create_key(name,info.base_type,info.model_type) unique_name = key counter = 1 From e3fc1b3816cb04873bfed2ffb19a2909a3d29d40 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 5 Jul 2023 13:43:09 +0300 Subject: [PATCH 21/28] Fix clip path in migrate script --- invokeai/backend/install/migrate_to_3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index c8e024f484..e192a4513c 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -219,11 +219,11 @@ class MigrateTo3(object): repo_id = 'openai/clip-vit-large-patch14' self._migrate_pretrained(CLIPTokenizer, repo_id= repo_id, - dest= target_dir / 'clip-vit-large-patch14' / 'tokenizer', + dest= target_dir / 'clip-vit-large-patch14', **kwargs) self._migrate_pretrained(CLIPTextModel, repo_id = repo_id, - dest = target_dir / 'clip-vit-large-patch14' / 'text_encoder', + dest = target_dir / 'clip-vit-large-patch14', **kwargs) # sd-2 From 596c791844d113e228fbda49649a1f5b0dda5a99 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 21:02:31 +1000 Subject: [PATCH 22/28] fix(ui): fix prompt resize & style resizer --- .../web/src/theme/components/textarea.ts | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/theme/components/textarea.ts b/invokeai/frontend/web/src/theme/components/textarea.ts index 85e6e37d3f..b737cf5e57 100644 --- a/invokeai/frontend/web/src/theme/components/textarea.ts +++ b/invokeai/frontend/web/src/theme/components/textarea.ts @@ -1,7 +1,28 @@ import { defineStyle, defineStyleConfig } from '@chakra-ui/react'; import { getInputOutlineStyles } from '../util/getInputOutlineStyles'; -const invokeAI = defineStyle((props) => getInputOutlineStyles(props)); +const invokeAI = defineStyle((props) => ({ + ...getInputOutlineStyles(props), + '::-webkit-scrollbar': { + display: 'initial', + }, + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-50) 0%, + var(--invokeai-colors-base-50) 70%, + var(--invokeai-colors-base-200) 70%, + var(--invokeai-colors-base-200) 100%)`, + }, + _dark: { + '::-webkit-resizer': { + backgroundImage: `linear-gradient(135deg, + var(--invokeai-colors-base-900) 0%, + var(--invokeai-colors-base-900) 70%, + var(--invokeai-colors-base-800) 70%, + var(--invokeai-colors-base-800) 100%)`, + }, + }, +})); export const textareaTheme = defineStyleConfig({ variants: { From ee042ab76d75e9292332dd0feea99056ee468767 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 5 Jul 2023 14:18:30 +0300 Subject: [PATCH 23/28] Fix ckpt scanning on conversion --- .../backend/model_management/convert_ckpt_to_diffusers.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 1eeee92fb7..e3e64940de 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -29,7 +29,7 @@ import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from .model_manager import ModelManager -from .model_cache import ModelCache +from picklescan.scanner import scan_file_path from .models import BaseModelType, ModelVariantType try: @@ -1014,7 +1014,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt( checkpoint = load_file(checkpoint_path) else: if scan_needed: - ModelCache.scan_model(checkpoint_path, checkpoint_path) + # scan model + scan_result = scan_file_path(checkpoint_path) + if scan_result.infected_files != 0: + raise "The model {checkpoint_path} is potentially infected by malware. Aborting import." checkpoint = torch.load(checkpoint_path) # sometimes there is a state_dict key and sometimes not From acd3b1a512cd54c8fc09d7417a8abeeb00b0b630 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 5 Jul 2023 19:28:40 +1000 Subject: [PATCH 24/28] build: remove web ui dist from gitignore The web UI should manage its own .gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index 7f3b1278df..e9918d4fb5 100644 --- a/.gitignore +++ b/.gitignore @@ -201,8 +201,6 @@ checkpoints # If it's a Mac .DS_Store -invokeai/frontend/web/dist/* - # Let the frontend manage its own gitignore !invokeai/frontend/web/* From 0ac9dca926a0690cde46cc69d9ee0fdf17faabb4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 5 Jul 2023 19:46:00 +0300 Subject: [PATCH 25/28] Fix loading diffusers ti --- invokeai/app/invocations/compel.py | 5 +++-- invokeai/backend/model_management/lora.py | 3 +++ invokeai/backend/model_management/model_manager.py | 6 +++--- invokeai/backend/model_management/models/__init__.py | 2 +- invokeai/backend/model_management/models/base.py | 3 +++ .../model_management/models/textual_inversion.py | 10 +++++++++- 6 files changed, 22 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index d4ba7efeda..4850b9670d 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -9,6 +9,7 @@ from compel.prompt_parser import (Blend, Conjunction, FlattenedPrompt, Fragment) from pydantic import BaseModel, Field +from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management import BaseModelType, ModelType, SubModelType from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent @@ -86,10 +87,10 @@ class CompelInvocation(BaseInvocation): model_type=ModelType.TextualInversion, ).context.model ) - except Exception: + except ModelNotFoundException: # print(e) #import traceback - # print(traceback.format_exc()) + #print(traceback.format_exc()) print(f"Warn: trigger: \"{trigger}\" not found") with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\ diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 5d27555ab3..ae576e39d9 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -655,6 +655,9 @@ class TextualInversionModel: else: result.embedding = next(iter(state_dict.values())) + if len(result.embedding.shape) == 1: + result.embedding = result.embedding.unsqueeze(0) + if not isinstance(result.embedding, torch.Tensor): raise ValueError(f"Invalid embeddings file: {file_path.name}") diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 8002ec9ba4..f15dcfac3c 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -249,7 +249,7 @@ from .model_cache import ModelCache, ModelLocker from .models import ( BaseModelType, ModelType, SubModelType, ModelError, SchedulerPredictionType, MODEL_CLASSES, - ModelConfigBase, + ModelConfigBase, ModelNotFoundException, ) # We are only starting to number the config file with release 3. @@ -409,7 +409,7 @@ class ModelManager(object): if model_key not in self.models: self.scan_models_directory(base_model=base_model, model_type=model_type) if model_key not in self.models: - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self.models[model_key] model_path = self.app_config.root_path / model_config.path @@ -421,7 +421,7 @@ class ModelManager(object): else: self.models.pop(model_key, None) - raise Exception(f"Model not found - {model_key}") + raise ModelNotFoundException(f"Model not found - {model_key}") # vae/movq override # TODO: diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 87b0ad3c4e..00630eef62 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -2,7 +2,7 @@ import inspect from enum import Enum from pydantic import BaseModel from typing import Literal, get_origin -from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings +from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings, ModelNotFoundException from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .vae import VaeModel from .lora import LoRAModel diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index afa62b2e4f..57c02bce76 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -15,6 +15,9 @@ from contextlib import suppress from pydantic import BaseModel, Field from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union +class ModelNotFoundException(Exception): + pass + class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" diff --git a/invokeai/backend/model_management/models/textual_inversion.py b/invokeai/backend/model_management/models/textual_inversion.py index 9a032218f0..4dcdbb24ba 100644 --- a/invokeai/backend/model_management/models/textual_inversion.py +++ b/invokeai/backend/model_management/models/textual_inversion.py @@ -8,6 +8,7 @@ from .base import ( ModelType, SubModelType, classproperty, + ModelNotFoundException, ) # TODO: naming from ..lora import TextualInversionModel as TextualInversionModelRaw @@ -37,8 +38,15 @@ class TextualInversionModel(ModelBase): if child_type is not None: raise Exception("There is no child models in textual inversion") + checkpoint_path = self.model_path + if os.path.isdir(checkpoint_path): + checkpoint_path = os.path.join(checkpoint_path, "learned_embeds.bin") + + if not os.path.exists(checkpoint_path): + raise ModelNotFoundException() + model = TextualInversionModelRaw.from_checkpoint( - file_path=self.model_path, + file_path=checkpoint_path, dtype=torch_dtype, ) From ea81ce94898597be522df802316ef3afad7b181a Mon Sep 17 00:00:00 2001 From: Mary Hipp Rogers Date: Wed, 5 Jul 2023 13:12:27 -0400 Subject: [PATCH 26/28] close modal when user clicks cancel (#3656) * close modal when user clicks cancel * close modal when delete image context cleared --------- Co-authored-by: Mary Hipp --- .../src/features/imageDeletion/components/DeleteImageModal.tsx | 2 ++ .../web/src/features/imageDeletion/store/imageDeletionSlice.ts | 1 + 2 files changed, 3 insertions(+) diff --git a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx index cdc8257488..8306437cc7 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx +++ b/invokeai/frontend/web/src/features/imageDeletion/components/DeleteImageModal.tsx @@ -23,6 +23,7 @@ import { stateSelector } from 'app/store/store'; import { imageDeletionConfirmed, imageToDeleteCleared, + isModalOpenChanged, selectImageUsage, } from '../store/imageDeletionSlice'; @@ -63,6 +64,7 @@ const DeleteImageModal = () => { const handleClose = useCallback(() => { dispatch(imageToDeleteCleared()); + dispatch(isModalOpenChanged(false)); }, [dispatch]); const handleDelete = useCallback(() => { diff --git a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts index 0daffba0d7..49630bcdb4 100644 --- a/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts +++ b/invokeai/frontend/web/src/features/imageDeletion/store/imageDeletionSlice.ts @@ -31,6 +31,7 @@ const imageDeletion = createSlice({ }, imageToDeleteCleared: (state) => { state.imageToDelete = null; + state.isModalOpen = false; }, }, }); From cf173b522bb14af4191414a48d338b344d88bf36 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 5 Jul 2023 13:14:41 -0400 Subject: [PATCH 27/28] allow clip-vit-large-patch14 text encoder to coexist with tokenizer in same directory --- invokeai/backend/install/migrate_to_3.py | 27 +++++++++++++----------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/invokeai/backend/install/migrate_to_3.py b/invokeai/backend/install/migrate_to_3.py index b32890f6b7..6f9cee6246 100644 --- a/invokeai/backend/install/migrate_to_3.py +++ b/invokeai/backend/install/migrate_to_3.py @@ -228,6 +228,7 @@ class MigrateTo3(object): self._migrate_pretrained(CLIPTextModel, repo_id = repo_id, dest = target_dir / 'clip-vit-large-patch14', + force = True, **kwargs) # sd-2 @@ -291,21 +292,21 @@ class MigrateTo3(object): def _model_probe_to_path(self, info: ModelProbeInfo)->Path: return Path(self.dest_models, info.base_type.value, info.model_type.value) - def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, **kwargs): - if dest.exists(): + def _migrate_pretrained(self, model_class, repo_id: str, dest: Path, force:bool=False, **kwargs): + if dest.exists() and not force: logger.info(f'Skipping existing {dest}') return model = model_class.from_pretrained(repo_id, **kwargs) - self._save_pretrained(model, dest) + self._save_pretrained(model, dest, overwrite=force) - def _save_pretrained(self, model, dest: Path): - if dest.exists(): - logger.info(f'Skipping existing {dest}') - return + def _save_pretrained(self, model, dest: Path, overwrite: bool=False): model_name = dest.name - download_path = dest.with_name(f'{model_name}.downloading') - model.save_pretrained(download_path, safe_serialization=True) - download_path.replace(dest) + if overwrite: + model.save_pretrained(dest, safe_serialization=True) + else: + download_path = dest.with_name(f'{model_name}.downloading') + model.save_pretrained(download_path, safe_serialization=True) + download_path.replace(dest) def _download_vae(self, repo_id: str, subfolder:str=None)->Path: vae = AutoencoderKL.from_pretrained(repo_id, cache_dir=self.root_directory / 'models/hub', subfolder=subfolder) @@ -573,8 +574,10 @@ script, which will perform a full upgrade in place.""" dest_directory = args.dest_directory assert dest_directory.is_dir(), f"{dest_directory} is not a valid directory" - assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" - assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." + + # TODO: revisit + # assert (dest_directory / 'models').is_dir(), f"{dest_directory} does not contain a 'models' subdirectory" + # assert (dest_directory / 'invokeai.yaml').exists(), f"{dest_directory} does not contain an InvokeAI init file." do_migrate(root_directory,dest_directory) From c21245f590ed4eaed8ca1ff624f829411e392ca1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 6 Jul 2023 15:34:50 +1000 Subject: [PATCH 28/28] fix(api): make list models params querys, make path `/`, remove defaults The list models route should just be the base route path, and should use query parameters as opposed to path parameters (which cannot be optional) Removed defaults for update model route - for the purposes of the API, we should always be explicit with this --- invokeai/app/api/routers/models.py | 71 +++--------------------------- 1 file changed, 7 insertions(+), 64 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 25b227e87a..1d070fdee1 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -26,17 +26,13 @@ class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @models_router.get( - "/{base_model}/{model_type}", + "/", operation_id="list_models", responses={200: {"model": ModelsList }}, ) async def list_models( - base_model: Optional[BaseModelType] = Path( - default=None, description="Base model" - ), - model_type: Optional[ModelType] = Path( - default=None, description="The type of model to get" - ), + base_model: Optional[BaseModelType] = Query(default=None, description="Base model"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Gets a list of models""" models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type) @@ -54,10 +50,10 @@ async def list_models( response_model = UpdateModelResponse, ) async def update_model( - base_model: BaseModelType = Path(default='sd-1', description="Base model"), - model_type: ModelType = Path(default='main', description="The type of model"), - model_name: str = Path(default=None, description="model name"), - info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), + info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> UpdateModelResponse: """ Add Model """ try: @@ -194,56 +190,3 @@ async def convert_model( except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response - - # @socketio.on("mergeDiffusersModels") - # def merge_diffusers_models(model_merge_info: dict): - # try: - # models_to_merge = model_merge_info["models_to_merge"] - # model_ids_or_paths = [ - # self.generate.model_manager.model_name_or_path(x) - # for x in models_to_merge - # ] - # merged_pipe = merge_diffusion_models( - # model_ids_or_paths, - # model_merge_info["alpha"], - # model_merge_info["interp"], - # model_merge_info["force"], - # ) - - # dump_path = global_models_dir() / "merged_models" - # if model_merge_info["model_merge_save_path"] is not None: - # dump_path = Path(model_merge_info["model_merge_save_path"]) - - # os.makedirs(dump_path, exist_ok=True) - # dump_path = dump_path / model_merge_info["merged_model_name"] - # merged_pipe.save_pretrained(dump_path, safe_serialization=1) - - # merged_model_config = dict( - # model_name=model_merge_info["merged_model_name"], - # description=f'Merge of models {", ".join(models_to_merge)}', - # commit_to_conf=opt.conf, - # ) - - # if vae := self.generate.model_manager.config[models_to_merge[0]].get( - # "vae", None - # ): - # print(f">> Using configured VAE assigned to {models_to_merge[0]}") - # merged_model_config.update(vae=vae) - - # self.generate.model_manager.import_diffuser_model( - # dump_path, **merged_model_config - # ) - # new_model_list = self.generate.model_manager.list_models() - - # socketio.emit( - # "modelsMerged", - # { - # "merged_models": models_to_merge, - # "merged_model_name": model_merge_info["merged_model_name"], - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Models Merged: {models_to_merge}") - # print(f">> New Model Added: {model_merge_info['merged_model_name']}") - # except Exception as e: