Merge remote-tracking branch 'upstream/development' into fix-prompts

Damian at mba 2022-10-21 11:59:44 +02:00
commit 8142b72bcd
24 changed files with 196 additions and 37 deletions

View File

@@ -2,7 +2,7 @@
# InvokeAI: A Stable Diffusion Toolkit
-_Formally known as lstein/stable-diffusion_
+_Formerly known as lstein/stable-diffusion_
![project logo](docs/assets/logo.png)

Binary file not shown (new image added, 519 KiB)

Binary file not shown (new image added, 11 KiB)

Binary file not shown (new image added, 519 KiB)

Binary file not shown (new image added, 439 KiB)

View File

@@ -503,6 +503,16 @@ invoke> !search surreal
This clears the search history from memory and disk. Be advised that
this operation is irreversible and does not issue any warnings!
Other ! Commands
### !mask
This command takes an image and a text prompt, and uses the `clipseg`
algorithm to automatically generate a mask of the area that matches
the text prompt. It is useful for debugging the text masking process
prior to inpainting with the `--text_mask` argument. See
[INPAINTING.md] for details.
## Command-line editing and completion
The command-line offers convenient history tracking, editing, and

View File

@@ -74,6 +74,60 @@ up at all!
invoke> a baseball -I /path/to/breakfast.png -tm orange 0.6
~~~
The `!mask` command may be useful for debugging problems with the
text2mask feature. The syntax is `!mask /path/to/image.png -tm <text>
<threshold>`
It will generate three files:
- The image with the selected area highlighted.
- The image with the un-selected area highlighted.
- The image with the selected area converted into a black and white
image according to the threshold level.
Note that none of these images are intended to be used as the mask
passed to invoke via `-M` and may give unexpected results if you try
to use them this way. Instead, use `!mask` for testing that you are
selecting the right mask area, and then do inpainting using the
best selection term and threshold.
Here is an example of how `!mask` works:
```
invoke> !mask ./test-pictures/curly.png -tm hair 0.5
>> generating masks from ./test-pictures/curly.png
>> Initializing clipseg model for text to mask inference
Outputs:
[941.1] outputs/img-samples/000019.curly.hair.deselected.png: !mask ./test-pictures/curly.png -tm hair 0.5
[941.2] outputs/img-samples/000019.curly.hair.selected.png: !mask ./test-pictures/curly.png -tm hair 0.5
[941.3] outputs/img-samples/000019.curly.hair.masked.png: !mask ./test-pictures/curly.png -tm hair 0.5
```
**Original image "curly.png"**
<img src="../assets/outpainting/curly.png">
**000019.curly.hair.selected.png**
<img src="../assets/inpainting/000019.curly.hair.selected.png">
**000019.curly.hair.deselected.png**
<img src="../assets/inpainting/000019.curly.hair.deselected.png">
**000019.curly.hair.masked.png**
<img src="../assets/inpainting/000019.curly.hair.masked.png">
It looks like we selected the hair pretty well at the 0.5 threshold
(which is the default, so we didn't actually have to specify it), so
let's have some fun:
```
invoke> medusa with cobras -I ./test-pictures/curly.png -tm hair 0.5 -C20
>> loaded input image of size 512x512 from ./test-pictures/curly.png
...
Outputs:
[946] outputs/img-samples/000024.801380492.png: "medusa with cobras" -s 50 -S 801380492 -W 512 -H 512 -C 20.0 -I ./test-pictures/curly.png -A k_lms -f 0.75
```
<img src="../assets/inpainting/000024.801380492.png">
### Inpainting is not changing the masked region enough!

View File

@@ -12,7 +12,7 @@ title: Home
-->
<div align="center" markdown>
-# ^^**InvokeAI: A Stable Diffusion Toolkit**^^ :tools: <br> <small>Formally known as lstein/stable-diffusion</small>
+# ^^**InvokeAI: A Stable Diffusion Toolkit**^^ :tools: <br> <small>Formerly known as lstein/stable-diffusion</small>
![project logo](assets/logo.png)

View File

@@ -57,7 +57,7 @@ dependencies:
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
+- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
- -e .
variables:
PYTORCH_ENABLE_MPS_FALLBACK: 1

View File

@@ -37,5 +37,5 @@ dependencies:
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
- -e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-- -e git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
+- -e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
- -e .

View File

@@ -1 +0,0 @@

File diff suppressed because one or more lines are too long

View File

@@ -6,7 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="/assets/favicon.0d253ced.ico" />
-<script type="module" crossorigin src="/assets/index.dcc1d08e.js"></script>
+<script type="module" crossorigin src="/assets/index.b06af007.js"></script>
<link rel="stylesheet" href="/assets/index.58175ea1.css">
</head>
@@ -15,4 +15,4 @@
</body>
</html>

View File

@@ -72,7 +72,13 @@ export const gallerySlice = createSlice({
},
addImage: (state, action: PayloadAction<InvokeAI.Image>) => {
const newImage = action.payload;
-const { uuid, mtime } = newImage;
+const { uuid, url, mtime } = newImage;
// Do not add duplicate images
if (state.images.find((i) => i.url === url && i.mtime === mtime)) {
return;
}
state.images.unshift(newImage);
state.currentImageUuid = uuid;
state.intermediateImage = undefined;
@@ -120,8 +126,15 @@ export const gallerySlice = createSlice({
) => {
const { images, areMoreImagesAvailable } = action.payload;
if (images.length > 0) {
// Filter images that already exist in the gallery
const newImages = images.filter(
(newImage) =>
!state.images.find(
(i) => i.url === newImage.url && i.mtime === newImage.mtime
)
);
state.images = state.images
-.concat(images)
+.concat(newImages)
.sort((a, b) => b.mtime - a.mtime);
if (!state.currentImage) {

View File

@@ -729,7 +729,7 @@ class Generate:
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
if self.embedding_path is not None:
-model.embedding_manager.load(
+self.model.embedding_manager.load(
self.embedding_path, self.precision == 'float32' or self.precision == 'autocast'
)
@@ -806,6 +806,23 @@ class Generate:
else:
r[0] = image
def apply_textmask(self, image_path:str, prompt:str, callback, threshold:float=0.5):
assert os.path.exists(image_path), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **'
basename,_ = os.path.splitext(os.path.basename(image_path))
if self.txt2mask is None:
self.txt2mask = Txt2Mask(device = self.device)
segmented = self.txt2mask.segment(image_path,prompt)
trans = segmented.to_transparent()
inverse = segmented.to_transparent(invert=True)
mask = segmented.to_mask(threshold)
path_filter = re.compile(r'[<>:"/\\|?*]')
safe_prompt = path_filter.sub('_', prompt)[:50].rstrip(' .')
callback(trans,f'{safe_prompt}.deselected',use_prefix=basename)
callback(inverse,f'{safe_prompt}.selected',use_prefix=basename)
callback(mask,f'{safe_prompt}.masked',use_prefix=basename)
# to help WebGUI - front end to generator util function
def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples)
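The `apply_textmask()` method added above is the programmatic core of the new `!mask` command: it segments the image with `Txt2Mask` and hands the three derived images to a callback for saving. Below is a minimal sketch of driving it directly, not part of the commit; `save_mask_output` is a hypothetical helper, and its signature is inferred from how `apply_textmask` invokes the callback (image, name suffix, `use_prefix=`), mirroring invoke.py's `image_writer` further down.

```python
from ldm.generate import Generate

def save_mask_output(image, suffix, use_prefix=None):
    # Hypothetical callback: compose a name like "curly.hair.masked.png"
    # and write the PIL image to disk.
    name = f'{use_prefix}.{suffix}.png' if use_prefix else f'{suffix}.png'
    image.save(name)

gen = Generate()  # assumes a default model configuration is already installed
gen.apply_textmask(
    image_path='./test-pictures/curly.png',  # image to segment
    prompt='hair',                           # region to select
    threshold=0.5,                           # mask cutoff; 0.5 is the default
    callback=save_mask_output,
)
```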

View File

@@ -56,6 +56,7 @@ COMMANDS = (
'--png_compression','-z',
'--text_mask','-tm',
'!fix','!fetch','!history','!search','!clear',
'!mask',
'!models','!switch','!import_model','!edit_model'
)
MODEL_COMMANDS = (
@@ -71,6 +72,7 @@ IMG_PATH_COMMANDS = (
IMG_FILE_COMMANDS=(
'!fix',
'!fetch',
'!mask',
'--init_img[=\s]','-I',
'--init_mask[=\s]','-M',
'--init_color[=\s]',

View File

@@ -41,10 +41,12 @@ class CodeFormerRestoration():
cf.eval()
image = image.convert('RGB')
# Codeformer expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
face_helper.clean_all()
-face_helper.read_image(np.array(image, dtype=np.uint8))
+face_helper.read_image(bgr_image_array)
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
@@ -71,7 +73,8 @@ class CodeFormerRestoration():
restored_img = face_helper.paste_faces_to_input_image()
-res = Image.fromarray(restored_img)
+# Flip the channels back to RGB
+res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed

View File

@@ -55,13 +55,18 @@ class GFPGAN():
image = image.convert('RGB')
# GFPGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
_, _, restored_img = self.gfpgan.enhance(
-np.array(image, dtype=np.uint8),
+bgr_image_array,
has_aligned=False,
only_center_face=False,
paste_back=True,
)
-res = Image.fromarray(restored_img)
+# Flip the channels back to RGB
+res = Image.fromarray(restored_img[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed

View File

@@ -60,14 +60,18 @@ class ESRGAN():
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
# Real-ESRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]
output, _ = upsampler.enhance(
-np.array(image, dtype=np.uint8),
+bgr_image_array,
outscale=upsampler_scale,
alpha_upsampler='realesrgan',
)
-res = Image.fromarray(output)
+# Flip the channels back to RGB
+res = Image.fromarray(output[...,::-1])
if strength < 1.0:
# Resize the image to the new image if the sizes have changed
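The CodeFormer, GFPGAN, and Real-ESRGAN changes above are all the same fix: PIL images are RGB, while these restorers expect OpenCV-style BGR arrays, so the channel axis is reversed on the way in and reversed again on the way out. A minimal illustration of that round trip, not part of the commit; the test image and assertions are for demonstration only:

```python
import numpy as np
from PIL import Image

rgb_image = Image.new('RGB', (64, 64), (255, 0, 0))         # pure-red test image

# RGB -> BGR: reverse the channel axis before handing the array to the restorer
bgr_array = np.array(rgb_image, dtype=np.uint8)[..., ::-1]
assert bgr_array[0, 0].tolist() == [0, 0, 255]              # red now sits in the blue slot

restored = bgr_array                                        # stand-in for the restorer's output
# BGR -> RGB: reverse the channel axis again before rebuilding a PIL image
rgb_again = Image.fromarray(restored[..., ::-1])
assert rgb_again.getpixel((0, 0)) == (255, 0, 0)
```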

View File

@@ -29,9 +29,9 @@ work fine.
import torch
import numpy as np
-from models.clipseg import CLIPDensePredT
+from clipseg_models.clipseg import CLIPDensePredT
from einops import rearrange, repeat
-from PIL import Image
+from PIL import Image, ImageOps
from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
@@ -50,9 +50,14 @@ class SegmentedGrayscale(object):
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
-def to_transparent(self)->Image:
+def to_transparent(self,invert:bool=False)->Image:
transparent_image = self.image.copy()
-transparent_image.putalpha(self.to_grayscale())
+gs = self.to_grayscale()
# The following line looks like a bug, but isn't.
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite.
gs = ImageOps.invert(gs) if not invert else gs
transparent_image.putalpha(gs)
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
@@ -79,7 +84,7 @@ class Txt2Mask(object):
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
-def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
+def segment(self, image, prompt:str) -> SegmentedGrayscale:
'''
Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter
@@ -94,6 +99,10 @@ class Txt2Mask(object):
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
])
if type(image) is str:
image = Image.open(image).convert('RGB')
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0)
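With these changes, `Txt2Mask.segment()` accepts either a PIL image or a file path (EXIF orientation is applied and the image converted to RGB), and `SegmentedGrayscale` gains the `invert` option that `apply_textmask()` relies on. A short usage sketch, not part of the commit; it assumes the clipseg weights have already been installed by `scripts/preload_models.py` and that CPU inference is acceptable:

```python
from ldm.invoke.txt2mask import Txt2Mask

t2m = Txt2Mask(device='cpu')

# segment() now also accepts a file path; the prompt names the region to find
segmented = t2m.segment('./test-pictures/curly.png', 'hair')

# Black-and-white mask at the 0.5 threshold (the "*.masked.png" output of !mask)
segmented.to_mask(0.5).save('hair.masked.png')

# Matching region made transparent (saved as "*.deselected.png" by apply_textmask)
segmented.to_transparent().save('hair.deselected.png')

# Everything except the match made transparent ("*.selected.png")
segmented.to_transparent(invert=True).save('hair.selected.png')
```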

View File

@@ -22,5 +22,5 @@ transformers==4.21.3
-e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
-e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
-e git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
--3 git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
+-e git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg
-e .

View File

@@ -35,4 +35,4 @@ realesrgan
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
-git+https://github.com/invoke-ai/clipseg.git#egg=clipseg
+git+https://github.com/invoke-ai/clipseg.git@models-rename#egg=clipseg

View File

@@ -225,9 +225,13 @@ def main_loop(gen, opt, infile):
os.makedirs(opt.outdir)
current_outdir = opt.outdir
-# write out the history at this point
+# Write out the history at this point.
# TODO: Fix the parsing of command-line parameters
# so that !operations don't need to be stripped and readded
if operation == 'postprocess':
completer.add_history(f'!fix {command}')
elif operation == 'mask':
completer.add_history(f'!mask {command}')
else:
completer.add_history(command)
@@ -247,13 +251,28 @@ def main_loop(gen, opt, infile):
# when the -v switch is used to generate variations
nonlocal prior_variations
nonlocal prefix
if use_prefix is not None:
prefix = use_prefix
path = None
if opt.grid:
grid_images[seed] = image
elif operation == 'mask':
filename = f'{prefix}.{use_prefix}.{seed}.png'
tm = opt.text_mask[0]
th = opt.text_mask[1] if len(opt.text_mask)>1 else 0.5
formatted_dream_prompt = f'!mask {opt.prompt} -tm {tm} {th}'
path = file_writer.save_image_and_prompt_to_png(
image = image,
dream_prompt = formatted_dream_prompt,
metadata = {},
name = filename,
compress_level = opt.png_compression,
)
results.append([path, formatted_dream_prompt])
else:
if use_prefix is not None:
prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess'
filename, formatted_dream_prompt = prepare_image_metadata(
opt,
@@ -292,7 +311,7 @@ def main_loop(gen, opt, infile):
results.append([path, formatted_dream_prompt])
# so that the seed autocompletes (on linux|mac when -S or --seed specified)
-if completer:
+if completer and operation == 'generate':
completer.add_seed(seed)
completer.add_seed(first_seed)
last_results.append([path, seed])
@@ -310,6 +329,10 @@ def main_loop(gen, opt, infile):
print(f'>> fixing {opt.prompt}')
opt.last_operation = do_postprocess(gen,opt,image_writer)
elif operation == 'mask':
print(f'>> generating masks from {opt.prompt}')
do_textmask(gen, opt, image_writer)
if opt.grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
grid_seeds = list(grid_images.keys())
@@ -355,6 +378,10 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = command.replace('!fix ','',1)
operation = 'postprocess'
elif command.startswith('!mask'):
command = command.replace('!mask ','',1)
operation = 'mask'
elif command.startswith('!switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
@@ -363,6 +390,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
elif command.startswith('!models'):
gen.model_cache.print_models()
completer.add_history(command)
operation = None
elif command.startswith('!import'):
@@ -494,6 +522,19 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
os.rename(tmpfile,conf_path)
return True
def do_textmask(gen, opt, callback):
image_path = opt.prompt
assert os.path.exists(image_path), f'** "{image_path}" not found. Please enter the name of an existing image file to mask **'
assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
tm = opt.text_mask[0]
threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5
gen.apply_textmask(
image_path = image_path,
prompt = tm,
threshold = threshold,
callback = callback,
)
def do_postprocess (gen, opt, callback):
file_path = opt.prompt # treat the prompt as the file pathname
if os.path.dirname(file_path) == '': #basename given
@@ -670,7 +711,7 @@ def load_face_restoration(opt):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def make_step_callback(gen, opt, prefix):
destination = os.path.join(opt.outdir,'intermediates',prefix)
os.makedirs(destination,exist_ok=True)

View File

@@ -107,25 +107,27 @@ except Exception:
print(traceback.format_exc())
print('...success')
-print('Loading clipseq model for text-based masking...',end='')
+print('Loading clipseg model for text-based masking...',end='')
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
model_dest = 'src/clipseg/clipseg_weights.zip'
-if not os.path.exists(model_dest):
+weights_dir = 'src/clipseg/weights'
+if not os.path.exists(weights_dir):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_url,model_dest)
with zipfile.ZipFile(model_dest,'r') as zip:
zip.extractall('src/clipseg')
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
-from models.clipseg import CLIPDensePredT
+os.remove(model_dest)
+from clipseg_models.clipseg import CLIPDensePredT
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
model.eval()
model.load_state_dict(
-torch.load('src/clipseg/weights/rd64-uni-refined.pth'),
-model.load_state_dict(torch.load('src/clipseg/weights/rd64-uni-refined.pth'),
-map_location=torch.device('cpu'),
-strict=False,
-)
+torch.load(
+'src/clipseg/weights/rd64-uni-refined.pth',
+map_location=torch.device('cpu')
+),
+strict=False,
)
except Exception:
print('Error installing clipseg model:')