Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
feat: Add Resolution to DepthAnything
parent 39fedb090b · commit 7cb49e65bd
@@ -621,6 +621,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
         default="small", description="The size of the depth model to use"
     )
+    resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
     offload: bool = InputField(default=False)
 
     def run_processor(self, image):
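The new field constrains the resolution to a multiple of 64, no smaller than 64. A minimal sketch of how those constraints behave, using plain pydantic on the assumption that InvokeAI's InputField forwards ge and multiple_of to pydantic's Field (the model name below is hypothetical, not part of the codebase):

from pydantic import BaseModel, Field, ValidationError

# Hypothetical stand-in for the invocation; only the constraint logic is shown.
class DepthAnythingSettings(BaseModel):
    resolution: int = Field(default=512, ge=64, multiple_of=64)

print(DepthAnythingSettings().resolution)                # 512 (the default)
print(DepthAnythingSettings(resolution=768).resolution)  # accepted: >= 64 and a multiple of 64
try:
    DepthAnythingSettings(resolution=500)                # rejected: not a multiple of 64
except ValidationError as err:
    print(err)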
@@ -630,5 +631,5 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
         if image.mode == "RGBA":
             image = image.convert("RGB")
 
-        processed_image = depth_anything_detector(image=image, offload=self.offload)
+        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
         return processed_image
@@ -64,12 +64,15 @@ class DepthAnythingDetector:
             del self.model
             self.model_size = model_size
 
-            if self.model_size == "small":
-                self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-            if self.model_size == "base":
-                self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-            if self.model_size == "large":
-                self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+            match self.model_size:
+                case "small":
+                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+                case "base":
+                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+                case "large":
+                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+                case _:
+                    raise TypeError("Not a supported model")
 
             self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
             self.model.eval()
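Besides being more idiomatic, the rewrite changes behavior: the old if chain fell through silently on an unrecognized size, leaving self.model unset, while the match statement (Python 3.10+) fails fast. A self-contained sketch of the same dispatch pattern; encoder_for is an illustrative helper, not a function in the codebase:

def encoder_for(model_size: str) -> str:
    # Mirrors the dispatch in DepthAnythingDetector: one branch per size,
    # plus a wildcard that raises on anything unrecognized.
    match model_size:
        case "small":
            return "vits"
        case "base":
            return "vitb"
        case "large":
            return "vitl"
        case _:
            raise TypeError("Not a supported model")

print(encoder_for("base"))  # vitb
# encoder_for("tiny")       # would raise TypeError: Not a supported model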
@@ -81,12 +84,11 @@ class DepthAnythingDetector:
         self.model.to(device)
         return self
 
-    def __call__(self, image, offload=False):
+    def __call__(self, image, resolution=512, offload=False):
         image = np.array(image, dtype=np.uint8)
-        original_width, original_height = image.shape[:2]
         image = image[:, :, ::-1] / 255.0
 
-        image_width, image_height = image.shape[:2]
+        image_height, image_width = image.shape[:2]
         image = transform({"image": image})["image"]
         image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
 
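The one-line unpacking fix matters because numpy stores images as (height, width, channels); the old code put the height into image_width, which would have broken the resolution math introduced below. A quick demonstration:

import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)  # a 640x480 image: 480 rows, 640 columns
image_height, image_width = image.shape[:2]      # corrected order: height first
print(image_height, image_width)                 # 480 640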
@@ -97,7 +99,9 @@ class DepthAnythingDetector:
 
         depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
         depth_map = Image.fromarray(depth_map)
-        depth_map = depth_map.resize((original_height, original_width))
+
+        new_height = int(image_height * (resolution / image_width))
+        depth_map = depth_map.resize((resolution, new_height))
 
         if offload:
             del self.model
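Instead of resizing the depth map back to the input dimensions, the detector now pins the output width to the requested resolution and scales the height to preserve the aspect ratio. A minimal sketch of that computation (depth_map_size is an illustrative name, not part of the codebase); note that PIL's Image.resize takes (width, height):

def depth_map_size(image_width: int, image_height: int, resolution: int = 512) -> tuple[int, int]:
    # Width is pinned to `resolution`; height is scaled by the same factor.
    new_height = int(image_height * (resolution / image_width))
    return (resolution, new_height)

print(depth_map_size(1024, 768))                   # (512, 384)
print(depth_map_size(1024, 768, resolution=1024))  # (1024, 768)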
@@ -1,5 +1,11 @@
 import type { ComboboxOnChange } from '@invoke-ai/ui';
-import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui';
+import {
+  Combobox,
+  CompositeNumberInput,
+  CompositeSlider,
+  FormControl,
+  FormLabel,
+} from '@invoke-ai/ui';
 import { useProcessorNodeChanged } from 'features/controlAdapters/components/hooks/useProcessorNodeChanged';
 import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
 import type {
@@ -23,7 +29,7 @@ type Props = {
 
 const DepthAnythingProcessor = (props: Props) => {
   const { controlNetId, processorNode, isEnabled } = props;
-  const { model_size } = processorNode;
+  const { model_size, resolution } = processorNode;
   const processorChanged = useProcessorNodeChanged();
 
   const { t } = useTranslation();
@@ -54,6 +60,17 @@ const DepthAnythingProcessor = (props: Props) => {
     [options, model_size]
   );
 
+  const handleResolutionChange = useCallback(
+    (v: number) => {
+      processorChanged(controlNetId, { resolution: v });
+    },
+    [controlNetId, processorChanged]
+  );
+
+  const handleResolutionDefaultChange = useCallback(() => {
+    processorChanged(controlNetId, { resolution: 512 });
+  }, [controlNetId, processorChanged]);
+
   return (
     <ProcessorWrapper>
       <FormControl isDisabled={!isEnabled}>
@@ -65,6 +82,27 @@ const DepthAnythingProcessor = (props: Props) => {
           onChange={handleModelSizeChange}
         />
       </FormControl>
+      <FormControl isDisabled={!isEnabled}>
+        <FormLabel>{t('controlnet.imageResolution')}</FormLabel>
+        <CompositeSlider
+          value={resolution}
+          onChange={handleResolutionChange}
+          defaultValue={DEFAULTS.resolution}
+          min={64}
+          max={4096}
+          step={64}
+          marks
+          onReset={handleResolutionDefaultChange}
+        />
+        <CompositeNumberInput
+          value={resolution}
+          onChange={handleResolutionChange}
+          defaultValue={DEFAULTS.resolution}
+          min={64}
+          max={4096}
+          step={64}
+        />
+      </FormControl>
     </ProcessorWrapper>
   );
 };
@@ -95,6 +95,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
       id: 'depth_anything_image_processor',
       type: 'depth_anything_image_processor',
       model_size: 'small',
+      resolution: 512,
       offload: false,
     },
   },
@@ -80,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
  */
 export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
   DepthAnythingImageProcessorInvocation,
-  'type' | 'model_size' | 'offload'
+  'type' | 'model_size' | 'resolution' | 'offload'
 >;
 
 export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
File diff suppressed because one or more lines are too long