feat: Add Resolution to DepthAnything

blessedcoolant 2024-01-23 10:13:03 +05:30 committed by Kent Keirsey
parent 39fedb090b
commit 7cb49e65bd
6 changed files with 87 additions and 37 deletions

View File

@@ -621,6 +621,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
         default="small", description="The size of the depth model to use"
     )
+    resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
     offload: bool = InputField(default=False)

     def run_processor(self, image):
@@ -630,5 +631,5 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
         if image.mode == "RGBA":
            image = image.convert("RGB")

-        processed_image = depth_anything_detector(image=image, offload=self.offload)
+        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
         return processed_image
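
The new `resolution` input is validated to at least 64 and a multiple of 64. A minimal sketch of the same constraint expressed with zod (which the frontend store below already uses); `zResolution` is an illustrative name, not part of this commit:

```typescript
import { z } from 'zod';

// Mirror of the backend InputField constraints (ge=64, multiple_of=64,
// default 512); zResolution is a hypothetical name for illustration.
const zResolution = z.number().int().min(64).multipleOf(64).default(512);

zResolution.parse(768); // ok: >= 64 and a multiple of 64
zResolution.parse(500); // throws ZodError: not a multiple of 64
```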

View File

@@ -64,12 +64,15 @@ class DepthAnythingDetector:
             del self.model

             self.model_size = model_size

-            if self.model_size == "small":
-                self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-            if self.model_size == "base":
-                self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-            if self.model_size == "large":
-                self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+            match self.model_size:
+                case "small":
+                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+                case "base":
+                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+                case "large":
+                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+                case _:
+                    raise TypeError("Not a supported model")

             self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
             self.model.eval()
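
The if-chain becomes a match statement, and unknown sizes now raise instead of silently leaving the model unset. A TypeScript analogue of the same dispatch with a compile-time exhaustiveness check; the config values come from the hunk above, and `dptConfig` is an illustrative name:

```typescript
type DepthAnythingModelSize = 'small' | 'base' | 'large';

// Same size -> encoder-config dispatch as the match statement above;
// dptConfig is a sketch, not part of the commit.
const dptConfig = (size: DepthAnythingModelSize) => {
  switch (size) {
    case 'small':
      return { encoder: 'vits', features: 64, outChannels: [48, 96, 192, 384] };
    case 'base':
      return { encoder: 'vitb', features: 128, outChannels: [96, 192, 384, 768] };
    case 'large':
      return { encoder: 'vitl', features: 256, outChannels: [256, 512, 1024, 1024] };
    default: {
      // Unreachable while the union above is fully covered.
      const exhaustive: never = size;
      throw new TypeError(`Not a supported model: ${exhaustive}`);
    }
  }
};
```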
@@ -81,12 +84,11 @@ class DepthAnythingDetector:
         self.model.to(device)
         return self

-    def __call__(self, image, offload=False):
+    def __call__(self, image, resolution=512, offload=False):
         image = np.array(image, dtype=np.uint8)
-        original_width, original_height = image.shape[:2]

         image = image[:, :, ::-1] / 255.0
-        image_width, image_height = image.shape[:2]
+        image_height, image_width = image.shape[:2]

         image = transform({"image": image})["image"]
         image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
@@ -97,7 +99,9 @@ class DepthAnythingDetector:
         depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
         depth_map = Image.fromarray(depth_map)
-        depth_map = depth_map.resize((original_height, original_width))
+
+        new_height = int(image_height * (resolution / image_width))
+        depth_map = depth_map.resize((resolution, new_height))

         if offload:
             del self.model
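
The depth map is no longer resized back to the input dimensions: `resolution` now becomes the output width, and the height follows the input aspect ratio (note that PIL's `resize` takes `(width, height)`). The same arithmetic as a small worked sketch; `outputSize` is an illustrative name:

```typescript
// Output size for a given input and requested resolution; Math.floor
// matches Python's int() truncation for positive values.
const outputSize = (imageWidth: number, imageHeight: number, resolution: number): [number, number] => {
  const newHeight = Math.floor(imageHeight * (resolution / imageWidth));
  return [resolution, newHeight];
};

outputSize(1024, 768, 512); // [512, 384]  -- a 1024x768 input at resolution 512
outputSize(640, 480, 1024); // [1024, 768] -- upscaling keeps the 4:3 ratio
```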

View File

@@ -1,5 +1,11 @@
 import type { ComboboxOnChange } from '@invoke-ai/ui';
-import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui';
+import {
+  Combobox,
+  CompositeNumberInput,
+  CompositeSlider,
+  FormControl,
+  FormLabel,
+} from '@invoke-ai/ui';
 import { useProcessorNodeChanged } from 'features/controlAdapters/components/hooks/useProcessorNodeChanged';
 import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
 import type {
import type {
@@ -23,7 +29,7 @@ type Props = {
 const DepthAnythingProcessor = (props: Props) => {
   const { controlNetId, processorNode, isEnabled } = props;
-  const { model_size } = processorNode;
+  const { model_size, resolution } = processorNode;
   const processorChanged = useProcessorNodeChanged();
   const { t } = useTranslation();
@@ -54,6 +60,17 @@ const DepthAnythingProcessor = (props: Props) => {
     [options, model_size]
   );

+  const handleResolutionChange = useCallback(
+    (v: number) => {
+      processorChanged(controlNetId, { resolution: v });
+    },
+    [controlNetId, processorChanged]
+  );
+
+  const handleResolutionDefaultChange = useCallback(() => {
+    processorChanged(controlNetId, { resolution: 512 });
+  }, [controlNetId, processorChanged]);
+
   return (
     <ProcessorWrapper>
       <FormControl isDisabled={!isEnabled}>
@@ -65,6 +82,27 @@ const DepthAnythingProcessor = (props: Props) => {
           onChange={handleModelSizeChange}
         />
       </FormControl>
+      <FormControl isDisabled={!isEnabled}>
+        <FormLabel>{t('controlnet.imageResolution')}</FormLabel>
+        <CompositeSlider
+          value={resolution}
+          onChange={handleResolutionChange}
+          defaultValue={DEFAULTS.resolution}
+          min={64}
+          max={4096}
+          step={64}
+          marks
+          onReset={handleResolutionDefaultChange}
+        />
+        <CompositeNumberInput
+          value={resolution}
+          onChange={handleResolutionChange}
+          defaultValue={DEFAULTS.resolution}
+          min={64}
+          max={4096}
+          step={64}
+        />
+      </FormControl>
     </ProcessorWrapper>
   );
 };
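
The slider and number input above share one value and one memoized change handler, so either control updates the same processor field. A self-contained sketch of that pattern, with local state standing in for the store-backed `processorChanged` dispatch; all names here are illustrative:

```typescript
import { useCallback, useState } from 'react';

// Local-state stand-in for the processorChanged dispatch used above.
const useResolutionControl = (defaultValue = 512) => {
  const [resolution, setResolution] = useState(defaultValue);
  // One handler serves both the slider and the number input.
  const handleResolutionChange = useCallback((v: number) => setResolution(v), []);
  // Reset mirrors handleResolutionDefaultChange above.
  const handleReset = useCallback(() => setResolution(defaultValue), [defaultValue]);
  return { resolution, handleResolutionChange, handleReset };
};
```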

View File

@@ -95,6 +95,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
       id: 'depth_anything_image_processor',
       type: 'depth_anything_image_processor',
       model_size: 'small',
+      resolution: 512,
       offload: false,
     },
   },

View File

@@ -80,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
  */
 export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
   DepthAnythingImageProcessorInvocation,
-  'type' | 'model_size' | 'offload'
+  'type' | 'model_size' | 'resolution' | 'offload'
 >;

 export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
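
`zDepthAnythingModelSize` gives both a runtime validator and a derived type. A short usage sketch:

```typescript
import { z } from 'zod';

const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
type DepthAnythingModelSize = z.infer<typeof zDepthAnythingModelSize>; // 'large' | 'base' | 'small'

// safeParse validates untrusted input without throwing.
const parsed = zDepthAnythingModelSize.safeParse('base');
if (parsed.success) {
  const size: DepthAnythingModelSize = parsed.data;
  console.log(size); // 'base'
}
```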

File diff suppressed because one or more lines are too long