using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reactive.Linq;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Compression.BSA;
using Wabbajack.Common;
using Wabbajack.Common.FileSignatures;
using Wabbajack.Common.StatusFeed.Errors;
using Wabbajack.VirtualFileSystem.SevenZipExtractor;

namespace Wabbajack.VirtualFileSystem
{
    /// <summary>
    /// Extracts the entries of an archive and gathers one mapped result per entry.
    /// Every entry accepted by the <c>shouldExtract</c> predicate is handed to the mapping
    /// function as a stream factory, and the value it returns is stored under the entry's
    /// relative path. Extraction is attempted through the in-process archive interface first;
    /// if the archive cannot be opened that way, 7z.exe is run against it instead.
    /// </summary>
    /// <typeparam name="T">Type of the value produced by the mapping function for each entry.</typeparam>
    public class GatheringExtractor<T> : IArchiveExtractCallback
    {
        private ArchiveFile _archive;
        private readonly Predicate<RelativePath> _shouldExtract;
        private readonly Func<RelativePath, IStreamFactory, ValueTask<T>> _mapFn;
        private readonly Dictionary<RelativePath, T> _results;
        private readonly Definitions.FileType _sig;
        private Exception _killException;
        private uint _itemsCount;
        private readonly IStreamFactory _streamFactory;

        /// <summary>
        /// Creates an extractor over the archive provided by <paramref name="sF"/>.
        /// </summary>
        /// <param name="sF">Factory that opens the archive stream.</param>
        /// <param name="sig">File signature passed to <c>ArchiveFile.Open</c> when opening the archive.</param>
        /// <param name="shouldExtract">Returns true for entry paths that should be extracted and mapped.</param>
        /// <param name="mapfn">Maps an entry path and its stream factory to the value stored in the result.</param>
        public GatheringExtractor(IStreamFactory sF, Definitions.FileType sig, Predicate<RelativePath> shouldExtract, Func<RelativePath, IStreamFactory, ValueTask<T>> mapfn)
        {
            _shouldExtract = shouldExtract;
            _mapFn = mapfn;
            _results = new Dictionary<RelativePath, T>();
            _streamFactory = sF;
            _sig = sig;
        }

        /// <summary>
        /// Runs the extraction on a dedicated below-normal-priority thread and returns the
        /// gathered results once it has finished.
        /// </summary>
        /// <returns>The mapped results, keyed by the relative path of each extracted entry.</returns>
        /// <exception cref="Exception">
        /// Any exception raised while opening or extracting the archive, or recorded by a
        /// callback through <c>Kill</c>, is surfaced through the returned task.
        /// </exception>
        public async Task<Dictionary<RelativePath, T>> Extract()
        {
            // Run continuations asynchronously so the awaiting caller does not resume inline
            // on the low-priority extraction thread from inside TrySetResult/TrySetException.
            var source = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

            var th = new Thread(() =>
            {
                try
                {
                    using var stream = _streamFactory.GetStream().Result;
                    _archive = ArchiveFile.Open(stream, _sig).Result;
                    try
                    {
                        ulong checkPos = (ulong)stream.Length;
                        var oresult = _archive._archive.Open(_archive._archiveStream, ref checkPos, new ArchiveCallback());
                        // Can't read this with the COM interface for some reason
                        if (oresult != 0)
                        {
                            // Fall back to 7z.exe; ExtractSlow completes the task source itself.
                            _ = ExtractSlow(source, _streamFactory);
                            return;
                        }

                        _itemsCount = _archive._archive.GetNumberOfItems();
                        var hresult = _archive._archive.Extract(null, 0xFFFFFFFF, 0, this);
                        if (hresult != 0 && _killException == null)
                        {
                            // Previously ignored: make a failing extraction visible without
                            // changing the best-effort result handling.
                            Utils.Log($"Archive extraction returned non-zero result code {hresult}");
                        }
                    }
                    finally
                    {
                        // Release the archive on every path (success, fallback and failure).
                        _archive.Dispose();
                    }

                    if (_killException != null)
                    {
                        source.TrySetException(_killException);
                    }
                    else
                    {
                        source.TrySetResult(true);
                    }
                }
                catch (Exception ex)
                {
                    source.TrySetException(ex);
                }
            }) {Priority = ThreadPriority.BelowNormal, Name = "7Zip Extraction Worker Thread"};
            th.Start();

            await source.Task;
            return _results;
        }

