diff --git a/src/transformation/utils/diagnostics.ts b/src/transformation/utils/diagnostics.ts index 94f7781de..fd443e596 100644 --- a/src/transformation/utils/diagnostics.ts +++ b/src/transformation/utils/diagnostics.ts @@ -155,3 +155,11 @@ export const unsupportedOptionalCompileMembersOnly = createErrorDiagnosticFactor export const undefinedInArrayLiteral = createErrorDiagnosticFactory( "Array literals may not contain undefined or null." ); + +export const invalidMethodCallExtensionUse = createErrorDiagnosticFactory( + "This language extension must be called as a method." +); + +export const invalidSpreadInCallExtension = createErrorDiagnosticFactory( + "Spread elements are not supported in call extensions." +); diff --git a/src/transformation/utils/language-extensions.ts b/src/transformation/utils/language-extensions.ts index 881da105b..108c6eef1 100644 --- a/src/transformation/utils/language-extensions.ts +++ b/src/transformation/utils/language-extensions.ts @@ -1,5 +1,6 @@ import * as ts from "typescript"; import { TransformationContext } from "../context"; +import { invalidMethodCallExtensionUse, invalidSpreadInCallExtension } from "./diagnostics"; export enum ExtensionKind { MultiFunction = "MultiFunction", @@ -53,6 +54,7 @@ export enum ExtensionKind { TableAddKeyType = "TableAddKey", TableAddKeyMethodType = "TableAddKeyMethod", } + const extensionValues: Set = new Set(Object.values(ExtensionKind)); export function getExtensionKindForType(context: TransformationContext, type: ts.Type): ExtensionKind | undefined { @@ -119,3 +121,78 @@ export function getIterableExtensionKindForNode( const type = context.checker.getTypeAtLocation(node); return getIterableExtensionTypeForType(context, type); } + +export const methodExtensionKinds: ReadonlySet = new Set([ + ExtensionKind.AdditionOperatorMethodType, + ExtensionKind.SubtractionOperatorMethodType, + ExtensionKind.MultiplicationOperatorMethodType, + ExtensionKind.DivisionOperatorMethodType, + ExtensionKind.ModuloOperatorMethodType, + ExtensionKind.PowerOperatorMethodType, + ExtensionKind.FloorDivisionOperatorMethodType, + ExtensionKind.BitwiseAndOperatorMethodType, + ExtensionKind.BitwiseOrOperatorMethodType, + ExtensionKind.BitwiseExclusiveOrOperatorMethodType, + ExtensionKind.BitwiseLeftShiftOperatorMethodType, + ExtensionKind.BitwiseRightShiftOperatorMethodType, + ExtensionKind.ConcatOperatorMethodType, + ExtensionKind.LessThanOperatorMethodType, + ExtensionKind.GreaterThanOperatorMethodType, + ExtensionKind.NegationOperatorMethodType, + ExtensionKind.BitwiseNotOperatorMethodType, + ExtensionKind.LengthOperatorMethodType, + ExtensionKind.TableDeleteMethodType, + ExtensionKind.TableGetMethodType, + ExtensionKind.TableHasMethodType, + ExtensionKind.TableSetMethodType, + ExtensionKind.TableAddKeyMethodType, +]); + +export function getNaryCallExtensionArgs( + context: TransformationContext, + node: ts.CallExpression, + kind: ExtensionKind, + numArgs: number +): readonly ts.Expression[] | undefined { + let expressions: readonly ts.Expression[]; + if (node.arguments.some(ts.isSpreadElement)) { + context.diagnostics.push(invalidSpreadInCallExtension(node)); + return undefined; + } + if (methodExtensionKinds.has(kind)) { + if (!(ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression))) { + context.diagnostics.push(invalidMethodCallExtensionUse(node)); + return undefined; + } + if (node.arguments.length < numArgs - 1) { + // assumed to be TS error + return undefined; + } + expressions = [node.expression.expression, ...node.arguments]; + } else { + if (node.arguments.length < numArgs) { + // assumed to be TS error + return undefined; + } + expressions = node.arguments; + } + return expressions; +} + +export function getUnaryCallExtensionArg( + context: TransformationContext, + node: ts.CallExpression, + kind: ExtensionKind +): ts.Expression | undefined { + return getNaryCallExtensionArgs(context, node, kind, 1)?.[0]; +} + +export function getBinaryCallExtensionArgs( + context: TransformationContext, + node: ts.CallExpression, + kind: ExtensionKind +): readonly [ts.Expression, ts.Expression] | undefined { + const expressions = getNaryCallExtensionArgs(context, node, kind, 2); + if (expressions === undefined) return undefined; + return [expressions[0], expressions[1]]; +} diff --git a/src/transformation/visitors/language-extensions/operators.ts b/src/transformation/visitors/language-extensions/operators.ts index 657c98b92..4cb08abc2 100644 --- a/src/transformation/visitors/language-extensions/operators.ts +++ b/src/transformation/visitors/language-extensions/operators.ts @@ -4,8 +4,9 @@ import { TransformationContext } from "../../context"; import { assert } from "../../../utils"; import { LuaTarget } from "../../../CompilerOptions"; import { unsupportedForTarget } from "../../utils/diagnostics"; -import { ExtensionKind } from "../../utils/language-extensions"; +import { ExtensionKind, getBinaryCallExtensionArgs, getUnaryCallExtensionArg } from "../../utils/language-extensions"; import { LanguageExtensionCallTransformerMap } from "./call-extension"; +import { transformOrderedExpressions } from "../expression-list"; const binaryOperatorMappings = new Map([ [ExtensionKind.AdditionOperatorType, lua.SyntaxKind.AdditionOperator], @@ -81,35 +82,21 @@ for (const kind of unaryOperatorMappings.keys()) { function transformBinaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) { if (requiresLua53.has(kind)) checkHasLua53(context, node, kind); - let args: readonly ts.Expression[] = node.arguments; - if ( - args.length === 1 && - (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) - ) { - args = [node.expression.expression, ...args]; - } + const args = getBinaryCallExtensionArgs(context, node, kind); + if (!args) return lua.createNilLiteral(); + + const [left, right] = transformOrderedExpressions(context, args); const luaOperator = binaryOperatorMappings.get(kind); assert(luaOperator); - return lua.createBinaryExpression( - context.transformExpression(args[0]), - context.transformExpression(args[1]), - luaOperator - ); + return lua.createBinaryExpression(left, right, luaOperator); } function transformUnaryOperator(context: TransformationContext, node: ts.CallExpression, kind: ExtensionKind) { if (requiresLua53.has(kind)) checkHasLua53(context, node, kind); - let arg: ts.Expression; - if ( - node.arguments.length === 0 && - (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) - ) { - arg = node.expression.expression; - } else { - arg = node.arguments[0]; - } + const arg = getUnaryCallExtensionArg(context, node, kind); + if (!arg) return lua.createNilLiteral(); const luaOperator = unaryOperatorMappings.get(kind); assert(luaOperator); diff --git a/src/transformation/visitors/language-extensions/table.ts b/src/transformation/visitors/language-extensions/table.ts index 2d22adf15..d3aaedc5d 100644 --- a/src/transformation/visitors/language-extensions/table.ts +++ b/src/transformation/visitors/language-extensions/table.ts @@ -1,16 +1,22 @@ import * as ts from "typescript"; import * as lua from "../../../LuaAST"; import { TransformationContext } from "../../context"; -import { ExtensionKind, getExtensionKindForNode } from "../../utils/language-extensions"; -import { transformExpressionList } from "../expression-list"; -import { LanguageExtensionCallTransformer } from "./call-extension"; +import { + ExtensionKind, + getBinaryCallExtensionArgs, + getExtensionKindForNode, + getNaryCallExtensionArgs, +} from "../../utils/language-extensions"; +import { transformOrderedExpressions } from "../expression-list"; +import { LanguageExtensionCallTransformerMap } from "./call-extension"; export function isTableNewCall(context: TransformationContext, node: ts.NewExpression) { return getExtensionKindForNode(context, node.expression) === ExtensionKind.TableNewType; } + export const tableNewExtensions = [ExtensionKind.TableNewType]; -export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtensionCallTransformer } = { +export const tableExtensionTransformers: LanguageExtensionCallTransformerMap = { [ExtensionKind.TableDeleteType]: transformTableDeleteExpression, [ExtensionKind.TableDeleteMethodType]: transformTableDeleteExpression, [ExtensionKind.TableGetType]: transformTableGetExpression, @@ -19,8 +25,8 @@ export const tableExtensionTransformers: { [P in ExtensionKind]?: LanguageExtens [ExtensionKind.TableHasMethodType]: transformTableHasExpression, [ExtensionKind.TableSetType]: transformTableSetExpression, [ExtensionKind.TableSetMethodType]: transformTableSetExpression, - [ExtensionKind.TableAddKeyType]: transformTableAddExpression, - [ExtensionKind.TableAddKeyMethodType]: transformTableAddExpression, + [ExtensionKind.TableAddKeyType]: transformTableAddKeyExpression, + [ExtensionKind.TableAddKeyMethodType]: transformTableAddKeyExpression, }; function transformTableDeleteExpression( @@ -28,48 +34,32 @@ function transformTableDeleteExpression( node: ts.CallExpression, extensionKind: ExtensionKind ): lua.Expression { - const args = node.arguments.slice(); - if ( - extensionKind === ExtensionKind.TableDeleteMethodType && - (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) - ) { - // In case of method (no table argument), push method owner to front of args list - args.unshift(node.expression.expression); + const args = getBinaryCallExtensionArgs(context, node, extensionKind); + if (!args) { + return lua.createNilLiteral(); } - const [table, accessExpression] = transformExpressionList(context, args); + const [table, key] = transformOrderedExpressions(context, args); // arg0[arg1] = nil context.addPrecedingStatements( - lua.createAssignmentStatement( - lua.createTableIndexExpression(table, accessExpression), - lua.createNilLiteral(), - node - ) + lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createNilLiteral(), node) ); return lua.createBooleanLiteral(true); } -function transformWithTableArgument(context: TransformationContext, node: ts.CallExpression): lua.Expression[] { - if (ts.isPropertyAccessExpression(node.expression) || ts.isElementAccessExpression(node.expression)) { - return transformExpressionList(context, [node.expression.expression, ...node.arguments]); - } - // todo: report diagnostic? - return [lua.createNilLiteral(), ...transformExpressionList(context, node.arguments)]; -} - function transformTableGetExpression( context: TransformationContext, node: ts.CallExpression, extensionKind: ExtensionKind ): lua.Expression { - const args = - extensionKind === ExtensionKind.TableGetMethodType - ? transformWithTableArgument(context, node) - : transformExpressionList(context, node.arguments); + const args = getBinaryCallExtensionArgs(context, node, extensionKind); + if (!args) { + return lua.createNilLiteral(); + } - const [table, accessExpression] = args; + const [table, key] = transformOrderedExpressions(context, args); // arg0[arg1] - return lua.createTableIndexExpression(table, accessExpression, node); + return lua.createTableIndexExpression(table, key, node); } function transformTableHasExpression( @@ -77,14 +67,14 @@ function transformTableHasExpression( node: ts.CallExpression, extensionKind: ExtensionKind ): lua.Expression { - const args = - extensionKind === ExtensionKind.TableHasMethodType - ? transformWithTableArgument(context, node) - : transformExpressionList(context, node.arguments); + const args = getBinaryCallExtensionArgs(context, node, extensionKind); + if (!args) { + return lua.createNilLiteral(); + } - const [table, accessExpression] = args; + const [table, key] = transformOrderedExpressions(context, args); // arg0[arg1] - const tableIndexExpression = lua.createTableIndexExpression(table, accessExpression); + const tableIndexExpression = lua.createTableIndexExpression(table, key); // arg0[arg1] ~= nil return lua.createBinaryExpression( @@ -100,37 +90,33 @@ function transformTableSetExpression( node: ts.CallExpression, extensionKind: ExtensionKind ): lua.Expression { - const args = - extensionKind === ExtensionKind.TableSetMethodType - ? transformWithTableArgument(context, node) - : transformExpressionList(context, node.arguments); + const args = getNaryCallExtensionArgs(context, node, extensionKind, 3); + if (!args) { + return lua.createNilLiteral(); + } - const [table, accessExpression, value] = args; + const [table, key, value] = transformOrderedExpressions(context, args); // arg0[arg1] = arg2 context.addPrecedingStatements( - lua.createAssignmentStatement(lua.createTableIndexExpression(table, accessExpression), value, node) + lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), value, node) ); return lua.createNilLiteral(); } -function transformTableAddExpression( +function transformTableAddKeyExpression( context: TransformationContext, node: ts.CallExpression, extensionKind: ExtensionKind ): lua.Expression { - const args = - extensionKind === ExtensionKind.TableAddKeyMethodType - ? transformWithTableArgument(context, node) - : transformExpressionList(context, node.arguments); + const args = getNaryCallExtensionArgs(context, node, extensionKind, 2); + if (!args) { + return lua.createNilLiteral(); + } - const [table, value] = args; + const [table, key] = transformOrderedExpressions(context, args); // arg0[arg1] = true context.addPrecedingStatements( - lua.createAssignmentStatement( - lua.createTableIndexExpression(table, value), - lua.createBooleanLiteral(true), - node - ) + lua.createAssignmentStatement(lua.createTableIndexExpression(table, key), lua.createBooleanLiteral(true), node) ); return lua.createNilLiteral(); } diff --git a/test/unit/__snapshots__/optionalChaining.spec.ts.snap b/test/unit/__snapshots__/optionalChaining.spec.ts.snap index e52fe69b8..707ecf68b 100644 --- a/test/unit/__snapshots__/optionalChaining.spec.ts.snap +++ b/test/unit/__snapshots__/optionalChaining.spec.ts.snap @@ -63,11 +63,13 @@ exports[`Unsupported optional chains Compile members only: diagnostics 1`] = `"m exports[`Unsupported optional chains Language extensions: code 1`] = ` "local ____opt_0 = ({}).has if ____opt_0 ~= nil then - local ____ = nil[3] ~= nil end" `; -exports[`Unsupported optional chains Language extensions: diagnostics 1`] = `"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions."`; +exports[`Unsupported optional chains Language extensions: diagnostics 1`] = ` +"main.ts(2,17): error TSTL: Optional calls are not supported for builtin or language extension functions. +main.ts(2,17): error TSTL: This language extension must be called as a method." +`; exports[`long optional chain 1`] = ` "local ____exports = {} diff --git a/test/unit/language-extensions/__snapshots__/operators.spec.ts.snap b/test/unit/language-extensions/__snapshots__/operators.spec.ts.snap index dda0e61e2..a4554a8a1 100644 --- a/test/unit/language-extensions/__snapshots__/operators.spec.ts.snap +++ b/test/unit/language-extensions/__snapshots__/operators.spec.ts.snap @@ -1,5 +1,17 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP +exports[`does not crash on invalid operator use global function: code 1`] = `""`; + +exports[`does not crash on invalid operator use global function: diagnostics 1`] = `"main.ts(3,13): error TS2554: Expected 2 arguments, but got 1."`; + +exports[`does not crash on invalid operator use method: code 1`] = `"left = {}"`; + +exports[`does not crash on invalid operator use method: diagnostics 1`] = `"main.ts(5,18): error TS2554: Expected 1 arguments, but got 0."`; + +exports[`does not crash on invalid operator use unary operator: code 1`] = `"op(_G)"`; + +exports[`does not crash on invalid operator use unary operator: diagnostics 1`] = `"main.ts(2,31): error TS2304: Cannot find name 'LuaUnaryMinus'."`; + exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): code 1`] = `"foo = op(_G, 1, 2)"`; exports[`operator mapping - invalid use (const foo = (op as any)(1, 2);): diagnostics 1`] = `"main.ts(3,22): error TSTL: This function must be called directly and cannot be referred to."`; diff --git a/test/unit/language-extensions/__snapshots__/table.spec.ts.snap b/test/unit/language-extensions/__snapshots__/table.spec.ts.snap index 9821d9777..5bf4ef2c8 100644 --- a/test/unit/language-extensions/__snapshots__/table.spec.ts.snap +++ b/test/unit/language-extensions/__snapshots__/table.spec.ts.snap @@ -133,3 +133,11 @@ __TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)" `; exports[`LuaTableHas extension invalid use method expression ("LuaTable"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`does not crash on invalid extension use global function: code 1`] = `""`; + +exports[`does not crash on invalid extension use global function: diagnostics 1`] = `"main.ts(3,9): error TS2554: Expected 2 arguments, but got 1."`; + +exports[`does not crash on invalid extension use method: code 1`] = `"left = {}"`; + +exports[`does not crash on invalid extension use method: diagnostics 1`] = `"main.ts(5,14): error TS2554: Expected 2 arguments, but got 0."`; diff --git a/test/unit/language-extensions/operators.spec.ts b/test/unit/language-extensions/operators.spec.ts index 043b1cebb..303a5d0be 100644 --- a/test/unit/language-extensions/operators.spec.ts +++ b/test/unit/language-extensions/operators.spec.ts @@ -2,7 +2,11 @@ import * as path from "path"; import * as util from "../../util"; import * as tstl from "../../../src"; import { LuaTarget } from "../../../src"; -import { unsupportedForTarget, invalidCallExtensionUse } from "../../../src/transformation/utils/diagnostics"; +import { + unsupportedForTarget, + invalidCallExtensionUse, + invalidSpreadInCallExtension, +} from "../../../src/transformation/utils/diagnostics"; const operatorsProjectOptions: tstl.CompilerOptions = { luaTarget: LuaTarget.Lua54, @@ -389,3 +393,41 @@ test.each([ .setOptions(operatorsProjectOptions) .expectDiagnosticsToMatchSnapshot([invalidCallExtensionUse.code]); }); + +describe("does not crash on invalid operator use", () => { + test("global function", () => { + util.testModule` + declare const op: LuaAddition; + op(1) + ` + .setOptions(operatorsProjectOptions) + .expectDiagnosticsToMatchSnapshot(); + }); + test("unary operator", () => { + util.testModule` + declare const op: LuaUnaryMinus; + op() + ` + .setOptions(operatorsProjectOptions) + .expectDiagnosticsToMatchSnapshot(); + }); + test("method", () => { + util.testModule` + const left = {} as { + op: LuaAdditionMethod; + } + left.op() + ` + .setOptions(operatorsProjectOptions) + .expectDiagnosticsToMatchSnapshot(); + }); +}); + +test("does not allow spread", () => { + util.testModule` + declare const op: LuaAddition; + op(...[1, 2] as const); + ` + .setOptions(operatorsProjectOptions) + .expectToHaveDiagnostics([invalidSpreadInCallExtension.code]); +}); diff --git a/test/unit/language-extensions/table.spec.ts b/test/unit/language-extensions/table.spec.ts index ed9e8e37c..b833de22b 100644 --- a/test/unit/language-extensions/table.spec.ts +++ b/test/unit/language-extensions/table.spec.ts @@ -467,3 +467,25 @@ test.each([ .withLanguageExtensions() .expectToEqual(expected); }); + +describe("does not crash on invalid extension use", () => { + test("global function", () => { + util.testModule` + declare const op: LuaTableGet<{}, string, any> + op({}) + ` + .withLanguageExtensions() + .expectDiagnosticsToMatchSnapshot(); + }); + + test("method", () => { + util.testModule` + const left = {} as { + op: LuaTableGet<{}, string, any> + } + left.op() + ` + .withLanguageExtensions() + .expectDiagnosticsToMatchSnapshot(); + }); +});