diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs index b82d47952da..ee3db835eec 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/InvokeRestMethodCommand.Common.cs @@ -85,7 +85,7 @@ public int MaximumFollowRelLink #region Helper Methods - private bool TryProcessFeedStream(BufferingStreamReader responseStream) + private bool TryProcessFeedStream(Stream responseStream) { bool isRssOrFeed = false; @@ -382,95 +382,95 @@ internal override void ProcessResponse(HttpResponseMessage response) { if (response == null) { throw new ArgumentNullException("response"); } - using (BufferingStreamReader responseStream = new BufferingStreamReader(StreamHelper.GetResponseStream(response))) + var baseResponseStream = StreamHelper.GetResponseStream(response); + + if (ShouldWriteToPipeline) { - if (ShouldWriteToPipeline) + using var responseStream = new BufferingStreamReader(baseResponseStream); + + // First see if it is an RSS / ATOM feed, in which case we can + // stream it - unless the user has overridden it with a return type of "XML" + if (TryProcessFeedStream(responseStream)) { - // First see if it is an RSS / ATOM feed, in which case we can - // stream it - unless the user has overridden it with a return type of "XML" - if (TryProcessFeedStream(responseStream)) + // Do nothing, content has been processed. + } + else + { + // determine the response type + RestReturnType returnType = CheckReturnType(response); + + // Try to get the response encoding from the ContentType header. + Encoding encoding = null; + string charSet = response.Content.Headers.ContentType?.CharSet; + if (!string.IsNullOrEmpty(charSet)) { - // Do nothing, content has been processed. + // NOTE: Don't use ContentHelper.GetEncoding; it returns a + // default which bypasses checking for a meta charset value. + StreamHelper.TryGetEncoding(charSet, out encoding); } - else - { - // determine the response type - RestReturnType returnType = CheckReturnType(response); - - // Try to get the response encoding from the ContentType header. - Encoding encoding = null; - string charSet = response.Content.Headers.ContentType?.CharSet; - if (!string.IsNullOrEmpty(charSet)) - { - // NOTE: Don't use ContentHelper.GetEncoding; it returns a - // default which bypasses checking for a meta charset value. - StreamHelper.TryGetEncoding(charSet, out encoding); - } - if (string.IsNullOrEmpty(charSet) && returnType == RestReturnType.Json) - { - encoding = Encoding.UTF8; - } - - object obj = null; - Exception ex = null; + if (string.IsNullOrEmpty(charSet) && returnType == RestReturnType.Json) + { + encoding = Encoding.UTF8; + } - string str = StreamHelper.DecodeStream(responseStream, ref encoding); + object obj = null; + Exception ex = null; - string encodingVerboseName; - try - { - encodingVerboseName = string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName; - } - catch (NotSupportedException) - { - encodingVerboseName = encoding.EncodingName; - } - // NOTE: Tests use this verbose output to verify the encoding. - WriteVerbose(string.Format - ( - System.Globalization.CultureInfo.InvariantCulture, - "Content encoding: {0}", - encodingVerboseName) - ); - bool convertSuccess = false; - - if (returnType == RestReturnType.Json) - { - convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex); - } - // default to try xml first since it's more common - else - { - convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex); - } + string str = StreamHelper.DecodeStream(responseStream, ref encoding); - if (!convertSuccess) - { - // fallback to string - obj = str; - } + string encodingVerboseName; + try + { + encodingVerboseName = string.IsNullOrEmpty(encoding.HeaderName) ? encoding.EncodingName : encoding.HeaderName; + } + catch (NotSupportedException) + { + encodingVerboseName = encoding.EncodingName; + } + // NOTE: Tests use this verbose output to verify the encoding. + WriteVerbose(string.Format + ( + System.Globalization.CultureInfo.InvariantCulture, + "Content encoding: {0}", + encodingVerboseName) + ); + bool convertSuccess = false; + + if (returnType == RestReturnType.Json) + { + convertSuccess = TryConvertToJson(str, out obj, ref ex) || TryConvertToXml(str, out obj, ref ex); + } + // default to try xml first since it's more common + else + { + convertSuccess = TryConvertToXml(str, out obj, ref ex) || TryConvertToJson(str, out obj, ref ex); + } - WriteObject(obj); + if (!convertSuccess) + { + // fallback to string + obj = str; } - } - if (ShouldSaveToOutFile) - { - StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this); + WriteObject(obj); } + } + else if (ShouldSaveToOutFile) + { + StreamHelper.SaveStreamToFile(baseResponseStream, QualifiedOutFile, this, _cancelToken.Token); + } - if (!string.IsNullOrEmpty(StatusCodeVariable)) - { - PSVariableIntrinsics vi = SessionState.PSVariable; - vi.Set(StatusCodeVariable, (int)response.StatusCode); - } + if (!string.IsNullOrEmpty(StatusCodeVariable)) + { + PSVariableIntrinsics vi = SessionState.PSVariable; + vi.Set(StatusCodeVariable, (int)response.StatusCode); + } - if (!string.IsNullOrEmpty(ResponseHeadersVariable)) - { - PSVariableIntrinsics vi = SessionState.PSVariable; - vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response)); - } + if (!string.IsNullOrEmpty(ResponseHeadersVariable)) + { + PSVariableIntrinsics vi = SessionState.PSVariable; + vi.Set(ResponseHeadersVariable, WebResponseHelper.GetHeadersDictionary(response)); } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs index ddac58fe57d..1aea7aa82ee 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/Common/WebRequestPSCmdlet.Common.cs @@ -911,7 +911,7 @@ public abstract partial class WebRequestPSCmdlet : PSCmdlet /// /// Cancellation token source. /// - private CancellationTokenSource _cancelToken = null; + internal CancellationTokenSource _cancelToken = null; /// /// Parse Rel Links. diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs index 0a5441faef5..63baf2946fc 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/CoreCLR/InvokeWebRequestCommand.CoreClr.cs @@ -51,7 +51,7 @@ internal override void ProcessResponse(HttpResponseMessage response) if (ShouldSaveToOutFile) { - StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this); + StreamHelper.SaveStreamToFile(responseStream, QualifiedOutFile, this, _cancelToken.Token); } } diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs index 91a9a0e9975..91662ed7878 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/WebCmdlet/StreamHelper.cs @@ -9,6 +9,8 @@ using System.Net.Http; using System.Text; using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; namespace Microsoft.PowerShell.Commands { @@ -99,7 +101,7 @@ public override long Length /// /// /// - public override System.Threading.Tasks.Task CopyToAsync(Stream destination, int bufferSize, System.Threading.CancellationToken cancellationToken) + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { Initialize(); return base.CopyToAsync(destination, bufferSize, cancellationToken); @@ -124,7 +126,7 @@ public override int Read(byte[] buffer, int offset, int count) /// /// /// - public override System.Threading.Tasks.Task ReadAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { Initialize(); return base.ReadAsync(buffer, offset, count, cancellationToken); @@ -175,7 +177,7 @@ public override void Write(byte[] buffer, int offset, int count) /// /// /// - public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, System.Threading.CancellationToken cancellationToken) + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { Initialize(); return base.WriteAsync(buffer, offset, count, cancellationToken); @@ -273,73 +275,55 @@ internal static class StreamHelper #region Static Methods - internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet) + internal static void WriteToStream(Stream input, Stream output, PSCmdlet cmdlet, CancellationToken cancellationToken) { - byte[] data = new byte[ChunkSize]; + if (cmdlet == null) + { + throw new ArgumentNullException(nameof(cmdlet)); + } - int read = 0; - long totalWritten = 0; - do + Task copyTask = input.CopyToAsync(output, cancellationToken); + + ProgressRecord record = new ProgressRecord( + ActivityId, + WebCmdletStrings.WriteRequestProgressActivity, + WebCmdletStrings.WriteRequestProgressStatus); + try { - if (cmdlet != null) + do { - ProgressRecord record = new ProgressRecord(ActivityId, - WebCmdletStrings.WriteRequestProgressActivity, - StringUtil.Format(WebCmdletStrings.WriteRequestProgressStatus, totalWritten)); + record.StatusDescription = StringUtil.Format(WebCmdletStrings.WriteRequestProgressStatus, output.Position); cmdlet.WriteProgress(record); - } - read = input.Read(data, 0, ChunkSize); + Task.Delay(1000).Wait(cancellationToken); + } + while (!copyTask.IsCompleted && !cancellationToken.IsCancellationRequested); - if (0 < read) + if (copyTask.IsCompleted) { - output.Write(data, 0, read); - totalWritten += read; + record.StatusDescription = StringUtil.Format(WebCmdletStrings.WriteRequestComplete, output.Position); + cmdlet.WriteProgress(record); } - } while (read != 0); - - if (cmdlet != null) + } + catch (OperationCanceledException) { - ProgressRecord record = new ProgressRecord(ActivityId, - WebCmdletStrings.WriteRequestProgressActivity, - StringUtil.Format(WebCmdletStrings.WriteRequestComplete, totalWritten)); - record.RecordType = ProgressRecordType.Completed; - cmdlet.WriteProgress(record); } - - output.Flush(); - } - - internal static void WriteToStream(byte[] input, Stream output) - { - output.Write(input, 0, input.Length); - output.Flush(); } /// /// Saves content from stream into filePath. /// Caller need to ensure position is properly set. /// - /// - /// - /// - internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet) + /// Input stream. + /// Output file name. + /// Current cmdlet (Invoke-WebRequest or Invoke-RestMethod). + /// CancellationToken to track the cmdlet cancellation. + internal static void SaveStreamToFile(Stream stream, string filePath, PSCmdlet cmdlet, CancellationToken cancellationToken) { // If the web cmdlet should resume, append the file instead of overwriting. - if (cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume) - { - using (FileStream output = new FileStream(filePath, FileMode.Append, FileAccess.Write, FileShare.Read)) - { - WriteToStream(stream, output, cmdlet); - } - } - else - { - using (FileStream output = File.Create(filePath)) - { - WriteToStream(stream, output, cmdlet); - } - } + FileMode fileMode = cmdlet is WebRequestPSCmdlet webCmdlet && webCmdlet.ShouldResume ? FileMode.Append : FileMode.Create; + using FileStream output = new FileStream(filePath, fileMode, FileAccess.Write, FileShare.Read); + WriteToStream(stream, output, cmdlet, cancellationToken); } private static string StreamToString(Stream stream, Encoding encoding)