Make the proxy stream the initial result

This commit is contained in:
Timothy Baldridge 2022-06-08 07:12:44 -06:00
parent 3f12422e01
commit ca74e79348
7 changed files with 146 additions and 13 deletions

View File

@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq; using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Text.Encodings.Web; using System.Text.Encodings.Web;
@ -73,6 +74,17 @@ public class GoogleDriveDownloader : ADownloader<DTOs.DownloadStates.GoogleDrive
public override Priority Priority => Priority.Normal; public override Priority Priority => Priority.Normal;
public async Task<T> DownloadStream<T>(Archive archive, Func<Stream, Task<T>> fn, CancellationToken token)
{
var state = archive.State as DTOs.DownloadStates.GoogleDrive;
var msg = await ToMessage(state, true, token);
using var result = await _client.SendAsync(msg, token);
HttpException.ThrowOnFailure(result);
await using var stream = await result.Content.ReadAsStreamAsync(token);
return await fn(stream);
}
public override async Task<Hash> Download(Archive archive, DTOs.DownloadStates.GoogleDrive state, public override async Task<Hash> Download(Archive archive, DTOs.DownloadStates.GoogleDrive state,
AbsolutePath destination, IJob job, CancellationToken token) AbsolutePath destination, IJob job, CancellationToken token)
{ {
@ -98,7 +110,10 @@ public class GoogleDriveDownloader : ADownloader<DTOs.DownloadStates.GoogleDrive
if (download) if (download)
{ {
var initialUrl = $"https://drive.google.com/uc?id={state.Id}&export=download"; var initialUrl = $"https://drive.google.com/uc?id={state.Id}&export=download";
using var response = await _client.GetAsync(initialUrl, token); var msg = new HttpRequestMessage(HttpMethod.Get, initialUrl);
msg.UseChromeUserAgent();
using var response = await _client.SendAsync(msg, token);
var cookies = response.GetSetCookies(); var cookies = response.GetSetCookies();
var warning = cookies.FirstOrDefault(c => c.Key.StartsWith("download_warning_")); var warning = cookies.FirstOrDefault(c => c.Key.StartsWith("download_warning_"));
@ -106,6 +121,8 @@ public class GoogleDriveDownloader : ADownloader<DTOs.DownloadStates.GoogleDrive
{ {
var doc = new HtmlDocument(); var doc = new HtmlDocument();
var txt = await response.Content.ReadAsStringAsync(token); var txt = await response.Content.ReadAsStringAsync(token);
if (txt.Contains("<title>Google Drive - Quota exceeded</title>"))
throw new Exception("Google Drive - Quota Exceeded");
doc.LoadHtml(txt); doc.LoadHtml(txt);
@ -127,13 +144,19 @@ public class GoogleDriveDownloader : ADownloader<DTOs.DownloadStates.GoogleDrive
var url = $"https://drive.google.com/uc?export=download&confirm={warning.Value}&id={state.Id}"; var url = $"https://drive.google.com/uc?export=download&confirm={warning.Value}&id={state.Id}";
var httpState = new HttpRequestMessage(HttpMethod.Get, url); var httpState = new HttpRequestMessage(HttpMethod.Get, url);
httpState.UseChromeUserAgent();
return httpState; return httpState;
} }
else else
{ {
var url = $"https://drive.google.com/file/d/{state.Id}/edit"; var url = $"https://drive.google.com/file/d/{state.Id}/edit";
using var response = await _client.GetAsync(url, token); var msg = new HttpRequestMessage(HttpMethod.Get, url);
return !response.IsSuccessStatusCode ? null : new HttpRequestMessage(HttpMethod.Get, url); msg.UseChromeUserAgent();
using var response = await _client.SendAsync(msg, token);
msg = new HttpRequestMessage(HttpMethod.Get, url);
msg.UseChromeUserAgent();
return !response.IsSuccessStatusCode ? null : msg;
} }
} }
} }

View File

