diff --git a/src/transformation/visitors/conditional.ts b/src/transformation/visitors/conditional.ts index 518f658e5..566dc9f75 100644 --- a/src/transformation/visitors/conditional.ts +++ b/src/transformation/visitors/conditional.ts @@ -6,28 +6,30 @@ import { performHoisting, popScope, pushScope, ScopeType } from "../utils/scope" import { transformBlockOrStatement } from "./block"; import { canBeFalsy } from "../utils/typescript"; +type EvaluatedExpression = [precedingStatemens: lua.Statement[], value: lua.Expression]; + function transformProtectedConditionalExpression( context: TransformationContext, - expression: ts.ConditionalExpression + expression: ts.ConditionalExpression, + condition: EvaluatedExpression, + whenTrue: EvaluatedExpression, + whenFalse: EvaluatedExpression ): lua.Expression { const tempVar = context.createTempNameForNode(expression.condition); - const condition = context.transformExpression(expression.condition); - - const [trueStatements, val1] = transformInPrecedingStatementScope(context, () => - context.transformExpression(expression.whenTrue) + const trueStatements = whenTrue[0].concat( + lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), whenTrue[1], expression.whenTrue) ); - trueStatements.push(lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), val1, expression.whenTrue)); - const [falseStatements, val2] = transformInPrecedingStatementScope(context, () => - context.transformExpression(expression.whenFalse) + const falseStatements = whenFalse[0].concat( + lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), whenFalse[1], expression.whenFalse) ); - falseStatements.push(lua.createAssignmentStatement(lua.cloneIdentifier(tempVar), val2, expression.whenFalse)); context.addPrecedingStatements([ lua.createVariableDeclarationStatement(tempVar, undefined, expression.condition), + ...condition[0], lua.createIfStatement( - condition, + condition[1], lua.createBlock(trueStatements, expression.whenTrue), lua.createBlock(falseStatements, expression.whenFalse), expression @@ -37,17 +39,27 @@ function transformProtectedConditionalExpression( } export const transformConditionalExpression: FunctionVisitor = (expression, context) => { - if (canBeFalsy(context, context.checker.getTypeAtLocation(expression.whenTrue))) { - return transformProtectedConditionalExpression(context, expression); + const condition = transformInPrecedingStatementScope(context, () => + context.transformExpression(expression.condition) + ); + const whenTrue = transformInPrecedingStatementScope(context, () => + context.transformExpression(expression.whenTrue) + ); + const whenFalse = transformInPrecedingStatementScope(context, () => + context.transformExpression(expression.whenFalse) + ); + if ( + whenTrue[0].length > 0 || + whenFalse[0].length > 0 || + canBeFalsy(context, context.checker.getTypeAtLocation(expression.whenTrue)) + ) { + return transformProtectedConditionalExpression(context, expression, condition, whenTrue, whenFalse); } - const condition = context.transformExpression(expression.condition); - const val1 = context.transformExpression(expression.whenTrue); - const val2 = context.transformExpression(expression.whenFalse); - // condition and v1 or v2 - const conditionAnd = lua.createBinaryExpression(condition, val1, lua.SyntaxKind.AndOperator); - return lua.createBinaryExpression(conditionAnd, val2, lua.SyntaxKind.OrOperator, expression); + context.addPrecedingStatements(condition[0]); + const conditionAnd = lua.createBinaryExpression(condition[1], whenTrue[1], lua.SyntaxKind.AndOperator); + return lua.createBinaryExpression(conditionAnd, whenFalse[1], lua.SyntaxKind.OrOperator, expression); }; export function transformIfStatement(statement: ts.IfStatement, context: TransformationContext): lua.IfStatement { diff --git a/test/unit/conditionals.spec.ts b/test/unit/conditionals.spec.ts index 738ecacc1..a1a234cf6 100644 --- a/test/unit/conditionals.spec.ts +++ b/test/unit/conditionals.spec.ts @@ -112,3 +112,29 @@ test.each([false, true, null])("Ternary conditional with generic whenTrue branch }) .expectToMatchJsResult(); }); + +test.each([false, true])("Ternary conditional with preceding statements in true branch (%p)", trueVal => { + // language=TypeScript + util.testFunction` + let i = 0; + const result = ${trueVal} ? i += 1 : i; + return { result, i }; + ` + .setOptions({ + strictNullChecks: true, + }) + .expectToMatchJsResult(); +}); + +test.each([false, true])("Ternary conditional with preceding statements in false branch (%p)", trueVal => { + // language=TypeScript + util.testFunction` + let i = 0; + const result = ${trueVal} ? i : i += 2; + return { result, i }; + ` + .setOptions({ + strictNullChecks: true, + }) + .expectToMatchJsResult(); +});