Skip to content

Commit 2d14503

Browse files
committed
feat: Add Rust trait inheritance and impl block extraction with method receiver type support
Addresses Rust's impl block syntax where trait implementations (`impl Trait for Type`) and trait supertraits (`trait Sub: Super`) create inheritance relationships. Adds getReceiverType to extract method receiver types from impl blocks, enabling proper method-to-struct relationships and qualified name resolution. Verified against Deno codebase and moved from "Needs Verification" to completed language support.
1 parent ce7b768 commit 2d14503

4 files changed

Lines changed: 243 additions & 7 deletions

File tree

__tests__/extraction.test.ts

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,78 @@ pub trait Repository {
650650
expect(traitNode).toBeDefined();
651651
expect(traitNode?.name).toBe('Repository');
652652
});
653+
654+
it('should extract impl Trait for Type as implements edges', () => {
655+
const code = `
656+
pub struct MyCache {}
657+
658+
pub trait Cache {
659+
fn get(&self, key: &str) -> Option<String>;
660+
}
661+
662+
impl Cache for MyCache {
663+
fn get(&self, key: &str) -> Option<String> {
664+
None
665+
}
666+
}
667+
`;
668+
const result = extractFromSource('cache.rs', code);
669+
670+
// Should have an unresolved reference for implements
671+
const implRef = result.unresolvedReferences.find(
672+
(r) => r.referenceKind === 'implements' && r.referenceName === 'Cache'
673+
);
674+
expect(implRef).toBeDefined();
675+
676+
// The struct MyCache should be the source
677+
const myCacheNode = result.nodes.find((n) => n.name === 'MyCache' && n.kind === 'struct');
678+
expect(myCacheNode).toBeDefined();
679+
expect(implRef?.fromNodeId).toBe(myCacheNode?.id);
680+
});
681+
682+
it('should extract trait supertraits as extends references', () => {
683+
const code = `
684+
pub trait Display {}
685+
686+
pub trait Error: Display {
687+
fn description(&self) -> &str;
688+
}
689+
`;
690+
const result = extractFromSource('error.rs', code);
691+
692+
const extendsRef = result.unresolvedReferences.find(
693+
(r) => r.referenceKind === 'extends' && r.referenceName === 'Display'
694+
);
695+
expect(extendsRef).toBeDefined();
696+
697+
const errorTrait = result.nodes.find((n) => n.name === 'Error' && n.kind === 'trait');
698+
expect(errorTrait).toBeDefined();
699+
expect(extendsRef?.fromNodeId).toBe(errorTrait?.id);
700+
});
701+
702+
it('should not create implements edges for plain impl blocks', () => {
703+
const code = `
704+
pub struct Counter {
705+
count: u32,
706+
}
707+
708+
impl Counter {
709+
pub fn new() -> Counter {
710+
Counter { count: 0 }
711+
}
712+
pub fn increment(&mut self) {
713+
self.count += 1;
714+
}
715+
}
716+
`;
717+
const result = extractFromSource('counter.rs', code);
718+
719+
// Should have no implements references (no trait involved)
720+
const implRefs = result.unresolvedReferences.filter(
721+
(r) => r.referenceKind === 'implements'
722+
);
723+
expect(implRefs).toHaveLength(0);
724+
});
653725
});
654726