        /// <summary>
        /// Fallback extraction: runs 7z.exe to unpack the archive into a temporary folder, then
        /// maps every extracted file accepted by the predicate and deletes it afterwards.
        /// </summary>
        /// <param name="tcs">Completion source that is completed when this method finishes or fails.</param>
        /// <param name="streamFactory">Source of the archive; copied to a temp file unless it is a native file.</param>
        private async Task ExtractSlow(TaskCompletionSource<bool> tcs, IStreamFactory streamFactory)
        {
            try
            {
                // NOTE(review): dest is never disposed here and files rejected by _shouldExtract
                // are left inside it - confirm whether TempFolder needs explicit cleanup.
                var dest = await TempFolder.Create();

                TempFile tempFile = null;
                try
                {
                    AbsolutePath source;
                    if (streamFactory is NativeFileStreamFactory nsf)
                    {
                        source = (AbsolutePath)nsf.Name;
                    }
                    else
                    {
                        // 7z.exe needs a file on disk, so spill the stream to a temp file and
                        // point the extractor at that copy.
                        await using var stream = await streamFactory.GetStream();
                        tempFile = new TempFile();
                        await tempFile.Path.WriteAllAsync(stream);
                        source = tempFile.Path;
                    }

                    Utils.Log(
                        $"The contents of {(string)source.FileName} are being extracted to {dest.Dir} using 7zip.exe");

                    var process = new ProcessHelper {Path = @"Extractors\7z.exe".RelativeTo(AbsolutePath.EntryPoint),};

                    process.Arguments = new object[] {"x", "-bsp1", "-y", $"-o\"{dest.Dir}\"", source, "-mmt=off"};

                    // Translate progress lines of the form "NNN% ..." into status updates.
                    _ = process.Output.Where(d => d.Type == ProcessHelper.StreamType.Output)
                        .ForEachAsync(p =>
                        {
                            var (_, line) = p;
                            if (line == null)
                                return;

                            if (line.Length <= 4 || line[3] != '%') return;

                            int.TryParse(line.Substring(0, 3), out var percentInt);
                            Utils.Status($"Extracting {(string)source.FileName} - {line.Trim()}",
                                Percent.FactoryPutInRange(percentInt / 100d));
                        });

                    var exitCode = await process.Start();

                    if (exitCode != 0)
                    {
                        Utils.ErrorThrow(new _7zipReturnError(exitCode, source, dest.Dir, ""));
                    }
                    else
                    {
                        Utils.Status($"Extracting {source.FileName} - done", Percent.One, alsoLog: true);
                    }
                }
                finally
                {
                    // Remove the temporary archive copy even when 7z.exe fails.
                    if (tempFile != null)
                    {
                        await tempFile.DisposeAsync();
                    }
                }

                foreach (var file in dest.Dir.EnumerateFiles())
                {
                    var relPath = file.RelativeTo(dest.Dir);
                    if (!_shouldExtract(relPath)) continue;

                    var result = await _mapFn(relPath, new NativeFileStreamFactory(file));
                    _results[relPath] = result;
                    await file.DeleteAsync();
                }

                tcs.TrySetResult(true);
            }
            catch (Exception ex)
            {
                tcs.TrySetException(ex);
            }
        }

        /// <summary>Extraction callback; the reported total is not used.</summary>
        public void SetTotal(ulong total)
        {
        }

        /// <summary>Extraction callback; the reported progress is not used.</summary>
        public void SetCompleted(ref ulong completeValue)
        {
        }

        /// <summary>
        /// Extraction callback that supplies the output stream for the entry at <paramref name="index"/>.
        /// Folders and entries rejected by the predicate get no stream and are skipped.
        /// </summary>
        /// <param name="index">Index of the archive entry.</param>
        /// <param name="outStream">Receives the stream that collects the entry's data, or null to skip it.</param>
        /// <param name="askExtractMode">Not used.</param>
        /// <returns>0 on success; 1 after an error has been recorded via <c>Kill</c>.</returns>
        public int GetStream(uint index, out ISequentialOutStream outStream, AskMode askExtractMode)
        {
            outStream = null;
            try
            {
                var entry = _archive.GetEntry(index);
                var path = (RelativePath)entry.FileName;
                if (entry.IsFolder || !_shouldExtract(path))
                {
                    return 0;
                }

                Utils.Status($"Extracting {path}", Percent.FactoryPutInRange(_results.Count, _itemsCount));
                // Empty files are never extracted via a write call, so we have to fake that now
                if (entry.Size == 0)
                {
                    var result = _mapFn(path, new MemoryStreamFactory(new MemoryStream(), path)).Result;
                    _results.Add(path, result);
                }

                outStream = new GatheringExtractorStream(this, entry, path);
                return 0;
            }
            catch (Exception ex)
            {
                // Same handling as GatheringExtractorStream.Write: record the failure so that
                // Extract() reports it instead of returning partial results.
                Utils.Log($"Error during extraction {ex}");
                Kill(ex);
                outStream = null;
                return 1;
            }
        }

        /// <summary>Extraction callback; nothing to prepare.</summary>
        public void PrepareOperation(AskMode askExtractMode)
        {
        }

        /// <summary>Extraction callback; the per-entry operation result is not inspected.</summary>
        public void SetOperationResult(OperationResult resultEOperationResult)
        {
        }

