Files
InvokeAI/tests/AC_model_manager/test_model_download.py
2023-10-09 00:28:21 -04:00

362 lines
14 KiB
Python

"""Test the queued download facility"""
import tempfile
import time
from pathlib import Path
import requests
from requests import HTTPError
from requests_testadapter import TestAdapter
import invokeai.backend.model_manager.download.model_queue as download_queue
from invokeai.backend.model_manager.download import (
DownloadJobBase,
DownloadJobStatus,
ModelDownloadQueue,
UnknownJobIDException,
)
# Allow for at least one chunk to be fetched during the pause/unpause test.
# Otherwise the pause test doesn't work, because the whole file contents are
# read before the pause is received.
download_queue.DOWNLOAD_CHUNK_SIZE = 16500

# Prevent pytest deprecation warnings
TestAdapter.__test__ = False


def _internet_available(url: str = "http://www.google.com/", timeout: float = 5.0) -> bool:
    """Report whether a GET of `url` answers with HTTP 200.

    Args:
        url: Address to probe.
        timeout: Seconds to wait before giving up, so that an unresponsive
            network cannot hang the import of this module.

    Returns:
        True if the response status code is 200. False on any other status
        code, or if `requests` raises (connection error, timeout, ...).
    """
    try:
        return requests.get(url, timeout=timeout).status_code == 200
    except requests.RequestException:
        # Being offline must only disable the internet-dependent tests,
        # not break the import of the whole test module.
        return False


# Disable some tests that require the internet.
INTERNET_AVAILABLE = _internet_available()
########################################################################################
# Dummy content is mounted on a requests session below, so that model downloads can be
# tested without using lots of bandwidth.
#
# The repo_id tests are not self-contained: they still use the HF API to retrieve
# metainformation about the files to fetch. The big weights files, however, are not
# downloaded.
#
# Without internet access the repo_id tests are skipped, while the single-URL tests
# are still run.
session = requests.Session()

# Three mock civitai models. Each body is padded with 32,000 zero bytes because the
# pause tests need the content to be large.
for i in ("12345", "9999", "54321"):
    content = b"I am a safetensors file " + i.encode("utf-8") + bytes(32_000)
    session.mount(
        f"http://www.civitai.com/models/{i}",
        TestAdapter(
            content,
            headers={
                "Content-Length": len(content),
                "Content-Disposition": f'filename="mock{i}.safetensors"',
            },
        ),
    )
# Malformed or hostile responses used by the bad-URL tests.

# Response without a Content-Length header
session.mount(
    "http://www.civitai.com/models/missing",
    TestAdapter(b"Missing content length", headers={"Content-Disposition": 'filename="missing.txt"'}),
)

# Response with a 404 status
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))

# Prevent us from going to civitai to get metadata: each API prefix answers 404
for _civitai_api_prefix in (
    "https://civitai.com/api/download/models/",
    "https://civitai.com/api/v1/models/",
    "https://civitai.com/api/v1/model-versions/",
):
    session.mount(_civitai_api_prefix, TestAdapter(b"Not found", status=404))

# Content disposition that may overwrite files in the parent directory
session.mount(
    "http://www.civitai.com/models/malicious",
    TestAdapter(b"Malicious URL", headers={"Content-Disposition": 'filename="../badness.txt"'}),
)

# Content disposition whose 1000-character filename would create a path that is too long
session.mount(
    "http://www.civitai.com/models/long",
    TestAdapter(b"Malicious URL", headers={"Content-Disposition": f'filename="{"i"*1000}"'}),
)
# Mock HuggingFace URLs: one small text body is served for every file path listed here.
hf_sd2_paths = [
    "feature_extractor/preprocessor_config.json",
    "scheduler/scheduler_config.json",
    "text_encoder/config.json",
    "text_encoder/model.fp16.safetensors",
    "text_encoder/model.safetensors",
    "text_encoder/pytorch_model.fp16.bin",
    "text_encoder/pytorch_model.bin",
    "tokenizer/merges.txt",
    "tokenizer/special_tokens_map.json",
    "tokenizer/tokenizer_config.json",
    "tokenizer/vocab.json",
    "unet/config.json",
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "unet/diffusion_pytorch_model.safetensors",
    "vae/config.json",
    "vae/diffusion_pytorch_model.fp16.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
]

