wabbajack/Wabbajack.Downloaders.WabbajackCDN/WabbajackCDNDownloader.cs
2022-08-21 14:23:11 -06:00

184 lines
6.6 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Wabbajack.Common;
using Wabbajack.Downloaders.Interfaces;
using Wabbajack.DTOs;
using Wabbajack.DTOs.CDN;
using Wabbajack.DTOs.DownloadStates;
using Wabbajack.DTOs.JsonConverters;
using Wabbajack.DTOs.Validation;
using Wabbajack.Hashing.xxHash64;
using Wabbajack.Paths;
using Wabbajack.Paths.IO;
using Wabbajack.RateLimiter;
namespace Wabbajack.Downloaders;
/// <summary>
/// Downloader for files described by a <see cref="WabbajackCDN"/> state. A file is fetched as a
/// gzip-compressed <c>definition.json.gz</c> plus a series of numbered parts under <c>/parts/</c>,
/// both addressed relative to the state's URL.
/// </summary>
public class WabbajackCDNDownloader : ADownloader<WabbajackCDN>, IUrlDownloader, IChunkedSeekableStreamDownloader
{
    /// <summary>
    /// Host names that are rewritten before a request is sent: a request whose host matches a key
    /// is sent to the corresponding value instead (see <see cref="MakeMessage"/>).
    /// </summary>
    public static Dictionary<string, string> DomainRemaps = new()
    {
        {"wabbajack.b-cdn.net", "authored-files.wabbajack.org"},
        {"wabbajack-mirror.b-cdn.net", "mirror.wabbajack.org"},
        {"wabbajack-patches.b-cdn.net", "patches.wabbajack.org"},
        {"wabbajacktest.b-cdn.net", "test-files.wabbajack.org"}
    };

    private readonly HttpClient _client;
    private readonly DTOSerializer _dtos;
    private readonly ILogger<WabbajackCDNDownloader> _logger;
    private readonly IResource<HttpClient> _limiter;

    /// <summary>Creates the downloader from its injected collaborators.</summary>
    /// <param name="logger">Logger for this downloader.</param>
    /// <param name="client">HTTP client used for definition and part requests.</param>
    /// <param name="limiter">Resource from which a job is begun for every part downloaded by <see cref="Download"/>.</param>
    /// <param name="dtos">Serializer used to read the file definition.</param>
    public WabbajackCDNDownloader(ILogger<WabbajackCDNDownloader> logger, HttpClient client,
        IResource<HttpClient> limiter, DTOSerializer dtos)
    {
        _client = client;
        _logger = logger;
        _dtos = dtos;
        _limiter = limiter;
    }

    /// <summary>No preparation is needed; always completes with <c>true</c>.</summary>
    public override Task<bool> Prepare()
    {
        return Task.FromResult(true);
    }

    /// <summary>Always returns <c>true</c>; the allow list is not consulted by this downloader.</summary>
    public override bool IsAllowed(ServerAllowList allowList, IDownloadState state)
    {
        return true;
    }

    /// <summary>
    /// Builds a download state from ini data when it carries a <c>directURL</c> entry that is an
    /// absolute URI accepted by <see cref="Parse"/>.
    /// </summary>
    /// <returns>The resolved state, or <c>null</c> when the entry is missing, malformed or not a known host.</returns>
    public override IDownloadState? Resolve(IReadOnlyDictionary<string, string> iniData)
    {
        // Single lookup instead of ContainsKey followed by the indexer.
        if (iniData.TryGetValue("directURL", out var rawUrl) &&
            Uri.TryCreate(rawUrl.CleanIniString(), UriKind.Absolute, out var uri))
        {
            return Parse(uri);
        }

        return null;
    }

    /// <summary>
    /// Returns a <see cref="WabbajackCDN"/> state for <paramref name="url"/> when its host appears
    /// in <see cref="DomainRemaps"/> as either a key or a value; otherwise <c>null</c>.
    /// </summary>
    public IDownloadState? Parse(Uri url)
    {
        if (DomainRemaps.ContainsKey(url.Host) || DomainRemaps.ContainsValue(url.Host))
        {
            return new WabbajackCDN {Url = url};
        }

        return null;
    }

    /// <summary>Returns the URL stored in a <see cref="WabbajackCDN"/> state.</summary>
    /// <exception cref="InvalidCastException"><paramref name="state"/> is not a <see cref="WabbajackCDN"/>.</exception>
    public Uri UnParse(IDownloadState state)
    {
        return ((WabbajackCDN) state).Url;
    }

    public override Priority Priority => Priority.Normal;

    /// <summary>
    /// Downloads every part listed in the file definition, checks each part's length and hash,
    /// and writes it to <paramref name="destination"/> at the part's offset.
    /// </summary>
    /// <returns>The hash recorded in the file definition.</returns>
    /// <exception cref="InvalidDataException">
    /// The definition could not be retrieved, or a part request failed or reported an unexpected length.
    /// </exception>
    /// <exception cref="Exception">A downloaded part does not match the hash in the definition.</exception>
    public override async Task<Hash> Download(Archive archive, WabbajackCDN state, AbsolutePath destination, IJob job,
        CancellationToken token)
    {
        // GetDefinition returns null on a non-success response; fail with a clear message
        // rather than a NullReferenceException, and before the destination file is created.
        var definition = await GetDefinition(state, token)
                         ?? throw new InvalidDataException($"Could not get CDN definition for {state.Url}");

        await using var fs = destination.Open(FileMode.Create, FileAccess.Write, FileShare.None);

        await definition.Parts.PMapAll(async part =>
        {
            using var partJob = await _limiter.Begin(
                $"Downloading {definition.MungedName} ({part.Index}/{definition.Size})",
                part.Size, token);

            using var response = await RequestPart(state, part, token);
            await using var data = await response.Content.ReadAsStreamAsync(token);

            // The part is buffered in memory so its hash can be checked before anything is written.
            var ms = new MemoryStream();
            try
            {
                var hash = await data.HashingCopy(ms, token, partJob);
                if (hash != part.Hash)
                {
                    throw new Exception(
                        $"Invalid part hash {part.Index} got {hash} instead of {part.Hash} for {definition.MungedName}");
                }

                ms.Position = 0;
                return (ms, part);
            }
            catch
            {
                // Ownership of the buffer is only handed on when the part is valid.
                await ms.DisposeAsync();
                throw;
            }
        }).Do(async rec =>
        {
            var (ms, part) = rec;
            await using (ms)
            {
                fs.Position = part.Offset;
                await job.Report((int) part.Size, token);
                await ms.CopyToAsync(fs, token);
                await fs.FlushAsync(token);
            }
        });

        return definition.Hash;
    }