        /// <summary>
        /// Output stream for a single entry. Buffers the written data (in memory, or in a temp
        /// file for very large entries) until the entry's full size has been received, then
        /// passes it to the owning extractor's mapping function and stores the result.
        /// </summary>
        private class GatheringExtractorStream : ISequentialOutStream, IOutStream
        {
            private readonly GatheringExtractor<T> _extractor;
            private readonly ulong _totalSize;
            private Stream _tmpStream;
            private TempFile _tmpFile;
            private readonly bool _diskCached;
            private readonly RelativePath _path;

            public GatheringExtractorStream(GatheringExtractor<T> extractor, Entry entry, RelativePath path)
            {
                _path = path;
                _extractor = extractor;
                _totalSize = entry.Size;
                // Entries close to or above int.MaxValue bytes are buffered on disk instead of in memory.
                _diskCached = _totalSize >= int.MaxValue - 1024;
            }

            private IPath GetPath()
            {
                return _path;
            }

            /// <summary>
            /// Receives the next <paramref name="size"/> bytes of the entry.
            /// </summary>
            /// <param name="data">Buffer holding the bytes.</param>
            /// <param name="size">Number of valid bytes in <paramref name="data"/>.</param>
            /// <param name="processedSize">If non-zero, receives the number of bytes consumed.</param>
            /// <returns>0 on success; 1 after an error has been logged and recorded on the extractor.</returns>
            public int Write(byte[] data, uint size, IntPtr processedSize)
            {
                try
                {
                    // A write that covers the whole entry can be mapped straight from the buffer.
                    if (size == _totalSize)
                        WriteSingleCall(data, size);
                    else if (_diskCached)
                        WriteDiskCached(data, size);
                    else
                        WriteMemoryCached(data, size);

                    if (processedSize != IntPtr.Zero)
                    {
                        Marshal.WriteInt32(processedSize, (int)size);
                    }

                    return 0;
                }
                catch (Exception ex)
                {
                    Utils.Log($"Error during extraction {ex}");
                    _extractor.Kill(ex);

                    return 1;
                }
            }

            /// <summary>Maps an entry that arrived in one write directly from the caller's buffer.</summary>
            private void WriteSingleCall(byte[] data, in uint size)
            {
                var result = _extractor._mapFn(_path, new MemoryBufferFactory(data, (int)size, GetPath())).Result;
                AddResult(result);
                Cleanup();
            }

            /// <summary>Disposes the buffering stream and temp file, if any were created.</summary>
            private void Cleanup()
            {
                _tmpStream?.Dispose();
                _tmpFile?.DisposeAsync().AsTask().Wait();
            }

            /// <summary>Stores the mapped result for this entry on the owning extractor.</summary>
            private void AddResult(T result)
            {
                _extractor._results.Add(_path, result);
            }

            /// <summary>Appends to an in-memory buffer and maps the entry once all bytes have arrived.</summary>
            private void WriteMemoryCached(byte[] data, in uint size)
            {
                if (_tmpStream == null)
                    _tmpStream = new MemoryStream();
                _tmpStream.Write(data, 0, (int)size);

                // Wait for further writes until the whole entry has been received.
                if (_tmpStream.Length != (long)_totalSize) return;

                _tmpStream.Flush();
                _tmpStream.Position = 0;
                var result = _extractor._mapFn(_path, new MemoryStreamFactory((MemoryStream)_tmpStream, GetPath())).Result;
                AddResult(result);
                Cleanup();
            }

            /// <summary>Appends to a temp file and maps the entry once all bytes have arrived.</summary>
            private void WriteDiskCached(byte[] data, in uint size)
            {
                if (_tmpFile == null)
                {
                    _tmpFile = new TempFile();
                    _tmpStream = _tmpFile.Path.Create().Result;
                }

                _tmpStream.Write(data, 0, (int)size);

                // Wait for further writes until the whole entry has been received.
                if (_tmpStream.Length != (long)_totalSize) return;

                _tmpStream.Flush();
                _tmpStream.Close();

                var result = _extractor._mapFn(_path, new NativeFileStreamFactory(_tmpFile.Path, GetPath())).Result;
                AddResult(result);
                Cleanup();
            }

            /// <summary>Seeking is not supported; the call is ignored.</summary>
            public void Seek(long offset, uint seekOrigin, IntPtr newPosition)
            {
            }

            /// <summary>Size hints are ignored; always reports success.</summary>
            public int SetSize(long newSize)
            {
                return 0;
            }
        }

        /// <summary>Records an exception raised inside a callback so <see cref="Extract"/> can report it.</summary>
        private void Kill(Exception ex)
        {
            _killException = ex;
        }

        /// <summary>Open callback that ignores all progress notifications.</summary>
        class ArchiveCallback : IArchiveOpenCallback
        {
            public void SetTotal(IntPtr files, IntPtr bytes)
            {
            }

            public void SetCompleted(IntPtr files, IntPtr bytes)
            {
            }
        }
    }
}