Skip to content

Commit a231205

Browse files
jselboJoshua Selbo
andauthored
Fix StackOverflowError with AbstractList after using mockSingleton (#3790)
* Fix StackOverflowError when mocking after singleton mock is closed Replace WeakHashMap with a WeakIdentityMap that uses System.identityHashCode and == for key comparison. This avoids calling instrumented hashCode()/equals() methods on mock instances during map lookups, which caused infinite recursion when mocking classes like AbstractList whose hashCode() invokes instrumented methods. * Update mockSingleton already registered exception message --------- Co-authored-by: Joshua Selbo <jselbo@meta.com>
1 parent f6a91a6 commit a231205

4 files changed

Lines changed: 307 additions & 2 deletions

File tree

mockito-core/src/main/java/org/mockito/internal/creation/bytebuddy/InlineDelegateByteBuddyMockMaker.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.mockito.internal.creation.instance.ConstructorInstantiator;
1919
import org.mockito.internal.framework.DisabledMockHandler;
2020
import org.mockito.internal.util.Platform;
21+
import org.mockito.internal.util.collections.WeakIdentityMap;
2122
import org.mockito.internal.util.concurrent.DetachedThreadLocal;
2223
import org.mockito.internal.util.concurrent.WeakConcurrentMap;
2324
import org.mockito.invocation.MockHandler;
@@ -695,7 +696,7 @@ public <T> SingletonMockControl<T> createSingletonMock(
695696

696697
Map<Object, MockMethodInterceptor> singletons = mockedSingletons.get();
697698
if (singletons == null) {
698-
singletons = new WeakHashMap<>();
699+
singletons = new WeakIdentityMap<>();
699700
mockedSingletons.set(singletons);
700701
}
701702
mockedSingletons.getBackingMap().expungeStaleEntries();
@@ -941,7 +942,7 @@ public void enable() {
941942
join(
942943
"The singleton instance "
943944
+ instance.getClass().getName()
944-
+ " is already registered as a mock",
945+
+ " is already registered as a mock in the current thread",
945946
"",
946947
"To create a new mock, the existing mock registration must be deregistered"));
947948
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright (c) 2024 Mockito contributors
3+
* This program is made available under the terms of the MIT License.
4+
*/
5+
package org.mockito.internal.util.collections;
6+
7+
import java.lang.ref.ReferenceQueue;
8+
import java.lang.ref.WeakReference;
9+
import java.util.*;
10+
11+
/**
12+
* A weak-key map that uses identity ({@code ==}) rather than {@code hashCode()}/{@code equals()}
13+
* for key comparison. This avoids calling instrumented methods on mock instances during map
14+
* operations, which would cause infinite recursion. Stale entries whose keys have been
15+
* garbage-collected are cleaned up via a {@link ReferenceQueue}.
16+
*
17+
* <p>Internally, entries are bucketed by {@link System#identityHashCode} for O(1) amortized
18+
* lookups. Hash collisions (distinct objects sharing an identity hash code) are resolved by
19+
* linear scan within the bucket using {@code ==}. Lookups do not allocate.
20+
*/
21+
public class WeakIdentityMap<V> extends AbstractMap<Object, V> {
22+
23+
private final Map<Integer, List<WeakEntry<V>>> delegate = new HashMap<>();
24+
private final ReferenceQueue<Object> queue = new ReferenceQueue<>();
25+
private int size;
26+
27+
private void expungeStaleEntries() {
28+
WeakEntry<?> ref;
29+
while ((ref = (WeakEntry<?>) queue.poll()) != null) {
30+
int hash = ref.hashCode;
31+
List<WeakEntry<V>> bucket = delegate.get(hash);
32+
if (bucket != null) {
33+
if (bucket.remove(ref)) {
34+
size--;
35+
if (bucket.isEmpty()) {
36+
delegate.remove(hash);
37+
}
38+
}
39+
}
40+
}
41+
}
42+
43+
@Override
44+
public V get(Object key) {
45+
expungeStaleEntries();
46+
List<WeakEntry<V>> bucket = delegate.get(System.identityHashCode(key));
47+
if (bucket != null) {
48+
for (WeakEntry<V> entry : bucket) {
49+
if (entry.get() == key) {
50+
return entry.value;
51+
}
52+
}
53+
}
54+
return null;
55+
}
56+
57+
@Override
58+
public boolean containsKey(Object key) {
59+
expungeStaleEntries();
60+
List<WeakEntry<V>> bucket = delegate.get(System.identityHashCode(key));
61+
if (bucket != null) {
62+
for (WeakEntry<V> entry : bucket) {
63+
if (entry.get() == key) {
64+
return true;
65+
}
66+
}
67+
}
68+
return false;
69+
}
70+
71+
@Override
72+
public V put(Object key, V value) {
73+
expungeStaleEntries();
74+
int hash = System.identityHashCode(key);
75+
List<WeakEntry<V>> bucket = delegate.get(hash);
76+
if (bucket != null) {
77+
for (WeakEntry<V> entry : bucket) {
78+
if (entry.get() == key) {
79+
V old = entry.value;
80+
entry.value = value;
81+
return old;
82+
}
83+
}
84+
} else {
85+
bucket = new ArrayList<>(2);
86+
delegate.put(hash, bucket);
87+
}
88+
bucket.add(new WeakEntry<>(key, hash, value, queue));
89+
size++;
90+
return null;
91+
}
92+
93+
@Override
94+
public V remove(Object key) {
95+
expungeStaleEntries();
96+
int hash = System.identityHashCode(key);
97+
List<WeakEntry<V>> bucket = delegate.get(hash);
98+
if (bucket != null) {
99+
Iterator<WeakEntry<V>> it = bucket.iterator();
100+
while (it.hasNext()) {
101+
WeakEntry<V> entry = it.next();
102+
if (entry.get() == key) {
103+
it.remove();
104+
size--;
105+
if (bucket.isEmpty()) {
106+
delegate.remove(hash);
107+
}
108+
return entry.value;
109+
}
110+
}
111+
}
112+
return null;
113+
}
114+
115+
@Override
116+
public int size() {
117+
expungeStaleEntries();
118+
return size;
119+
}
120+
121+
@Override
122+
public Set<Entry<Object, V>> entrySet() {
123+
expungeStaleEntries();
124+
Set<Entry<Object, V>> result = new HashSet<>();
125+
for (List<WeakEntry<V>> bucket : delegate.values()) {
126+
for (WeakEntry<V> entry : bucket) {
127+
Object referent = entry.get();
128+
if (referent != null) {
129+
result.add(new SimpleEntry<>(referent, entry.value));
130+
}
131+
}
132+
}
133+
return result;
134+
}
135+
136+
private static class WeakEntry<V> extends WeakReference<Object> {
137+
138+
final int hashCode;
139+
V value;
140+
141+
WeakEntry(Object key, int hashCode, V value, ReferenceQueue<Object> queue) {
142+
super(key, queue);
143+
this.hashCode = hashCode;
144+
this.value = value;
145+
}
146+
}
147+
}
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Copyright (c) 2024 Mockito contributors
3+
* This program is made available under the terms of the MIT License.
4+
*/
5+
package org.mockito.internal.util.collections;
6+
7+
import static org.junit.Assert.*;
8+
9+
import org.junit.Test;
10+
11+
public class WeakIdentityMapTest {
12+
13+
WeakIdentityMap<String> map = new WeakIdentityMap<>();
14+
15+
@Test
16+
public void putAndGet() {
17+
Object key = new Object();
18+
map.put(key, "value");
19+
20+
assertEquals("value", map.get(key));
21+
assertTrue(map.containsKey(key));
22+
}
23+
24+
@Test
25+
public void getReturnsNullForAbsentKey() {
26+
assertNull(map.get(new Object()));
27+
}
28+
29+
@Test
30+
public void remove() {
31+
Object key = new Object();
32+
map.put(key, "value");
33+
34+
assertEquals("value", map.remove(key));
35+
assertNull(map.get(key));
36+
assertFalse(map.containsKey(key));
37+
}
38+
39+
@Test
40+
public void removeReturnsNullForAbsentKey() {
41+
assertNull(map.remove(new Object()));
42+
}
43+
44+
@Test
45+
public void putOverwritesPreviousValue() {
46+
Object key = new Object();
47+
map.put(key, "first");
48+
map.put(key, "second");
49+
50+
assertEquals("second", map.get(key));
51+
assertEquals(1, map.size());
52+
}
53+
54+
@Test
55+
public void usesIdentityNotEquals() {
56+
// Two objects that are equals() but not ==
57+
String key1 = new String("same");
58+
String key2 = new String("same");
59+
assertEquals(key1, key2);
60+
assertNotSame(key1, key2);
61+
62+
map.put(key1, "value1");
63+
map.put(key2, "value2");
64+
65+
assertEquals("value1", map.get(key1));
66+
assertEquals("value2", map.get(key2));
67+
assertEquals(2, map.size());
68+
}
69+
70+
@Test
71+
public void doesNotCallHashCodeOnKeys() {
72+
Object key =
73+
new Object() {
74+
@Override
75+
public int hashCode() {
76+
throw new RuntimeException("hashCode should not be called");
77+
}
78+
79+
@Override
80+
public boolean equals(Object obj) {
81+
throw new RuntimeException("equals should not be called");
82+
}
83+
};
84+
85+
map.put(key, "value");
86+
assertEquals("value", map.get(key));
87+
assertTrue(map.containsKey(key));
88+
assertEquals("value", map.remove(key));
89+
}
90+
91+
@Test
92+
public void entrySetReflectsContents() {
93+
Object key1 = new Object();
94+
Object key2 = new Object();
95+
map.put(key1, "a");
96+
map.put(key2, "b");
97+
98+
assertEquals(2, map.entrySet().size());
99+
}
100+
101+
@Test
102+
public void sizeReflectsContents() {
103+
assertEquals(0, map.size());
104+
105+
Object key = new Object();
106+
map.put(key, "value");
107+
assertEquals(1, map.size());
108+
109+
map.remove(key);
110+
assertEquals(0, map.size());
111+
}
112+
113+
@Test
114+
public void weakReferenceBehavior() {
115+
Object key = new Object();
116+
map.put(key, "value");
117+
assertEquals(1, map.size());
118+
119+
// Clear the strong reference and trigger GC
120+
key = null;
121+
forceGc();
122+
123+
// After GC, the entry should be expunged on next access
124+
assertEquals(0, map.size());
125+
}
126+
127+
private static void forceGc() {
128+
for (int i = 0; i < 5; i++) {
129+
System.gc();
130+
try {
131+
Thread.sleep(50);
132+
} catch (InterruptedException e) {
133+
Thread.currentThread().interrupt();
134+
}
135+
}
136+
}
137+
}

mockito-integration-tests/inline-mocks-tests/src/test/java/org/mockitoinline/RecursionTest.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,15 @@
55
package org.mockitoinline;
66

77
import org.junit.Test;
8+
import org.mockito.MockedSingleton;
89

10+
import java.util.ArrayList;
11+
import java.util.List;
912
import java.util.concurrent.ConcurrentHashMap;
1013
import java.util.concurrent.ConcurrentMap;
1114

15+
import static org.mockito.Mockito.mock;
16+
import static org.mockito.Mockito.mockSingleton;
1217
import static org.mockito.Mockito.spy;
1318

1419
public class RecursionTest {
@@ -18,4 +23,19 @@ public void testMockConcurrentHashMap() {
1823
ConcurrentMap<String, String> map = spy(new ConcurrentHashMap<String, String>());
1924
map.putIfAbsent("a", "b");
2025
}
26+
27+
enum MyEnum {
28+
A
29+
}
30+
31+
@Test
32+
public void testSingletonMockAndInstrumentingAbstractList() {
33+
// Initializes mockedSingletons map
34+
try (MockedSingleton<MyEnum> ignored = mockSingleton(MyEnum.A)) {}
35+
// instruments AbstractList whose hashCode() implementation invokes the instrumented method
36+
// iterator()
37+
List<?> listMock = mock(ArrayList.class);
38+
// Verify no StackOverflowError when invoking method on instrumented class
39+
listMock.clear();
40+
}
2141
}

0 commit comments

Comments
 (0)