655727
describe('Java Extraction', () => {

docs/SEARCH_QUALITY_LOOP.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,12 +523,12 @@ if (receiverType) {
523523
- [x] **Swift** — NOT needed. Tree-sitter nests methods inside class/extension bodies
524524
- [x] **Java** — NOT needed. Methods nested in class body. Verified against Guava
525525
- [x] **Python** — NOT needed. Methods nested in class body. Verified against Flask
526+
- [x] **Rust** — `getReceiverType` walks up to the parent `impl_item` to extract the type name. Also adds `contains` edges from the struct to its impl methods. Verified against Deno
526527

527528
### Needs Verification
528529

529530
Check these — may need `getReceiverType` if methods are top-level in the AST:
530531

531-
- [ ] Rust — methods in `impl Type { }` blocks
532532
- [ ] C++ — out-of-class method definitions `Type::method()`
533533
- [ ] Kotlin — extension functions `fun Type.method()`
534534

src/extraction/languages/rust.ts

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,43 @@ export const rustExtractor: LanguageExtractor = {
4545
}
4646
return 'private'; // Rust defaults to private
4747
},
48+
/**
 * Resolve the type that owns a function defined inside a Rust `impl` block.
 *
 * Walks up from `node` to the nearest enclosing `impl_item` and returns the
 * bare name of the implementing type:
 *   - `impl Type { ... }`              → `Type`
 *   - `impl Trait for Type { ... }`    → `Type`
 *   - `impl<T> Trait for Type<T>`      → `Type` (generic arguments dropped)
 *   - `impl Trait for path::Type`      → `Type` (last path segment)
 *
 * Returns `undefined` when `node` is not inside an `impl_item`, or when the
 * implementing type is not a named type (e.g. `impl Trait for &T`).
 */
getReceiverType: (node, source) => {
  // Node kinds that can name the implementing type as a direct child of impl_item.
  const isNamedType = (c: SyntaxNode): boolean =>
    c.type === 'type_identifier' ||
    c.type === 'generic_type' ||
    c.type === 'scoped_type_identifier';

  // Reduce a type node to its bare identifier text.
  const baseTypeName = (typeNode: SyntaxNode): string | undefined => {
    if (typeNode.type === 'type_identifier') {
      return source.substring(typeNode.startIndex, typeNode.endIndex);
    }
    if (typeNode.type === 'generic_type') {
      // `Type<T>` / `path::Type<T>`: the base type is a direct child; the
      // generic arguments live in a separate child and are ignored here.
      const base = typeNode.namedChildren.find(
        (c: SyntaxNode) => c.type === 'type_identifier' || c.type === 'scoped_type_identifier'
      );
      return base ? baseTypeName(base) : undefined;
    }
    if (typeNode.type === 'scoped_type_identifier') {
      // `path::Type`: use the final segment so it matches the declared name.
      const segments = typeNode.namedChildren.filter(
        (c: SyntaxNode) => c.type === 'type_identifier'
      );
      const last = segments[segments.length - 1];
      return last ? source.substring(last.startIndex, last.endIndex) : undefined;
    }
    return undefined;
  };

  let parent = node.parent;
  while (parent) {
    if (parent.type === 'impl_item') {
      const children = parent.children;
      // In `impl Trait for Type`, everything before the `for` keyword belongs
      // to the trait. Picking the "last type_identifier" is not enough: when
      // the implementing type is generic or scoped, the only direct
      // type_identifier is the trait itself.
      const forIndex = children.findIndex((c: SyntaxNode) => c.type === 'for');
      const candidates = forIndex >= 0 ? children.slice(forIndex + 1) : children;
      const typeNode = candidates.find(isNamedType);
      return typeNode ? baseTypeName(typeNode) : undefined;
    }
    parent = parent.parent;
  }
  return undefined;
},
84+
4885
extractImport: (node, source) => {
4986
const importText = source.substring(node.startIndex, node.endIndex).trim();
5087

src/extraction/tree-sitter.ts

Lines changed: 133 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,10 @@ export class TreeSitterExtractor {
303303
else if (this.extractor.callTypes.includes(nodeType)) {
304304
this.extractCall(node);
305305
}
306+
// Rust: `impl Trait for Type { ... }` — creates implements edge from Type to Trait
307+
else if (nodeType === 'impl_item') {
308+
this.extractRustImplItem(node);
309+
}
306310

307311
// Visit children (unless the extract method already visited them)
308312
if (!skipChildren) {
@@ -406,6 +410,13 @@ export class TreeSitterExtractor {
406410
private extractFunction(node: SyntaxNode): void {
407411
if (!this.extractor) return;
408412

413+
// If the language provides getReceiverType and this function has a receiver
414+
// (e.g., Rust function_item inside an impl block), extract as method instead
415+
if (this.extractor.getReceiverType?.(node, this.source)) {
416+
this.extractMethod(node);
417+
return;
418+
}
419+
409420
let name = extractName(node, this.source, this.extractor);
410421
// For arrow functions and function expressions assigned to variables,
411422
// resolve the name from the parent variable_declarator.
@@ -498,10 +509,15 @@ export class TreeSitterExtractor {
498509
private extractMethod(node: SyntaxNode): void {
499510
if (!this.extractor) return;
500511

512+
// For languages with receiver types (Go, Rust), include receiver in qualified name
513+
// so FTS can match "scrapeLoop.run" → qualified_name "...::scrapeLoop::run"
514+
const receiverType = this.extractor.getReceiverType?.(node, this.source);
515+
501516
// For most languages, only extract as method if inside a class-like node
502517
// Languages with methodsAreTopLevel (e.g. Go) always treat them as methods
503-
if (!this.isInsideClassLikeNode() && !this.extractor.methodsAreTopLevel) {
504-
// Not inside a class-like node and not Go, treat as function
518+
// Languages with getReceiverType (e.g. Rust) extract as method when receiver is found
519+
if (!this.isInsideClassLikeNode() && !this.extractor.methodsAreTopLevel && !receiverType) {
520+
// Not inside a class-like node and no receiver type, treat as function
505521
this.extractFunction(node);
506522
return;
507523
}
@@ -512,10 +528,6 @@ export class TreeSitterExtractor {
512528
const visibility = this.extractor.getVisibility?.(node);
513529
const isAsync = this.extractor.isAsync?.(node);
514530
const isStatic = this.extractor.isStatic?.(node);
515-
516-
// For languages with receiver types (Go), include receiver in qualified name
517-
// so FTS can match "scrapeLoop.run" → qualified_name "...::scrapeLoop::run"
518-
const receiverType = this.extractor.getReceiverType?.(node, this.source);
519531
const extraProps: Partial<Node> = {
520532
docstring,
521533
signature,
@@ -530,6 +542,24 @@ export class TreeSitterExtractor {
530542
const methodNode = this.createNode('method', name, node, extraProps);
531543
if (!methodNode) return;
532544

545+
// For methods with a receiver type but no class-like parent on the stack
546+
// (e.g., Rust impl blocks), add a contains edge from the owning struct/trait
547+
if (receiverType && !this.isInsideClassLikeNode()) {
548+
const ownerNode = this.nodes.find(
549+
(n) =>
550+
n.name === receiverType &&
551+
n.filePath === this.filePath &&
552+
(n.kind === 'struct' || n.kind === 'class' || n.kind === 'enum' || n.kind === 'trait')
553+
);
554+
if (ownerNode) {
555+
this.edges.push({
556+
source: ownerNode.id,
557+
target: methodNode.id,
558+
kind: 'contains',
559+
});
560+
}
561+
}
562+
533563
// Extract type annotations (parameter types and return type)
534564
this.extractTypeAnnotations(node, methodNode.id);
535565

@@ -1311,6 +1341,40 @@ export class TreeSitterExtractor {
13111341
}
13121342
}
13131343

1344+
// Rust trait supertraits: `trait SubTrait: SuperTrait + Display { ... }`
1345+
// trait_bounds contains type_identifier, generic_type, or higher_ranked_trait_bound children
1346+
if (child.type === 'trait_bounds') {
1347+
for (const bound of child.namedChildren) {
1348+
let typeName: string | undefined;
1349+
let posNode: SyntaxNode | undefined;
1350+
1351+
if (bound.type === 'type_identifier') {
1352+
typeName = getNodeText(bound, this.source);
1353+
posNode = bound;
1354+
} else if (bound.type === 'generic_type') {
1355+
// e.g. `Deserialize<'de>`
1356+
const inner = bound.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier');
1357+
if (inner) { typeName = getNodeText(inner, this.source); posNode = inner; }
1358+
} else if (bound.type === 'higher_ranked_trait_bound') {
1359+
// e.g. `for<'de> Deserialize<'de>`
1360+
const generic = bound.namedChildren.find((c: SyntaxNode) => c.type === 'generic_type');
1361+
const typeId = generic?.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier')
1362+
?? bound.namedChildren.find((c: SyntaxNode) => c.type === 'type_identifier');
1363+
if (typeId) { typeName = getNodeText(typeId, this.source); posNode = typeId; }
1364+
}
1365+
1366+
if (typeName && posNode) {
1367+
this.unresolvedReferences.push({
1368+
fromNodeId: classId,
1369+
referenceName: typeName,
1370+
referenceKind: 'extends',
1371+
line: posNode.startPosition.row + 1,
1372+
column: posNode.startPosition.column,
1373+
});
1374+
}
1375+
}
1376+
}
1377+
13141378
// Swift: inheritance_specifier > user_type > type_identifier
13151379
// Used for class inheritance, protocol conformance, and protocol inheritance
13161380
if (child.type === 'inheritance_specifier') {
@@ -1336,6 +1400,69 @@ export class TreeSitterExtractor {
13361400
}
13371401
}
13381402

1403+
/**
 * Rust `impl Trait for Type { ... }` — records an unresolved `implements`
 * reference from the implementing type's node to the trait.
 *
 * Plain `impl Type { ... }` blocks (no `for` keyword) produce nothing.
 *
 * Generic arguments are dropped from both sides so the names line up with
 * declared items: `impl From<String> for Wrapper<T>` yields a reference named
 * `From` originating at `Wrapper`. A scoped trait path such as
 * `std::fmt::Display` is kept as written; a scoped implementing type is
 * reduced to its last segment because it is looked up by declared name.
 *
 * The reference is only recorded when `findNodeByName` already knows the
 * implementing type, i.e. the type was extracted before this impl block.
 *
 * @param node The `impl_item` syntax node.
 */
private extractRustImplItem(node: SyntaxNode): void {
  const children = node.children;

  // Only `impl Trait for Type` has an anonymous `for` keyword child.
  const forIndex = children.findIndex(
    (c: SyntaxNode) => c.type === 'for' && !c.isNamed
  );
  if (forIndex < 0) return;

  const isNamedType = (c: SyntaxNode): boolean =>
    c.type === 'type_identifier' || c.type === 'generic_type' || c.type === 'scoped_type_identifier';

  // The trait sits before `for`, the implementing type after it.
  const traitNode = children.slice(0, forIndex).find(isNamedType);
  const typeNode = children.slice(forIndex + 1).find(isNamedType);
  if (!traitNode || !typeNode) return;

  // `Name<Args>` → `Name` (or `path::Name`); anything else is returned unchanged.
  const stripGenerics = (n: SyntaxNode): SyntaxNode =>
    n.type === 'generic_type'
      ? n.namedChildren.find(
          (c: SyntaxNode) => c.type === 'type_identifier' || c.type === 'scoped_type_identifier'
        ) ?? n
      : n;

  // Trait name: generics removed, scoped paths (std::fmt::Display) kept whole.
  const traitNameNode = stripGenerics(traitNode);
  const traitName = getNodeText(traitNameNode, this.source);

  // Implementing type name: generics removed, scoped paths reduced to the
  // final segment so the lookup matches the declared struct/enum name.
  let typeNameNode = stripGenerics(typeNode);
  if (typeNameNode.type === 'scoped_type_identifier') {
    const segments = typeNameNode.namedChildren.filter(
      (c: SyntaxNode) => c.type === 'type_identifier'
    );
    typeNameNode = segments[segments.length - 1] ?? typeNameNode;
  }
  const typeName = getNodeText(typeNameNode, this.source);

  // Find the struct/type node for the implementing type
  const typeNodeId = this.findNodeByName(typeName);
  if (!typeNodeId) return;

  this.unresolvedReferences.push({
    fromNodeId: typeNodeId,
    referenceName: traitName,
    referenceKind: 'implements',
    line: traitNameNode.startPosition.row + 1,
    column: traitNameNode.startPosition.column,
  });
}
1453+
1454+
/**
 * Look up an already-extracted struct, enum, or class node by its name and
 * return its id. Used for back-references such as Rust impl blocks.
 *
 * @param name Declared name of the type to look for.
 * @returns The id of the first matching node, or `undefined` if none has been
 *          extracted yet.
 */
private findNodeByName(name: string): string | undefined {
  const match = this.nodes.find((candidate) => {
    if (candidate.name !== name) return false;
    return candidate.kind === 'struct' || candidate.kind === 'enum' || candidate.kind === 'class';
  });
  return match ? match.id : undefined;
}
1465+
13391466
/**
13401467
* Languages that support type annotations (TypeScript, etc.)
13411468
*/

0 commit comments

Comments
 (0)