mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add checks for malformed URLs and malicious content dispositions
This commit is contained in:
@ -284,8 +284,12 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if job.destination.is_dir():
|
||||
try:
|
||||
file_name = re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "bug-noname")).group(1)
|
||||
except AttributeError:
|
||||
file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
|
||||
self._validate_filename(job.destination, file_name) # will raise a ValueError exception if file_name is suspicious
|
||||
except ValueError:
|
||||
self._logger.warning(f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead")
|
||||
file_name = os.path.basename(job.source)
|
||||
except KeyError:
|
||||
file_name = os.path.basename(job.source)
|
||||
job.destination = job.destination / file_name
|
||||
dest = job.destination
|
||||
@ -336,6 +340,16 @@ class DownloadQueue(DownloadQueueBase):
|
||||
job.error = excp
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _validate_filename(self, directory: str, filename: str):
|
||||
if '/' in filename:
|
||||
raise ValueError
|
||||
if filename.startswith('..'):
|
||||
raise ValueError
|
||||
if len(filename) > os.pathconf(directory, "PC_NAME_MAX"):
|
||||
raise ValueError
|
||||
if len(os.path.join(directory, filename)) > os.pathconf(directory, "PC_PATH_MAX"):
|
||||
raise ValueError
|
||||
|
||||
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
|
||||
"""Optionally change the job status and send an event indicating a change of state."""
|
||||
if new_status:
|
||||
|
@ -43,7 +43,36 @@ for i in ["12345", "9999", "54321"]:
|
||||
),
|
||||
)
|
||||
|
||||
# Endpoints that misbehave in various ways, used to exercise error handling.

# Response whose headers carry no content length.
session.mount(
    "http://www.civitai.com/models/missing",
    TestAdapter(
        b"Missing content length",
        headers={"Content-Disposition": 'filename="missing.txt"'},
    ),
)

# Response with a 404 status.
session.mount(
    "http://www.civitai.com/models/broken",
    TestAdapter(b"Not found", status=404),
)

# Content disposition that points outside the download directory and could
# overwrite files in its parent.
session.mount(
    "http://www.civitai.com/models/malicious",
    TestAdapter(
        b"Malicious URL",
        headers={"Content-Disposition": 'filename="../badness.txt"'},
    ),
)

# Content disposition whose filename would produce an over-long path.
session.mount(
    "http://www.civitai.com/models/long",
    TestAdapter(
        b"Malicious URL",
        headers={"Content-Disposition": f'filename="{"i"*1000}"'},
    ),
)
|
||||
|
||||
# mock HuggingFace URLs
|
||||
hf_sd2_paths = [
|
||||
@ -176,18 +205,42 @@ def test_repo_id_download():
|
||||
assert not Path(repo_root, "text_encoder", "model.fp16.safetensors").exists()
|
||||
|
||||
|
||||
def test_failure_modes():
|
||||
def test_bad_urls():
|
||||
queue = DownloadQueue(
|
||||
requests_session=session,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# do we handle 404 and other HTTP errors?
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/broken", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "error"
|
||||
assert isinstance(job.error, HTTPError)
|
||||
assert str(job.error) == "NOT FOUND"
|
||||
|
||||
# Do we handle missing content length field?
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/missing", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "completed"
|
||||
assert job.total_bytes == 0
|
||||
assert job.bytes > 0
|
||||
assert job.bytes == Path(tmpdir, "missing.txt").stat().st_size
|
||||
|
||||
# Don't let the URL specify a filename with slashes or double dots... (e.g. '../../etc/passwd')
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/malicious", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "completed"
|
||||
assert job.destination == Path(tmpdir, 'malicious')
|
||||
assert Path(tmpdir, 'malicious').exists()
|
||||
|
||||
# Nor a destination that would exceed the maximum filename or path length
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/long", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "completed"
|
||||
assert job.destination == Path(tmpdir, 'long')
|
||||
assert Path(tmpdir, 'long').exists()
|
||||
|
||||
|
||||
# create a foreign job which will be invalid for the queue
|
||||
bad_job = DownloadJobBase(id=999, source="mock", destination="mock")
|
||||
try:
|
||||
@ -288,3 +341,4 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
).exists(), "cancelled file should be deleted"
|
||||
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
||||
|
Reference in New Issue
Block a user