diff --git a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs index a71101f269e..cd5a42413aa 100644 --- a/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs +++ b/src/Microsoft.PowerShell.Commands.Utility/commands/utility/AddType.cs @@ -763,6 +763,14 @@ internal string GetUsingSet(Language language) /// protected override void EndProcessing() { + // Prevent code compilation in ConstrainedLanguage mode + if (SessionState.LanguageMode == PSLanguageMode.ConstrainedLanguage) + { + ThrowTerminatingError( + new ErrorRecord( + new PSNotSupportedException(AddTypeStrings.CannotDefineNewType), "CannotDefineNewType", ErrorCategory.PermissionDenied, null)); + } + // Generate an error if they've specified an output // assembly type without an output assembly if (String.IsNullOrEmpty(outputAssembly) && outputTypeSpecified) @@ -779,36 +787,6 @@ protected override void EndProcessing() ThrowTerminatingError(errorRecord); return; } - - PopulateSource(); - } - - internal void PopulateSource() - { - // Prevent code compilation in ConstrainedLanguage mode - if (SessionState.LanguageMode == PSLanguageMode.ConstrainedLanguage) - { - ThrowTerminatingError( - new ErrorRecord( - new PSNotSupportedException(AddTypeStrings.CannotDefineNewType), "CannotDefineNewType", ErrorCategory.PermissionDenied, null)); - } - - // Load the source if they want to load from a file - if (String.Equals(ParameterSetName, "FromPath", StringComparison.OrdinalIgnoreCase) || - String.Equals(ParameterSetName, "FromLiteralPath", StringComparison.OrdinalIgnoreCase) - ) - { - sourceCode = ""; - foreach (string file in paths) - { - sourceCode += System.IO.File.ReadAllText(file) + "\n"; - } - } - - if (String.Equals(ParameterSetName, "FromMember", StringComparison.OrdinalIgnoreCase)) - { - sourceCode = GenerateTypeSource(typeNamespace, Name, sourceCode, language); - } } internal void HandleCompilerErrors(AddTypeCompilerError[] compilerErrors) @@ -853,12 +831,12 @@ private void OutputError(AddTypeCompilerError error, string[] actualSource) { if (!String.IsNullOrEmpty(error.FileName)) { - actualSource = System.IO.File.ReadAllText(error.FileName).Split(Utils.Separators.Newline); + actualSource = System.IO.File.ReadAllLines(error.FileName); } } string errorText = StringUtil.Format(AddTypeStrings.CompilationErrorFormat, - error.FileName, error.Line, error.ErrorText) + "\n"; + error.FileName, error.Line, error.ErrorText) + Environment.NewLine; for (int lineNumber = error.Line - 1; lineNumber < error.Line + 2; lineNumber++) { @@ -876,8 +854,8 @@ private void OutputError(AddTypeCompilerError error, string[] actualSource) lineText += actualSource[lineNumber - 1]; - errorText += "\n" + StringUtil.Format(AddTypeStrings.CompilationErrorFormat, - error.FileName, lineNumber, lineText) + "\n"; + errorText += Environment.NewLine + StringUtil.Format(AddTypeStrings.CompilationErrorFormat, + error.FileName, lineNumber, lineText) + Environment.NewLine; } } @@ -932,13 +910,35 @@ protected override void EndProcessing() { // Load the source if they want to load from a file if (String.Equals(ParameterSetName, "FromPath", StringComparison.OrdinalIgnoreCase) || - String.Equals(ParameterSetName, "FromLiteralPath", StringComparison.OrdinalIgnoreCase)) + String.Equals(ParameterSetName, "FromLiteralPath", StringComparison.OrdinalIgnoreCase) + ) { - this.sourceCode = ""; - foreach (string file in paths) + if (paths.Length == 1) { - this.sourceCode += System.IO.File.ReadAllText(file) + "\n"; + sourceCode = File.ReadAllText(paths[0]); } + else + { + + // We replace 'ReadAllText' with 'StringBuilder' and 'ReadAllLines' + // to avoide temporary LOH allocations. + + StringBuilder sb = new StringBuilder(8192); + + foreach (string file in paths) + { + foreach (string line in File.ReadAllLines(file)) + { + sb.AppendLine(line); + } + } + + sourceCode = sb.ToString(); + } + } + else if (String.Equals(ParameterSetName, "FromMember", StringComparison.OrdinalIgnoreCase)) + { + sourceCode = GenerateTypeSource(typeNamespace, Name, sourceCode, language); } CompileSourceToAssembly(this.sourceCode); diff --git a/test/powershell/Modules/Microsoft.PowerShell.Utility/Add-Type.Tests.ps1 b/test/powershell/Modules/Microsoft.PowerShell.Utility/Add-Type.Tests.ps1 index e4efeb10f32..b36ce180d51 100644 --- a/test/powershell/Modules/Microsoft.PowerShell.Utility/Add-Type.Tests.ps1 +++ b/test/powershell/Modules/Microsoft.PowerShell.Utility/Add-Type.Tests.ps1 @@ -1,6 +1,38 @@ -$guid = [Guid]::NewGuid().ToString().Replace("-","") - Describe "Add-Type" -Tags "CI" { + BeforeAll { + $guid = [Guid]::NewGuid().ToString().Replace("-","") + + $code1 = @" + namespace Test.AddType + { + public class BasicTest1 + { + public static int Add1(int a, int b) + { + return (a + b); + } + } + } +"@ + $code2 = @" + namespace Test.AddType + { + public class BasicTest2 + { + public static int Add2(int a, int b) + { + return (a + b); + } + } + } +"@ + $codeFile1 = Join-Path -Path $TestDrive -ChildPath "codeFile1.cs" + $codeFile2 = Join-Path -Path $TestDrive -ChildPath "codeFile2.cs" + + Set-Content -Path $codeFile1 -Value $code1 -Force + Set-Content -Path $codeFile2 -Value $code2 -Force + } + It "Public 'Language' enumeration contains all members" { [Enum]::GetNames("Microsoft.PowerShell.Commands.Language") -join "," | Should Be "CSharp,CSharpVersion7,CSharpVersion6,CSharpVersion5,CSharpVersion4,CSharpVersion3,CSharpVersion2,CSharpVersion1,VisualBasic,JScript" } @@ -20,4 +52,15 @@ public class AttributeTest$guid {} It "Can load TPA assembly System.Runtime.Serialization.Primitives.dll" { Add-Type -AssemblyName 'System.Runtime.Serialization.Primitives' -PassThru | Should Not Be $null } + + It "Can compile C# files" { + + { [Test.AddType.BasicTest1]::Add1(1, 2) } | Should Throw + { [Test.AddType.BasicTest2]::Add2(3, 4) } | Should Throw + + Add-Type -Path $codeFile1,$codeFile2 + + { [Test.AddType.BasicTest1]::Add1(1, 2) } | Should Not Throw + { [Test.AddType.BasicTest2]::Add2(3, 4) } | Should Not Throw + } }