Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -514,12 +514,23 @@ object ICustomAstVisitor.VisitInvokeMemberExpression(InvokeMemberExpressionAst i

object ICustomAstVisitor.VisitArrayExpression(ArrayExpressionAst arrayExpressionAst)
{
return new[] { new PSTypeName(typeof(object[])) };
if (arrayExpressionAst.SubExpression.Statements.Count == 0)
{
return new[] { new PSTypeName(typeof(object[])) };
}

return new[] { GetArrayType(InferTypes(arrayExpressionAst.SubExpression)) };
}

object ICustomAstVisitor.VisitArrayLiteral(ArrayLiteralAst arrayLiteralAst)
{
return new[] { new PSTypeName(typeof(object[])) };
var inferredElementTypes = new List<PSTypeName>();
foreach (ExpressionAst expression in arrayLiteralAst.Elements)
{
inferredElementTypes.AddRange(InferTypes(expression));
}

return new[] { GetArrayType(inferredElementTypes) };
}

object ICustomAstVisitor.VisitHashtable(HashtableAst hashtableAst)
Expand Down Expand Up @@ -1930,9 +1941,22 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List<PST
int startOffset = variableExpressionAst.Extent.StartOffset;
var targetAsts = (List<Ast>)AstSearcher.FindAll(
parent,
ast => (ast is ParameterAst || ast is AssignmentStatementAst || ast is ForEachStatementAst || ast is CommandAst)
&& variableExpressionAst.AstAssignsToSameVariable(ast)
&& ast.Extent.EndOffset < startOffset,
ast =>
{
if (ast is ParameterAst || ast is AssignmentStatementAst || ast is CommandAst)
{
return variableExpressionAst.AstAssignsToSameVariable(ast)
&& ast.Extent.EndOffset < startOffset;
}

if (ast is ForEachStatementAst)
{
return variableExpressionAst.AstAssignsToSameVariable(ast)
&& ast.Extent.StartOffset < startOffset;
}

return false;
},
searchNestedScriptBlocks: true);

foreach (var ast in targetAsts)
Expand Down Expand Up @@ -1965,7 +1989,8 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List<PST
var foreachAst = targetAsts.OfType<ForEachStatementAst>().FirstOrDefault();
if (foreachAst != null)
{
inferredTypes.AddRange(InferTypes(foreachAst.Condition));
inferredTypes.AddRange(
GetInferredEnumeratedTypes(InferTypes(foreachAst.Condition)));
return;
}

Expand Down Expand Up @@ -1999,6 +2024,177 @@ private void InferTypeFrom(VariableExpressionAst variableExpressionAst, List<PST
}
}

/// <summary>
/// Gets the most specific array type possible from a group of inferred types.
/// </summary>
/// <param name="inferredTypes">The inferred types all the items in the array.</param>
/// <returns>The inferred strongly typed array type.</returns>
private PSTypeName GetArrayType(IEnumerable<PSTypeName> inferredTypes)
Comment thread
SeeminglyScience marked this conversation as resolved.
{
PSTypeName foundType = null;
foreach (PSTypeName inferredType in inferredTypes)
{
if (inferredType.Type == null)
{
return new PSTypeName(typeof(object[]));
}

// IEnumerable<>.GetEnumerator and IDictionary.GetEnumerator will always be
// inferred as multiple types due to explicit implementations, so if we find
// one then assume the rest are also enumerators.
if (typeof(IEnumerator).IsAssignableFrom(inferredType.Type))
{
foundType = inferredType;
break;
}

if (foundType == null)
{
foundType = inferredType;
continue;
}

// If there are mixed types then fall back to object[].
if (foundType.Type != inferredType.Type)
{
return new PSTypeName(typeof(object[]));
}
}

if (foundType == null)
{
return new PSTypeName(typeof(object[]));
}

if (foundType.Type.IsArray)
{
return foundType;
}

Type enumeratedItemType = GetMostSpecificEnumeratedItemType(foundType.Type);
if (enumeratedItemType != null)
{
return new PSTypeName(enumeratedItemType.MakeArrayType());
}

return new PSTypeName(foundType.Type.MakeArrayType());
}

