Skip to content

Commit 2446fa6

Browse files
authored
🎨 binarywang#1592 实现简单的redis分布式锁 RedisTemplateSimpleDistributedLock
1 parent 0a2e4d8 commit 2446fa6

File tree

4 files changed

+213
-7
lines changed

4 files changed

+213
-7
lines changed

weixin-java-common/src/main/java/me/chanjar/weixin/common/redis/RedisTemplateWxRedisOps.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package me.chanjar.weixin.common.redis;
22

3+
import lombok.NonNull;
34
import lombok.RequiredArgsConstructor;
5+
import me.chanjar.weixin.common.util.locks.RedisTemplateSimpleDistributedLock;
46
import org.springframework.data.redis.core.StringRedisTemplate;
57

68
import java.util.concurrent.TimeUnit;
79
import java.util.concurrent.locks.Lock;
8-
import java.util.concurrent.locks.ReentrantLock;
910

1011
@RequiredArgsConstructor
1112
public class RedisTemplateWxRedisOps implements WxRedisOps {
@@ -37,7 +38,7 @@ public void expire(String key, int expire, TimeUnit timeUnit) {
3738
}
3839

3940
@Override
40-
public Lock getLock(String key) {
41-
return new ReentrantLock();
41+
public Lock getLock(@NonNull String key) {
42+
return new RedisTemplateSimpleDistributedLock(redisTemplate, key, 60 * 1000);
4243
}
4344
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
package me.chanjar.weixin.common.util.locks;
2+
3+
import lombok.Getter;
4+
import lombok.NonNull;
5+
import org.jetbrains.annotations.NotNull;
6+
import org.springframework.dao.DataAccessException;
7+
import org.springframework.data.redis.connection.RedisConnection;
8+
import org.springframework.data.redis.connection.RedisStringCommands;
9+
import org.springframework.data.redis.core.RedisCallback;
10+
import org.springframework.data.redis.core.StringRedisTemplate;
11+
import org.springframework.data.redis.core.script.DefaultRedisScript;
12+
import org.springframework.data.redis.core.script.RedisScript;
13+
import org.springframework.data.redis.core.types.Expiration;
14+
15+
import java.nio.charset.StandardCharsets;
16+
import java.util.Arrays;
17+
import java.util.List;
18+
import java.util.UUID;
19+
import java.util.concurrent.TimeUnit;
20+
import java.util.concurrent.locks.Condition;
21+
import java.util.concurrent.locks.Lock;
22+
23+
/**
24+
* 实现简单的redis分布式锁, 支持重入, 不是红锁
25+
*
26+
* @see <a href="https://redis.io/topics/distlock">reids distlock</a>
27+
*/
28+
public class RedisTemplateSimpleDistributedLock implements Lock {
29+
30+
@Getter
31+
private final StringRedisTemplate redisTemplate;
32+
@Getter
33+
private final String key;
34+
@Getter
35+
private final int leaseMilliseconds;
36+
37+
private final ThreadLocal<String> valueThreadLocal = new ThreadLocal<>();
38+
39+
public RedisTemplateSimpleDistributedLock(@NonNull StringRedisTemplate redisTemplate, int leaseMilliseconds) {
40+
this(redisTemplate, "lock:" + UUID.randomUUID().toString(), leaseMilliseconds);
41+
}
42+
43+
public RedisTemplateSimpleDistributedLock(@NonNull StringRedisTemplate redisTemplate, @NonNull String key, int leaseMilliseconds) {
44+
if (leaseMilliseconds <= 0) {
45+
throw new IllegalArgumentException("Parameter 'leaseMilliseconds' must grate then 0: " + leaseMilliseconds);
46+
}
47+
this.redisTemplate = redisTemplate;
48+
this.key = key;
49+
this.leaseMilliseconds = leaseMilliseconds;
50+
}
51+
52+
@Override
53+
public void lock() {
54+
while (!tryLock()) {
55+
try {
56+
Thread.sleep(1000);
57+
} catch (InterruptedException e) {
58+
// Ignore
59+
}
60+
}
61+
}
62+
63+
@Override
64+
public void lockInterruptibly() throws InterruptedException {
65+
while (!tryLock()) {
66+
Thread.sleep(1000);
67+
}
68+
}
69+
70+
@Override
71+
public boolean tryLock() {
72+
String value = valueThreadLocal.get();
73+
if (value == null || value.length() == 0) {
74+
value = UUID.randomUUID().toString();
75+
valueThreadLocal.set(value);
76+
}
77+
final byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8);
78+
final byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8);
79+
List<Object> redisResults = redisTemplate.executePipelined(new RedisCallback<String>() {
80+
@Override
81+
public String doInRedis(RedisConnection connection) throws DataAccessException {
82+
connection.set(keyBytes, valueBytes, Expiration.milliseconds(leaseMilliseconds), RedisStringCommands.SetOption.SET_IF_ABSENT);
83+
connection.get(keyBytes);
84+
return null;
85+
}
86+
});
87+
Object currentLockSecret = redisResults.size() > 1 ? redisResults.get(1) : redisResults.get(0);
88+
return currentLockSecret != null && currentLockSecret.toString().equals(value);
89+
}
90+
91+
@Override
92+
public boolean tryLock(long time, @NotNull TimeUnit unit) throws InterruptedException {
93+
long waitMs = unit.toMillis(time);
94+
boolean locked = tryLock();
95+
while (!locked && waitMs > 0) {
96+
long sleep = waitMs < 1000 ? waitMs : 1000;
97+
Thread.sleep(sleep);
98+
waitMs -= sleep;
99+
locked = tryLock();
100+
}
101+
return locked;
102+
}
103+
104+
@Override
105+
public void unlock() {
106+
if (valueThreadLocal.get() != null) {
107+
// 提示: 必须指定returnType, 类型: 此处必须为Long, 不能是Integer
108+
RedisScript<Long> script = new DefaultRedisScript("if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end", Long.class);
109+
redisTemplate.execute(script, Arrays.asList(key), valueThreadLocal.get());
110+
valueThreadLocal.remove();
111+
}
112+
}
113+
114+
@Override
115+
public Condition newCondition() {
116+
throw new UnsupportedOperationException();
117+
}
118+
119+
/**
120+
* 获取当前锁的值
121+
* return 返回null意味着没有加锁, 但是返回非null值并不以为着当前加锁成功(redis中key可能自动过期)
122+
*/
123+
public String getLockSecretValue() {
124+
return valueThreadLocal.get();
125+
}
126+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package me.chanjar.weixin.common.util.locks;
2+
3+
import lombok.SneakyThrows;
4+
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
5+
import org.springframework.data.redis.core.StringRedisTemplate;
6+
import org.testng.annotations.BeforeTest;
7+
import org.testng.annotations.Test;
8+
9+
import java.util.concurrent.CountDownLatch;
10+
import java.util.concurrent.TimeUnit;
11+
import java.util.concurrent.atomic.AtomicInteger;
12+
13+
import static org.testng.Assert.*;
14+
15+
@Test(enabled = false)
16+
public class RedisTemplateSimpleDistributedLockTest {
17+
18+
RedisTemplateSimpleDistributedLock redisLock;
19+
20+
StringRedisTemplate redisTemplate;
21+
22+
AtomicInteger lockCurrentExecuteCounter;
23+
24+
@BeforeTest
25+
public void init() {
26+
JedisConnectionFactory connectionFactory = new JedisConnectionFactory();
27+
connectionFactory.setHostName("127.0.0.1");
28+
connectionFactory.setPort(6379);
29+
connectionFactory.afterPropertiesSet();
30+
StringRedisTemplate redisTemplate = new StringRedisTemplate(connectionFactory);
31+
this.redisTemplate = redisTemplate;
32+
this.redisLock = new RedisTemplateSimpleDistributedLock(redisTemplate, 60000);
33+
this.lockCurrentExecuteCounter = new AtomicInteger(0);
34+
}
35+
36+
@Test(description = "多线程测试锁排他性")
37+
public void testLockExclusive() throws InterruptedException {
38+
int threadSize = 100;
39+
final CountDownLatch startLatch = new CountDownLatch(threadSize);
40+
final CountDownLatch endLatch = new CountDownLatch(threadSize);
41+
42+
for (int i = 0; i < threadSize; i++) {
43+
new Thread(new Runnable() {
44+
@SneakyThrows
45+
@Override
46+
public void run() {
47+
startLatch.await();
48+
49+
redisLock.lock();
50+
assertEquals(lockCurrentExecuteCounter.incrementAndGet(), 1, "临界区同时只能有一个线程执行");
51+
lockCurrentExecuteCounter.decrementAndGet();
52+
redisLock.unlock();
53+
54+
endLatch.countDown();
55+
}
56+
}).start();
57+
startLatch.countDown();
58+
}
59+
endLatch.await();
60+
}
61+
62+
@Test
63+
public void testTryLock() throws InterruptedException {
64+
assertTrue(redisLock.tryLock(3, TimeUnit.SECONDS), "第一次加锁应该成功");
65+
assertNotNull(redisLock.getLockSecretValue());
66+
String redisValue = this.redisTemplate.opsForValue().get(redisLock.getKey());
67+
assertEquals(redisValue, redisLock.getLockSecretValue());
68+
69+
redisLock.unlock();
70+
assertNull(redisLock.getLockSecretValue());
71+
redisValue = this.redisTemplate.opsForValue().get(redisLock.getKey());
72+
assertNull(redisValue, "释放锁后key会被删除");
73+
74+
redisLock.unlock();
75+
}
76+
77+
78+
}
79+

weixin-java-pay/src/test/java/com/github/binarywang/wxpay/bean/result/WxPaySendRedpackResultTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package com.github.binarywang.wxpay.bean.result;
22

3-
import org.testng.*;
4-
import org.testng.annotations.*;
5-
63
import com.thoughtworks.xstream.XStream;
74
import me.chanjar.weixin.common.util.xml.XStreamInitializer;
5+
import org.testng.Assert;
6+
import org.testng.annotations.BeforeTest;
7+
import org.testng.annotations.Test;
88

99
/**
1010
* The type Wx pay send redpack result test.
@@ -68,6 +68,6 @@ public void loadFailureResult() {
6868
Assert.assertEquals("FAIL", wxMpRedpackResult.getReturnCode());
6969
Assert.assertEquals("FAIL", wxMpRedpackResult.getResultCode());
7070
Assert.assertEquals("onqOjjmM1tad-3ROpncN-yUfa6uI", wxMpRedpackResult.getReOpenid());
71-
Assert.assertEquals(1, wxMpRedpackResult.getTotalAmount());
71+
Assert.assertEquals(Integer.valueOf(1), wxMpRedpackResult.getTotalAmount());
7272
}
7373
}

0 commit comments

Comments
 (0)