Start adding types

This commit is contained in:
Ivan Habunek 2022-08-20 11:35:07 +02:00
parent 280a284fb2
commit a4b2434735
No known key found for this signature in database
GPG Key ID: CDBD63C43A30BB95
6 changed files with 34 additions and 25 deletions

View File

@ -9,7 +9,7 @@ import tempfile
from os import path from os import path
from pathlib import Path from pathlib import Path
from typing import OrderedDict from typing import List, Optional, OrderedDict
from urllib.parse import urlparse, urlencode from urllib.parse import urlparse, urlencode
from twitchdl import twitch, utils from twitchdl import twitch, utils
@ -137,7 +137,7 @@ def _clip_target_filename(clip, args):
raise ConsoleError("Invalid key {} used in --output. Supported keys are: {}".format(e, supported)) raise ConsoleError("Invalid key {} used in --output. Supported keys are: {}".format(e, supported))
def _get_vod_paths(playlist, start, end): def _get_vod_paths(playlist, start: Optional[int], end: Optional[int]) -> List[str]:
"""Extract unique VOD paths for download from playlist.""" """Extract unique VOD paths for download from playlist."""
files = [] files = []
vod_start = 0 vod_start = 0
@ -157,7 +157,7 @@ def _get_vod_paths(playlist, start, end):
return files return files
def _crete_temp_dir(base_uri): def _crete_temp_dir(base_uri: str) -> str:
"""Create a temp dir to store downloads if it doesn't exist.""" """Create a temp dir to store downloads if it doesn't exist."""
path = urlparse(base_uri).path.lstrip("/") path = urlparse(base_uri).path.lstrip("/")
temp_dir = Path(tempfile.gettempdir(), "twitch-dl", path) temp_dir = Path(tempfile.gettempdir(), "twitch-dl", path)
@ -166,11 +166,11 @@ def _crete_temp_dir(base_uri):
def download(args): def download(args):
for video in args.videos: for video_id in args.videos:
download_one(video, args) download_one(video_id, args)
def download_one(video, args): def download_one(video: str, args):
video_id = utils.parse_video_identifier(video) video_id = utils.parse_video_identifier(video)
if video_id: if video_id:
return _download_video(video_id, args) return _download_video(video_id, args)
@ -227,7 +227,7 @@ def get_clip_authenticated_url(slug, quality):
return "{}?{}".format(url, query) return "{}?{}".format(url, query)
def _download_clip(slug, args): def _download_clip(slug: str, args) -> None:
print_out("<dim>Looking up clip...</dim>") print_out("<dim>Looking up clip...</dim>")
clip = twitch.get_clip(slug) clip = twitch.get_clip(slug)
game = clip["game"]["name"] if clip["game"] else "Unknown" game = clip["game"]["name"] if clip["game"] else "Unknown"
@ -260,7 +260,7 @@ def _download_clip(slug, args):
print_out("Downloaded: <blue>{}</blue>".format(target)) print_out("Downloaded: <blue>{}</blue>".format(target))
def _download_video(video_id, args): def _download_video(video_id, args) -> None:
if args.start and args.end and args.end <= args.start: if args.start and args.end and args.end <= args.start:
raise ConsoleError("End time must be greater than start time") raise ConsoleError("End time must be greater than start time")

View File

@ -5,7 +5,7 @@ import sys
import re import re
from argparse import ArgumentParser, ArgumentTypeError from argparse import ArgumentParser, ArgumentTypeError
from collections import namedtuple from typing import NamedTuple, List, Tuple, Any, Dict
from twitchdl.exceptions import ConsoleError from twitchdl.exceptions import ConsoleError
from twitchdl.output import print_err from twitchdl.output import print_err
@ -13,12 +13,19 @@ from twitchdl.twitch import GQLError
from . import commands, __version__ from . import commands, __version__
Command = namedtuple("Command", ["name", "description", "arguments"]) Argument = Tuple[List[str], Dict[str, Any]]
class Command(NamedTuple):
name: str
description: str
arguments: List[Argument]
CLIENT_WEBSITE = 'https://github.com/ihabunek/twitch-dl' CLIENT_WEBSITE = 'https://github.com/ihabunek/twitch-dl'
def time(value): def time(value: str) -> int:
"""Parse a time string (hh:mm or hh:mm:ss) to number of seconds.""" """Parse a time string (hh:mm or hh:mm:ss) to number of seconds."""
parts = [int(p) for p in value.split(":")] parts = [int(p) for p in value.split(":")]
@ -35,19 +42,19 @@ def time(value):
return hours * 3600 + minutes * 60 + seconds return hours * 3600 + minutes * 60 + seconds
def pos_integer(value): def pos_integer(value: str) -> int:
try: try:
value = int(value) parsed = int(value)
except ValueError: except ValueError:
raise ArgumentTypeError("must be an integer") raise ArgumentTypeError("must be an integer")
if value < 1: if parsed < 1:
raise ArgumentTypeError("must be positive") raise ArgumentTypeError("must be positive")
return value return parsed
def rate(value): def rate(value: str) -> int:
match = re.search(r"^([0-9]+)(k|m|)$", value, flags=re.IGNORECASE) match = re.search(r"^([0-9]+)(k|m|)$", value, flags=re.IGNORECASE)
if not match: if not match:

