diff --git a/Compression.BSA/MemoryStreamFactory.cs b/Compression.BSA/MemoryStreamFactory.cs index c428ed9e..09821140 100644 --- a/Compression.BSA/MemoryStreamFactory.cs +++ b/Compression.BSA/MemoryStreamFactory.cs @@ -21,4 +21,23 @@ namespace Compression.BSA public DateTime LastModifiedUtc => DateTime.UtcNow; public IPath Name => (RelativePath)"BSA Memory Stream"; } + + public class MemoryBufferFactory : IStreamFactory + { + private readonly byte[] _data; + private int _size; + + public MemoryBufferFactory(byte[] data, int size) + { + _data = data; + _size = size; + } + public async ValueTask GetStream() + { + return new MemoryStream(_data, 0, _size); + } + + public DateTime LastModifiedUtc => DateTime.UtcNow; + public IPath Name => (RelativePath)"BSA Memory Stream"; + } } diff --git a/Wabbajack.Lib/AInstaller.cs b/Wabbajack.Lib/AInstaller.cs index c76db728..ad291fed 100644 --- a/Wabbajack.Lib/AInstaller.cs +++ b/Wabbajack.Lib/AInstaller.cs @@ -150,9 +150,11 @@ namespace Wabbajack.Lib { var patchData = await LoadBytesFromPath(pfa.PatchID); var toFile = file.To.RelativeTo(OutputFolder); - await using var os = await toFile.Create(); - Utils.ApplyPatch(s, () => new MemoryStream(patchData), os); - + { + await using var os = await toFile.Create(); + Utils.ApplyPatch(s, () => new MemoryStream(patchData), os); + } + if (await VirusScanner.ShouldScan(toFile) && await ClientAPI.GetVirusScanResult(toFile) == VirusScanner.Result.Malware) { diff --git a/Wabbajack.VirtualFileSystem/FileExtractor2/FileExtractor.cs b/Wabbajack.VirtualFileSystem/FileExtractor2/FileExtractor.cs index ab66907f..27261a26 100644 --- a/Wabbajack.VirtualFileSystem/FileExtractor2/FileExtractor.cs +++ b/Wabbajack.VirtualFileSystem/FileExtractor2/FileExtractor.cs @@ -27,6 +27,10 @@ namespace Wabbajack.VirtualFileSystem public static async Task> GatheringExtract(IStreamFactory sFn, Predicate shouldExtract, Func> mapfn) { + if (sFn is NativeFileStreamFactory) + { + Utils.Log($"Extracting {sFn.Name}"); + } await using var archive = await sFn.GetStream(); var sig = await ArchiveSigs.MatchesAsync(archive); archive.Position = 0; diff --git a/Wabbajack.VirtualFileSystem/FileExtractor2/GatheringExtractor.cs b/Wabbajack.VirtualFileSystem/FileExtractor2/GatheringExtractor.cs index 154ea053..c19c9158 100644 --- a/Wabbajack.VirtualFileSystem/FileExtractor2/GatheringExtractor.cs +++ b/Wabbajack.VirtualFileSystem/FileExtractor2/GatheringExtractor.cs @@ -21,6 +21,7 @@ namespace Wabbajack.VirtualFileSystem private Dictionary _indexes; private Stream _stream; private Definitions.FileType _sig; + private Exception _killException; public GatheringExtractor(Stream stream, Definitions.FileType sig, Predicate shouldExtract, Func> mapfn) { @@ -51,7 +52,14 @@ namespace Wabbajack.VirtualFileSystem _archive._archive.Extract(null, 0xFFFFFFFF, 0, this); _archive.Dispose(); - source.SetResult(true); + if (_killException != null) + { + source.SetException(_killException); + } + else + { + source.SetResult(true); + } } catch (Exception ex) { @@ -105,51 +113,97 @@ namespace Wabbajack.VirtualFileSystem private uint _index; private bool _written; private ulong _totalSize; - private MemoryStream _tmpStream; + private Stream _tmpStream; + private TempFile _tmpFile; + private IStreamFactory _factory; + private bool _diskCached; public GatheringExtractorStream(GatheringExtractor extractor, uint index) { _extractor = extractor; _index = index; - _written = false; _totalSize = extractor._indexes[index].Item2; - _tmpStream = new MemoryStream(); + _diskCached = _totalSize >= 500_000_000; } - public int Write(IntPtr data, uint size, IntPtr processedSize) + public int Write(byte[] data, uint size, IntPtr processedSize) { - unsafe + try { - var ums = new UnmanagedMemoryStream((byte*)data, size); - ums.CopyTo(_tmpStream); - if ((ulong)_tmpStream.Length >= _totalSize) - { - _tmpStream.Position = 0; - var result = _extractor._mapFn(_extractor._indexes[_index].Item1, new MemoryStreamFactory(_tmpStream)).AsTask().Result; + if (size == _totalSize) + WriteSingleCall(data, size); + else if (_diskCached) + WriteDiskCached(data, size); + else + WriteMemoryCached(data, size); - _extractor._results[_extractor._indexes[_index].Item1] = result; - } - if (processedSize != IntPtr.Zero) { - Marshal.WriteInt32(processedSize, (int) size); + Marshal.WriteInt32(processedSize, (int)size); } - } - - return 0; - - if (_written) throw new Exception("TODO"); - unsafe - { - - - - - _written = true; - - return 0; + return 0; } + catch (Exception ex) + { + Utils.Log($"Error during extraction {ex}"); + _extractor.Kill(ex); + + return 1; + } + } + + private void WriteSingleCall(byte[] data, in uint size) + { + var result = _extractor._mapFn(_extractor._indexes[_index].Item1, new MemoryBufferFactory(data, (int)size)).Result; + AddResult(result); + Cleanup(); + } + + private void Cleanup() + { + _tmpStream?.Dispose(); + _tmpFile?.DisposeAsync().AsTask().Wait(); + } + + private void AddResult(T result) + { + _extractor._results.Add(_extractor._indexes[_index].Item1, result); + } + + private void WriteMemoryCached(byte[] data, in uint size) + { + if (_tmpStream == null) + _tmpStream = new MemoryStream(); + _tmpStream.Write(data, 0, (int)size); + + if (_tmpStream.Length != (long)_totalSize) return; + + _tmpStream.Flush(); + _tmpStream.Position = 0; + var result = _extractor._mapFn(_extractor._indexes[_index].Item1, new MemoryStreamFactory((MemoryStream)_tmpStream)).Result; + AddResult(result); + Cleanup(); + } + + private void WriteDiskCached(byte[] data, in uint size) + { + if (_tmpFile == null) + { + _tmpFile = new TempFile(); + _tmpStream = _tmpFile.Path.Create().Result; + } + + _tmpStream.Write(data, 0, (int)size); + + if (_tmpStream.Length != (long)_totalSize) return; + + _tmpStream.Flush(); + _tmpStream.Close(); + + var result = _extractor._mapFn(_extractor._indexes[_index].Item1, new NativeFileStreamFactory(_tmpFile.Path)).Result; + AddResult(result); + Cleanup(); } public void Seek(long offset, uint seekOrigin, IntPtr newPosition) @@ -162,5 +216,10 @@ namespace Wabbajack.VirtualFileSystem return 0; } } + + private void Kill(Exception ex) + { + _killException = ex; + } } } diff --git a/Wabbajack.VirtualFileSystem/SevenZipExtractor/SevenZipInterface.cs b/Wabbajack.VirtualFileSystem/SevenZipExtractor/SevenZipInterface.cs index cfc0d371..f9b6811c 100644 --- a/Wabbajack.VirtualFileSystem/SevenZipExtractor/SevenZipInterface.cs +++ b/Wabbajack.VirtualFileSystem/SevenZipExtractor/SevenZipInterface.cs @@ -207,7 +207,7 @@ namespace Wabbajack.VirtualFileSystem.SevenZipExtractor { [PreserveSig] int Write( - IntPtr data, + [In, MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)]byte[] data, uint size, IntPtr processedSize); // ref uint processedSize /* @@ -246,7 +246,7 @@ namespace Wabbajack.VirtualFileSystem.SevenZipExtractor { [PreserveSig] int Write( - IntPtr data, + [Out, MarshalAs(UnmanagedType.LPArray, SizeParamIndex = 1)] byte[] data, uint size, IntPtr processedSize); // ref uint processedSize @@ -444,7 +444,7 @@ namespace Wabbajack.VirtualFileSystem.SevenZipExtractor return 0; } - public int Write(IntPtr data, uint size, IntPtr processedSize) + public int Write(byte[] data, uint size, IntPtr processedSize) { throw new NotImplementedException(); /* diff --git a/Wabbajack.VirtualFileSystem/VirtualFile.cs b/Wabbajack.VirtualFileSystem/VirtualFile.cs index 9719ccbb..63bdf4b2 100644 --- a/Wabbajack.VirtualFileSystem/VirtualFile.cs +++ b/Wabbajack.VirtualFileSystem/VirtualFile.cs @@ -209,12 +209,16 @@ namespace Wabbajack.VirtualFileSystem try { - var list = await FileExtractor2.GatheringExtract(extractedFile, - _ => true, + var list = await FileExtractor2.GatheringExtract(extractedFile, + _ => true, async (path, sfactory) => await Analyze(context, self, sfactory, path, depth + 1)); - + self.Children = list.Values.ToImmutableList(); } + catch (EndOfStreamException ex) + { + return self; + } catch (Exception ex) { Utils.Log($"Error while examining the contents of {relPath.FileName}");