diff --git a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs index e86e87e137b..df5980abaa8 100644 --- a/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs +++ b/src/System.Management.Automation/engine/CommandCompletion/CompletionCompleters.cs @@ -5524,7 +5524,7 @@ private static void AddUniqueVariable(HashSet hashedResults, List s_varModificationCommands = new(StringComparer.OrdinalIgnoreCase) + internal static readonly HashSet s_varModificationCommands = new(StringComparer.OrdinalIgnoreCase) { "New-Variable", "nv", @@ -5533,13 +5533,13 @@ private static void AddUniqueVariable(HashSet hashedResults, List hashedResults, List s_localScopeCommandNames = new(StringComparer.OrdinalIgnoreCase) + internal static readonly HashSet s_localScopeCommandNames = new(StringComparer.OrdinalIgnoreCase) { "Microsoft.PowerShell.Core\\ForEach-Object", "ForEach-Object", diff --git a/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs b/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs index a1763755f8f..6996b64eb2a 100644 --- a/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs +++ b/src/System.Management.Automation/engine/parser/TypeInferenceVisitor.cs @@ -1272,7 +1272,7 @@ object ICustomAstVisitor.VisitFileRedirection(FileRedirectionAst fileRedirection return TypeInferenceContext.EmptyPSTypeNameArray; } - private void InferTypesFrom(CommandAst commandAst, List inferredTypes) + private void InferTypesFrom(CommandAst commandAst, List inferredTypes, bool forRedirection = false) { if (commandAst.Redirections.Count > 0) { @@ -1282,7 +1282,7 @@ private void InferTypesFrom(CommandAst commandAst, List inferredType { if (streamRedirection is FileRedirectionAst fileRedirection) { - if (fileRedirection.FromStream is RedirectionStream.All or RedirectionStream.Output) + if (!forRedirection && fileRedirection.FromStream is RedirectionStream.All or RedirectionStream.Output) { // command output is redirected so it returns nothing. return; @@ -1952,6 +1952,36 @@ private IEnumerable InferTypesFrom(MemberExpressionAst memberExpress return res; } + private static IEnumerable InferTypeFromRef(InvokeMemberExpressionAst invokeMember, ExpressionAst refArgument) + { + Type expressionClrType = (invokeMember.Expression as TypeExpressionAst)?.TypeName.GetReflectionType(); + string memberName = (invokeMember.Member as StringConstantExpressionAst)?.Value; + int argumentIndex = invokeMember.Arguments.IndexOf(refArgument); + if (expressionClrType is null || string.IsNullOrEmpty(memberName) || argumentIndex == -1) + { + yield break; + } + + foreach (MemberInfo memberInfo in expressionClrType.GetMember(memberName)) + { + if (memberInfo.MemberType == MemberTypes.Method) + { + var methodInfo = memberInfo as MethodInfo; + ParameterInfo[] methodParams = methodInfo.GetParameters(); + if (methodParams.Length < argumentIndex) + { + continue; + } + + ParameterInfo paramCandidate = methodParams[argumentIndex]; + if (paramCandidate.IsOut) + { + yield return new PSTypeName(paramCandidate.ParameterType.GetElementType()); + } + } + } + } + private void GetTypesOfMembers( PSTypeName thisType, string memberName, @@ -2248,7 +2278,7 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List 0) { if (switchErrorStatement.Conditions[0].Extent.EndOffset < variableExpressionAst.Extent.StartOffset) { - parent = switchErrorStatement.Conditions[0]; + currentAst = switchErrorStatement.Conditions[0]; break; } else { // $_ is inside the condition that is being declared, eg: Get-Process | Sort-Object -Property {switch ($_.Proc - parent = switchErrorStatement.Parent; + currentAst = switchErrorStatement.Parent; continue; } } break; } - else if (parent is ScriptBlockExpressionAst) + else if (currentAst is ScriptBlockExpressionAst) { hasSeenScriptBlock = true; } else if (hasSeenScriptBlock) { - if (parent is InvokeMemberExpressionAst invokeMember) + if (currentAst is InvokeMemberExpressionAst invokeMember) { - parent = invokeMember.Expression; + currentAst = invokeMember.Expression; break; } - else if (parent is CommandAst cmdAst && cmdAst.Parent is PipelineAst pipeline && pipeline.PipelineElements.Count > 1) + else if (currentAst is CommandAst cmdAst && cmdAst.Parent is PipelineAst pipeline && pipeline.PipelineElements.Count > 1) { // We've found a pipeline with multiple commands, now we need to determine what command came before the command with the scriptblock: // eg Get-Partition in this example: Get-Disk | Get-Partition | Where {$_} var indexOfPreviousCommand = pipeline.PipelineElements.IndexOf(cmdAst) - 1; if (indexOfPreviousCommand >= 0) { - parent = pipeline.PipelineElements[indexOfPreviousCommand]; + currentAst = pipeline.PipelineElements[indexOfPreviousCommand]; break; } } } - parent = parent.Parent; + currentAst = currentAst.Parent; } - if (parent is CatchClauseAst catchBlock) + if (currentAst is CatchClauseAst catchBlock) { if (catchBlock.CatchTypes.Count > 0) { @@ -2339,7 +2369,7 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List)AstSearcher.FindAll( - parent, - ast => + if (!astVariablePath.UnqualifiedPath.EqualsOrdinalIgnoreCase(SpecialVariables.AutomaticVariables[i])) { - if (ast is ParameterAst || ast is AssignmentStatementAst || ast is CommandAst) - { - return variableExpressionAst.AstAssignsToSameVariable(ast) - && ast.Extent.EndOffset < startOffset; - } + continue; + } - if (ast is ForEachStatementAst) - { - return variableExpressionAst.AstAssignsToSameVariable(ast) - && ast.Extent.StartOffset < startOffset; - } + Type type = SpecialVariables.AutomaticVariableTypes[i]; + if (type != typeof(object)) + { + inferredTypes.Add(new PSTypeName(type)); + } - return false; - }, - searchNestedScriptBlocks: true); + return; + } - foreach (var ast in targetAsts) + // This visitor + loop finds the start of the current scope and traverses top to bottom to find the nearest variable assignment. + // Then repeats the process for each parent scope. + var assignmentVisitor = new VariableAssignmentVisitor() + { + ScopeIsLocal = true, + LocalScopeOnly = variableExpressionAst.VariablePath.IsLocal || variableExpressionAst.VariablePath.IsPrivate, + StopSearchOffset = variableExpressionAst.Extent.StartOffset, + VariableTarget = variableExpressionAst + }; + while (currentAst is not null) { - if (ast is ParameterAst parameterAst) + if (currentAst is IParameterMetadataProvider) { - var currentCount = inferredTypes.Count; - inferredTypes.AddRange(InferTypes(parameterAst)); + currentAst.Visit(assignmentVisitor); - if (inferredTypes.Count != currentCount) + if (assignmentVisitor.LocalScopeOnly + || assignmentVisitor.LastConstraint is not null + || ((assignmentVisitor.LastAssignment is not null || assignmentVisitor.LastAssignmentType is not null) + && (currentAst.Parent is not ScriptBlockExpressionAst scriptBlock || !scriptBlock.IsDotsourced()))) { - return; + // We only care about the parent scopes if no assignment has been made in the current scope + // or if it's a dot sourced scriptblock where an earlier defined type constraint could influence the final type + break; } - } - } - - var assignAsts = targetAsts.OfType().ToArray(); - // If any of the assignments lhs use a type constraint, then we use that. - // Otherwise, we use the rhs of the "nearest" assignment - for (int i = assignAsts.Length - 1; i >= 0; i--) - { - if (assignAsts[i].Left is ConvertExpressionAst lhsConvert) - { - inferredTypes.Add(new PSTypeName(lhsConvert.Type.TypeName)); - return; + assignmentVisitor.ScopeIsLocal = false; + assignmentVisitor.StopSearchOffset = currentAst.Extent.StartOffset; } - } - var foreachAst = targetAsts.OfType().FirstOrDefault(); - if (foreachAst != null) - { - inferredTypes.AddRange( - GetInferredEnumeratedTypes(InferTypes(foreachAst.Condition))); - return; + currentAst = currentAst.Parent; } - var commandCompletionAst = targetAsts.OfType().FirstOrDefault(); - if (commandCompletionAst != null) + // The visitor is done finding the last assignment, now we need to infer the type of that assignment. + if (assignmentVisitor.LastConstraint is not null) { - inferredTypes.AddRange(InferTypes(commandCompletionAst)); - return; + inferredTypes.Add(new PSTypeName(assignmentVisitor.LastConstraint)); } - - int smallestDiff = int.MaxValue; - AssignmentStatementAst closestAssignment = null; - foreach (var assignAst in assignAsts) + else if (assignmentVisitor.LastAssignment is not null) { - var endOffset = assignAst.Extent.EndOffset; - if ((startOffset - endOffset) < smallestDiff) + if (assignmentVisitor.EnumerateAssignment) { - smallestDiff = startOffset - endOffset; - closestAssignment = assignAst; + inferredTypes.AddRange(GetInferredEnumeratedTypes(InferTypes(assignmentVisitor.LastAssignment))); + } + else + { + if (assignmentVisitor.LastAssignment is ConvertExpressionAst convertExpression + && convertExpression.IsRef()) + { + if (convertExpression.Parent is InvokeMemberExpressionAst memberInvoke) + { + inferredTypes.AddRange(InferTypeFromRef(memberInvoke, convertExpression)); + } + } + else if (assignmentVisitor.RedirectionAssignment && assignmentVisitor.LastAssignment is CommandAst cmdAst) + { + InferTypesFrom(cmdAst, inferredTypes, forRedirection: true); + } + else + { + inferredTypes.AddRange(InferTypes(assignmentVisitor.LastAssignment)); + } } } - - if (closestAssignment != null) + else if (assignmentVisitor.LastAssignmentType is not null) { - inferredTypes.AddRange(InferTypes(closestAssignment.Right)); + inferredTypes.Add(assignmentVisitor.LastAssignmentType); } if (_context.TryGetRepresentativeTypeNameFromExpressionSafeEval(variableExpressionAst, out var evalTypeName)) @@ -2858,95 +2858,447 @@ private static CommandBaseAst GetPreviousPipelineCommand(CommandAst commandAst) var i = pipe.PipelineElements.IndexOf(commandAst); return i != 0 ? pipe.PipelineElements[i - 1] : null; } - } - internal static class TypeInferenceExtension - { - public static bool EqualsOrdinalIgnoreCase(this string s, string t) + private sealed class VariableAssignmentVisitor : AstVisitor2 { - return string.Equals(s, t, StringComparison.OrdinalIgnoreCase); - } + /// + /// If set, we only look for local/private assignments in the scope of the variable we are inferring. + /// + internal bool LocalScopeOnly; + + /// + /// The current scope is local to the variable that is being inferred. + /// + internal bool ScopeIsLocal; + + /// + /// The variable that we are trying to determine the type of. + /// + internal VariableExpressionAst VariableTarget; - public static IEnumerable GetGetterProperty(this Type type, string propertyName) - { - var res = new List(); - foreach (var m in type.GetMethods(BindingFlags.Public | BindingFlags.Instance)) + /// + /// The last type constraint applied to the variable. This takes priority when determining the type of the variable. + /// + internal ITypeName LastConstraint; + + /// + /// The last ast that assigned a value to the variable. This determines the value of the variable unless a type constraint has been applied. + /// + internal Ast LastAssignment; + + /// + /// The inferred type from the most recent assignment. This is only used for stream redirections to variables, or the special OutVariable common parameters. + /// + internal PSTypeName LastAssignmentType; + + /// + /// Whether or not the types from the last assignment should be enumerated. + /// For assignments made by the PipelineVariable parameter or the foreach statement. + /// + internal bool EnumerateAssignment; + + /// + /// Whether or not the last assignment was via command redirection. + /// + internal bool RedirectionAssignment; + internal int StopSearchOffset; + private int LastAssignmentOffset = -1; + + private void SetLastAssignment(Ast ast, bool enumerate = false, bool redirectionAssignment = false) { - var name = m.Name; - // Equals without string allocation - if (name.Length == propertyName.Length + 4 - && name.StartsWith("get_") - && name.IndexOf(propertyName, 4, StringComparison.Ordinal) == 4) + if (LastAssignmentOffset < ast.Extent.StartOffset) { - res.Add(m); + ClearAssignmentData(); + LastAssignment = ast; + EnumerateAssignment = enumerate; + RedirectionAssignment = redirectionAssignment; + LastAssignmentOffset = ast.Extent.StartOffset; } } - return res; - } + private void SetLastAssignmentType(PSTypeName typeName, int assignmentOffset) + { + if (LastAssignmentOffset < assignmentOffset) + { + ClearAssignmentData(); + LastAssignmentType = typeName; + LastAssignmentOffset = assignmentOffset; + } + } - public static bool AstAssignsToSameVariable(this VariableExpressionAst variableAst, Ast ast) - { - var parameterAst = ast as ParameterAst; - var variableAstVariablePath = variableAst.VariablePath; - if (parameterAst != null) + private void ClearAssignmentData() + { + LastAssignment = null; + LastAssignmentType = null; + EnumerateAssignment = false; + RedirectionAssignment = false; + } + + private bool AssignsToTargetVar(VariableExpressionAst foundVar) + { + if (!foundVar.VariablePath.UnqualifiedPath.EqualsOrdinalIgnoreCase(VariableTarget.VariablePath.UnqualifiedPath)) + { + return false; + } + + int scopeIndex = foundVar.VariablePath.UserPath.IndexOf(':'); + string scopeName = scopeIndex == -1 ? string.Empty : foundVar.VariablePath.UserPath.Remove(scopeIndex); + return AssignsToTargetScope(scopeName); + } + + private bool AssignsToTargetVar(string userPath) + { + if (string.IsNullOrEmpty(userPath)) + { + return false; + } + + string scopeName; + string varName; + int scopeIndex = userPath.IndexOf(':'); + if (scopeIndex == -1) + { + scopeName = string.Empty; + varName = userPath; + } + else + { + scopeName = userPath.Remove(scopeIndex); + varName = userPath.Substring(scopeIndex + 1); + } + + if (!varName.EqualsOrdinalIgnoreCase(VariableTarget.VariablePath.UnqualifiedPath)) + { + return false; + } + + return AssignsToTargetScope(scopeName); + } + + private bool AssignsToTargetScope(string scopeName) + => LocalScopeOnly + ? string.IsNullOrEmpty(scopeName) || scopeName.EqualsOrdinalIgnoreCase("Local") || scopeName.EqualsOrdinalIgnoreCase("Private") + : ScopeIsLocal || !(scopeName.EqualsOrdinalIgnoreCase("Local") || scopeName.EqualsOrdinalIgnoreCase("Private")); + + public override AstVisitAction DefaultVisit(Ast ast) { - return variableAstVariablePath.IsUnscopedVariable && - parameterAst.Name.VariablePath.UnqualifiedPath.Equals(variableAstVariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase) && - parameterAst.Parent.Parent.Extent.EndOffset > variableAst.Extent.StartOffset; + if (ast.Extent.StartOffset >= StopSearchOffset) + { + // When visiting do while/until statements, the condition will be visited before the statement block + // The condition itself may not be interesting if it's after the cursor, but the statement block could be + // Example: + // do + // { + // $Var = gci + // $Var. + // } + // until($false) + return ast is PipelineBaseAst && ast.Parent is DoUntilStatementAst or DoWhileStatementAst + ? AstVisitAction.SkipChildren + : AstVisitAction.StopVisit; + } + + return AstVisitAction.Continue; } - if (ast is ForEachStatementAst foreachAst) + public override AstVisitAction VisitAssignmentStatement(AssignmentStatementAst assignmentStatementAst) { - return variableAstVariablePath.IsUnscopedVariable && - foreachAst.Variable.VariablePath.UnqualifiedPath.Equals(variableAstVariablePath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase); + if (assignmentStatementAst.Extent.StartOffset >= StopSearchOffset) + { + return assignmentStatementAst.Parent is DoUntilStatementAst or DoWhileStatementAst + ? AstVisitAction.SkipChildren + : AstVisitAction.StopVisit; + } + + if (assignmentStatementAst.Left is AttributedExpressionAst attributedExpression) + { + var firstConvertExpression = attributedExpression as ConvertExpressionAst; + ExpressionAst child = attributedExpression.Child; + while (child is AttributedExpressionAst attributeChild) + { + if (firstConvertExpression is null && attributeChild is ConvertExpressionAst convertExpression) + { + // Multiple type constraint can be set on a variable like this: [int] [string] $Var1 = 1 + // But it's the left most type constraint that determines the final type. + firstConvertExpression = convertExpression; + } + + child = attributeChild.Child; + } + + if (child is VariableExpressionAst variableExpression && AssignsToTargetVar(variableExpression)) + { + if (firstConvertExpression is not null) + { + LastConstraint = firstConvertExpression.Type.TypeName; + } + else + { + SetLastAssignment(assignmentStatementAst.Right); + } + } + } + else if (assignmentStatementAst.Left is VariableExpressionAst variableExpression && AssignsToTargetVar(variableExpression)) + { + SetLastAssignment(assignmentStatementAst.Right); + } + + return AstVisitAction.Continue; } - if (ast is CommandAst commandAst) + public override AstVisitAction VisitCommand(CommandAst commandAst) { - string[] variableParameters = { "PV", "PipelineVariable", "OV", "OutVariable" }; - StaticBindingResult bindingResult = StaticParameterBinder.BindCommand(commandAst, false, variableParameters); + if (commandAst.Extent.StartOffset >= StopSearchOffset) + { + return AstVisitAction.StopVisit; + } - if (bindingResult != null) + string commandName = commandAst.GetCommandName(); + if (commandName is not null && CompletionCompleters.s_varModificationCommands.Contains(commandName)) { - foreach (string commandVariableParameter in variableParameters) + StaticBindingResult bindingResult = StaticParameterBinder.BindCommand(commandAst, resolve: false, CompletionCompleters.s_varModificationParameters); + if (bindingResult is not null + && bindingResult.BoundParameters.TryGetValue("Name", out ParameterBindingResult variableName) + && variableName.ConstantValue is string nameValue + && AssignsToTargetVar(nameValue) + && bindingResult.BoundParameters.TryGetValue("Value", out ParameterBindingResult variableValue)) { - if (bindingResult.BoundParameters.TryGetValue(commandVariableParameter, out ParameterBindingResult parameterBindingResult)) + SetLastAssignment(variableValue.Value); + return AstVisitAction.Continue; + } + } + + StaticBindingResult bindResult = StaticParameterBinder.BindCommand(commandAst, resolve: false); + if (bindResult is not null) + { + foreach (string parameterName in CompletionCompleters.s_outVarParameters) + { + if (bindResult.BoundParameters.TryGetValue(parameterName, out ParameterBindingResult outVarBind) + && outVarBind.ConstantValue is string varName + && AssignsToTargetVar(varName)) { - if (string.Equals(variableAstVariablePath.UnqualifiedPath, (string)parameterBindingResult.ConstantValue, StringComparison.OrdinalIgnoreCase)) + // The *Variable parameters actually always results in an ArrayList + // But to make type inference of individual elements better, we say it's a generic list. + switch (parameterName) { - return true; + case "ErrorVariable": + case "ev": + SetLastAssignmentType(new PSTypeName(typeof(List)), commandAst.Extent.StartOffset); + break; + + case "WarningVariable": + case "wv": + SetLastAssignmentType(new PSTypeName(typeof(List)), commandAst.Extent.StartOffset); + break; + + case "InformationVariable": + case "iv": + SetLastAssignmentType(new PSTypeName(typeof(List)), commandAst.Extent.StartOffset); + break; + + case "OutVariable": + case "ov": + SetLastAssignment(commandAst); + break; + + default: + break; + } + + return AstVisitAction.Continue; + } + } + + if (commandAst.Parent is PipelineAst pipeline && pipeline.Extent.EndOffset > VariableTarget.Extent.StartOffset) + { + foreach (string parameterName in CompletionCompleters.s_pipelineVariableParameters) + { + if (bindResult.BoundParameters.TryGetValue(parameterName, out ParameterBindingResult pipeVarBind) + && pipeVarBind.ConstantValue is string varName + && AssignsToTargetVar(varName)) + { + SetLastAssignment(commandAst, enumerate: true); + return AstVisitAction.Continue; } } } } - return false; + foreach (RedirectionAst redirection in commandAst.Redirections) + { + if (redirection is FileRedirectionAst fileRedirection + && fileRedirection.Location is StringConstantExpressionAst redirectTarget + && redirectTarget.Value.StartsWith("variable:", StringComparison.OrdinalIgnoreCase) + && redirectTarget.Value.Length > "variable:".Length) + { + string varName = redirectTarget.Value.Substring("variable:".Length); + if (!AssignsToTargetVar(varName)) + { + continue; + } + + switch (fileRedirection.FromStream) + { + case RedirectionStream.Error: + SetLastAssignmentType(new PSTypeName(typeof(ErrorRecord)), commandAst.Extent.StartOffset); + break; + + case RedirectionStream.Warning: + SetLastAssignmentType(new PSTypeName(typeof(WarningRecord)), commandAst.Extent.StartOffset); + break; + + case RedirectionStream.Verbose: + SetLastAssignmentType(new PSTypeName(typeof(VerboseRecord)), commandAst.Extent.StartOffset); + break; + + case RedirectionStream.Debug: + SetLastAssignmentType(new PSTypeName(typeof(DebugRecord)), commandAst.Extent.StartOffset); + break; + + case RedirectionStream.Information: + SetLastAssignmentType(new PSTypeName(typeof(InformationRecord)), commandAst.Extent.StartOffset); + break; + + default: + SetLastAssignment(commandAst, redirectionAssignment: true); + break; + } + } + } + + return AstVisitAction.Continue; } - var assignmentAst = (AssignmentStatementAst)ast; - var lhs = assignmentAst.Left; - if (lhs is ConvertExpressionAst convertExpr) + public override AstVisitAction VisitParameter(ParameterAst parameterAst) { - lhs = convertExpr.Child; + if (parameterAst.Extent.StartOffset >= StopSearchOffset) + { + return AstVisitAction.StopVisit; + } + + if (AssignsToTargetVar(parameterAst.Name)) + { + foreach (AttributeBaseAst attribute in parameterAst.Attributes) + { + if (attribute is TypeConstraintAst typeConstraint) + { + LastConstraint = typeConstraint.TypeName; + return AstVisitAction.Continue; + } + } + } + + return AstVisitAction.Continue; } - if (lhs is not VariableExpressionAst varExpr) + public override AstVisitAction VisitForEachStatement(ForEachStatementAst forEachStatementAst) { - return false; + if (forEachStatementAst.Extent.StartOffset >= StopSearchOffset) + { + return AstVisitAction.StopVisit; + } + + if (AssignsToTargetVar(forEachStatementAst.Variable) && forEachStatementAst.Condition.Extent.EndOffset < VariableTarget.Extent.StartOffset) + { + SetLastAssignment(forEachStatementAst.Condition, enumerate: true); + } + + return AstVisitAction.Continue; } - var candidateVarPath = varExpr.VariablePath; - if (candidateVarPath.UserPath.Equals(variableAstVariablePath.UserPath, StringComparison.OrdinalIgnoreCase)) + public override AstVisitAction VisitConvertExpression(ConvertExpressionAst convertExpressionAst) { - return true; + if (convertExpressionAst.IsRef() + && convertExpressionAst.Child is VariableExpressionAst varAst + && AssignsToTargetVar(varAst)) + { + SetLastAssignment(convertExpressionAst); + } + + return AstVisitAction.Continue; } - // The following condition is making an assumption that at script scope, we didn't use $script:, but in the local scope, we did - // If we are searching anything other than script scope, this is wrong. - if (variableAstVariablePath.IsScript && variableAstVariablePath.UnqualifiedPath.Equals(candidateVarPath.UnqualifiedPath, StringComparison.OrdinalIgnoreCase)) + public override AstVisitAction VisitAttribute(AttributeAst attributeAst) { - return true; + // Attributes can't assign values to variables so they aren't interesting. + return AstVisitAction.SkipChildren; + } + + public override AstVisitAction VisitScriptBlockExpression(ScriptBlockExpressionAst scriptBlockExpressionAst) + { + return scriptBlockExpressionAst.IsDotsourced() + ? AstVisitAction.Continue + : AstVisitAction.SkipChildren; + } + + public override AstVisitAction VisitDataStatement(DataStatementAst dataStatementAst) + { + if (dataStatementAst.Extent.StartOffset >= StopSearchOffset) + { + return AstVisitAction.StopVisit; + } + + if (AssignsToTargetVar(dataStatementAst.Variable) && dataStatementAst.Extent.EndOffset < VariableTarget.Extent.StartOffset) + { + SetLastAssignment(dataStatementAst.Body); + } + + return AstVisitAction.SkipChildren; + } + + public override AstVisitAction VisitFunctionDefinition(FunctionDefinitionAst functionDefinitionAst) + { + return AstVisitAction.SkipChildren; + } + } + } + + internal static class TypeInferenceExtension + { + public static bool EqualsOrdinalIgnoreCase(this string s, string t) + { + return string.Equals(s, t, StringComparison.OrdinalIgnoreCase); + } + + public static IEnumerable GetGetterProperty(this Type type, string propertyName) + { + var res = new List(); + foreach (var m in type.GetMethods(BindingFlags.Public | BindingFlags.Instance)) + { + var name = m.Name; + // Equals without string allocation + if (name.Length == propertyName.Length + 4 + && name.StartsWith("get_") + && name.IndexOf(propertyName, 4, StringComparison.Ordinal) == 4) + { + res.Add(m); + } + } + + return res; + } + + public static bool IsDotsourced(this ScriptBlockExpressionAst scriptBlockExpressionAst) + { + Ast parent = scriptBlockExpressionAst.Parent; + + // This loop checks if the scriptblock is used as a dot sourced command + // or an argument for a command that uses the local scope eg: ForEach-Object -Process {$Var1 = "Hello"}, {Var2 = $true} + while (parent is not null) + { + if (parent is CommandAst cmdAst) + { + string cmdName = cmdAst.GetCommandName(); + return CompletionCompleters.s_localScopeCommandNames.Contains(cmdName) + || (cmdAst.CommandElements[0] is ScriptBlockExpressionAst && cmdAst.InvocationOperator == TokenKind.Dot); + } + + if (parent is not CommandExpressionAst and not PipelineAst and not StatementBlockAst and not ArrayExpressionAst and not ArrayLiteralAst) + { + break; + } + + parent = parent.Parent; } return false; diff --git a/test/powershell/engine/Api/TypeInference.Tests.ps1 b/test/powershell/engine/Api/TypeInference.Tests.ps1 index 0d0188dc23f..2ab5681dddb 100644 --- a/test/powershell/engine/Api/TypeInference.Tests.ps1 +++ b/test/powershell/engine/Api/TypeInference.Tests.ps1 @@ -1481,6 +1481,30 @@ Describe "Type inference Tests" -tags "CI" { $res.Name | Should -Be 'System.Management.Automation.Internal.Host.InternalHost' } + It 'Infers type of variable assigned inside do while loop' { + $res = [AstTypeInference]::InferTypeOf(({ + do + { + $Test = 1 + $Test + } + while (1) + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Int32' + } + + It 'Infers type of variable assigned inside do until loop' { + $res = [AstTypeInference]::InferTypeOf(({ + do + { + $Test = 1 + $Test + } + until ($null = gci) + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst] -and $Ast.VariablePath.UserPath -eq 'Test'}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Int32' + } + It 'Infers type of external applications' { $res = [AstTypeInference]::InferTypeOf( { pwsh }.Ast) $res.Name | Should -Be 'System.String' @@ -1494,6 +1518,178 @@ Describe "Type inference Tests" -tags "CI" { $null = [AstTypeInference]::InferTypeOf($FoundAst) } + It 'Ignores type constraint defined outside of scope' { + $res = [AstTypeInference]::InferTypeOf(({ + function Outer + { + [string] $Test = "Hello" + function Inner + { + $Test = 2 + $Test + } + } + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Int32' + } + + It 'Considers the type constraint defined outside of scope when dot sourcing' { + $res = [AstTypeInference]::InferTypeOf(({ + [string] $Test = "Hello" + . {$Test = 2; $Test} + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.String' + } + + It 'Infers type of ref assigned variable' { + $res = [AstTypeInference]::InferTypeOf(({ + $MyRefVar = $null + $null = [System.Management.Automation.Language.Parser]::ParseInput("", [ref] $MyRefVar, [ref] $null) + $MyRefVar + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Management.Automation.Language.Token[]' + } + + It 'Infers type of variable assigned with New/Set-Variable' { + $res = [AstTypeInference]::InferTypeOf( { + New-Variable -Name Var1 -Value $true | Out-Null + New-Variable -Name Var2 -Value "Hello" | Out-Null + $Var1 + $Var2 + }.Ast) + $res[0].Name | Should -Be 'System.Boolean' + $res[1].Name | Should -Be 'System.String' + } + + It 'Infers type of variable assigned with common parameter' -TestCases @( + @{ParameterName = "WarningVariable"; ExpectedType = [List[WarningRecord]]} + @{ParameterName = "wv"; ExpectedType = [List[WarningRecord]]} + @{ParameterName = "ErrorVariable"; ExpectedType = [List[ErrorRecord]]} + @{ParameterName = "ev"; ExpectedType = [List[ErrorRecord]]} + @{ParameterName = "InformationVariable"; ExpectedType = [List[InformationalRecord]]} + @{ParameterName = "iv"; ExpectedType = [List[InformationalRecord]]} + @{ParameterName = "OutVariable"; ExpectedType = [guid]} + @{ParameterName = "ov"; ExpectedType = [guid]} + @{ParameterName = "PipelineVariable"; ExpectedType = [guid]} + @{ParameterName = "pv"; ExpectedType = [guid]} + ) -Test { + param($ParameterName, $ExpectedType) + $Ast = [scriptblock]::Create("New-Guid -$ParameterName MyOutVar | % {`$MyOutVar}").Ast.FindAll({ + param($Ast) + $Ast -is [Language.VariableExpressionAst] + }, $true) | Select-Object -Last 1 + $res = [AstTypeInference]::InferTypeOf($Ast) + $res.Type | Should -Be $ExpectedType + } + + It 'Infers type of variable assigned via Data statement' { + $res = [AstTypeInference]::InferTypeOf(({ + Data MyDataVar {"Hello"} + $MyDataVar + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.String' + } + + It 'Infers type of well known variable with global scope' { + $res = [AstTypeInference]::InferTypeOf({$global:true}.Ast) + $res.Name | Should -Be 'System.Boolean' + } + + It 'Infers parameter type from closest parameter' { + $res = [AstTypeInference]::InferTypeOf( ({ + param([string]$Param1) + function TestFunction {param([bool]$Param1) $Param1} + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Boolean' + } + + It 'Infers variable type from closest foreach statement' { + $res = [AstTypeInference]::InferTypeOf( ({ + foreach ($X in 1..10) + { + $X + } + foreach ($X in New-Guid) + { + $X + } + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Guid' + } + + It 'Infers global variable type in child scope' { + $res = [AstTypeInference]::InferTypeOf( ({ + $Global:GlobalTest1 = "Hello" + function TestFunction {$GlobalTest1} + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.String' + } + + It 'Does not infer private variable type in child scope' { + $res = [AstTypeInference]::InferTypeOf( ({ + $Private:PrivateTest1 = "Hello" + function TestFunction {$PrivateTest1} + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Count | Should -Be 0 + } + + It 'Infers variable assigned with an attribute' { + $res = [AstTypeInference]::InferTypeOf( ({ + [ValidateNotNull()]$ValidatedVar1 = New-Guid + $ValidatedVar1 + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Guid' + } + + It 'Infers variable assigned with multiple type constraints' { + $res = [AstTypeInference]::InferTypeOf( ({ + [int] [string]$MultiConstraintVar1 = "10" + $MultiConstraintVar1 + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $res.Name | Should -Be 'System.Int32' + } + + It 'Infers variable assigned by redirection' { + $res = [AstTypeInference]::InferTypeOf( ({ + New-Guid *>&1 1>variable:RedirVar1; $RedirVar1 + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) | Select-Object -Last 1 )) + $ExpectedTypeNames = @( + [ErrorRecord].FullName + [WarningRecord].FullName + [VerboseRecord].FullName + [DebugRecord].FullName + [InformationRecord].FullName + [guid].FullName + ) -join ';' + $res.Name -join ';' | Should -Be $ExpectedTypeNames + } + + It 'Infers variables assigned by redirection from specific streams' { + $VarAsts = [List[Language.Ast]]{ + [void](New-Guid 1>variable:RedirSuccess 2>variable:RedirError 3>variable:RedirWarning 4>variable:RedirVerbose 5>variable:RedirDebug 6>variable:RedirInfo) + $RedirSuccess + $RedirError + $RedirWarning + $RedirVerbose + $RedirDebug + $RedirInfo + }.Ast.FindAll({param($Ast) $Ast -is [Language.VariableExpressionAst]}, $true) + $ExpectedTypeNames = @( + [guid].FullName + [ErrorRecord].FullName + [WarningRecord].FullName + [VerboseRecord].FullName + [DebugRecord].FullName + [InformationRecord].FullName + ) + + for ($i = 0; $i -lt $VarAsts.Count; $i++) + { + $res = [AstTypeInference]::InferTypeOf($VarAsts[$i]) + $res.Name | Should -Be $ExpectedTypeNames[$i] + } + } + It 'Should infer output from anonymous function' { $res = [AstTypeInference]::InferTypeOf( { & {"Hello"} }.Ast) $res.Name | Should -Be 'System.String'