From fa48badb1edcbfb27fae61909b43545c0478d0b2 Mon Sep 17 00:00:00 2001 From: liergou <736540362@qq.com> Date: Sat, 4 Apr 2020 02:27:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0socket=20hook=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=20=20=E5=AE=9E=E7=8E=B0socket=E5=B1=82=E6=8B=A6?= =?UTF-8?q?=E6=88=AASSRF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../org/joychou/security/SSRFChecker.java | 15 +- .../org/joychou/security/SecurityUtil.java | 24 ++ .../java/org/joychou/security/SocketHook.java | 25 ++ .../joychou/security/SocketHookFactory.java | 77 +++++ .../org/joychou/security/SocketHookImpl.java | 294 ++++++++++++++++++ .../org/joychou/security/SocketHookUtils.java | 29 ++ 6 files changed, 463 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/joychou/security/SocketHook.java create mode 100644 src/main/java/org/joychou/security/SocketHookFactory.java create mode 100644 src/main/java/org/joychou/security/SocketHookImpl.java create mode 100644 src/main/java/org/joychou/security/SocketHookUtils.java diff --git a/src/main/java/org/joychou/security/SSRFChecker.java b/src/main/java/org/joychou/security/SSRFChecker.java index b678d598..ef393074 100644 --- a/src/main/java/org/joychou/security/SSRFChecker.java +++ b/src/main/java/org/joychou/security/SSRFChecker.java @@ -5,6 +5,8 @@ import java.net.URI; import java.net.URL; import java.util.ArrayList; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.net.util.SubnetUtils; import org.joychou.config.WebConfig; @@ -14,6 +16,7 @@ class SSRFChecker { private static Logger logger = LoggerFactory.getLogger(SSRFChecker.class); + private final static Pattern IP_PATTERN = Pattern.compile("((25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(25[0-5]|2[0-4]\\d|[01]?\\d\\d?)"); static boolean checkURLFckSSRF(String url) { if (null == url){ @@ -122,7 +125,7 @@ static boolean isInnerIPByUrl(String url) { * @param strIP ip字符串 * @return 如果是内网ip,返回true,否则返回false。 */ - private static boolean isInnerIp(String strIP){ + static boolean isInnerIp(String strIP){ ArrayList blackSubnets= WebConfig.getSsrfBlockIps(); @@ -176,4 +179,14 @@ private static String url2host(String url) { } } + + /** + * 匹配ip + * @return + */ + static String getIpFromStr(String ipStr){ + Matcher matcher = IP_PATTERN.matcher(ipStr); + System.out.println(matcher.find()); + return matcher.group(); + } } diff --git a/src/main/java/org/joychou/security/SecurityUtil.java b/src/main/java/org/joychou/security/SecurityUtil.java index 14cc397f..2cbe52ab 100644 --- a/src/main/java/org/joychou/security/SecurityUtil.java +++ b/src/main/java/org/joychou/security/SecurityUtil.java @@ -4,6 +4,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.io.UnsupportedEncodingException; import java.net.URI; import java.net.URLDecoder; @@ -118,6 +119,29 @@ public static boolean checkSSRFWithoutRedirect(String url) { return !SSRFChecker.isInnerIPByUrl(url); } + /** + * @Author liergou + * @Description 基于Socket hook 进行SSRF检测拦截 + * @Date 2:15 2020/4/4 + * @Param [] + * @return void + **/ + public static void startSSRFHook() throws NoSuchFieldException, IOException { + SocketHook.startHook(); + } + + /** + * @Author liergou + * @Description 关闭Socket hook + * @Date 2:15 2020/4/4 + * @Param [] + * @return void + **/ + public static void stopSSRFHook(){ + SocketHook.stopHook(); + } + + /** * Filter file path to prevent path traversal vulns. diff --git a/src/main/java/org/joychou/security/SocketHook.java b/src/main/java/org/joychou/security/SocketHook.java new file mode 100644 index 00000000..359e33f3 --- /dev/null +++ b/src/main/java/org/joychou/security/SocketHook.java @@ -0,0 +1,25 @@ +package org.joychou.security; + +import java.io.IOException; +import java.net.Socket; +import java.net.SocketException; + +/** + * @Author liergou + * @Description Socket hook开关自如 + * @Date 2:12 2020/4/4 + **/ +class SocketHook { + static void startHook() throws NoSuchFieldException, IOException { + SocketHookFactory.initSocket(); + SocketHookFactory.setHook(true); + try{ + Socket.setSocketImplFactory(new SocketHookFactory()); + }catch (SocketException ignored){ + } + } + + static void stopHook(){ + SocketHookFactory.setHook(false); + } +} \ No newline at end of file diff --git a/src/main/java/org/joychou/security/SocketHookFactory.java b/src/main/java/org/joychou/security/SocketHookFactory.java new file mode 100644 index 00000000..2dbe738f --- /dev/null +++ b/src/main/java/org/joychou/security/SocketHookFactory.java @@ -0,0 +1,77 @@ +package org.joychou.security; + + +import java.io.IOException; +import java.lang.reflect.Field; +import java.net.Socket; +import java.net.SocketImpl; +import java.net.SocketImplFactory; +import java.util.logging.Level; +import java.util.logging.Logger; + + +/** + * @Author liergou + * @Description socket factory impl + * @Date 23:41 2020/4/3 + * @Param + * @return + **/ +public class SocketHookFactory implements SocketImplFactory + { + private static SocketImpl clazz; + private static Boolean isHook = false; + + /** + * @Author liergou + * @Description switch hook + * @Date 23:42 2020/4/2 + * @Param [set] + * @return void + **/ + public static void setHook(Boolean set){ + isHook = set; + } + + /** + * @Author liergou + * @Description 初始化 + * @Date 23:42 2020/4/2 + * @Param [] + * @return void + **/ + public static synchronized void initSocket() throws NoSuchFieldException { + if ( clazz != null ) { return; } + + Socket socket = new Socket(); + try{ + Field implField = Socket.class.getDeclaredField("impl"); + implField.setAccessible( true ); + clazz = (SocketImpl) implField.get(socket); + }catch (NoSuchFieldException | IllegalAccessException e){ + throw new RuntimeException("SocketHookFactory init failed!"); + } + + try { + socket.close(); + } + catch ( IOException ignored) + { + + } + } + + public SocketImpl createSocketImpl() { + + if(isHook) { + try { + return new SocketHookImpl(clazz); + } catch (Exception e) { + Logger.getLogger(SocketHookFactory.class.getName()).log(Level.WARNING, "hook 失败 请检查" ); + return clazz; + } + }else{ + return clazz; + } + } + } diff --git a/src/main/java/org/joychou/security/SocketHookImpl.java b/src/main/java/org/joychou/security/SocketHookImpl.java new file mode 100644 index 00000000..ec41ebac --- /dev/null +++ b/src/main/java/org/joychou/security/SocketHookImpl.java @@ -0,0 +1,294 @@ +package org.joychou.security; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.*; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * @Author liergou + * @Description socket impl + * @Date 23:39 2020/4/2 + * @Param + * @return + **/ +public class SocketHookImpl extends SocketImpl implements SocketOptions +{ + + private SocketImpl socketImpl = null; + private Method createImpl; + private Method connectHostImpl; + private Method connectInetAddressImpl; + private Method connectSocketAddressIMPL; + private Method bindImpl; + private Method listenImpl; + private Method acceptImpl; + private Method getInputStreamImpl; + private Method getOutputStreamImpl; + private Method availableImpl; + private Method closeImpl; + private Method shutdownInputImpl; + private Method shutdownOutputImpl; + private Method sendUrgentDataImpl; + + + /** + * @Author liergou + * @Description 初始化反射方法 + * @Date 23:40 2020/4/2 + * @Param [initSocketImpl] + * @return + **/ + public SocketHookImpl(SocketImpl initSocketImpl) { + + if ( initSocketImpl == null){ + throw new RuntimeException(""); + //TODO close hook + } + + this.socketImpl = initSocketImpl; + final Class clazz = this.socketImpl.getClass(); + Method[] allMethod = clazz.getDeclaredMethods(); + createImpl = SocketHookUtils.findMethod( clazz,"create", new Class[]{ boolean.class } ); + connectHostImpl = SocketHookUtils.findMethod( clazz, "connect", new Class[]{ String.class, int.class } ); + connectInetAddressImpl = SocketHookUtils.findMethod( clazz, "connect", new Class[]{ InetAddress.class, int.class } ); + connectSocketAddressIMPL = SocketHookUtils.findMethod( clazz, "connect", new Class[]{ SocketAddress.class, int.class } ); + bindImpl = SocketHookUtils.findMethod( clazz, "bind", new Class[]{ InetAddress.class, int.class } ); + listenImpl = SocketHookUtils.findMethod( clazz, "listen", new Class[]{ int.class } ); + acceptImpl = SocketHookUtils.findMethod( clazz, "accept", new Class[]{ SocketImpl.class } ); + getInputStreamImpl = SocketHookUtils.findMethod( clazz, "getInputStream", new Class[]{ } ); + getOutputStreamImpl = SocketHookUtils.findMethod( clazz, "getOutputStream", new Class[]{ } ); + availableImpl = SocketHookUtils.findMethod( clazz, "available", new Class[]{ } ); + closeImpl = SocketHookUtils.findMethod( clazz, "close", new Class[]{ } ); + shutdownInputImpl = SocketHookUtils.findMethod( clazz, "shutdownInput", new Class[]{ } ); + shutdownOutputImpl = SocketHookUtils.findMethod( clazz, "shutdownOutput", new Class[]{ } ); + sendUrgentDataImpl = SocketHookUtils.findMethod( clazz, "sendUrgantData", new Class[]{ int.class } ); + } + + + /** + * socket base method impl + */ + @Override + protected void create(boolean stream) throws IOException { + try + { + this.createImpl.invoke( this.socketImpl, stream); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + } + + @Override + protected void connect(String host, int port) throws IOException { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.INFO, "host=" + host + ",port=" + port ); + + try + { + this.connectHostImpl.invoke( this.socketImpl, host, port); + } + catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + } + + + @Override + protected void connect(InetAddress address, int port) throws IOException { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.INFO, "InetAddress=" + address.toString()); + + //start check SSRF + if(SSRFChecker.isInnerIp(SSRFChecker.getIpFromStr(address.toString()))){ + throw new RuntimeException("Socket SSRF check failed. InetAddress:"+address.toString()); + } + try + { + this.connectInetAddressImpl.invoke( this.socketImpl, address, port); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected void connect(SocketAddress address, int timeout) throws IOException { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.INFO, "SocketAddress=" + address.toString()); + //start check SSRF + if(SSRFChecker.isInnerIp(SSRFChecker.getIpFromStr(address.toString()))){ + throw new RuntimeException("Socket SSRF check failed. SocketAddress:"+address.toString()); + } + + try + { + this.connectSocketAddressIMPL.invoke( this.socketImpl, address, timeout); + } + catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected void bind(InetAddress host, int port) throws IOException { + try + { + this.bindImpl.invoke( this.socketImpl, host, port); + } + catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected void listen(int backlog) throws IOException { + + try + { + this.listenImpl.invoke( this.socketImpl, backlog); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected void accept(SocketImpl s) throws IOException { + + try + { + this.acceptImpl.invoke( this.socketImpl, s); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected InputStream getInputStream() throws IOException { + InputStream inStream = null; + + try + { + inStream = (InputStream)this.getInputStreamImpl.invoke( this.socketImpl); + } + catch ( ClassCastException | InvocationTargetException | IllegalArgumentException | IllegalAccessException ex ) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + return inStream; + } + + @Override + protected OutputStream getOutputStream() throws IOException { + OutputStream outStream = null; + + try + { + outStream = (OutputStream)this.getOutputStreamImpl.invoke( this.socketImpl); + } + catch ( ClassCastException | IllegalArgumentException | IllegalAccessException | InvocationTargetException ex ) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + return outStream; + } + + @Override + protected int available() throws IOException { + int result = -1; + + try + { + result = (Integer)this.availableImpl.invoke( this.socketImpl); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + return result; + } + + @Override + protected void close() throws IOException { + try + { + this.closeImpl.invoke( this.socketImpl); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + @Override + protected void shutdownInput() throws IOException { + try + { + this.shutdownInputImpl.invoke( this.socketImpl); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + } + + @Override + protected void shutdownOutput() throws IOException { + try + { + this.shutdownOutputImpl.invoke( this.socketImpl); + } + catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + + } + + @Override + protected void sendUrgentData(int data) throws IOException { + try + { + this.sendUrgentDataImpl.invoke( this.socketImpl, data); + } + catch (IllegalAccessException | InvocationTargetException | IllegalArgumentException ex) + { + Logger.getLogger(SocketHookImpl.class.getName()).log(Level.SEVERE, null, ex); + } + } + + public void setOption(int optID, Object value) throws SocketException { + if ( null != this.socketImpl ) + { + this.socketImpl.setOption( optID, value ); + } + } + + public Object getOption(int optID) throws SocketException { + return this.socketImpl.getOption( optID ); + } + + /** + * dont impl other child method now + * dont sure where will use it + **/ + + +} diff --git a/src/main/java/org/joychou/security/SocketHookUtils.java b/src/main/java/org/joychou/security/SocketHookUtils.java new file mode 100644 index 00000000..75d48707 --- /dev/null +++ b/src/main/java/org/joychou/security/SocketHookUtils.java @@ -0,0 +1,29 @@ +package org.joychou.security; + +import java.lang.reflect.Method; + +public class SocketHookUtils { + + /** + * @Author liergou + * @Description 轮询父类查找反射方法 + * @Date 1:43 2020/4/4 + * @Param [inputClazz, findName, args] + * @return java.lang.reflect.Method + **/ + public static Method findMethod(Class inputClazz, String findName ,Class[] args){ + Class temp=inputClazz; + Method tmpMethod = null; + while(temp!=null){ + try{ + tmpMethod = temp.getDeclaredMethod(findName,args); + tmpMethod.setAccessible(true); + return tmpMethod; + }catch (NoSuchMethodException e){ + temp=temp.getSuperclass(); + } + } + return null; + } + +}