/// <summary>
/// Gets the most specific type item type from a type that is potentially enumerable.
/// </summary>
/// <param name="enumerableType">The type to infer enumerated item type from.</param>
/// <returns>The inferred enumerated item type.</returns>
private Type GetMostSpecificEnumeratedItemType(Type enumerableType)
Comment thread
SeeminglyScience marked this conversation as resolved.
{
if (enumerableType.IsArray)
{
return enumerableType.GetElementType();
}

// These types implement IEnumerable, but we intentionally do not enumerate them.
if (enumerableType == typeof(string) ||
typeof(IDictionary).IsAssignableFrom(enumerableType) ||
typeof(Xml.XmlNode).IsAssignableFrom(enumerableType))
{
return enumerableType;
}

if (enumerableType == typeof(Data.DataTable))
{
return typeof(Data.DataRow);
}

bool hasSeenNonGeneric = false;
bool hasSeenDictionaryEnumerator = false;
Type collectionInterface = GetGenericCollectionLikeInterface(
enumerableType,
ref hasSeenNonGeneric,
ref hasSeenDictionaryEnumerator);

if (collectionInterface != null)
{
return collectionInterface.GetGenericArguments()[0];
}

foreach (Type interfaceType in enumerableType.GetInterfaces())
{
collectionInterface = GetGenericCollectionLikeInterface(
interfaceType,
ref hasSeenNonGeneric,
ref hasSeenDictionaryEnumerator);

if (collectionInterface != null)
{
return collectionInterface.GetGenericArguments()[0];
}
}

if (hasSeenDictionaryEnumerator)
{
return typeof(DictionaryEntry);
}

if (hasSeenNonGeneric)
{
return typeof(object);
}

return null;
}

/// <summary>
/// Determines if the interface can be used to infer a specific enumerated type.
/// </summary>
/// <param name="interfaceType">The interface to test.</param>
/// <param name="hasSeenNonGeneric">
/// A reference to a value indicating whether a non-generic enumerable type has been
/// seen. If <see paramref="interfaceType" /> is a non-generic enumerable type this
/// value will be set to <see langword="true" />.
/// </param>
/// <param name="hasSeenDictionaryEnumerator">
/// A reference to a value indicating whether <see cref="IDictionaryEnumerator" /> has been
/// seen. If <paramref name="interfaceType" /> is a <see cref="IDictionaryEnumerator" /> this
/// value will be set to <see langword="true" />.
/// </param>
/// <returns>
/// The value of <paramref name="interfaceType" /> if it can be used to infer a specific
/// enumerated type, otherwise <see langword="null" />.
/// </returns>
private Type GetGenericCollectionLikeInterface(
Comment thread
SeeminglyScience marked this conversation as resolved.
Type interfaceType,
ref bool hasSeenNonGeneric,
ref bool hasSeenDictionaryEnumerator)
{
if (!interfaceType.IsInterface)
{
return null;
}

if (interfaceType.IsConstructedGenericType)
{
Type openGeneric = interfaceType.GetGenericTypeDefinition();
if (openGeneric == typeof(IEnumerator<>) ||
openGeneric == typeof(IEnumerable<>))
{
return interfaceType;
}
}

if (interfaceType == typeof(IDictionaryEnumerator))
{
hasSeenDictionaryEnumerator = true;
}

if (interfaceType == typeof(IEnumerator) ||
interfaceType == typeof(IEnumerable))
{
hasSeenNonGeneric = true;
}

return null;
}

private IEnumerable<PSTypeName> InferTypeFrom(IndexExpressionAst indexExpressionAst)
{
var targetTypes = InferTypes(indexExpressionAst.Target);
Expand Down Expand Up @@ -2061,6 +2257,32 @@ private IEnumerable<PSTypeName> InferTypeFrom(IndexExpressionAst indexExpression
}
}

/// <summary>
/// Infers the types as if they were enumerated. For example, a <see cref="List{T}" />
/// of type <see cref="string" /> would be returned as <see cref="string" />.
/// </summary>
/// <param name="enumerableTypes">
/// The potentially enumerable types to infer enumerated type from.
/// </param>
/// <returns>The enumerated item types.</returns>
private IEnumerable<PSTypeName> GetInferredEnumeratedTypes(IEnumerable<PSTypeName> enumerableTypes)
Comment thread
SeeminglyScience marked this conversation as resolved.
{
foreach (PSTypeName maybeEnumerableType in enumerableTypes)
{
Type type = maybeEnumerableType.Type;
if (type == null)
{
yield return maybeEnumerableType;
continue;
}

Type enumeratedItemType = GetMostSpecificEnumeratedItemType(type);
yield return enumeratedItemType == null
? maybeEnumerableType
: new PSTypeName(enumeratedItemType);
}
}

private void GetInferredTypeFromScriptBlockParameter(AstParameterArgumentPair argument, List<PSTypeName> inferredTypes)
{
var argumentPair = argument as AstPair;
Expand Down
Loading