diff --git a/src/transformation/visitors/function.ts b/src/transformation/visitors/function.ts index f79682223..1c21a0f7c 100644 --- a/src/transformation/visitors/function.ts +++ b/src/transformation/visitors/function.ts @@ -24,11 +24,17 @@ import { transformBindingPattern } from "./variable-declaration"; function transformParameterDefaultValueDeclaration( context: TransformationContext, parameterName: lua.Identifier, - value?: ts.Expression, + value: ts.Expression, tsOriginal?: ts.Node -): lua.Statement { - const parameterValue = value ? context.transformExpression(value) : undefined; - const assignment = lua.createAssignmentStatement(parameterName, parameterValue); +): lua.Statement | undefined { + const { precedingStatements: statements, result: parameterValue } = transformInPrecedingStatementScope( + context, + () => context.transformExpression(value) + ); + if (!lua.isNilLiteral(parameterValue)) { + statements.push(lua.createAssignmentStatement(parameterName, parameterValue)); + } + if (statements.length === 0) return undefined; const nilCondition = lua.createBinaryExpression( parameterName, @@ -36,7 +42,7 @@ function transformParameterDefaultValueDeclaration( lua.SyntaxKind.EqualityOperator ); - const ifBlock = lua.createBlock([assignment]); + const ifBlock = lua.createBlock(statements, tsOriginal); return lua.createIfStatement(nilCondition, ifBlock, undefined, tsOriginal); } @@ -106,7 +112,7 @@ export function transformFunctionBodyHeader( parameters: ts.NodeArray, spreadIdentifier?: lua.Identifier ): lua.Statement[] { - const headerStatements = []; + const headerStatements: lua.Statement[] = []; // Add default parameters and object binding patterns const bindingPatternDeclarations: lua.Statement[] = []; @@ -116,9 +122,12 @@ export function transformFunctionBodyHeader( const identifier = lua.createIdentifier(`____bindingPattern${bindPatternIndex++}`); if (declaration.initializer !== undefined) { // Default binding parameter - headerStatements.push( - transformParameterDefaultValueDeclaration(context, identifier, declaration.initializer) + const initializer = transformParameterDefaultValueDeclaration( + context, + identifier, + declaration.initializer ); + if (initializer) headerStatements.push(initializer); } // Binding pattern @@ -129,13 +138,12 @@ export function transformFunctionBodyHeader( bindingPatternDeclarations.push(...precedingStatements, ...bindings); } else if (declaration.initializer !== undefined) { // Default parameter - headerStatements.push( - transformParameterDefaultValueDeclaration( - context, - transformIdentifier(context, declaration.name), - declaration.initializer - ) + const initializer = transformParameterDefaultValueDeclaration( + context, + transformIdentifier(context, declaration.name), + declaration.initializer ); + if (initializer) headerStatements.push(initializer); } } diff --git a/test/unit/functions/__snapshots__/functions.spec.ts.snap b/test/unit/functions/__snapshots__/functions.spec.ts.snap index 4e07a4367..83345e13b 100644 --- a/test/unit/functions/__snapshots__/functions.spec.ts.snap +++ b/test/unit/functions/__snapshots__/functions.spec.ts.snap @@ -1,5 +1,27 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP +exports[`Function default parameter with value "null" 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local function foo(self, x) + return x + end + return foo(nil) +end +return ____exports" +`; + +exports[`Function default parameter with value "undefined" 1`] = ` +"local ____exports = {} +function ____exports.__main(self) + local function foo(self, x) + return x + end + return foo(nil) +end +return ____exports" +`; + exports[`function.length unsupported ("5.0"): code 1`] = ` "local ____exports = {} function ____exports.__main(self) diff --git a/test/unit/functions/functions.spec.ts b/test/unit/functions/functions.spec.ts index 7bd224fba..7bdf6b040 100644 --- a/test/unit/functions/functions.spec.ts +++ b/test/unit/functions/functions.spec.ts @@ -107,6 +107,48 @@ test("Function default binding parameter maintains order", () => { `.expectToMatchJsResult(); }); +test.each(["undefined", "null"])("Function default parameter with value %p", defaultValue => { + util.testFunction` + function foo(x = ${defaultValue}) { + return x; + } + return foo(); + ` + .expectToMatchJsResult() + .tap(builder => { + const lua = builder.getMainLuaCodeChunk(); + expect(lua).not.toMatch("if x == nil then"); + }) + .expectLuaToMatchSnapshot(); +}); + +test("Function default parameter with preceding statements", () => { + util.testFunction` + let i = 1 + function foo(x = i++) { + return x; + } + return [i, foo(), i]; + `.expectToMatchJsResult(); +}); + +test("Function default parameter with nil value and preceding statements", () => { + util.testFunction` + const a = new LuaTable() + a.set("foo", "bar") + function foo(x: any = a.set("foo", "baz")) { + return x ?? "nil"; + } + return [a.get("foo"), foo(), a.get("foo")]; + ` + .withLanguageExtensions() + .tap(builder => { + const lua = builder.getMainLuaCodeChunk(); + expect(lua).not.toMatch(" x = nil"); + }) + .expectToEqual(["bar", "nil", "baz"]); +}); + test("Class method call", () => { util.testFunction` class TestClass {