for path in hf_sd2_paths:
    # Build the URL from the string form before `path` is rebound to a Path object.
    url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/" + path
    path = Path(path)
    filename = path.name
    content = b"This is the content for path " + path.as_posix().encode("utf-8")
    session.mount(
        url,
        TestAdapter(
            content,
            status=200,
            headers={
                "Content-Length": len(content),
                "Content-Disposition": f'filename="{filename}"',
            },
        ),
    )
# Mock response for the repo's `model_index.json`.
# The body is the content of `model_index.json` for stable-diffusion-2-1.
model_index_content = b'{"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.8.0", "feature_extractor": ["transformers", "CLIPImageProcessor"], "requires_safety_checker": false, "safety_checker": [null, null], "scheduler": ["diffusers", "DDIMScheduler"], "text_encoder": ["transformers", "CLIPTextModel"], "tokenizer": ["transformers", "CLIPTokenizer"], "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]}'

session.mount(
    "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/model_index.json",
    TestAdapter(
        model_index_content,
        status=200,
        headers={
            "Content-Length": len(model_index_content),
            "Content-Disposition": 'filename="model_index.json"',
        },
    ),
)
# ================================================================================================================== #
def test_basic_queue_download():
    """Download one mock URL and check the job's status events and its output file."""
    seen_statuses = []

    def record_status(job: DownloadJobBase):
        # Remember the status the job has each time the queue reports an event.
        seen_statuses.append(job.status)

    queue = ModelDownloadQueue(requests_session=session, event_handlers=[record_status])

    with tempfile.TemporaryDirectory() as tmpdir:
        job = queue.create_download_job(
            source="http://www.civitai.com/models/12345",
            destdir=tmpdir,
            start=False,
        )

        # A job created with start=False is a numbered, idle DownloadJobBase.
        assert isinstance(job, DownloadJobBase), "expected the job to be of type DownloadJobBase"
        assert isinstance(job.id, int), "expected the job id to be numeric"
        assert job.status == "idle", "expected job status to be idle"
        assert job.status == DownloadJobStatus.IDLE

        queue.start_job(job)
        queue.join()

        # The recorded events must start with ENQUEUED, end with COMPLETED,
        # and include RUNNING somewhere.
        first_status = seen_statuses[0]
        last_status = seen_statuses[-1]
        assert first_status == DownloadJobStatus.ENQUEUED
        assert last_status == DownloadJobStatus.COMPLETED
        assert DownloadJobStatus.RUNNING in seen_statuses

        downloaded = Path(tmpdir, "mock12345.safetensors")
        assert downloaded.exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
def test_queue_priority():
    """Check that changing a job's priority changes the order in which jobs run."""
    url_12345 = "http://www.civitai.com/models/12345"
    url_9999 = "http://www.civitai.com/models/9999"
    queue = ModelDownloadQueue(requests_session=session)

    with tempfile.TemporaryDirectory() as tmpdir:
        # Round one: give the 12345 job the better priority, so it runs first.
        job_12345 = queue.create_download_job(source=url_12345, destdir=tmpdir, start=False)
        job_9999 = queue.create_download_job(source=url_9999, destdir=tmpdir, start=False)
        queue.change_priority(job_12345, -10)
        assert job_12345 < job_9999

        queue.start_all_jobs()
        queue.join()
        assert job_12345.job_sequence < job_9999.job_sequence

        # Round two: fresh jobs, this time the 9999 job gets the better priority.
        job_12345 = queue.create_download_job(source=url_12345, destdir=tmpdir, start=False)
        job_9999 = queue.create_download_job(source=url_9999, destdir=tmpdir, start=False)
        queue.change_priority(job_9999, -10)
        assert job_9999 < job_12345

        queue.start_all_jobs()
        queue.join()
        assert job_9999.job_sequence < job_12345.job_sequence

        # Both files must have been written to the destination directory.
        assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
        assert Path(tmpdir, "mock9999.safetensors").exists(), f"expected {tmpdir}/mock9999.safetensors to exist"
