diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index ae10cce140..11453d97f1 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -70,27 +70,25 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") -@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") +@images_router.delete("/{image_name}", operation_id="delete_image") async def delete_image( - image_origin: ResourceOrigin = Path(description="The origin of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" try: - ApiDependencies.invoker.services.images.delete(image_origin, image_name) + ApiDependencies.invoker.services.images.delete(image_name) except Exception as e: # TODO: Does this need any exception handling at all? pass @images_router.patch( - "/{image_origin}/{image_name}", + "/{image_name}", operation_id="update_image", response_model=ImageDTO, ) async def update_image( - image_origin: ResourceOrigin = Path(description="The origin of image to update"), image_name: str = Path(description="The name of the image to update"), image_changes: ImageRecordChanges = Body( description="The changes to apply to the image" @@ -99,32 +97,29 @@ async def update_image( """Updates an image""" try: - return ApiDependencies.invoker.services.images.update( - image_origin, image_name, image_changes - ) + return ApiDependencies.invoker.services.images.update(image_name, image_changes) except Exception as e: raise HTTPException(status_code=400, detail="Failed to update image") @images_router.get( - "/{image_origin}/{image_name}/metadata", + "/{image_name}/metadata", operation_id="get_image_metadata", response_model=ImageDTO, ) async def get_image_metadata( - image_origin: ResourceOrigin = Path(description="The origin of image to get"), image_name: str = Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) + return ApiDependencies.invoker.services.images.get_dto(image_name) except Exception as e: raise HTTPException(status_code=404) @images_router.get( - "/{image_origin}/{image_name}", + "/{image_name}", operation_id="get_image_full", response_class=Response, responses={ @@ -136,15 +131,12 @@ async def get_image_metadata( }, ) async def get_image_full( - image_origin: ResourceOrigin = Path( - description="The type of full-resolution image file to get" - ), image_name: str = Path(description="The name of full-resolution image file to get"), ) -> FileResponse: """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) + path = ApiDependencies.invoker.services.images.get_path(image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -160,7 +152,7 @@ async def get_image_full( @images_router.get( - "/{image_origin}/{image_name}/thumbnail", + "/{image_name}/thumbnail", operation_id="get_image_thumbnail", response_class=Response, responses={ @@ -172,14 +164,13 @@ async def get_image_full( }, ) async def get_image_thumbnail( - image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"), ) -> FileResponse: """Gets a thumbnail image file""" try: path = ApiDependencies.invoker.services.images.get_path( - image_origin, image_name, thumbnail=True + image_name, thumbnail=True ) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -192,25 +183,21 @@ async def get_image_thumbnail( @images_router.get( - "/{image_origin}/{image_name}/urls", + "/{image_name}/urls", operation_id="get_image_urls", response_model=ImageUrlsDTO, ) async def get_image_urls( - image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" try: - image_url = ApiDependencies.invoker.services.images.get_url( - image_origin, image_name - ) + image_url = ApiDependencies.invoker.services.images.get_url(image_name) thumbnail_url = ApiDependencies.invoker.services.images.get_url( - image_origin, image_name, thumbnail=True + image_name, thumbnail=True ) return ImageUrlsDTO( - image_origin=image_origin, image_name=image_name, image_url=image_url, thumbnail_url=thumbnail_url, diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index b32afe4941..f40954ebb6 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -193,9 +193,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): return image def invoke(self, context: InvocationContext) -> ImageOutput: - raw_image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + raw_image = context.services.images.get_pil_image(self.image.image_name) # image type should be PIL.PngImagePlugin.PngImageFile ? processed_image = self.run_processor(raw_image) @@ -216,10 +214,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): ) """Builds an ImageOutput and its ImageField""" - processed_image_field = ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ) + processed_image_field = ImageField(image_name=image_dto.image_name) return ImageOutput( image=processed_image_field, # width=processed_image.width, diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5275116a2a..dd0ab4d027 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -36,12 +36,8 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) - mask = context.services.images.get_pil_image( - self.mask.image_origin, self.mask.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) + mask = context.services.images.get_pil_image(self.mask.image_name) # Convert to cv image/mask # TODO: consider making these utility functions @@ -65,10 +61,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 53d4d16330..21574c7323 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -83,9 +83,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): # loading controlnet image (currently requires pre-processed image) control_image = ( None if self.control_image is None - else context.services.images.get_pil_image( - self.control_image.image_origin, self.control_image.image_name - ) + else context.services.images.get_pil_image(self.control_image.image_name) ) # loading controlnet model if (self.control_model is None or self.control_model==''): @@ -125,10 +123,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -166,9 +161,7 @@ class ImageToImageInvocation(TextToImageInvocation): image = ( None if self.image is None - else context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + else context.services.images.get_pil_image(self.image.image_name) ) if self.fit: @@ -206,10 +199,7 @@ class ImageToImageInvocation(TextToImageInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -279,14 +269,12 @@ class InpaintInvocation(ImageToImageInvocation): image = ( None if self.image is None - else context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + else context.services.images.get_pil_image(self.image.image_name) ) mask = ( None if self.mask is None - else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) + else context.services.images.get_pil_image(self.mask.image_name) ) # Handle invalid model parameter @@ -322,10 +310,7 @@ class InpaintInvocation(ImageToImageInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index d048410468..f85669eab1 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -72,13 +72,10 @@ class LoadImageInvocation(BaseInvocation): ) # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) + image = context.services.images.get_pil_image(self.image.image_name) return ImageOutput( - image=ImageField( - image_name=self.image.image_name, - image_origin=self.image.image_origin, - ), + image=ImageField(image_name=self.image.image_name), width=image.width, height=image.height, ) @@ -95,19 +92,14 @@ class ShowImageInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) if image: image.show() # TODO: how to handle failure? return ImageOutput( - image=ImageField( - image_name=self.image.image_name, - image_origin=self.image.image_origin, - ), + image=ImageField(image_name=self.image.image_name), width=image.width, height=image.height, ) @@ -128,9 +120,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) image_crop = Image.new( mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0) @@ -147,10 +137,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -171,19 +158,13 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get_pil_image( - self.base_image.image_origin, self.base_image.image_name - ) - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + base_image = context.services.images.get_pil_image(self.base_image.image_name) + image = context.services.images.get_pil_image(self.image.image_name) mask = ( None if self.mask is None else ImageOps.invert( - context.services.images.get_pil_image( - self.mask.image_origin, self.mask.image_name - ) + context.services.images.get_pil_image(self.mask.image_name) ) ) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -209,10 +190,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -230,9 +208,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> MaskOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) image_mask = image.split()[-1] if self.invert: @@ -248,9 +224,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): ) return MaskOutput( - mask=ImageField( - image_origin=image_dto.image_origin, image_name=image_dto.image_name - ), + mask=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -268,12 +242,8 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image1 = context.services.images.get_pil_image( - self.image1.image_origin, self.image1.image_name - ) - image2 = context.services.images.get_pil_image( - self.image2.image_origin, self.image2.image_name - ) + image1 = context.services.images.get_pil_image(self.image1.image_name) + image2 = context.services.images.get_pil_image(self.image2.image_name) multiply_image = ImageChops.multiply(image1, image2) @@ -287,9 +257,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_origin=image_dto.image_origin, image_name=image_dto.image_name - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -310,9 +278,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) channel_image = image.getchannel(self.channel) @@ -326,9 +292,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_origin=image_dto.image_origin, image_name=image_dto.image_name - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -349,9 +313,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) converted_image = image.convert(self.mode) @@ -365,9 +327,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_origin=image_dto.image_origin, image_name=image_dto.image_name - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -386,9 +346,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) blur = ( ImageFilter.GaussianBlur(self.radius) @@ -407,10 +365,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -450,9 +405,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] @@ -471,10 +424,7 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -493,9 +443,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) resample_mode = PIL_RESAMPLING_MAP[self.resample_mode] width = int(image.width * self.scale_factor) @@ -516,10 +464,7 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -538,9 +483,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = image_arr * (self.max - self.min) + self.max @@ -557,10 +500,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -579,9 +519,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = ( @@ -603,10 +541,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index a06780c1f5..ad67594c29 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -134,9 +134,7 @@ class InfillColorInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) @@ -153,10 +151,7 @@ class InfillColorInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -179,9 +174,7 @@ class InfillTileInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) infilled = tile_fill_missing( image.copy(), seed=self.seed, tile_size=self.tile_size @@ -198,10 +191,7 @@ class InfillTileInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -217,9 +207,7 @@ class InfillPatchMatchInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) if PatchMatch.patchmatch_available(): infilled = infill_patchmatch(image.copy()) @@ -236,10 +224,7 @@ class InfillPatchMatchInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 7da697a1ff..cf216e6c54 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -316,8 +316,7 @@ class TextToLatentsInvocation(BaseInvocation): torch_dtype=model.unet.dtype).to(model.device) control_models.append(control_model) control_image_field = control_info.image - input_image = context.services.images.get_pil_image(control_image_field.image_origin, - control_image_field.image_name) + input_image = context.services.images.get_pil_image(control_image_field.image_name) # self.image.image_type, self.image.image_name # FIXME: still need to test with different widths, heights, devices, dtypes # and add in batch_size, num_images_per_prompt? @@ -537,10 +536,7 @@ class LatentsToImageInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_type=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) @@ -635,9 +631,7 @@ class ImageToLatentsInvocation(BaseInvocation): # image = context.services.images.get( # self.image.image_type, self.image.image_name # ) - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) #vae_info = context.services.model_manager.get_model(**self.vae.vae.dict()) vae_info = context.services.model_manager.get_model( diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 5313411400..4185de3fd3 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -28,9 +28,7 @@ class RestoreFaceInvocation(BaseInvocation): } def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], upscale=None, @@ -51,10 +49,7 @@ class RestoreFaceInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 80e1567047..42f85fd18d 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -30,9 +30,7 @@ class UpscaleInvocation(BaseInvocation): } def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image( - self.image.image_origin, self.image.image_name - ) + image = context.services.images.get_pil_image(self.image.image_name) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], upscale=(self.level, self.strength), @@ -53,10 +51,7 @@ class UpscaleInvocation(BaseInvocation): ) return ImageOutput( - image=ImageField( - image_name=image_dto.image_name, - image_origin=image_dto.image_origin, - ), + image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, ) diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 6d48f2dbb1..988a3e1447 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -66,13 +66,10 @@ class InvalidImageCategoryException(ValueError): class ImageField(BaseModel): """An image field used for passing image objects between invocations""" - image_origin: ResourceOrigin = Field( - default=ResourceOrigin.INTERNAL, description="The type of the image" - ) image_name: Optional[str] = Field(default=None, description="The name of the image") class Config: - schema_extra = {"required": ["image_origin", "image_name"]} + schema_extra = {"required": ["image_name"]} class ColorField(BaseModel): diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 68a994ea75..aeacfe3f1c 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -40,14 +40,12 @@ class ImageFileStorageBase(ABC): """Low-level service responsible for storing and retrieving image files.""" @abstractmethod - def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: + def get(self, image_name: str) -> PILImageType: """Retrieves an image as PIL Image.""" pass @abstractmethod - def get_path( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: """Gets the internal path to an image or thumbnail.""" pass @@ -62,7 +60,6 @@ class ImageFileStorageBase(ABC): def save( self, image: PILImageType, - image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, @@ -71,7 +68,7 @@ class ImageFileStorageBase(ABC): pass @abstractmethod - def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: + def delete(self, image_name: str) -> None: """Deletes an image and its thumbnail (if one exists).""" pass @@ -93,17 +90,14 @@ class DiskImageFileStorage(ImageFileStorageBase): Path(output_folder).mkdir(parents=True, exist_ok=True) # TODO: don't hard-code. get/save/delete should maybe take subpath? - for image_origin in ResourceOrigin: - Path(os.path.join(output_folder, image_origin)).mkdir( - parents=True, exist_ok=True - ) - Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir( - parents=True, exist_ok=True - ) + Path(os.path.join(output_folder)).mkdir(parents=True, exist_ok=True) + Path(os.path.join(output_folder, "thumbnails")).mkdir( + parents=True, exist_ok=True + ) - def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: + def get(self, image_name: str) -> PILImageType: try: - image_path = self.get_path(image_origin, image_name) + image_path = self.get_path(image_name) cache_item = self.__get_cache(image_path) if cache_item: return cache_item @@ -117,13 +111,12 @@ class DiskImageFileStorage(ImageFileStorageBase): def save( self, image: PILImageType, - image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: try: - image_path = self.get_path(image_origin, image_name) + image_path = self.get_path(image_name) if metadata is not None: pnginfo = PngImagePlugin.PngInfo() @@ -133,7 +126,7 @@ class DiskImageFileStorage(ImageFileStorageBase): image.save(image_path, "PNG") thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) + thumbnail_path = self.get_path(thumbnail_name, thumbnail=True) thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image.save(thumbnail_path) @@ -142,10 +135,10 @@ class DiskImageFileStorage(ImageFileStorageBase): except Exception as e: raise ImageFileSaveException from e - def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: + def delete(self, image_name: str) -> None: try: basename = os.path.basename(image_name) - image_path = self.get_path(image_origin, basename) + image_path = self.get_path(basename) if os.path.exists(image_path): send2trash(image_path) @@ -153,7 +146,7 @@ class DiskImageFileStorage(ImageFileStorageBase): del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_origin, thumbnail_name, True) + thumbnail_path = self.get_path(thumbnail_name, True) if os.path.exists(thumbnail_path): send2trash(thumbnail_path) @@ -163,19 +156,19 @@ class DiskImageFileStorage(ImageFileStorageBase): raise ImageFileDeleteException from e # TODO: make this a bit more flexible for e.g. cloud storage - def get_path( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: # strip out any relative path shenanigans basename = os.path.basename(image_name) if thumbnail: thumbnail_name = get_thumbnail_name(basename) path = os.path.join( - self.__output_folder, image_origin, "thumbnails", thumbnail_name + self.__output_folder, + "thumbnails", + thumbnail_name, ) else: - path = os.path.join(self.__output_folder, image_origin, basename) + path = os.path.join(self.__output_folder, basename) abspath = os.path.abspath(path) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 6907ac3952..30b379ed8b 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -21,6 +21,7 @@ from invokeai.app.services.models.image_record import ( T = TypeVar("T", bound=BaseModel) + class OffsetPaginatedResults(GenericModel, Generic[T]): """Offset-paginated results""" @@ -60,7 +61,7 @@ class ImageRecordStorageBase(ABC): # TODO: Implement an `update()` method @abstractmethod - def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: + def get(self, image_name: str) -> ImageRecord: """Gets an image record.""" pass @@ -68,7 +69,6 @@ class ImageRecordStorageBase(ABC): def update( self, image_name: str, - image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: """Updates an image record.""" @@ -89,7 +89,7 @@ class ImageRecordStorageBase(ABC): # TODO: The database has a nullable `deleted_at` column, currently unused. # Should we implement soft deletes? Would need coordination with ImageFileStorage. @abstractmethod - def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: + def delete(self, image_name: str) -> None: """Deletes an image record.""" pass @@ -196,9 +196,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) - def get( - self, image_origin: ResourceOrigin, image_name: str - ) -> Union[ImageRecord, None]: + def get(self, image_name: str) -> Union[ImageRecord, None]: try: self._lock.acquire() @@ -225,7 +223,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def update( self, image_name: str, - image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: try: @@ -294,9 +291,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): if categories is not None: ## Convert the enum values to unique list of strings - category_strings = list( - map(lambda c: c.value, set(categories)) - ) + category_strings = list(map(lambda c: c.value, set(categories))) # Create the correct length of placeholders placeholders = ",".join("?" * len(category_strings)) query_conditions += f"AND image_category IN ( {placeholders} )\n" @@ -337,7 +332,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): items=images, offset=offset, limit=limit, total=count ) - def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: + def delete(self, image_name: str) -> None: try: self._lock.acquire() self._cursor.execute( diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 2618a9763e..9f7188f607 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -57,7 +57,6 @@ class ImageServiceABC(ABC): @abstractmethod def update( self, - image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: @@ -65,22 +64,22 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: + def get_pil_image(self, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" pass @abstractmethod - def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: + def get_record(self, image_name: str) -> ImageRecord: """Gets an image record.""" pass @abstractmethod - def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: + def get_dto(self, image_name: str) -> ImageDTO: """Gets an image DTO.""" pass @abstractmethod - def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: + def get_path(self, image_name: str) -> str: """Gets an image's path.""" pass @@ -90,9 +89,7 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_url( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_url(self, image_name: str, thumbnail: bool = False) -> str: """Gets an image's or thumbnail's URL.""" pass @@ -109,7 +106,7 @@ class ImageServiceABC(ABC): pass @abstractmethod - def delete(self, image_origin: ResourceOrigin, image_name: str): + def delete(self, image_name: str): """Deletes an image.""" pass @@ -206,16 +203,13 @@ class ImageService(ImageServiceABC): ) self._services.files.save( - image_origin=image_origin, image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_origin, image_name) - thumbnail_url = self._services.urls.get_image_url( - image_origin, image_name, True - ) + image_url = self._services.urls.get_image_url(image_name) + thumbnail_url = self._services.urls.get_image_url(image_name, True) return ImageDTO( # Non-nullable fields @@ -249,13 +243,12 @@ class ImageService(ImageServiceABC): def update( self, - image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, image_origin, changes) - return self.get_dto(image_origin, image_name) + self._services.records.update(image_name, changes) + return self.get_dto(image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") raise @@ -263,9 +256,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem updating image record") raise e - def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: + def get_pil_image(self, image_name: str) -> PILImageType: try: - return self._services.files.get(image_origin, image_name) + return self._services.files.get(image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -273,9 +266,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image file") raise e - def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: + def get_record(self, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_origin, image_name) + return self._services.records.get(image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -283,14 +276,14 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image record") raise e - def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: + def get_dto(self, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_origin, image_name) + image_record = self._services.records.get(image_name) image_dto = image_record_to_dto( image_record, - self._services.urls.get_image_url(image_origin, image_name), - self._services.urls.get_image_url(image_origin, image_name, True), + self._services.urls.get_image_url(image_name), + self._services.urls.get_image_url(image_name, True), ) return image_dto @@ -301,11 +294,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image DTO") raise e - def get_path( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_path(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.files.get_path(image_origin, image_name, thumbnail) + return self._services.files.get_path(image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -317,11 +308,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem validating image path") raise e - def get_url( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_url(self, image_name: str, thumbnail: bool = False) -> str: try: - return self._services.urls.get_image_url(image_origin, image_name, thumbnail) + return self._services.urls.get_image_url(image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -347,10 +336,8 @@ class ImageService(ImageServiceABC): map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(r.image_origin, r.image_name), - self._services.urls.get_image_url( - r.image_origin, r.image_name, True - ), + self._services.urls.get_image_url(r.image_name), + self._services.urls.get_image_url(r.image_name, True), ), results.items, ) @@ -366,10 +353,10 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting paginated image DTOs") raise e - def delete(self, image_origin: ResourceOrigin, image_name: str): + def delete(self, image_name: str): try: - self._services.files.delete(image_origin, image_name) - self._services.records.delete(image_origin, image_name) + self._services.files.delete(image_name) + self._services.records.delete(image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index 051236b12b..d971d65916 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -79,8 +79,6 @@ class ImageUrlsDTO(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_origin: ResourceOrigin = Field(description="The type of the image.") - """The origin of the image.""" image_url: str = Field(description="The URL of the image.") """The URL of the image.""" thumbnail_url: str = Field(description="The URL of the image's thumbnail.") diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 4c8354c899..5920e9e6c1 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -1,17 +1,12 @@ import os from abc import ABC, abstractmethod -from invokeai.app.models.image import ResourceOrigin -from invokeai.app.util.thumbnails import get_thumbnail_name - class UrlServiceBase(ABC): """Responsible for building URLs for resources.""" @abstractmethod - def get_image_url( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_image_url(self, image_name: str, thumbnail: bool = False) -> str: """Gets the URL for an image or thumbnail.""" pass @@ -20,15 +15,11 @@ class LocalUrlService(UrlServiceBase): def __init__(self, base_url: str = "api/v1"): self._base_url = base_url - def get_image_url( - self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False - ) -> str: + def get_image_url(self, image_name: str, thumbnail: bool = False) -> str: image_basename = os.path.basename(image_name) # These paths are determined by the routes in invokeai/app/api/routers/images.py if thumbnail: - return ( - f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail" - ) + return f"{self._base_url}/images/{image_basename}/thumbnail" - return f"{self._base_url}/images/{image_origin.value}/{image_basename}" + return f"{self._base_url}/images/{image_basename}" diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 64b9a828cd..c072a9d95c 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -62,10 +62,12 @@ "@dagrejs/graphlib": "^2.1.12", "@dnd-kit/core": "^6.0.8", "@dnd-kit/modifiers": "^6.0.1", - "@emotion/react": "^11.10.6", + "@emotion/react": "^11.11.1", "@emotion/styled": "^11.10.6", "@floating-ui/react-dom": "^2.0.0", "@fontsource/inter": "^4.5.15", + "@mantine/core": "^6.0.13", + "@mantine/hooks": "^6.0.13", "@reduxjs/toolkit": "^1.9.5", "@roarr/browser-log-writer": "^1.1.5", "chakra-ui-contextmenu": "^1.0.5", diff --git a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx index 6aa38fc15b..82065d83e3 100644 --- a/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx +++ b/invokeai/frontend/web/src/app/components/ThemeLocaleProvider.tsx @@ -3,11 +3,11 @@ import { createLocalStorageManager, extendTheme, } from '@chakra-ui/react'; +import { RootState } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; import { ReactNode, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; import { theme as invokeAITheme } from 'theme/theme'; -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; import { greenTeaThemeColors } from 'theme/colors/greenTea'; import { invokeAIThemeColors } from 'theme/colors/invokeAI'; @@ -15,6 +15,8 @@ import { lightThemeColors } from 'theme/colors/lightTheme'; import { oceanBlueColors } from 'theme/colors/oceanBlue'; import '@fontsource/inter/variable.css'; +import { MantineProvider } from '@mantine/core'; +import { mantineTheme } from 'mantine-theme/theme'; import 'overlayscrollbars/overlayscrollbars.css'; import 'theme/css/overlayscrollbars.css'; @@ -51,9 +53,11 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) { }, [direction]); return ( - - {children} - + + + {children} + + ); } diff --git a/invokeai/frontend/web/src/app/constants.ts b/invokeai/frontend/web/src/app/constants.ts index 6700a732b3..c2e525ad7d 100644 --- a/invokeai/frontend/web/src/app/constants.ts +++ b/invokeai/frontend/web/src/app/constants.ts @@ -22,9 +22,9 @@ export const SCHEDULERS = [ export type Scheduler = (typeof SCHEDULERS)[number]; // Valid upscaling levels -export const UPSCALING_LEVELS: Array<{ key: string; value: number }> = [ - { key: '2x', value: 2 }, - { key: '4x', value: 4 }, +export const UPSCALING_LEVELS: Array<{ label: string; value: string }> = [ + { label: '2x', value: '2' }, + { label: '4x', value: '4' }, ]; export const NUMPY_RAND_MIN = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 717417792c..ce1b515b84 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -34,10 +34,7 @@ export const addControlNetImageProcessedListener = () => { [controlNet.processorNode.id]: { ...controlNet.processorNode, is_intermediate: true, - image: pick(controlNet.controlImage, [ - 'image_name', - 'image_origin', - ]), + image: pick(controlNet.controlImage, ['image_name']), }, }, }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index f4376a4959..4c0c057242 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -25,7 +25,7 @@ export const addRequestedImageDeletionListener = () => { effect: (action, { dispatch, getState }) => { const { image, imageUsage } = action.payload; - const { image_name, image_origin } = image; + const { image_name } = image; const state = getState(); const selectedImage = state.gallery.selectedImage; @@ -79,9 +79,7 @@ export const addRequestedImageDeletionListener = () => { dispatch(imageRemoved(image_name)); // Delete from server - dispatch( - imageDeleted({ imageName: image_name, imageOrigin: image_origin }) - ); + dispatch(imageDeleted({ imageName: image_name })); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts index 016e3ec8a8..ed308f08a8 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageMetadataReceived.ts @@ -20,7 +20,6 @@ export const addImageMetadataReceivedFulfilledListener = () => { dispatch( imageUpdated({ imageName: image.image_name, - imageOrigin: image.image_origin, requestBody: { is_intermediate: false }, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 0b47f7a1be..c9ab894ddb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -36,13 +36,12 @@ export const addInvocationCompleteEventListener = () => { // This complete event has an associated image output if (isImageOutput(result) && !nodeDenylist.includes(node.type)) { - const { image_name, image_origin } = result.image; + const { image_name } = result.image; // Get its metadata dispatch( imageMetadataReceived({ imageName: image_name, - imageOrigin: image_origin, }) ); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts index 9bd3cd6dd2..3e211f73bb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved.ts @@ -11,12 +11,11 @@ export const addStagingAreaImageSavedListener = () => { startAppListening({ actionCreator: stagingAreaImageSaved, effect: async (action, { dispatch, getState, take }) => { - const { image_name, image_origin } = action.payload; + const { image_name } = action.payload; dispatch( imageUpdated({ imageName: image_name, - imageOrigin: image_origin, requestBody: { is_intermediate: false, }, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts index d02ffbe931..7cb8012848 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateImageUrlsOnConnect.ts @@ -80,11 +80,10 @@ export const addUpdateImageUrlsOnConnectListener = () => { `Fetching new image URLs for ${allUsedImages.length} images` ); - allUsedImages.forEach(({ image_name, image_origin }) => { + allUsedImages.forEach(({ image_name }) => { dispatch( imageUrlsReceived({ imageName: image_name, - imageOrigin: image_origin, }) ); }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts index 0ee3016bdb..4d8177d7f3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedCanvas.ts @@ -116,7 +116,6 @@ export const addUserInvokedCanvasListener = () => { // Update the base node with the image name and type baseNode.image = { image_name: baseImageDTO.image_name, - image_origin: baseImageDTO.image_origin, }; } @@ -143,7 +142,6 @@ export const addUserInvokedCanvasListener = () => { // Update the base node with the image name and type baseNode.mask = { image_name: maskImageDTO.image_name, - image_origin: maskImageDTO.image_origin, }; } @@ -160,7 +158,6 @@ export const addUserInvokedCanvasListener = () => { dispatch( imageUpdated({ imageName: baseNode.image.image_name, - imageOrigin: baseNode.image.image_origin, requestBody: { session_id: sessionId }, }) ); @@ -171,7 +168,6 @@ export const addUserInvokedCanvasListener = () => { dispatch( imageUpdated({ imageName: baseNode.mask.image_name, - imageOrigin: baseNode.mask.image_origin, requestBody: { session_id: sessionId }, }) ); diff --git a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx b/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx deleted file mode 100644 index 1d9ae763b1..0000000000 --- a/invokeai/frontend/web/src/common/components/IAICustomSelect.tsx +++ /dev/null @@ -1,256 +0,0 @@ -import { CheckIcon, ChevronUpIcon } from '@chakra-ui/icons'; -import { - Box, - Flex, - FormControl, - FormControlProps, - FormLabel, - Grid, - GridItem, - List, - ListItem, - Text, - Tooltip, - TooltipProps, -} from '@chakra-ui/react'; -import { autoUpdate, offset, shift, useFloating } from '@floating-ui/react-dom'; -import { useSelect } from 'downshift'; -import { isString } from 'lodash-es'; -import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; - -import { memo, useLayoutEffect, useMemo } from 'react'; -import { getInputOutlineStyles } from 'theme/util/getInputOutlineStyles'; - -export type ItemTooltips = { [key: string]: string }; - -export type IAICustomSelectOption = { - value: string; - label: string; - tooltip?: string; -}; - -type IAICustomSelectProps = { - label?: string; - value: string; - data: IAICustomSelectOption[] | string[]; - onChange: (v: string) => void; - withCheckIcon?: boolean; - formControlProps?: FormControlProps; - tooltip?: string; - tooltipProps?: Omit; - ellipsisPosition?: 'start' | 'end'; - isDisabled?: boolean; -}; - -const IAICustomSelect = (props: IAICustomSelectProps) => { - const { - label, - withCheckIcon, - formControlProps, - tooltip, - tooltipProps, - ellipsisPosition = 'end', - data, - value, - onChange, - isDisabled = false, - } = props; - - const values = useMemo(() => { - return data.map((v) => { - if (isString(v)) { - return { value: v, label: v }; - } - return v; - }); - }, [data]); - - const stringValues = useMemo(() => { - return values.map((v) => v.value); - }, [values]); - - const valueData = useMemo(() => { - return values.find((v) => v.value === value); - }, [values, value]); - - const { - isOpen, - getToggleButtonProps, - getLabelProps, - getMenuProps, - highlightedIndex, - getItemProps, - } = useSelect({ - items: stringValues, - selectedItem: value, - onSelectedItemChange: ({ selectedItem: newSelectedItem }) => { - newSelectedItem && onChange(newSelectedItem); - }, - }); - - const { refs, floatingStyles, update } = useFloating({ - // whileElementsMounted: autoUpdate, - middleware: [offset(4), shift({ crossAxis: true, padding: 8 })], - }); - - useLayoutEffect(() => { - if (isOpen && refs.reference.current && refs.floating.current) { - return autoUpdate(refs.reference.current, refs.floating.current, update); - } - }, [isOpen, update, refs.floating, refs.reference]); - - const labelTextDirection = useMemo(() => { - if (ellipsisPosition === 'start') { - return document.dir === 'rtl' ? 'ltr' : 'rtl'; - } - - return document.dir; - }, [ellipsisPosition]); - - return ( - - {label && ( - { - refs.floating.current && refs.floating.current.focus(); - }} - > - {label} - - )} - - - - {valueData?.label} - - - - - - {isOpen && ( - - - {values.map((v, index) => { - const isSelected = value === v.value; - const isHighlighted = highlightedIndex === index; - const fontWeight = isSelected ? 700 : 500; - const bg = isHighlighted - ? 'base.700' - : isSelected - ? 'base.750' - : undefined; - return ( - - - {withCheckIcon ? ( - - - {isSelected && } - - - - {v.label} - - - - ) : ( - - {v.label} - - )} - - - ); - })} - - - )} - - - ); -}; - -export default memo(IAICustomSelect); diff --git a/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx new file mode 100644 index 0000000000..30517d0f41 --- /dev/null +++ b/invokeai/frontend/web/src/common/components/IAIMantineSelect.tsx @@ -0,0 +1,76 @@ +import { Tooltip } from '@chakra-ui/react'; +import { Select, SelectProps } from '@mantine/core'; +import { memo } from 'react'; + +export type IAISelectDataType = { + value: string; + label: string; + tooltip?: string; +}; + +type IAISelectProps = SelectProps & { + tooltip?: string; +}; + +const IAIMantineSelect = (props: IAISelectProps) => { + const { searchable = true, tooltip, ...rest } = props; + return ( + +