    /// <summary>
    /// Fetches and deserializes <c>definition.json.gz</c> for <paramref name="state"/>.
    /// </summary>
    /// <returns>The definition, or <c>null</c> when the server answers with a non-success status.</returns>
    private async Task<FileDefinition?> GetDefinition(WabbajackCDN state, CancellationToken token)
    {
        var msg = MakeMessage(new Uri(state.Url + "/definition.json.gz"));
        using var data = await _client.SendAsync(msg, token);
        if (!data.IsSuccessStatusCode)
        {
            return null;
        }

        await using var stream = await data.Content.ReadAsStreamAsync(token);
        await using var gz = new GZipStream(stream, CompressionMode.Decompress);
        return (await _dtos.DeserializeAsync<FileDefinition>(gz, token))!;
    }

    /// <summary>
    /// Builds a GET request for <paramref name="url"/>. When the host is a key of
    /// <see cref="DomainRemaps"/> the request targets the mapped host instead; in both cases the
    /// <c>Host</c> header is set explicitly to the host actually used.
    /// </summary>
    private HttpRequestMessage MakeMessage(Uri url)
    {
        if (DomainRemaps.TryGetValue(url.Host, out var host))
        {
            url = new UriBuilder(url) {Host = host}.Uri;
        }
        else
        {
            host = url.Host;
        }

        var msg = new HttpRequestMessage(HttpMethod.Get, url);
        msg.Headers.Add("Host", host);
        return msg;
    }

    /// <summary>
    /// Requests a single part and validates the response status and declared content length.
    /// Only the headers have been read when this returns; the caller owns and must dispose the response.
    /// </summary>
    /// <exception cref="InvalidDataException">
    /// The response status is not successful, or its content length differs from the part's size.
    /// </exception>
    private async Task<HttpResponseMessage> RequestPart(WabbajackCDN state, PartDefinition part,
        CancellationToken token)
    {
        var msg = MakeMessage(new Uri(state.Url + $"/parts/{part.Index}"));
        var response = await _client.SendAsync(msg, HttpCompletionOption.ResponseHeadersRead, token);
        try
        {
            if (!response.IsSuccessStatusCode)
            {
                throw new InvalidDataException($"Bad response for part request for part {part.Index}");
            }

            var length = response.Content.Headers.ContentLength;
            if (length != part.Size)
            {
                throw new InvalidDataException(
                    $"Bad part size, expected {part.Size} got {length} for part {part.Index}");
            }

            return response;
        }
        catch
        {
            // The caller never sees a rejected response, so release it here.
            response.Dispose();
            throw;
        }
    }

    /// <summary>
    /// Reports the archive as valid when its definition can be retrieved.
    /// </summary>
    public override async Task<bool> Verify(Archive archive, WabbajackCDN archiveState, IJob job,
        CancellationToken token)
    {
        return await GetDefinition(archiveState, token) is not null;
    }

    /// <summary>Emits the <c>directURL</c> ini line for <paramref name="state"/>.</summary>
    public override IEnumerable<string> MetaIni(Archive a, WabbajackCDN state)
    {
        return new[] {$"directURL={state.Url}"};
    }

    /// <summary>
    /// Creates a <see cref="ChunkedSeekableDownloader"/> over the archive, constructed from its state,
    /// its file definition and this downloader.
    /// </summary>
    /// <exception cref="ArgumentException">The archive's state is not a <see cref="WabbajackCDN"/>.</exception>
    /// <exception cref="Exception">The definition could not be retrieved.</exception>
    public async ValueTask<Stream> GetChunkedSeekableStream(Archive archive, CancellationToken token)
    {
        // Reject a foreign state up front instead of dereferencing a null cast result later.
        if (archive.State is not WabbajackCDN state)
        {
            throw new ArgumentException(
                $"Archive state must be a {nameof(WabbajackCDN)} state", nameof(archive));
        }

        var definition = await GetDefinition(state, token);
        if (definition == null)
        {
            throw new Exception("Could not get CDN definition");
        }

        return new ChunkedSeekableDownloader(state, definition, this);
    }

    /// <summary>
    /// Downloads one part and returns its bytes. The response status and declared length are
    /// validated; the part hash is not checked here.
    /// </summary>
    /// <exception cref="InvalidDataException">
    /// The response status is not successful, or its content length differs from the part's size.
    /// </exception>
    public async Task<byte[]> GetPart(WabbajackCDN state, PartDefinition part, CancellationToken token)
    {
        using var response = await RequestPart(state, part, token);
        return await response.Content.ReadAsByteArrayAsync(token);
    }
}