diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index c37dcda998..7eff62a8a5 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -9,6 +9,7 @@ from typing import Literal, Optional, Union, List, Dict
 from PIL import Image
 from pydantic import BaseModel, Field, validator
 
+from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageField, ImageCategory, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
@@ -105,9 +106,15 @@ CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
 # CONTROLNET_RESIZE_VALUES = Literal[tuple(["just_resize", "crop_resize", "fill_resize"])]
 
 
+class ControlNetModelField(BaseModel):
+    """ControlNet model field"""
+
+    model_name: str = Field(description="Name of the ControlNet model")
+    base_model: BaseModelType = Field(description="Base model")
+
+
 class ControlField(BaseModel):
     image: ImageField = Field(default=None, description="The control image")
-    control_model: Optional[str] = Field(default=None, description="The ControlNet model to use")
+    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
     # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
     control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(default=0, ge=0, le=1,
@@ -118,15 +125,15 @@ class ControlField(BaseModel):
     # resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
 
     @validator("control_weight")
-    def abs_le_one(cls, v):
-        """validate that all abs(values) are <=1"""
+    def validate_control_weight(cls, v):
+        """Validate that all control weights are in the valid range"""
         if isinstance(v, list):
             for i in v:
-                if abs(i) > 1:
-                    raise ValueError('all abs(control_weight) must be <= 1')
+                if i < -1 or i > 2:
+                    raise ValueError('Control weights must be within -1 to 2 range')
         else:
-            if abs(v) > 1:
-                raise ValueError('abs(control_weight) must be <= 1')
+            if v < -1 or v > 2:
+                raise ValueError('Control weights must be within -1 to 2 range')
         return v
 
     class Config:
         schema_extra = {
@@ -134,6 +141,7 @@ class ControlField(BaseModel):
             "ui": {
                 "type_hints": {
                     "control_weight": "float",
+                    "control_model": "controlnet_model",
                     # "control_weight": "number",
                 }
             }
@@ -154,10 +162,10 @@ class ControlNetInvocation(BaseInvocation):
     type: Literal["controlnet"] = "controlnet"
     # Inputs
     image: ImageField = Field(default=None, description="The control image")
-    control_model: CONTROLNET_NAME_VALUES = Field(default="lllyasviel/sd-controlnet-canny",
+    control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny",
                                                   description="control model used")
     control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet")
     begin_step_percent: float = Field(default=0, ge=0, le=1,
                                       description="When the ControlNet is first applied (% of total steps)")
     end_step_percent: float = Field(default=1, ge=0, le=1,
                                     description="When the ControlNet is last applied (% of total steps)")
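Worth pausing on what this file-level change does: a ControlNet is now addressed by a structured `(model_name, base_model)` field instead of a raw Hugging Face repo string, and the weight validator widens from |w| <= 1 to the -1 to 2 range. Below is a standalone sketch of the new validation behavior, assuming the pydantic v1 API used by the codebase at this point; `ControlFieldSketch` and the plain-string `base_model` are illustrative stand-ins, not the real invocation classes.

```python
# Standalone sketch of the new field/validator behavior (pydantic v1 API).
# Names mirror the diff above; base_model is a plain str here, not the enum.
from typing import List, Union

from pydantic import BaseModel, Field, ValidationError, validator


class ControlNetModelField(BaseModel):
    """Identifies a ControlNet by name + base model, not by a raw string."""

    model_name: str = Field(description="Name of the ControlNet model")
    base_model: str = Field(description="Base model")  # BaseModelType enum in the real code


class ControlFieldSketch(BaseModel):
    control_model: ControlNetModelField
    control_weight: Union[float, List[float]] = 1.0

    @validator("control_weight")
    def validate_control_weight(cls, v):
        """Reject any weight outside the new [-1, 2] range."""
        weights = v if isinstance(v, list) else [v]
        if any(w < -1 or w > 2 for w in weights):
            raise ValueError("Control weights must be within -1 to 2 range")
        return v


model = ControlNetModelField(model_name="canny", base_model="sd-1")
ControlFieldSketch(control_model=model, control_weight=[0.5, 1.75])  # ok: 1.75 was illegal before
try:
    ControlFieldSketch(control_model=model, control_weight=2.5)
except ValidationError as err:
    print(err)  # weights above 2 are still rejected
```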
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index b3f95f3658..baf78c7c23 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 
+from contextlib import ExitStack
 from typing import List, Literal, Optional, Union
 
 import einops
@@ -11,6 +12,7 @@ from pydantic import BaseModel, Field, validator
 
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
+from invokeai.backend.model_management.models.base import ModelType
 
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.stable_diffusion import PipelineIntermediateState
@@ -71,16 +73,21 @@ def get_scheduler(
     scheduler_name: str,
 ) -> Scheduler:
     scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
-        scheduler_name, SCHEDULER_MAP['ddim'])
+        scheduler_name, SCHEDULER_MAP['ddim']
+    )
     orig_scheduler_info = context.services.model_manager.get_model(
-        **scheduler_info.dict())
+        **scheduler_info.dict()
+    )
     with orig_scheduler_info as orig_scheduler:
         scheduler_config = orig_scheduler.config
 
     if "_backup" in scheduler_config:
         scheduler_config = scheduler_config["_backup"]
-    scheduler_config = {**scheduler_config, **
-                        scheduler_extra_config, "_backup": scheduler_config}
+    scheduler_config = {
+        **scheduler_config,
+        **scheduler_extra_config,
+        "_backup": scheduler_config,
+    }
     scheduler = scheduler_class.from_config(scheduler_config)
 
     # hack copied over from generate.py
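The reflowed dictionary merge in `get_scheduler` above is the subtle part of this hunk: node-level overrides are layered onto the model's scheduler config, while the untouched original is stashed under `_backup` so the next call can unwind the override before applying a new one. A minimal worked example of that round trip, on plain dicts with hypothetical keys:

```python
# The "_backup" round-trip used by get_scheduler, shown on plain dicts.
# Keys are hypothetical; only the merge/restore shape matters.
scheduler_config = {"num_train_timesteps": 1000, "beta_schedule": "linear"}
scheduler_extra_config = {"beta_schedule": "scaled_linear"}  # node override

merged = {
    **scheduler_config,
    **scheduler_extra_config,     # later keys win: beta_schedule is overridden
    "_backup": scheduler_config,  # untouched original for the next swap
}
assert merged["beta_schedule"] == "scaled_linear"

# On the next get_scheduler call, the guard at the top unwinds the override:
if "_backup" in merged:
    restored = merged["_backup"]
assert restored == {"num_train_timesteps": 1000, "beta_schedule": "linear"}
```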
@@ -137,8 +144,11 @@ class TextToLatentsInvocation(BaseInvocation):
 
     # TODO: pass this an emitter method or something? or a session for dispatching?
     def dispatch_progress(
-            self, context: InvocationContext, source_node_id: str,
-            intermediate_state: PipelineIntermediateState) -> None:
+        self,
+        context: InvocationContext,
+        source_node_id: str,
+        intermediate_state: PipelineIntermediateState,
+    ) -> None:
         stable_diffusion_step_callback(
             context=context,
             intermediate_state=intermediate_state,
@@ -147,11 +157,16 @@ class TextToLatentsInvocation(BaseInvocation):
         )
 
     def get_conditioning_data(
-            self, context: InvocationContext, scheduler) -> ConditioningData:
+        self,
+        context: InvocationContext,
+        scheduler,
+    ) -> ConditioningData:
         c, extra_conditioning_info = context.services.latents.get(
-            self.positive_conditioning.conditioning_name)
+            self.positive_conditioning.conditioning_name
+        )
         uc, _ = context.services.latents.get(
-            self.negative_conditioning.conditioning_name)
+            self.negative_conditioning.conditioning_name
+        )
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
@@ -178,7 +193,10 @@ class TextToLatentsInvocation(BaseInvocation):
         return conditioning_data
 
     def create_pipeline(
-            self, unet, scheduler) -> StableDiffusionGeneratorPipeline:
+        self,
+        unet,
+        scheduler,
+    ) -> StableDiffusionGeneratorPipeline:
         # TODO:
         # configure_model_padding(
         #    unet,
@@ -213,6 +231,7 @@ class TextToLatentsInvocation(BaseInvocation):
         model: StableDiffusionGeneratorPipeline,
         control_input: List[ControlField],
         latents_shape: List[int],
+        exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
 
@@ -238,25 +257,19 @@ class TextToLatentsInvocation(BaseInvocation):
             control_data = []
             control_models = []
             for control_info in control_list:
-                # handle control models
-                if ("," in control_info.control_model):
-                    control_model_split = control_info.control_model.split(",")
-                    control_name = control_model_split[0]
-                    control_subfolder = control_model_split[1]
-                    print("Using HF model subfolders")
-                    print("    control_name: ", control_name)
-                    print("    control_subfolder: ", control_subfolder)
-                    control_model = ControlNetModel.from_pretrained(
-                        control_name, subfolder=control_subfolder,
-                        torch_dtype=model.unet.dtype).to(
-                        model.device)
-                else:
-                    control_model = ControlNetModel.from_pretrained(
-                        control_info.control_model, torch_dtype=model.unet.dtype).to(model.device)
+                control_model = exit_stack.enter_context(
+                    context.services.model_manager.get_model(
+                        model_name=control_info.control_model.model_name,
+                        model_type=ModelType.ControlNet,
+                        base_model=control_info.control_model.base_model,
+                    )
+                )
+
                 control_models.append(control_model)
                 control_image_field = control_info.image
                 input_image = context.services.images.get_pil_image(
-                    control_image_field.image_name)
+                    control_image_field.image_name
+                )
                 # self.image.image_type, self.image.image_name
                 # FIXME: still need to test with different widths, heights, devices, dtypes
                 #        and add in batch_size, num_images_per_prompt?
@@ -278,7 +291,8 @@ class TextToLatentsInvocation(BaseInvocation):
                     weight=control_info.control_weight,
                     begin_step_percent=control_info.begin_step_percent,
                     end_step_percent=control_info.end_step_percent,
-                    control_mode=control_info.control_mode,)
+                    control_mode=control_info.control_mode,
+                )
                 control_data.append(control_item)
                 # MultiControlNetModel has been refactored out, just need list[ControlNetData]
         return control_data
@@ -289,7 +303,8 @@ class TextToLatentsInvocation(BaseInvocation):
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -298,14 +313,17 @@ class TextToLatentsInvocation(BaseInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                 unet_info as unet:
 
             scheduler = get_scheduler(
@@ -322,6 +340,7 @@ class TextToLatentsInvocation(BaseInvocation):
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
             )
 
             # TODO: Verify the noise is the right size
@@ -374,7 +393,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(
-            context.graph_execution_state_id)
+            context.graph_execution_state_id
+        )
         source_node_id = graph_execution_state.prepared_source_mapping[self.id]
 
         def step_callback(state: PipelineIntermediateState):
@@ -383,14 +403,17 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
         def _lora_loader():
             for lora in self.unet.loras:
                 lora_info = context.services.model_manager.get_model(
-                    **lora.dict(exclude={"weight"}))
+                    **lora.dict(exclude={"weight"})
+                )
                 yield (lora_info.context.model, lora.weight)
                 del lora_info
             return
 
         unet_info = context.services.model_manager.get_model(
-            **self.unet.unet.dict())
-        with ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
+            **self.unet.unet.dict()
+        )
+        with ExitStack() as exit_stack,\
+                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
                 unet_info as unet:
 
             scheduler = get_scheduler(
@@ -407,11 +430,13 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
                 latents_shape=noise.shape,
                 # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
                 do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
            )
 
             # TODO: Verify the noise is the right size
             initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
-                latent, device=unet.device, dtype=latent.dtype)
+                latent, device=unet.device, dtype=latent.dtype
+            )
 
             timesteps, _ = pipeline.get_img2img_timesteps(
                 self.steps,
@@ -535,7 +560,8 @@ class ResizeLatentsInvocation(BaseInvocation):
         resized_latents = torch.nn.functional.interpolate(
             latents, size=(self.height // 8, self.width // 8),
             mode=self.mode, antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
@@ -569,7 +595,8 @@ class ScaleLatentsInvocation(BaseInvocation):
         resized_latents = torch.nn.functional.interpolate(
             latents, scale_factor=self.scale_factor, mode=self.mode,
             antialias=self.antialias
-            if self.mode in ["bilinear", "bicubic"] else False,)
+            if self.mode in ["bilinear", "bicubic"] else False,
+        )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         torch.cuda.empty_cache()
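The key behavioral change in `latent.py`: ControlNet weights are no longer loaded ad hoc via `ControlNetModel.from_pretrained`; they are checked out of the model manager as context-managed resources. Because the control-data preparation method returns models that must stay resident until denoising finishes, each one is entered on an `ExitStack` owned by the enclosing `with` block and released all at once when it exits. A minimal sketch of that lifetime pattern, with toy context managers standing in for `context.services.model_manager.get_model`:

```python
# Sketch of the ExitStack lifetime pattern introduced above. The model
# manager hands back context managers; entering them on the caller's
# ExitStack keeps every ControlNet loaded until the whole block exits.
from contextlib import ExitStack, contextmanager


@contextmanager
def get_model(name: str):  # stand-in for context.services.model_manager.get_model
    print(f"load {name}")
    try:
        yield f"<{name} weights>"
    finally:
        print(f"unload {name}")


def prepare_control_data(control_list, exit_stack: ExitStack):
    # Each model is entered on the caller's stack rather than a local
    # `with`, so it outlives this helper and is alive during denoising.
    return [exit_stack.enter_context(get_model(name)) for name in control_list]


with ExitStack() as exit_stack:
    controls = prepare_control_data(["canny", "depth"], exit_stack)
    print("denoising with", controls)
# both models are unloaded here, in reverse order of loading
```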
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts
index dd2fb6f469..a923bd0b60 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetAutoProcess.ts
@@ -13,7 +13,11 @@ import { RootState } from 'app/store/store';
 
 const moduleLog = log.child({ namespace: 'controlNet' });
 
-const predicate: AnyListenerPredicate<RootState> = (action, state) => {
+const predicate: AnyListenerPredicate<RootState> = (
+  action,
+  state,
+  prevState
+) => {
   const isActionMatched =
     controlNetProcessorParamsChanged.match(action) ||
     controlNetModelChanged.match(action) ||
@@ -25,6 +29,16 @@ const predicate: AnyListenerPredicate<RootState> = (action, state) => {
     return false;
   }
 
+  if (controlNetAutoConfigToggled.match(action)) {
+    // do not process if the user just disabled auto-config
+    if (
+      prevState.controlNet.controlNets[action.payload.controlNetId]
+        .shouldAutoConfig === true
+    ) {
+      return false;
+    }
+  }
+
   const { controlImage, processorType, shouldAutoConfig } =
     state.controlNet.controlNets[action.payload.controlNetId];
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
index ee879a8915..05076960fb 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts
@@ -10,6 +10,7 @@ import { zMainModel } from 'features/parameters/types/parameterSchemas';
 import { addToast } from 'features/system/store/systemSlice';
 import { forEach } from 'lodash-es';
 import { startAppListening } from '..';
+import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
 
 const moduleLog = log.child({ module: 'models' });
 
@@ -51,7 +52,14 @@ export const addModelSelectedListener = () => {
         modelsCleared += 1;
       }
 
-      // TODO: handle incompatible controlnet; pending model manager support
+      const { controlNets } = state.controlNet;
+      forEach(controlNets, (controlNet, controlNetId) => {
+        if (controlNet.model?.base_model !== base_model) {
+          dispatch(controlNetRemoved({ controlNetId }));
+          modelsCleared += 1;
+        }
+      });
+
       if (modelsCleared > 0) {
         dispatch(
           addToast(
diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
index f8abcfa758..5e3caa7c99 100644
--- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
+++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts
@@ -11,6 +11,7 @@ import {
 import { forEach, some } from 'lodash-es';
 import { modelsApi } from 'services/api/endpoints/models';
 import { startAppListening } from '..';
+import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
 
 const moduleLog = log.child({ module: 'models' });
 
@@ -127,7 +128,22 @@ export const addModelsLoadedListener = () => {
     matcher: modelsApi.endpoints.getControlNetModels.matchFulfilled,
     effect: async (action, { getState, dispatch }) => {
       // ControlNet models loaded - need to remove missing ControlNets from state
-      // TODO: pending model manager controlnet support
+      const controlNets = getState().controlNet.controlNets;
+
+      forEach(controlNets, (controlNet, controlNetId) => {
+        const isControlNetAvailable = some(
+          action.payload.entities,
+          (m) =>
+            m?.model_name === controlNet?.model?.model_name &&
+            m?.base_model === controlNet?.model?.base_model
+        );
+
+        if (isControlNetAvailable) {
+          return;
+        }
+
+        dispatch(controlNetRemoved({ controlNetId }));
+      });
     },
   });
 };
diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts
index 40b8c1c73a..be642a6435 100644
--- a/invokeai/frontend/web/src/app/types/invokeai.ts
+++ b/invokeai/frontend/web/src/app/types/invokeai.ts
@@ -1,5 +1,5 @@
 import {
-  CONTROLNET_MODELS,
+  // CONTROLNET_MODELS,
   CONTROLNET_PROCESSORS,
 } from 'features/controlNet/store/constants';
 import { InvokeTabName } from 'features/ui/store/tabMap';
@@ -128,7 +128,7 @@ export type AppConfig = {
   canRestoreDeletedImagesFromBin: boolean;
   sd: {
     defaultModel?: string;
-    disabledControlNetModels: (keyof typeof CONTROLNET_MODELS)[];
+    disabledControlNetModels: string[];
     disabledControlNetProcessors: (keyof typeof CONTROLNET_PROCESSORS)[];
     iterations: {
       initial: number;
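Both listeners above identify a ControlNet by the same `(model_name, base_model)` pair that `ControlNetModelField` introduced on the backend: `modelSelected` drops ControlNets whose base model no longer matches the newly selected main model, and `modelsLoaded` drops ones that are no longer installed. Here is that availability predicate sketched in Python for symmetry with the backend field, with plain dicts standing in for the RTK Query entities:

```python
# The (model_name, base_model) identity check the listeners rely on,
# sketched over plain dicts. Field names match ControlNetModelField.
def is_control_net_available(installed: list[dict], selected: dict) -> bool:
    return any(
        m["model_name"] == selected["model_name"]
        and m["base_model"] == selected["base_model"]
        for m in installed
    )


installed = [{"model_name": "canny", "base_model": "sd-1"}]
assert is_control_net_available(installed, {"model_name": "canny", "base_model": "sd-1"})
# the same name under a different base model no longer matches:
assert not is_control_net_available(installed, {"model_name": "canny", "base_model": "sd-2"})
```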
diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
index 59a1d281fe..991398f5a0 100644
--- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx
@@ -170,12 +170,14 @@ const IAIDndImage = (props: IAIDndImageProps) => {
         )}
         {!imageDTO && isUploadDisabled && noContentFallback}
-        <IAIDroppable … />
-        {imageDTO && (
+        {!isDropDisabled && (
+          <IAIDroppable … />
+        )}
+        {imageDTO && !isDragDisabled && (
diff --git a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx
--- a/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIMantineMultiSelect.tsx
@@ … @@
-type IAIMultiSelectProps = MultiSelectProps & {
+type IAIMultiSelectProps = Omit<MultiSelectProps, 'label'> & {
   tooltip?: string;
   inputRef?: RefObject<HTMLInputElement>;
+  label?: string;
 };
 
 const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
-  const { searchable = true, tooltip, inputRef, ...rest } = props;
+  const {
+    searchable = true,
+    tooltip,
+    inputRef,
+    label,
+    disabled,
+    ...rest
+  } = props;
   const dispatch = useAppDispatch();
 
   const handleKeyDown = useCallback(
@@ -37,7 +45,15 @@ const IAIMantineMultiSelect = (props: IAIMultiSelectProps) => {
 
   return (
     <Tooltip …>
       <MultiSelect
+        label={
+          label ? (
+            <FormControl …>
+              <FormLabel …>{label}</FormLabel>
+            </FormControl>
+          ) : undefined
+        }
         ref={inputRef}
+        disabled={disabled}
         onKeyDown={handleKeyDown}
         onKeyUp={handleKeyUp}
         searchable={searchable}
diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx
index edf1665bb4..2c3f5434ad 100644
--- a/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx
+++ b/invokeai/frontend/web/src/common/components/IAIMantineSearchableSelect.tsx
@@ -1,4 +1,4 @@
-import { Tooltip } from '@chakra-ui/react';
+import { FormControl, FormLabel, Tooltip } from '@chakra-ui/react';
 import { Select, SelectProps } from '@mantine/core';
 import { useAppDispatch } from 'app/store/storeHooks';
 import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice';
@@ -11,13 +11,22 @@ export type IAISelectDataType = {
   tooltip?: string;
 };
 
-type IAISelectProps = SelectProps & {
+type IAISelectProps = Omit<SelectProps, 'label'> & {
   tooltip?: string;
+  label?: string;
   inputRef?: RefObject<HTMLInputElement>;
 };
 
 const IAIMantineSearchableSelect = (props: IAISelectProps) => {
-  const { searchable = true, tooltip, inputRef, onChange, ...rest } = props;
+  const {
+    searchable = true,
+    tooltip,
+    inputRef,
+    onChange,
+    label,
+    disabled,
+    ...rest
+  } = props;
   const dispatch = useAppDispatch();
 
   const [searchValue, setSearchValue] = useState('');
@@ -61,6 +70,14 @@ const IAIMantineSearchableSelect = (props: IAISelectProps) => {
+