diff --git a/src/transformation/visitors/access.ts b/src/transformation/visitors/access.ts index f0eced595..12e9385ad 100644 --- a/src/transformation/visitors/access.ts +++ b/src/transformation/visitors/access.ts @@ -3,12 +3,18 @@ import * as lua from "../../LuaAST"; import { transformBuiltinPropertyAccessExpression } from "../builtins"; import { FunctionVisitor, TransformationContext } from "../context"; import { AnnotationKind, getTypeAnnotations } from "../utils/annotations"; -import { invalidMultiReturnAccess, unsupportedOptionalCompileMembersOnly } from "../utils/diagnostics"; +import { + invalidCallExtensionUse, + invalidMultiReturnAccess, + unsupportedOptionalCompileMembersOnly, +} from "../utils/diagnostics"; +import { getExtensionKindForNode } from "../utils/language-extensions"; import { addToNumericExpression } from "../utils/lua-ast"; import { LuaLibFeature, transformLuaLibFunction } from "../utils/lualib"; import { isArrayType, isNumberType, isStringType } from "../utils/typescript"; import { tryGetConstEnumValue } from "./enum"; import { transformOrderedExpressions } from "./expression-list"; +import { callExtensions } from "./language-extensions/call-extension"; import { isMultiReturnCall, returnsMultiType } from "./language-extensions/multi"; import { transformOptionalChainWithCapture, @@ -143,6 +149,18 @@ export function transformPropertyAccessExpressionWithCapture( return { expression: builtinResult }; } + if ( + ts.isIdentifier(node.expression) && + node.parent && + (!ts.isCallExpression(node.parent) || node.parent.expression !== node) + ) { + // Check if this is a method call extension that is not used as a call + const extensionType = getExtensionKindForNode(context, node); + if (extensionType && callExtensions.has(extensionType)) { + context.diagnostics.push(invalidCallExtensionUse(node)); + } + } + const table = context.transformExpression(node.expression); if (thisValueCapture) { diff --git a/src/transformation/visitors/identifier.ts b/src/transformation/visitors/identifier.ts index efb5349bc..78f2bb8c2 100644 --- a/src/transformation/visitors/identifier.ts +++ b/src/transformation/visitors/identifier.ts @@ -31,7 +31,10 @@ function transformNonValueIdentifier( if (extensionKind) { if (callExtensions.has(extensionKind)) { - context.diagnostics.push(invalidCallExtensionUse(identifier)); + // Avoid putting duplicate diagnostic on the name of a variable declaration, due to the inferred type + if (!(ts.isVariableDeclaration(identifier.parent) && identifier.parent.name === identifier)) { + context.diagnostics.push(invalidCallExtensionUse(identifier)); + } // fall through } else if (isIdentifierExtensionValue(symbol, extensionKind)) { reportInvalidExtensionValue(context, identifier, extensionKind); diff --git a/test/setup.ts b/test/setup.ts index de78ad24b..e468c5ab8 100644 --- a/test/setup.ts +++ b/test/setup.ts @@ -37,7 +37,9 @@ expect.extend({ const message = this.isNot ? diagnosticMessages : expected - ? `Expected:\n${expected.join("\n")}\nReceived:\n${diagnostics.map(diag => diag.code).join("\n")}\n` + ? `Expected:\n${expected.join("\n")}\nReceived:\n${diagnostics + .map(diag => diag.code) + .join("\n")}\n\n${diagnosticMessages}\n` : `Received: ${this.utils.printReceived([])}\n`; return matcherHint + "\n\n" + message; diff --git a/test/unit/language-extensions/__snapshots__/table.spec.ts.snap b/test/unit/language-extensions/__snapshots__/table.spec.ts.snap index 831d1ec54..9821d9777 100644 --- a/test/unit/language-extensions/__snapshots__/table.spec.ts.snap +++ b/test/unit/language-extensions/__snapshots__/table.spec.ts.snap @@ -85,3 +85,51 @@ exports[`LuaTableHas extension invalid use ("const foo: unknown = tableHas;"): d exports[`LuaTableHas extension invalid use ("declare function foo(tableHas: LuaTableHas<{}, string>): void; foo(tableHas);"): code 1`] = `"foo(_G, tableHas)"`; exports[`LuaTableHas extension invalid use ("declare function foo(tableHas: LuaTableHas<{}, string>): void; foo(tableHas);"): diagnostics 1`] = `"main.ts(3,80): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaMap"): code 1`] = ` +"____table = {} +has = ____table.has" +`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaMap"): diagnostics 1`] = `"main.ts(3,29): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaSet"): code 1`] = ` +"____table = {} +has = ____table.has" +`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaSet"): diagnostics 1`] = `"main.ts(3,29): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaTable"): code 1`] = ` +"____table = {} +has = ____table.has" +`; + +exports[`LuaTableHas extension invalid use method assignment ("LuaTable"): diagnostics 1`] = `"main.ts(3,29): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method expression ("LuaMap"): code 1`] = ` +"local ____lualib = require(\\"lualib_bundle\\") +local __TS__ArrayMap = ____lualib.__TS__ArrayMap +____table = {} +__TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)" +`; + +exports[`LuaTableHas extension invalid use method expression ("LuaMap"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method expression ("LuaSet"): code 1`] = ` +"local ____lualib = require(\\"lualib_bundle\\") +local __TS__ArrayMap = ____lualib.__TS__ArrayMap +____table = {} +__TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)" +`; + +exports[`LuaTableHas extension invalid use method expression ("LuaSet"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`; + +exports[`LuaTableHas extension invalid use method expression ("LuaTable"): code 1`] = ` +"local ____lualib = require(\\"lualib_bundle\\") +local __TS__ArrayMap = ____lualib.__TS__ArrayMap +____table = {} +__TS__ArrayMap({\\"a\\", \\"b\\", \\"c\\"}, ____table.has)" +`; + +exports[`LuaTableHas extension invalid use method expression ("LuaTable"): diagnostics 1`] = `"main.ts(3,37): error TSTL: This function must be called directly and cannot be referred to."`; diff --git a/test/unit/language-extensions/table.spec.ts b/test/unit/language-extensions/table.spec.ts index 8f21dc8d1..ed9e8e37c 100644 --- a/test/unit/language-extensions/table.spec.ts +++ b/test/unit/language-extensions/table.spec.ts @@ -170,6 +170,30 @@ describe("LuaTableHas extension", () => { .withLanguageExtensions() .expectDiagnosticsToMatchSnapshot([invalidCallExtensionUse.code]); }); + + test.each(["LuaTable", "LuaMap", "LuaSet"])( + "invalid use method assignment (%p)", + type => { + util.testModule` + const table = new ${type}(); + const has = table.has; + ` + .withLanguageExtensions() + .expectDiagnosticsToMatchSnapshot([invalidCallExtensionUse.code]); + } + ); + + test.each(["LuaTable", "LuaMap", "LuaSet"])( + "invalid use method expression (%p)", + type => { + util.testModule` + const table = new ${type}(); + ["a", "b", "c"].map(table.has); + ` + .withLanguageExtensions() + .expectDiagnosticsToMatchSnapshot([invalidCallExtensionUse.code]); + } + ); }); describe("LuaTableDelete extension", () => {