def test_repo_id_download():
    """Download a repo_id with and without the fp16 variant and check which files arrive.

    Returns immediately, without asserting anything, when INTERNET_AVAILABLE is false.
    """
    if not INTERNET_AVAILABLE:
        return

    repo_id = "stabilityai/stable-diffusion-2-1"
    queue = ModelDownloadQueue(requests_session=session)

    # Pass 1: request the fp16 variant.
    with tempfile.TemporaryDirectory() as tmpdir:
        queue.create_download_job(source=repo_id, destdir=tmpdir, variant="fp16", start=True)
        queue.join()

        repo_root = Path(tmpdir, "stable-diffusion-2-1")
        text_encoder_dir = Path(repo_root, "text_encoder")
        assert repo_root.exists()
        assert Path(repo_root, "model_index.json").exists()
        assert Path(text_encoder_dir, "config.json").exists()
        assert Path(text_encoder_dir, "model.fp16.safetensors").exists()

    # Pass 2: no variant requested, so the fp16 weights must not be present.
    with tempfile.TemporaryDirectory() as tmpdir:
        queue.create_download_job(source=repo_id, destdir=tmpdir, start=True)
        queue.join()

        text_encoder_dir = Path(tmpdir, "stable-diffusion-2-1", "text_encoder")
        assert Path(text_encoder_dir, "model.safetensors").exists()
        assert not Path(text_encoder_dir, "model.fp16.safetensors").exists()
def test_bad_urls():
    """Exercise the queue against broken, incomplete and hostile mock responses."""
    base_url = "http://www.civitai.com/models"
    queue = ModelDownloadQueue(requests_session=session)

    with tempfile.TemporaryDirectory() as tmpdir:
        # A 404 (or other HTTP error) must end the job in the error state.
        job = queue.create_download_job(source=f"{base_url}/broken", destdir=tmpdir)
        queue.join()
        assert job.status == "error"
        assert isinstance(job.error, HTTPError)
        assert str(job.error) == "NOT FOUND"

        # A response without a content length still completes; total_bytes stays 0.
        job = queue.create_download_job(source=f"{base_url}/missing", destdir=tmpdir)
        queue.join()
        assert job.status == "completed"
        assert job.total_bytes == 0
        assert job.bytes > 0
        written_size = Path(tmpdir, "missing.txt").stat().st_size
        assert job.bytes == written_size

        # The URL must not get to specify a filename with slashes or double dots
        # (e.g. '../../etc/passwd').
        job = queue.create_download_job(source=f"{base_url}/malicious", destdir=tmpdir)
        queue.join()
        safe_destination = Path(tmpdir, "malicious")
        assert job.status == "completed"
        assert job.destination == safe_destination
        assert safe_destination.exists()

        # Nor a destination that would exceed the maximum filename or path length.
        job = queue.create_download_job(source=f"{base_url}/long", destdir=tmpdir)
        queue.join()
        safe_destination = Path(tmpdir, "long")
        assert job.status == "completed"
        assert job.destination == safe_destination
        assert safe_destination.exists()

        # A foreign job that was not created by this queue must be rejected.
        foreign_job = DownloadJobBase(id=999, source="mock", destination="mock")
        rejected = False
        try:
            queue.start_job(foreign_job)  # this should fail
        except UnknownJobIDException:
            rejected = True
        assert rejected
