Skip to content

Commit d608941

Browse files
committed
Implement type inference in conditional types
1 parent 8e337b5 commit d608941

10 files changed

Lines changed: 158 additions & 13 deletions

File tree

src/compiler/binder.ts

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ namespace ts {
101101
HasLocals = 1 << 5,
102102
IsInterface = 1 << 6,
103103
IsObjectLiteralOrClassExpressionMethod = 1 << 7,
104+
IsInferenceContainer = 1 << 8,
104105
}
105106

106107
const binder = createBinder();
@@ -119,6 +120,7 @@ namespace ts {
119120
let parent: Node;
120121
let container: Node;
121122
let blockScopeContainer: Node;
123+
let inferenceContainer: Node;
122124
let lastContainer: Node;
123125
let seenThisKeyword: boolean;
124126

@@ -186,6 +188,7 @@ namespace ts {
186188
parent = undefined;
187189
container = undefined;
188190
blockScopeContainer = undefined;
191+
inferenceContainer = undefined;
189192
lastContainer = undefined;
190193
seenThisKeyword = false;
191194
currentFlow = undefined;
@@ -561,6 +564,13 @@ namespace ts {
561564
bindChildren(node);
562565
node.flags = seenThisKeyword ? node.flags | NodeFlags.ContainsThis : node.flags & ~NodeFlags.ContainsThis;
563566
}
567+
else if (containerFlags & ContainerFlags.IsInferenceContainer) {
568+
const saveInferenceContainer = inferenceContainer;
569+
inferenceContainer = node;
570+
node.locals = undefined;
571+
bindChildren(node);
572+
inferenceContainer = saveInferenceContainer;
573+
}
564574
else {
565575
bindChildren(node);
566576
}
@@ -1417,6 +1427,9 @@ namespace ts {
14171427
case SyntaxKind.MappedType:
14181428
return ContainerFlags.IsContainer | ContainerFlags.HasLocals;
14191429

1430+
case SyntaxKind.ConditionalType:
1431+
return ContainerFlags.IsInferenceContainer;
1432+
14201433
case SyntaxKind.SourceFile:
14211434
return ContainerFlags.IsContainer | ContainerFlags.IsControlFlowContainer | ContainerFlags.HasLocals;
14221435

@@ -2059,7 +2072,7 @@ namespace ts {
20592072
case SyntaxKind.TypePredicate:
20602073
return checkTypePredicate(node as TypePredicateNode);
20612074
case SyntaxKind.TypeParameter:
2062-
return declareSymbolAndAddToSymbolTable(<Declaration>node, SymbolFlags.TypeParameter, SymbolFlags.TypeParameterExcludes);
2075+
return bindTypeParameter(node as TypeParameterDeclaration);
20632076
case SyntaxKind.Parameter:
20642077
return bindParameter(<ParameterDeclaration>node);
20652078
case SyntaxKind.VariableDeclaration:
@@ -2576,6 +2589,23 @@ namespace ts {
25762589
: declareSymbolAndAddToSymbolTable(node, symbolFlags, symbolExcludes);
25772590
}
25782591

2592+
function bindTypeParameter(node: TypeParameterDeclaration) {
2593+
if (node.parent.kind === SyntaxKind.InferType) {
2594+
if (inferenceContainer) {
2595+
if (!inferenceContainer.locals) {
2596+
inferenceContainer.locals = createSymbolTable();
2597+
}
2598+
declareSymbol(inferenceContainer.locals, /*parent*/ undefined, node, SymbolFlags.TypeParameter, SymbolFlags.TypeParameterExcludes);
2599+
}
2600+
else {
2601+
bindAnonymousDeclaration(node, SymbolFlags.TypeParameter, getDeclarationName(node));
2602+
}
2603+
}
2604+
else {
2605+
declareSymbolAndAddToSymbolTable(node, SymbolFlags.TypeParameter, SymbolFlags.TypeParameterExcludes);
2606+
}
2607+
}
2608+
25792609
// reachability checks
25802610

25812611
function shouldReportErrorOnModuleDeclaration(node: ModuleDeclaration): boolean {
@@ -3441,6 +3471,7 @@ namespace ts {
34413471
case SyntaxKind.UnionType:
34423472
case SyntaxKind.IntersectionType:
34433473
case SyntaxKind.ConditionalType:
3474+
case SyntaxKind.InferType:
34443475
case SyntaxKind.ParenthesizedType:
34453476
case SyntaxKind.InterfaceDeclaration:
34463477
case SyntaxKind.TypeAliasDeclaration:

src/compiler/checker.ts

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,11 @@ namespace ts {
11691169
);
11701170
}
11711171
}
1172+
else if (location.kind === SyntaxKind.ConditionalType) {
1173+
// A type parameter declared using 'infer T' in a conditional type is visible only in
1174+
// the true branch of the conditional type.
1175+
useResult = lastLocation === (<ConditionalTypeNode>location).trueType;
1176+
}
11721177

11731178
if (useResult) {
11741179
break loop;
@@ -4628,10 +4633,14 @@ namespace ts {
46284633
case SyntaxKind.TypeAliasDeclaration:
46294634
case SyntaxKind.JSDocTemplateTag:
46304635
case SyntaxKind.MappedType:
4636+
case SyntaxKind.ConditionalType:
46314637
const outerTypeParameters = getOuterTypeParameters(node, includeThisTypes);
46324638
if (node.kind === SyntaxKind.MappedType) {
46334639
return append(outerTypeParameters, getDeclaredTypeOfTypeParameter(getSymbolOfNode((<MappedTypeNode>node).typeParameter)));
46344640
}
4641+
else if (node.kind === SyntaxKind.ConditionalType) {
4642+
return concatenate(outerTypeParameters, getInferTypeParameters(<ConditionalTypeNode>node));
4643+
}
46354644
const outerAndOwnTypeParameters = appendTypeParameters(outerTypeParameters, getEffectiveTypeParameterDeclarations(<DeclarationWithTypeParameters>node) || emptyArray);
46364645
const thisType = includeThisTypes &&
46374646
(node.kind === SyntaxKind.ClassDeclaration || node.kind === SyntaxKind.ClassExpression || node.kind === SyntaxKind.InterfaceDeclaration) &&
@@ -8078,27 +8087,43 @@ namespace ts {
80788087
return type.flags & TypeFlags.Substitution ? (<SubstitutionType>type).typeParameter : type;
80798088
}
80808089

8081-
function createConditionalType(checkType: Type, extendsType: Type, trueType: Type, falseType: Type, target: ConditionalType, mapper: TypeMapper, aliasSymbol: Symbol, aliasTypeArguments: Type[]) {
8090+
function createConditionalType(checkType: Type, extendsType: Type, trueType: Type, falseType: Type, inferTypeParameters: TypeParameter[], target: ConditionalType, mapper: TypeMapper, aliasSymbol: Symbol, aliasTypeArguments: Type[]) {
80828091
const type = <ConditionalType>createType(TypeFlags.Conditional);
80838092
type.checkType = checkType;
80848093
type.extendsType = extendsType;
80858094
type.trueType = trueType;
80868095
type.falseType = falseType;
8096+
type.inferTypeParameters = inferTypeParameters;
80878097
type.target = target;
80888098
type.mapper = mapper;
80898099
type.aliasSymbol = aliasSymbol;
80908100
type.aliasTypeArguments = aliasTypeArguments;
80918101
return type;
80928102
}
80938103

8094-
function getConditionalType(checkType: Type, extendsType: Type, baseTrueType: Type, baseFalseType: Type, target: ConditionalType, mapper: TypeMapper, aliasSymbol?: Symbol, baseAliasTypeArguments?: Type[]): Type {
8104+
function getConditionalType(checkType: Type, baseExtendsType: Type, baseTrueType: Type, baseFalseType: Type, inferTypeParameters: TypeParameter[], target: ConditionalType, mapper: TypeMapper, aliasSymbol?: Symbol, baseAliasTypeArguments?: Type[]): Type {
8105+
// Instantiate extends type without instantiating any 'infer T' type parameters
8106+
const extendsType = instantiateType(baseExtendsType, mapper);
8107+
let combinedMapper: TypeMapper;
8108+
if (inferTypeParameters) {
8109+
const inferences = map(inferTypeParameters, createInferenceInfo);
8110+
// We don't want inferences from constraints as they may cause us to eagerly resolve the
8111+
// conditional type instead of deferring resolution.
8112+
inferTypes(inferences, checkType, extendsType, InferencePriority.NoConstraints);
8113+
// We infer 'never' when there are no candidates for a type parameter
8114+
const inferredTypes = map(inferences, inference => getTypeFromInference(inference) || neverType);
8115+
const inferenceMapper = createTypeMapper(inferTypeParameters, inferredTypes);
8116+
combinedMapper = mapper ? combineTypeMappers(mapper, inferenceMapper) : inferenceMapper;
8117+
}
80958118
// Return union of trueType and falseType for any and never since they match anything
80968119
if (checkType.flags & (TypeFlags.Any | TypeFlags.Never)) {
8097-
return getUnionType([instantiateType(baseTrueType, mapper), instantiateType(baseFalseType, mapper)]);
8120+
return getUnionType([instantiateType(baseTrueType, combinedMapper || mapper), instantiateType(baseFalseType, mapper)]);
80988121
}
8122+
// Instantiate the extends type including inferences for 'infer T' type parameters
8123+
const inferredExtendsType = combinedMapper ? instantiateType(baseExtendsType, combinedMapper) : extendsType;
80998124
// Return trueType for a definitely true extends check
8100-
if (isTypeAssignableTo(checkType, extendsType)) {
8101-
return instantiateType(baseTrueType, mapper);
8125+
if (isTypeAssignableTo(checkType, inferredExtendsType)) {
8126+
return instantiateType(baseTrueType, combinedMapper || mapper);
81028127
}
81038128
// Return falseType for a definitely false extends check
81048129
if (!isTypeAssignableTo(instantiateType(checkType, anyMapper), instantiateType(extendsType, constraintMapper))) {
@@ -8114,25 +8139,45 @@ namespace ts {
81148139
return cached;
81158140
}
81168141
const result = createConditionalType(erasedCheckType, extendsType, trueType, falseType,
8117-
target, mapper, aliasSymbol, instantiateTypes(baseAliasTypeArguments, mapper));
8142+
inferTypeParameters, target, mapper, aliasSymbol, instantiateTypes(baseAliasTypeArguments, mapper));
81188143
if (id) {
81198144
conditionalTypes.set(id, result);
81208145
}
81218146
return result;
81228147
}
81238148

8149+
function getInferTypeParameters(node: ConditionalTypeNode): TypeParameter[] {
8150+
let result: TypeParameter[];
8151+
if (node.locals) {
8152+
node.locals.forEach(symbol => {
8153+
if (symbol.flags & SymbolFlags.TypeParameter) {
8154+
result = append(result, getDeclaredTypeOfSymbol(symbol));
8155+
}
8156+
});
8157+
}
8158+
return result;
8159+
}
8160+
81248161
function getTypeFromConditionalTypeNode(node: ConditionalTypeNode): Type {
81258162
const links = getNodeLinks(node);
81268163
if (!links.resolvedType) {
81278164
links.resolvedType = getConditionalType(
81288165
getTypeFromTypeNode(node.checkType), getTypeFromTypeNode(node.extendsType),
81298166
getTypeFromTypeNode(node.trueType), getTypeFromTypeNode(node.falseType),
8130-
/*target*/ undefined, /*mapper*/ undefined,
8167+
getInferTypeParameters(node), /*target*/ undefined, /*mapper*/ undefined,
81318168
getAliasSymbolForTypeNode(node), getAliasTypeArgumentsForTypeNode(node));
81328169
}
81338170
return links.resolvedType;
81348171
}
81358172

8173+
function getTypeFromInferTypeNode(node: InferTypeNode): Type {
8174+
const links = getNodeLinks(node);
8175+
if (!links.resolvedType) {
8176+
links.resolvedType = getDeclaredTypeOfTypeParameter(getSymbolOfNode(node.typeParameter));
8177+
}
8178+
return links.resolvedType;
8179+
}
8180+
81368181
function getTypeFromTypeLiteralOrFunctionOrConstructorTypeNode(node: TypeNode): Type {
81378182
const links = getNodeLinks(node);
81388183
if (!links.resolvedType) {
@@ -8423,6 +8468,8 @@ namespace ts {
84238468
return getTypeFromMappedTypeNode(<MappedTypeNode>node);
84248469
case SyntaxKind.ConditionalType:
84258470
return getTypeFromConditionalTypeNode(<ConditionalTypeNode>node);
8471+
case SyntaxKind.InferType:
8472+
return getTypeFromInferTypeNode(<InferTypeNode>node);
84268473
// This function assumes that an identifier or qualified name is a type expression
84278474
// Callers should first ensure this by calling isTypeNode
84288475
case SyntaxKind.Identifier:
@@ -8714,8 +8761,8 @@ namespace ts {
87148761
}
87158762

87168763
function instantiateConditionalType(type: ConditionalType, mapper: TypeMapper): Type {
8717-
return getConditionalType(instantiateType(type.checkType, mapper), instantiateType(type.extendsType, mapper),
8718-
type.trueType, type.falseType, type, mapper, type.aliasSymbol, type.aliasTypeArguments);
8764+
return getConditionalType(instantiateType(type.checkType, mapper), type.extendsType, type.trueType, type.falseType,
8765+
type.inferTypeParameters, type, mapper, type.aliasSymbol, type.aliasTypeArguments);
87198766
}
87208767

87218768
function instantiateType(type: Type, mapper: TypeMapper): Type {
@@ -11206,7 +11253,7 @@ namespace ts {
1120611253
const templateType = getTemplateTypeFromMappedType(target);
1120711254
const inference = createInferenceInfo(typeParameter);
1120811255
inferTypes([inference], sourceType, templateType);
11209-
return inference.candidates ? getUnionType(inference.candidates, UnionReduction.Subtype) : emptyObjectType;
11256+
return getTypeFromInference(inference) || emptyObjectType;
1121011257
}
1121111258

1121211259
function getUnmatchedProperty(source: Type, target: Type, requireOptionalProperties: boolean) {
@@ -11222,6 +11269,12 @@ namespace ts {
1122211269
return undefined;
1122311270
}
1122411271

11272+
function getTypeFromInference(inference: InferenceInfo) {
11273+
return inference.candidates ? getUnionType(inference.candidates, UnionReduction.Subtype) :
11274+
inference.contraCandidates ? getCommonSubtype(inference.contraCandidates) :
11275+
undefined;
11276+
}
11277+
1122511278
function inferTypes(inferences: InferenceInfo[], originalSource: Type, originalTarget: Type, priority: InferencePriority = 0) {
1122611279
let symbolStack: Symbol[];
1122711280
let visited: Map<boolean>;
@@ -11381,7 +11434,9 @@ namespace ts {
1138111434
}
1138211435
}
1138311436
else {
11384-
source = getApparentType(source);
11437+
if (!(priority && InferencePriority.NoConstraints && source.flags & (TypeFlags.Intersection | TypeFlags.Instantiable))) {
11438+
source = getApparentType(source);
11439+
}
1138511440
if (source.flags & (TypeFlags.Object | TypeFlags.Intersection)) {
1138611441
const key = source.id + "," + target.id;
1138711442
if (visited && visited.get(key)) {

src/compiler/declarationEmitter.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,8 @@ namespace ts {
452452
return emitIntersectionType(<IntersectionTypeNode>type);
453453
case SyntaxKind.ConditionalType:
454454
return emitConditionalType(<ConditionalTypeNode>type);
455+
case SyntaxKind.InferType:
456+
return emitInferType(<InferTypeNode>type);
455457
case SyntaxKind.ParenthesizedType:
456458
return emitParenType(<ParenthesizedTypeNode>type);
457459
case SyntaxKind.TypeOperator:
@@ -557,6 +559,11 @@ namespace ts {
557559
emitType(node.falseType);
558560
}
559561

562+
function emitInferType(node: InferTypeNode) {
563+
write("infer ");
564+
writeTextOfNode(currentText, node.typeParameter.name);
565+
}
566+
560567
function emitParenType(type: ParenthesizedTypeNode) {
561568
write("(");
562569
emitType(type.type);

src/compiler/emitter.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,8 @@ namespace ts {
602602
return emitIntersectionType(<IntersectionTypeNode>node);
603603
case SyntaxKind.ConditionalType:
604604
return emitConditionalType(<ConditionalTypeNode>node);
605+
case SyntaxKind.InferType:
606+
return emitInferType(<InferTypeNode>node);
605607
case SyntaxKind.ParenthesizedType:
606608
return emitParenthesizedType(<ParenthesizedTypeNode>node);
607609
case SyntaxKind.ExpressionWithTypeArguments:
@@ -1202,6 +1204,11 @@ namespace ts {
12021204
emit(node.falseType);
12031205
}
12041206

1207+
function emitInferType(node: InferTypeNode) {
1208+
write("infer ");
1209+
emit(node.typeParameter);
1210+
}
1211+
12051212
function emitParenthesizedType(node: ParenthesizedTypeNode) {
12061213
writePunctuation("(");
12071214
emit(node.type);

src/compiler/factory.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,18 @@ namespace ts {
747747
: node;
748748
}
749749

750+
export function createInferTypeNode(typeParameter: TypeParameterDeclaration) {
751+
const node = <InferTypeNode>createSynthesizedNode(SyntaxKind.InferType);
752+
node.typeParameter = typeParameter;
753+
return node;
754+
}
755+
756+
export function updateInferTypeNode(node: InferTypeNode, typeParameter: TypeParameterDeclaration) {
757+
return node.typeParameter !== typeParameter
758+
? updateNode(createInferTypeNode(typeParameter), node)
759+
: node;
760+
}
761+
750762
export function createParenthesizedType(type: TypeNode) {
751763
const node = <ParenthesizedTypeNode>createSynthesizedNode(SyntaxKind.ParenthesizedType);
752764
node.type = type;

src/compiler/parser.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ namespace ts {
180180
visitNode(cbNode, (<ConditionalTypeNode>node).extendsType) ||
181181
visitNode(cbNode, (<ConditionalTypeNode>node).trueType) ||
182182
visitNode(cbNode, (<ConditionalTypeNode>node).falseType);
183+
case SyntaxKind.InferType:
184+
return visitNode(cbNode, (<InferTypeNode>node).typeParameter);
183185
case SyntaxKind.ParenthesizedType:
184186
case SyntaxKind.TypeOperator:
185187
return visitNode(cbNode, (<ParenthesizedTypeNode | TypeOperatorNode>node).type);
@@ -2647,6 +2649,15 @@ namespace ts {
26472649
return finishNode(node);
26482650
}
26492651

2652+
function parseInferType(): InferTypeNode {
2653+
const node = <InferTypeNode>createNode(SyntaxKind.InferType);
2654+
parseExpected(SyntaxKind.InferKeyword);
2655+
const typeParameter = <TypeParameterDeclaration>createNode(SyntaxKind.TypeParameter);
2656+
typeParameter.name = parseIdentifier();
2657+
node.typeParameter = finishNode(typeParameter);
2658+
return finishNode(node);
2659+
}
2660+
26502661
function parseFunctionOrConstructorType(kind: SyntaxKind): FunctionOrConstructorTypeNode {
26512662
const node = <FunctionOrConstructorTypeNode>createNodeWithJSDoc(kind);
26522663
if (kind === SyntaxKind.ConstructorType) {
@@ -2733,6 +2744,8 @@ namespace ts {
27332744
return parseTupleType();
27342745
case SyntaxKind.OpenParenToken:
27352746
return parseParenthesizedType();
2747+
case SyntaxKind.InferKeyword:
2748+
return parseInferType();
27362749
default:
27372750
return parseTypeReference();
27382751
}
@@ -2767,6 +2780,7 @@ namespace ts {
27672780
case SyntaxKind.QuestionToken:
27682781
case SyntaxKind.ExclamationToken:
27692782
case SyntaxKind.DotDotDotToken:
2783+
case SyntaxKind.InferKeyword:
27702784
return true;
27712785
case SyntaxKind.MinusToken:
27722786
return !inStartOfParameter && lookAhead(nextTokenIsNumericLiteral);

src/compiler/scanner.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ namespace ts {
9292
"implements": SyntaxKind.ImplementsKeyword,
9393
"import": SyntaxKind.ImportKeyword,
9494
"in": SyntaxKind.InKeyword,
95+
"infer": SyntaxKind.InferKeyword,
9596
"instanceof": SyntaxKind.InstanceOfKeyword,
9697
"interface": SyntaxKind.InterfaceKeyword,
9798
"is": SyntaxKind.IsKeyword,

0 commit comments

Comments
 (0)