diff --git a/src/transformation/builtins/array.ts b/src/transformation/builtins/array.ts index c9074874f..22afd5b45 100644 --- a/src/transformation/builtins/array.ts +++ b/src/transformation/builtins/array.ts @@ -5,7 +5,7 @@ import { TransformationContext } from "../context"; import { unsupportedProperty } from "../utils/diagnostics"; import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib"; import { transformArguments, transformCallAndArguments } from "../visitors/call"; -import { findFirstNonOuterParent, typeAlwaysHasSomeOfFlags } from "../utils/typescript"; +import { expressionResultIsUsed, typeAlwaysHasSomeOfFlags } from "../utils/typescript"; import { moveToPrecedingTemp } from "../visitors/expression-list"; import { isUnpackCall, wrapInTable } from "../utils/lua-ast"; @@ -54,8 +54,6 @@ function transformSingleElementArrayPush( caller: lua.Expression, param: lua.Expression ): lua.Expression { - const expressionIsUsed = !ts.isExpressionStatement(findFirstNonOuterParent(node)); - const arrayIdentifier = lua.isIdentifier(caller) ? caller : moveToPrecedingTemp(context, caller); // #array + 1 @@ -65,6 +63,7 @@ function transformSingleElementArrayPush( lua.SyntaxKind.AdditionOperator ); + const expressionIsUsed = expressionResultIsUsed(node); if (expressionIsUsed) { // store length in a temp lengthExpression = moveToPrecedingTemp(context, lengthExpression); diff --git a/src/transformation/utils/typescript/index.ts b/src/transformation/utils/typescript/index.ts index dca15e00b..e5984b08f 100644 --- a/src/transformation/utils/typescript/index.ts +++ b/src/transformation/utils/typescript/index.ts @@ -32,6 +32,10 @@ export function findFirstNonOuterParent(node: ts.Node): ts.Node { return current; } +export function expressionResultIsUsed(node: ts.Expression): boolean { + return !ts.isExpressionStatement(findFirstNonOuterParent(node)); +} + export function getFirstDeclarationInFile(symbol: ts.Symbol, sourceFile: ts.SourceFile): ts.Declaration | undefined { const originalSourceFile = ts.getParseTreeNode(sourceFile) ?? sourceFile; const declarations = (symbol.getDeclarations() ?? []).filter(d => d.getSourceFile() === originalSourceFile); diff --git a/src/transformation/utils/typescript/types.ts b/src/transformation/utils/typescript/types.ts index d13a40b38..1a657d673 100644 --- a/src/transformation/utils/typescript/types.ts +++ b/src/transformation/utils/typescript/types.ts @@ -115,8 +115,6 @@ export function canBeFalsy(context: TransformationContext, type: ts.Type): boole } export function canBeFalsyWhenNotNull(context: TransformationContext, type: ts.Type): boolean { - const strictNullChecks = context.options.strict === true || context.options.strictNullChecks === true; - if (!strictNullChecks && !type.isLiteral()) return true; const falsyFlags = ts.TypeFlags.Boolean | ts.TypeFlags.BooleanLiteral | diff --git a/src/transformation/visitors/expression-statement.ts b/src/transformation/visitors/expression-statement.ts index ebc0b1600..a247fd29b 100644 --- a/src/transformation/visitors/expression-statement.ts +++ b/src/transformation/visitors/expression-statement.ts @@ -1,6 +1,6 @@ import * as ts from "typescript"; import * as lua from "../../LuaAST"; -import { FunctionVisitor, tempSymbolId, TransformationContext } from "../context"; +import { FunctionVisitor, tempSymbolId } from "../context"; import { transformBinaryExpressionStatement } from "./binary-expression"; import { transformUnaryExpressionStatement } from "./unary-expression"; @@ -15,15 +15,10 @@ export const transformExpressionStatement: FunctionVisitor(); @@ -74,12 +77,16 @@ const optionalContinuations = new WeakMap() function createOptionalContinuationIdentifier(text: string, tsOriginal: ts.Expression): ts.Identifier { const identifier = ts.factory.createIdentifier(text); ts.setOriginalNode(identifier, tsOriginal); - optionalContinuations.set(identifier, {}); + optionalContinuations.set(identifier, { + usedIdentifiers: [], + }); return identifier; } + export function isOptionalContinuation(node: ts.Node): boolean { return ts.isIdentifier(node) && optionalContinuations.has(node); } + export function getOptionalContinuationData(identifier: ts.Identifier): OptionalContinuation | undefined { return optionalContinuations.get(identifier); } @@ -90,16 +97,16 @@ export function transformOptionalChain(context: TransformationContext, node: ts. export function transformOptionalChainWithCapture( context: TransformationContext, - node: ts.OptionalChain, + tsNode: ts.OptionalChain, thisValueCapture: lua.Identifier | undefined, isDelete?: ts.DeleteExpression ): ExpressionWithThisValue { - const luaTemp = context.createTempNameForNode(node); + const luaTempName = context.createTempName("opt"); - const { expression: tsLeftExpression, chain } = flattenChain(node); + const { expression: tsLeftExpression, chain } = flattenChain(tsNode); // build temp.b.c.d - const tsTemp = createOptionalContinuationIdentifier(luaTemp.text, tsLeftExpression); + const tsTemp = createOptionalContinuationIdentifier(luaTempName, tsLeftExpression); let tsRightExpression: ts.Expression = tsTemp; for (const link of chain) { if (ts.isPropertyAccessExpression(link)) { @@ -121,18 +128,18 @@ export function transformOptionalChainWithCapture( // transform right expression first to check if thisValue capture is needed // capture and return thisValue if requested from outside let returnThisValue: lua.Expression | undefined; - const [rightPrecedingStatements, rightAssignment] = transformInPrecedingStatementScope(context, () => { - let result: lua.Expression; - if (thisValueCapture) { - ({ expression: result, thisValue: returnThisValue } = transformExpressionWithThisValueCapture( - context, - tsRightExpression, - thisValueCapture - )); - } else { - result = context.transformExpression(tsRightExpression); + const [rightPrecedingStatements, rightExpression] = transformInPrecedingStatementScope(context, () => { + if (!thisValueCapture) { + return context.transformExpression(tsRightExpression); } - return lua.createAssignmentStatement(luaTemp, result); + + const { expression: result, thisValue } = transformExpressionWithThisValueCapture( + context, + tsRightExpression, + thisValueCapture + ); + returnThisValue = thisValue; + return result; }); // transform left expression, handle thisValue if needed by rightExpression @@ -140,7 +147,8 @@ export function transformOptionalChainWithCapture( const leftThisValueTemp = lua.createIdentifier(thisValueCaptureName, undefined, tempSymbolId); let capturedThisValue: lua.Expression | undefined; - const rightContextualCall = getOptionalContinuationData(tsTemp)?.contextualCall; + const optionalContinuationData = getOptionalContinuationData(tsTemp); + const rightContextualCall = optionalContinuationData?.contextualCall; const [leftPrecedingStatements, leftExpression] = transformInPrecedingStatementScope(context, () => { let result: lua.Expression; if (rightContextualCall) { @@ -177,26 +185,78 @@ export function transformOptionalChainWithCapture( } } - // - // local temp = - // if temp ~= nil then - // - // temp = temp.b.c.d - // end - // return temp - - context.addPrecedingStatements([ - ...leftPrecedingStatements, - lua.createVariableDeclarationStatement(luaTemp, leftExpression), - lua.createIfStatement( - lua.createBinaryExpression(luaTemp, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator), - lua.createBlock([...rightPrecedingStatements, rightAssignment]) - ), - ]); - return { - expression: luaTemp, - thisValue: returnThisValue, - }; + // evaluate optional chain + context.addPrecedingStatements(leftPrecedingStatements); + + // try use existing variable instead of creating new one, if possible + let leftIdentifier: lua.Identifier | undefined; + const usedLuaIdentifiers = optionalContinuationData?.usedIdentifiers; + const reuseLeftIdentifier = + usedLuaIdentifiers && + usedLuaIdentifiers.length > 0 && + lua.isIdentifier(leftExpression) && + (rightPrecedingStatements.length === 0 || !shouldMoveToTemp(context, leftExpression, tsLeftExpression)); + if (reuseLeftIdentifier) { + leftIdentifier = leftExpression; + for (const usedIdentifier of usedLuaIdentifiers) { + usedIdentifier.text = leftIdentifier.text; + } + } else { + leftIdentifier = lua.createIdentifier(luaTempName, undefined, tempSymbolId); + context.addPrecedingStatements(lua.createVariableDeclarationStatement(leftIdentifier, leftExpression)); + } + + if (!expressionResultIsUsed(tsNode) || isDelete) { + // if left ~= nil then + // + // + // end + + const innerExpression = wrapInStatement(rightExpression); + const innerStatements = rightPrecedingStatements; + if (innerExpression) innerStatements.push(innerExpression); + + context.addPrecedingStatements( + lua.createIfStatement( + lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator), + lua.createBlock(innerStatements) + ) + ); + return { expression: lua.createNilLiteral(), thisValue: returnThisValue }; + } else if ( + rightPrecedingStatements.length === 0 && + !canBeFalsyWhenNotNull(context, context.checker.getTypeAtLocation(tsLeftExpression)) + ) { + // return a && a.b + return { + expression: lua.createBinaryExpression(leftIdentifier, rightExpression, lua.SyntaxKind.AndOperator, tsNode), + thisValue: returnThisValue, + }; + } else { + let resultIdentifier: lua.Identifier; + if (!reuseLeftIdentifier) { + // reuse temp variable for output + resultIdentifier = leftIdentifier; + } else { + resultIdentifier = lua.createIdentifier(context.createTempName("opt_result"), undefined, tempSymbolId); + context.addPrecedingStatements(lua.createVariableDeclarationStatement(resultIdentifier)); + } + // if left ~= nil then + // + // result = + // end + // return result + context.addPrecedingStatements( + lua.createIfStatement( + lua.createBinaryExpression(leftIdentifier, lua.createNilLiteral(), lua.SyntaxKind.InequalityOperator), + lua.createBlock([ + ...rightPrecedingStatements, + lua.createAssignmentStatement(resultIdentifier, rightExpression), + ]) + ) + ); + return { expression: resultIdentifier, thisValue: returnThisValue }; + } } export function transformOptionalDeleteExpression( diff --git a/src/transformation/visitors/void.ts b/src/transformation/visitors/void.ts index ab5286db5..c942e859c 100644 --- a/src/transformation/visitors/void.ts +++ b/src/transformation/visitors/void.ts @@ -1,13 +1,13 @@ import * as ts from "typescript"; import * as lua from "../../LuaAST"; import { FunctionVisitor } from "../context"; -import { transformExpressionToStatement } from "./expression-statement"; +import { wrapInStatement } from "./expression-statement"; // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/void export const transformVoidExpression: FunctionVisitor = (node, context) => { // If content is a literal it is safe to replace the entire expression with nil if (!ts.isLiteralExpression(node.expression)) { - const statements = transformExpressionToStatement(context, node.expression); + const statements = wrapInStatement(context.transformExpression(node.expression)); if (statements) context.addPrecedingStatements(statements); } diff --git a/test/unit/__snapshots__/optionalChaining.spec.ts.snap b/test/unit/__snapshots__/optionalChaining.spec.ts.snap index 94655ca01..e52fe69b8 100644 --- a/test/unit/__snapshots__/optionalChaining.spec.ts.snap +++ b/test/unit/__snapshots__/optionalChaining.spec.ts.snap @@ -3,18 +3,18 @@ exports[`Unsupported optional chains Builtin global method: code 1`] = ` "local ____lualib = require(\\"lualib_bundle\\") local __TS__Number = ____lualib.__TS__Number -local ____Number_result_0 = Number -if ____Number_result_0 ~= nil then - ____Number_result_0 = __TS__Number(\\"3\\") +local ____opt_0 = Number +if ____opt_0 ~= nil then + __TS__Number(\\"3\\") end" `; exports[`Unsupported optional chains Builtin global method: diagnostics 1`] = `"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions."`; exports[`Unsupported optional chains Builtin global property: code 1`] = ` -"local ____console_log_result_0 = console -if ____console_log_result_0 ~= nil then - ____console_log_result_0 = print(\\"3\\") +"local ____opt_0 = console +if ____opt_0 ~= nil then + print(\\"3\\") end" `; @@ -23,9 +23,9 @@ exports[`Unsupported optional chains Builtin global property: diagnostics 1`] = exports[`Unsupported optional chains Builtin prototype method: code 1`] = ` "local ____lualib = require(\\"lualib_bundle\\") local __TS__ArrayForEach = ____lualib.__TS__ArrayForEach -local ____table_forEach_result_0 = ({1, 2, 3, 4}).forEach -if ____table_forEach_result_0 ~= nil then - ____table_forEach_result_0 = __TS__ArrayForEach( +local ____opt_0 = ({1, 2, 3, 4}).forEach +if ____opt_0 ~= nil then + __TS__ArrayForEach( {1, 2, 3, 4}, function() end @@ -50,9 +50,9 @@ function ____exports.__main(self) --- -- @compileMembersOnly local D = \\"D\\" - local ____TestEnum_B_0 = TestEnum - if ____TestEnum_B_0 ~= nil then - ____TestEnum_B_0 = B + local ____opt_0 = TestEnum + if ____opt_0 ~= nil then + local ____ = B end end return ____exports" @@ -61,10 +61,214 @@ return ____exports" exports[`Unsupported optional chains Compile members only: diagnostics 1`] = `"main.ts(10,17): error TSTL: Optional calls are not supported on enums marked with @compileMembersOnly."`; exports[`Unsupported optional chains Language extensions: code 1`] = ` -"local ____table_has_result_0 = ({}).has -if ____table_has_result_0 ~= nil then - ____table_has_result_0 = nil[3] ~= nil +"local ____opt_0 = ({}).has +if ____opt_0 ~= nil then + local ____ = nil[3] ~= nil end" `; exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions."`; + +exports[`long optional chain 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local a = {b = {c = {d = {e = {f = \\"hello!\\"}}}}} + local ____opt_2 = a.b + local ____opt_0 = ____opt_2 and ____opt_2.c + return ____opt_0 and ____opt_0.d.e.f +end +return ____exports" +`; + +exports[`optional chaining ("{ foo: \\"foo\\" }") 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local obj = {foo = \\"foo\\"} + return obj and obj.foo +end +return ____exports" +`; + +exports[`optional chaining ("null") 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local obj = nil + return obj and obj.foo +end +return ____exports" +`; + +exports[`optional chaining ("undefined") 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local obj = nil + return obj and obj.foo +end +return ____exports" +`; + +exports[`optional element function calls 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local obj = { + value = \\"foobar\\", + foo = function(v) return v + 10 end + } + local fooKey = \\"foo\\" + local barKey = \\"bar\\" + local ____opt_0 = obj[barKey] + local ____temp_4 = ____opt_0 and ____opt_0(5) + if ____temp_4 == nil then + local ____opt_2 = obj[fooKey] + ____temp_4 = ____opt_2 and ____opt_2(15) + end + return ____temp_4 +end +return ____exports" +`; + +exports[`unused call 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local result + local obj = {foo = function(self) + result = \\"bar\\" + end} + if obj ~= nil then + obj:foo() + end + return result +end +return ____exports" +`; + +exports[`unused expression 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local obj = {foo = \\"bar\\"} + if obj ~= nil then + local ____ = obj.foo + end +end +return ____exports" +`; + +exports[`unused result with preceding statements on right side 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = nil + if obj ~= nil then + local ____opt_0_foo_2 = obj.foo + local ____i_1 = i + i = ____i_1 + 1 + ____opt_0_foo_2(obj, ____i_1) + end + return i +end +return ____exports" +`; + +exports[`unused result with preceding statements on right side 2`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = {foo = function(self, val) + return val + end} + if obj ~= nil then + local ____opt_0_foo_2 = obj.foo + local ____i_1 = i + i = ____i_1 + 1 + ____opt_0_foo_2(obj, ____i_1) + end + return i +end +return ____exports" +`; + +exports[`with preceding statements on right side 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = nil + local ____opt_result_4 + if obj ~= nil then + local ____opt_0_foo_2 = obj.foo + local ____i_1 = i + i = ____i_1 + 1 + ____opt_result_4 = ____opt_0_foo_2(obj, ____i_1) + end + return {result = ____opt_result_4, i = i} +end +return ____exports" +`; + +exports[`with preceding statements on right side 2`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = {foo = function(____, v) return v end} + local ____opt_result_4 + if obj ~= nil then + local ____opt_0_foo_2 = obj.foo + local ____i_1 = i + i = ____i_1 + 1 + ____opt_result_4 = ____opt_0_foo_2(obj, ____i_1) + end + return {result = ____opt_result_4, i = i} +end +return ____exports" +`; + +exports[`with preceding statements on right side modifying left 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = nil + local function bar(self) + if obj then + obj.foo = nil + end + obj = nil + return 1 + end + local ____opt_0 = obj + if ____opt_0 ~= nil then + local ____opt_0_foo_3 = ____opt_0.foo + local ____bar_result_2 = bar(nil) + local ____i_1 = i + i = ____i_1 + 1 + ____opt_0 = ____opt_0_foo_3(____opt_0, ____bar_result_2, ____i_1) + end + return {result = ____opt_0, obj = obj, i = i} +end +return ____exports" +`; + +exports[`with preceding statements on right side modifying left 2`] = ` +"local ____exports = {} +function ____exports.__main(self) + local i = 0 + local obj = {foo = function(self, v) + return v + end} + local function bar(self) + if obj then + obj.foo = nil + end + obj = nil + return 1 + end + local ____opt_0 = obj + if ____opt_0 ~= nil then + local ____opt_0_foo_3 = ____opt_0.foo + local ____bar_result_2 = bar(nil) + local ____i_1 = i + i = ____i_1 + 1 + ____opt_0 = ____opt_0_foo_3(____opt_0, ____bar_result_2, ____i_1) + end + return {result = ____opt_0, obj = obj, i = i} +end +return ____exports" +`; diff --git a/test/unit/optionalChaining.spec.ts b/test/unit/optionalChaining.spec.ts index 675fc59f9..8a30e1647 100644 --- a/test/unit/optionalChaining.spec.ts +++ b/test/unit/optionalChaining.spec.ts @@ -4,16 +4,22 @@ import { ScriptTarget } from "typescript"; test.each(["null", "undefined", '{ foo: "foo" }'])("optional chaining (%p)", value => { util.testFunction` - const obj: any = ${value}; + const obj: {foo: string} | null | undefined = ${value}; return obj?.foo; - `.expectToMatchJsResult(); + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should use "and" expression }); test("long optional chain", () => { util.testFunction` const a = { b: { c: { d: { e: { f: "hello!"}}}}}; return a.b?.c?.d.e.f; - `.expectToMatchJsResult(); + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should use "and" expression }); test.each(["undefined", "{}", "{ foo: {} }", "{ foo: {bar: 'baz'}}"])("nested optional chaining (%p)", value => { @@ -69,7 +75,89 @@ test("optional element function calls", () => { const fooKey = "foo"; const barKey = "bar"; return obj[barKey]?.(5) ?? obj[fooKey]?.(15); - `.expectToMatchJsResult(); + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should still use "and" statement, as functions have no self +}); + +test("unused expression", () => { + util.testFunction` + const obj = { foo: "bar" }; + obj?.foo; + ` + .expectToHaveNoDiagnostics() + .expectNoExecutionError() + .expectLuaToMatchSnapshot(); + // should use if statement, as result is not used +}); + +test("unused call", () => { + util.testFunction` + let result + const obj = { + foo() { + result = "bar" + } + }; + obj?.foo(); + return result; + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should use if statement, as result is not used +}); + +test.each(["undefined", "{ foo: v=>v }"])("with preceding statements on right side", value => { + util.testFunction` + let i = 0 + const obj: any = ${value}; + return {result: obj?.foo(i++), i}; + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should use if statement, as there are preceding statements +}); + +// unused, with preceding statements on right side +test.each(["undefined", "{ foo(val) {return val} }"])( + "unused result with preceding statements on right side", + value => { + util.testFunction` + let i = 0 + const obj = ${value}; + obj?.foo(i++); + return i + ` + .expectToHaveNoDiagnostics() + .expectLuaToMatchSnapshot(); + // should use if statement, as there are preceding statements + } +); + +test.each(["undefined", "{ foo(v) { return v} }"])("with preceding statements on right side modifying left", value => { + util.testFunction` + let i = 0 + let obj: any = ${value}; + function bar() { + if(obj) obj.foo = undefined + obj = undefined + return 1 + } + + return {result: obj?.foo(bar(), i++), obj, i} + ` + .expectToMatchJsResult() + .expectLuaToMatchSnapshot(); + // should use if statement, as there are preceding statements +}); + +test("does not suppress error if left side is false", () => { + const result = util.testFunction` + const obj: any = false + return obj?.foo + `.getLuaExecutionResult(); + expect(result).toBeInstanceOf(util.ExecutionError); }); describe("optional access method calls", () => { @@ -85,7 +173,7 @@ describe("optional access method calls", () => { `.expectToMatchJsResult(); }); - test("optional access call", () => { + test("property access call", () => { util.testFunction` const obj: { value: string; foo?(prefix: string): string; bar?(prefix: string): string; } = { value: "foobar",