def test_pause_cancel_url():
    """Pause, restart and cancel single-URL jobs, then check final states and files.

    This one is tricky because of potential race conditions: it relies on short
    sleeps to observe intermediate job states, so the statement order matters.
    """

    def slow_down(job: DownloadJobBase):
        # Slow down the thread so that we can recover the paused state.
        time.sleep(0.5)

    base_url = "http://www.civitai.com/models"
    queue = ModelDownloadQueue(requests_session=session, event_handlers=[slow_down])

    with tempfile.TemporaryDirectory() as tmpdir:
        job1 = queue.create_download_job(source=f"{base_url}/12345", destdir=tmpdir, start=False)
        job2 = queue.create_download_job(source=f"{base_url}/9999", destdir=tmpdir, start=False)
        job3 = queue.create_download_job(source=f"{base_url}/54321", destdir=tmpdir, start=False)
        assert job1.status == "idle"

        queue.start_job(job1)
        queue.start_job(job3)
        time.sleep(0.1)  # give the queue time to enqueue
        assert job1.status in ("enqueued", "running")

        # Pause job1, then restart it.
        queue.pause_job(job1)
        time.sleep(0.1)  # give the job time to pause
        assert job1.status == "paused"

        queue.start_job(job1)
        time.sleep(0.1)
        assert job1.status == "running"

        # Start job2, then cancel it.
        queue.start_job(job2)
        time.sleep(0.1)
        assert job2.status == "running"

        queue.cancel_job(job2)
        time.sleep(0.1)
        assert job2.status == "cancelled"

        # Once the queue drains, only the cancelled job has no output file.
        queue.join()
        assert job1.status == "completed"
        assert job2.status == "cancelled"
        assert job3.status == "completed"

        assert Path(tmpdir, "mock12345.safetensors").exists()
        assert not Path(tmpdir, "mock9999.safetensors").exists(), "cancelled file should be deleted"
        assert Path(tmpdir, "mock54321.safetensors").exists()

        # All three jobs stay listed until they are pruned.
        assert len(queue.list_jobs()) == 3
        queue.prune_jobs()
        assert len(queue.list_jobs()) == 0
def test_pause_cancel_repo_id():
    """Pause, restart and cancel repo_id jobs, then check final states and files.

    This one is tricky because of potential race conditions: it relies on short
    sleeps to observe intermediate job states, so the statement order matters.
    Returns immediately, without asserting anything, when INTERNET_AVAILABLE is false.
    """
    if not INTERNET_AVAILABLE:
        return

    def slow_down(job: DownloadJobBase):
        # Slow down the thread by blocking it just a bit at every step.
        time.sleep(0.5)

    repo_id = "stabilityai/stable-diffusion-2-1"
    queue = ModelDownloadQueue(requests_session=session, event_handlers=[slow_down])

    with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
        job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
        job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
        assert job1.status == "idle"

        queue.start_job(job1)
        time.sleep(0.1)  # give the queue time to enqueue
        assert job1.status in ("enqueued", "running")

        # Pause job1, then restart it.
        queue.pause_job(job1)
        time.sleep(0.1)  # give the job time to pause
        assert job1.status == "paused"

        queue.start_job(job1)
        time.sleep(0.1)
        assert job1.status == "running"

        # Start job2, then cancel it.
        queue.start_job(job2)
        time.sleep(0.1)
        assert job2.status == "running"

        queue.cancel_job(job2)
        time.sleep(0.1)
        assert job2.status == "cancelled"

        # Once the queue drains, only the completed job has its files on disk.
        queue.join()
        assert job1.status == "completed"
        assert job2.status == "cancelled"

        kept_index = Path(tmpdir1, "stable-diffusion-2-1", "model_index.json")
        cancelled_index = Path(tmpdir2, "stable-diffusion-2-1", "model_index.json")
        assert kept_index.exists()
        assert not cancelled_index.exists(), "cancelled file should be deleted"

        # NOTE(review): unlike the single-URL test, no prune_jobs() call precedes
        # this check -- confirm that repo_id jobs are expected to leave the list
        # empty on their own.
        assert len(queue.list_jobs()) == 0