mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Resolution to DepthAnything
This commit is contained in:
parent
39fedb090b
commit
7cb49e65bd
@ -621,6 +621,7 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
|
||||
default="small", description="The size of the depth model to use"
|
||||
)
|
||||
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
|
||||
offload: bool = InputField(default=False)
|
||||
|
||||
def run_processor(self, image):
|
||||
@ -630,5 +631,5 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
|
||||
processed_image = depth_anything_detector(image=image, offload=self.offload)
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
|
||||
return processed_image
|
||||
|
@ -64,12 +64,15 @@ class DepthAnythingDetector:
|
||||
del self.model
|
||||
self.model_size = model_size
|
||||
|
||||
if self.model_size == "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
if self.model_size == "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
if self.model_size == "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
match self.model_size:
|
||||
case "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
case _:
|
||||
raise TypeError("Not a supported model")
|
||||
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
@ -81,12 +84,11 @@ class DepthAnythingDetector:
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, image, offload=False):
|
||||
def __call__(self, image, resolution=512, offload=False):
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
original_width, original_height = image.shape[:2]
|
||||
image = image[:, :, ::-1] / 255.0
|
||||
|
||||
image_width, image_height = image.shape[:2]
|
||||
image_height, image_width = image.shape[:2]
|
||||
image = transform({"image": image})["image"]
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
|
||||
|
||||
@ -97,7 +99,9 @@ class DepthAnythingDetector:
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
depth_map = depth_map.resize((original_height, original_width))
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
if offload:
|
||||
del self.model
|
||||
|
@ -1,5 +1,11 @@
|
||||
import type { ComboboxOnChange } from '@invoke-ai/ui';
|
||||
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui';
|
||||
import {
|
||||
Combobox,
|
||||
CompositeNumberInput,
|
||||
CompositeSlider,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
} from '@invoke-ai/ui';
|
||||
import { useProcessorNodeChanged } from 'features/controlAdapters/components/hooks/useProcessorNodeChanged';
|
||||
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
|
||||
import type {
|
||||
@ -23,7 +29,7 @@ type Props = {
|
||||
|
||||
const DepthAnythingProcessor = (props: Props) => {
|
||||
const { controlNetId, processorNode, isEnabled } = props;
|
||||
const { model_size } = processorNode;
|
||||
const { model_size, resolution } = processorNode;
|
||||
const processorChanged = useProcessorNodeChanged();
|
||||
|
||||
const { t } = useTranslation();
|
||||
@ -54,6 +60,17 @@ const DepthAnythingProcessor = (props: Props) => {
|
||||
[options, model_size]
|
||||
);
|
||||
|
||||
const handleResolutionChange = useCallback(
|
||||
(v: number) => {
|
||||
processorChanged(controlNetId, { resolution: v });
|
||||
},
|
||||
[controlNetId, processorChanged]
|
||||
);
|
||||
|
||||
const handleResolutionDefaultChange = useCallback(() => {
|
||||
processorChanged(controlNetId, { resolution: 512 });
|
||||
}, [controlNetId, processorChanged]);
|
||||
|
||||
return (
|
||||
<ProcessorWrapper>
|
||||
<FormControl isDisabled={!isEnabled}>
|
||||
@ -65,6 +82,27 @@ const DepthAnythingProcessor = (props: Props) => {
|
||||
onChange={handleModelSizeChange}
|
||||
/>
|
||||
</FormControl>
|
||||
<FormControl isDisabled={!isEnabled}>
|
||||
<FormLabel>{t('controlnet.imageResolution')}</FormLabel>
|
||||
<CompositeSlider
|
||||
value={resolution}
|
||||
onChange={handleResolutionChange}
|
||||
defaultValue={DEFAULTS.resolution}
|
||||
min={64}
|
||||
max={4096}
|
||||
step={64}
|
||||
marks
|
||||
onReset={handleResolutionDefaultChange}
|
||||
/>
|
||||
<CompositeNumberInput
|
||||
value={resolution}
|
||||
onChange={handleResolutionChange}
|
||||
defaultValue={DEFAULTS.resolution}
|
||||
min={64}
|
||||
max={4096}
|
||||
step={64}
|
||||
/>
|
||||
</FormControl>
|
||||
</ProcessorWrapper>
|
||||
);
|
||||
};
|
||||
|
@ -95,6 +95,7 @@ export const CONTROLNET_PROCESSORS: ControlNetProcessorsDict = {
|
||||
id: 'depth_anything_image_processor',
|
||||
type: 'depth_anything_image_processor',
|
||||
model_size: 'small',
|
||||
resolution: 512,
|
||||
offload: false,
|
||||
},
|
||||
},
|
||||
|
@ -80,7 +80,7 @@ export type RequiredContentShuffleImageProcessorInvocation = O.Required<
|
||||
*/
|
||||
export type RequiredDepthAnythingImageProcessorInvocation = O.Required<
|
||||
DepthAnythingImageProcessorInvocation,
|
||||
'type' | 'model_size' | 'offload'
|
||||
'type' | 'model_size' | 'resolution' | 'offload'
|
||||
>;
|
||||
|
||||
export const zDepthAnythingModelSize = z.enum(['large', 'base', 'small']);
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user