Merge branch 'main' into lstein/migrate-fix

This commit is contained in:
Kent Keirsey 2023-07-15 10:37:56 -04:00 committed by GitHub
commit 77b0129b4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 613 additions and 6 deletions

View File

@ -1,8 +1,7 @@
import os
import torch
from enum import Enum
from pathlib import Path
from typing import Optional, Union, Literal
from typing import Optional
from .base import (
ModelBase,
ModelConfigBase,
@ -14,6 +13,7 @@ from .base import (
calc_model_size_by_data,
classproperty,
InvalidModelException,
ModelNotFoundException,
)
class ControlNetModelFormat(str, Enum):
@ -60,10 +60,20 @@ class ControlNetModel(ModelBase):
if child_type is not None:
raise Exception("There is no child models in controlnet model")
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
)
model = None
for variant in ['fp16',None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except:
pass
if not model:
raise ModelNotFoundException()
# calc more accurate size
self.model_size = calc_model_size_by_data(model)
return model

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long