@ -1,6 +1,12 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Wabbajack.DTOs;
namespace Wabbajack.Downloaders.Interfaces; namespace Wabbajack.Downloaders.Interfaces;
public interface IProxyable : IUrlDownloader public interface IProxyable : IUrlDownloader
{ {
public Task<T> DownloadStream<T>(Archive archive, Func<Stream, Task<T>> fn, CancellationToken token);
} }

View File

@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO;
using System.Linq; using System.Linq;
using System.Net.Http; using System.Net.Http;
using System.Threading; using System.Threading;
@ -71,6 +72,16 @@ public class MediaFireDownloader : ADownloader<DTOs.DownloadStates.MediaFire>, I
return ((DTOs.DownloadStates.MediaFire) state).Url; return ((DTOs.DownloadStates.MediaFire) state).Url;
} }
public async Task<T> DownloadStream<T>(Archive archive, Func<Stream, Task<T>> fn, CancellationToken token)
{
var state = archive.State as DTOs.DownloadStates.MediaFire;
var url = await Resolve(state!);
var msg = new HttpRequestMessage(HttpMethod.Get, url!);
using var result = await _httpClient.SendAsync(msg, token);
await using var stream = await result.Content.ReadAsStreamAsync(token);
return await fn(stream);
}
public override async Task<Hash> Download(Archive archive, DTOs.DownloadStates.MediaFire state, public override async Task<Hash> Download(Archive archive, DTOs.DownloadStates.MediaFire state,
AbsolutePath destination, IJob job, CancellationToken token) AbsolutePath destination, IJob job, CancellationToken token)
{ {
@ -85,7 +96,7 @@ public class MediaFireDownloader : ADownloader<DTOs.DownloadStates.MediaFire>, I
return await Resolve(archiveState, job, token) != null; return await Resolve(archiveState, job, token) != null;
} }
private async Task<Uri?> Resolve(DTOs.DownloadStates.MediaFire state, IJob job, CancellationToken? token = null) private async Task<Uri?> Resolve(DTOs.DownloadStates.MediaFire state, IJob? job = null, CancellationToken? token = null)
{ {
token ??= CancellationToken.None; token ??= CancellationToken.None;
using var result = await _httpClient.GetAsync(state.Url, HttpCompletionOption.ResponseHeadersRead, using var result = await _httpClient.GetAsync(state.Url, HttpCompletionOption.ResponseHeadersRead,
@ -93,6 +104,7 @@ public class MediaFireDownloader : ADownloader<DTOs.DownloadStates.MediaFire>, I
if (!result.IsSuccessStatusCode) if (!result.IsSuccessStatusCode)
return null; return null;
if (job != null)
job.Size = result.Content.Headers.ContentLength ?? 0; job.Size = result.Content.Headers.ContentLength ?? 0;
if (result.Content.Headers.ContentType!.MediaType!.StartsWith("text/html", if (result.Content.Headers.ContentType!.MediaType!.StartsWith("text/html",

View File

@ -60,6 +60,16 @@ public class MegaDownloader : ADownloader<Mega>, IUrlDownloader, IProxyable
return ((Mega) state).Url; return ((Mega) state).Url;
} }
public async Task<T> DownloadStream<T>(Archive archive, Func<Stream, Task<T>> fn, CancellationToken token)
{
var state = archive.State as Mega;
if (!_apiClient.IsLoggedIn)
await _apiClient.LoginAsync();
await using var ins = await _apiClient.DownloadAsync(state.Url, cancellationToken: token);
return await fn(ins);
}
public override async Task<Hash> Download(Archive archive, Mega state, AbsolutePath destination, IJob job, public override async Task<Hash> Download(Archive archive, Mega state, AbsolutePath destination, IJob job,
CancellationToken token) CancellationToken token)
{ {

View File

@ -1,3 +1,4 @@
using System;
using System.Buffers; using System.Buffers;
using System.IO; using System.IO;
using System.Threading; using System.Threading;
@ -72,4 +73,57 @@ public static class StreamExtensions
return new Hash(finalHash); return new Hash(finalHash);
} }
public static async Task<Hash> HashingCopy(this Stream inputStream, Func<Memory<byte>, Task> fn,
CancellationToken token)
{
using var rented = MemoryPool<byte>.Shared.Rent(1024 * 1024);
var buffer = rented.Memory;
var hasher = new xxHashAlgorithm(0);
var running = true;
ulong finalHash = 0;
while (running && !token.IsCancellationRequested)
{
var totalRead = 0;
while (totalRead != buffer.Length)
{
var read = await inputStream.ReadAsync(buffer.Slice(totalRead, buffer.Length - totalRead),
token);
if (read == 0)
{
running = false;
break;
}
totalRead += read;
}
var pendingWrite = fn(buffer[..totalRead]);
if (running)
{
hasher.TransformByteGroupsInternal(buffer.Span);
await pendingWrite;
}
else
{
var preSize = (totalRead >> 5) << 5;
if (preSize > 0)
{
hasher.TransformByteGroupsInternal(buffer[..preSize].Span);
finalHash = hasher.FinalizeHashValueInternal(buffer[preSize..totalRead].Span);
await pendingWrite;
break;
}
finalHash = hasher.FinalizeHashValueInternal(buffer[..totalRead].Span);
await pendingWrite;
break;
}
}
return new Hash(finalHash);
}
} }

View File

@ -20,4 +20,10 @@ public class HttpException : Exception
public string Reason { get; set; } public string Reason { get; set; }
public int Code { get; set; } public int Code { get; set; }
public static void ThrowOnFailure(HttpResponseMessage result)
{
if (result.IsSuccessStatusCode) return;
throw new HttpException(result);
}
} }

View File

@ -3,6 +3,7 @@ using FluentFTP.Helpers;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Net.Http.Headers;
using Wabbajack.BuildServer; using Wabbajack.BuildServer;
using Wabbajack.Downloaders; using Wabbajack.Downloaders;
using Wabbajack.Downloaders.Interfaces; using Wabbajack.Downloaders.Interfaces;
@ -107,18 +108,39 @@ public class Proxy : ControllerBase
var tempFile = _tempFileManager.CreateFile(deleteOnDispose:false); var tempFile = _tempFileManager.CreateFile(deleteOnDispose:false);
var result = await _dispatcher.Download(archive, tempFile.Path, token); var proxyDownloader = _dispatcher.Downloader(archive) as IProxyable;
await using (var of = tempFile.Path.Open(FileMode.Create, FileAccess.Write, FileShare.None))
{
Response.StatusCode = 200;
if (name != null)
{
Response.Headers.Add(HeaderNames.ContentDisposition, $"attachment; filename=\"{name}\"");
}
Response.Headers.Add( HeaderNames.ContentType, "application/octet-stream" );
var result = await proxyDownloader.DownloadStream(archive, async s => {
return await s.HashingCopy(async m =>
{
var strmA = of.WriteAsync(m, token);
await Response.Body.WriteAsync(m, token);
await Response.Body.FlushAsync(token);
await strmA;
}, token); },
token);
if (hash != default && result != shouldMatch) if (hash != default && result != shouldMatch)
{ {
if (tempFile.Path.FileExists()) if (tempFile.Path.FileExists())
tempFile.Path.Delete(); tempFile.Path.Delete();
return BadRequest(new {Type = "Unmatching Hashes", Expected = shouldMatch.ToHex(), Found = result.ToHex()});
} }
}
await tempFile.Path.MoveToAsync(cacheFile, true, token); await tempFile.Path.MoveToAsync(cacheFile, true, token);
_logger.LogInformation("Returning proxy request for {Uri} {Size}", uri, cacheFile.Size().FileSizeToString()); _logger.LogInformation("Returning proxy request for {Uri} {Size}", uri, cacheFile.Size().FileSizeToString());
return new PhysicalFileResult(cacheFile.ToString(), "application/binary"); return new EmptyResult();
} }
} }