View File

@ -10,7 +10,7 @@ class DownloadFailed(Exception):
pass pass
def _download(url, path): def _download(url: str, path: str):
tmp_path = path + ".tmp" tmp_path = path + ".tmp"
size = 0 size = 0
with httpx.stream("GET", url, timeout=CONNECT_TIMEOUT) as response: with httpx.stream("GET", url, timeout=CONNECT_TIMEOUT) as response:
@ -23,7 +23,7 @@ def _download(url, path):
return size return size
def download_file(url, path, retries=RETRY_COUNT): def download_file(url: str, path: str, retries: int = RETRY_COUNT):
if os.path.exists(path): if os.path.exists(path):
from_disk = True from_disk = True
return (os.path.getsize(path), from_disk) return (os.path.getsize(path), from_disk)

View File

@ -55,7 +55,7 @@ class TokenBucket:
class EndlessTokenBucket: class EndlessTokenBucket:
"""Used when download speed is not limited.""" """Used when download speed is not limited."""
def advance(self, size): def advance(self, size: int):
pass pass

View File

@ -6,6 +6,7 @@ import re
from itertools import islice from itertools import islice
from twitchdl import utils from twitchdl import utils
from typing import Any, Match
START_CODES = { START_CODES = {
@ -29,26 +30,26 @@ END_PATTERN = "</(" + "|".join(START_CODES.keys()) + ")>"
USE_ANSI_COLOR = "--no-color" not in sys.argv USE_ANSI_COLOR = "--no-color" not in sys.argv
def start_code(match): def start_code(match: Match[str]) -> str:
name = match.group(1) name = match.group(1)
return START_CODES[name] return START_CODES[name]
def colorize(text): def colorize(text: str) -> str:
text = re.sub(START_PATTERN, start_code, text) text = re.sub(START_PATTERN, start_code, text)
text = re.sub(END_PATTERN, END_CODE, text) text = re.sub(END_PATTERN, END_CODE, text)
return text return text
def strip_tags(text): def strip_tags(text: str) -> str:
text = re.sub(START_PATTERN, '', text) text = re.sub(START_PATTERN, '', text)
text = re.sub(END_PATTERN, '', text) text = re.sub(END_PATTERN, '', text)
return text return text
def truncate(string, length): def truncate(string: str, length: int) -> str:
if len(string) > length: if len(string) > length:
return string[:length - 1] + "" return string[:length - 1] + ""
@ -60,7 +61,7 @@ def print_out(*args, **kwargs):
print(*args, **kwargs) print(*args, **kwargs)
def print_json(data): def print_json(data: Any):
print(json.dumps(data)) print(json.dumps(data))

View File

@ -4,6 +4,7 @@ Twitch API access.
import httpx import httpx
from typing import Dict
from twitchdl import CLIENT_ID from twitchdl import CLIENT_ID
from twitchdl.exceptions import ConsoleError from twitchdl.exceptions import ConsoleError
@ -37,7 +38,7 @@ def gql_post(query):
return response return response
def gql_query(query, headers={}): def gql_query(query: str, headers: Dict[str, str] = {}):
url = "https://gql.twitch.tv/gql" url = "https://gql.twitch.tv/gql"
response = authenticated_post(url, json={"query": query}, headers=headers).json() response = authenticated_post(url, json={"query": query}, headers=headers).json()