diff --git a/go.mod b/go.mod index 9e3eb8d8..d8f28af0 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a github.com/sagernet/nftables v0.3.0-beta.4 - github.com/sagernet/sing v0.7.0-beta.1 + github.com/sagernet/sing v0.7.6 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 diff --git a/go.sum b/go.sum index fb428267..5a1be49b 100644 --- a/go.sum +++ b/go.sum @@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= -github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw= -github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.7.6 h1:6LBfDH+aI/26J3r9UHlaxTNjJeMhBpU/wrk0JKDZYI4= +github.com/sagernet/sing v0.7.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= diff --git a/internal/gtcpip/header/ipv4.go b/internal/gtcpip/header/ipv4.go index d76db68e..72066ff1 100644 --- a/internal/gtcpip/header/ipv4.go +++ b/internal/gtcpip/header/ipv4.go @@ -458,7 +458,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) { // CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { - return checksum.Checksum(b[:b.HeaderLength()], 0) + // return checksum.Checksum(b[:b.HeaderLength()], 0) + xsum0 := checksum.Checksum(b[:xsum], 0) + xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0) + return xsum0 } // Encode encodes all the fields of the IPv4 header. @@ -550,7 +553,8 @@ func (b IPv4) IsChecksumValid() bool { // same set of octets, including the checksum field. If the result // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. - return b.CalculateChecksum() == 0xffff + //return b.CalculateChecksum() == 0xffff + return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff } // IsV4MulticastAddress determines if the provided address is an IPv4 multicast diff --git a/internal/gtcpip/header/tcp.go b/internal/gtcpip/header/tcp.go index 58552538..1b58df86 100644 --- a/internal/gtcpip/header/tcp.go +++ b/internal/gtcpip/header/tcp.go @@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) { // and the checksum of the segment data. func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 { // Calculate the rest of the checksum. - return checksum.Checksum(b[:b.DataOffset()], partialChecksum) + // return checksum.Checksum(b[:b.DataOffset()], partialChecksum) + xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum) + xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum) + return xsum } // IsChecksumValid returns true iff the TCP header's checksum is valid. func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool { xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength) xsum = checksum.Combine(xsum, payloadChecksum) - return b.CalculateChecksum(xsum) == 0xffff + // return b.CalculateChecksum(xsum) == 0xffff + return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff } // Options returns a slice that holds the unparsed TCP options in the segment. diff --git a/internal/gtcpip/header/udp.go b/internal/gtcpip/header/udp.go index 080a97fd..a995a172 100644 --- a/internal/gtcpip/header/udp.go +++ b/internal/gtcpip/header/udp.go @@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) { // CalculateChecksum calculates the checksum of the UDP packet, given the // checksum of the network-layer pseudo-header and the checksum of the payload. func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { - // Calculate the rest of the checksum. - return checksum.Checksum(b[:UDPMinimumSize], partialChecksum) + // Calculate the rest of the checksum.\ + // return checksum.Checksum(b[:UDPMinimumSize], partialChecksum) + xsum := checksum.Checksum(b[:udpChecksum], partialChecksum) + xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum) + return xsum } // IsChecksumValid returns true iff the UDP header's checksum is valid. func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool { xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length()) xsum = checksum.Combine(xsum, payloadChecksum) - return b.CalculateChecksum(xsum) == 0xffff + return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff } // Encode encodes all the fields of the UDP header. diff --git a/internal/wintun/wintun_windows.go b/internal/wintun/wintun_windows.go index a817e6c5..288d364b 100644 --- a/internal/wintun/wintun_windows.go +++ b/internal/wintun/wintun_windows.go @@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) { // deterministically. If it is set to nil, the GUID is chosen by the system at random, // and hence a new NLA entry is created for each new adapter. func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) { + err = procWintunCloseAdapter.Find() + if err != nil { + return + } var name16 *uint16 name16, err = windows.UTF16PtrFromString(name) if err != nil { diff --git a/monitor_shared.go b/monitor_shared.go index 12e3e21b..3595d856 100644 --- a/monitor_shared.go +++ b/monitor_shared.go @@ -5,9 +5,9 @@ package tun import ( "errors" "sync" + "sync/atomic" "time" - "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/control" "github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/x/list" diff --git a/redirect_linux.go b/redirect_linux.go index 5441bc10..2eb0a455 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -20,27 +20,29 @@ import ( ) type autoRedirect struct { - tunOptions *Options - ctx context.Context - handler N.TCPConnectionHandlerEx - logger logger.Logger - tableName string - networkMonitor NetworkUpdateMonitor - networkListener *list.Element[NetworkUpdateCallback] - interfaceFinder control.InterfaceFinder - localAddresses []netip.Prefix - customRedirectPortFunc func() int - customRedirectPort int - redirectServer *redirectServer - enableIPv4 bool - enableIPv6 bool - iptablesPath string - ip6tablesPath string - useNFTables bool - androidSu bool - suPath string - routeAddressSet *[]*netipx.IPSet - routeExcludeAddressSet *[]*netipx.IPSet + tunOptions *Options + ctx context.Context + handler N.TCPConnectionHandlerEx + logger logger.Logger + tableName string + networkMonitor NetworkUpdateMonitor + networkListener *list.Element[NetworkUpdateCallback] + interfaceFinder control.InterfaceFinder + localAddresses []netip.Prefix + customRedirectPortFunc func() int + customRedirectPort int + redirectServer *redirectServer + enableIPv4 bool + enableIPv6 bool + iptablesPath string + ip6tablesPath string + useNFTables bool + androidSu bool + suPath string + routeAddressSet *[]*netipx.IPSet + routeExcludeAddressSet *[]*netipx.IPSet + redirectRouteTableIndex int + redirectInterfaces []control.Interface } func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { @@ -69,6 +71,7 @@ func (r *autoRedirect) Start() error { r.androidSu = true for _, suPath := range []string{ "su", + "/product/bin/su", "/system/bin/su", } { r.suPath, err = exec.LookPath(suPath) @@ -133,6 +136,12 @@ func (r *autoRedirect) Start() error { if r.useNFTables { r.cleanupNFTables() err = r.setupNFTables() + if err == nil && r.tunOptions.AutoRedirectMarkMode { + err = r.setupRedirectRoutes() + if err != nil { + r.cleanupNFTables() + } + } } else { r.cleanupIPTables() err = r.setupIPTables() @@ -142,6 +151,7 @@ func (r *autoRedirect) Start() error { func (r *autoRedirect) Close() error { if r.useNFTables { + r.cleanupRedirectRoutes() r.cleanupNFTables() } else { r.cleanupIPTables() diff --git a/redirect_nftables.go b/redirect_nftables.go index ec4bf5d8..abf6e449 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -4,6 +4,7 @@ package tun import ( "net/netip" + "strings" "github.com/sagernet/nftables" "github.com/sagernet/nftables/binaryutil" @@ -143,12 +144,26 @@ func (r *autoRedirect) setupNFTables() error { } } chainPreRoutingUDP := nft.AddChain(&nftables.Chain{ - Name: "prerouting_udp", + Name: "prerouting_udp_icmp", Table: table, Hooknum: nftables.ChainHookPrerouting, Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2), Type: nftables.ChainTypeFilter, }) + ipProto := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeInetProto, + } + err = nft.AddSet(ipProto, []nftables.SetElement{ + {Key: []byte{unix.IPPROTO_UDP}}, + {Key: []byte{unix.IPPROTO_ICMP}}, + {Key: []byte{unix.IPPROTO_ICMPV6}}, + }) + if err != nil { + return err + } nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRoutingUDP, @@ -157,11 +172,31 @@ func (r *autoRedirect) setupNFTables() error { Key: expr.MetaKeyL4PROTO, Register: 1, }, + &expr.Lookup{ + SourceRegister: 1, + SetID: ipProto.ID, + SetName: ipProto.Name, + Invert: true, + }, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chainPreRoutingUDP, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, &expr.Cmp{ - Op: expr.CmpOpNeq, + Op: expr.CmpOpEq, Register: 1, - Data: []byte{unix.IPPROTO_UDP}, + Data: nftablesIfname(r.tunOptions.Name), }, + &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictReturn, }, @@ -248,12 +283,22 @@ func (r *autoRedirect) setupNFTables() error { if err != nil { r.logger.Error("update local address set: ", err) } + if r.tunOptions.AutoRedirectMarkMode { + err = r.updateRedirectRoutes() + if err != nil { + r.logger.Error("update redirect routes: ", err) + } + } }) return nil } // TODO; test is this works func (r *autoRedirect) nftablesUpdateLocalAddressSet() error { + err := r.interfaceFinder.Update() + if err != nil { + return err + } newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix { return common.Filter(it.Addresses, func(prefix netip.Prefix) bool { return it.Name == "lo" || prefix.Addr().IsGlobalUnicast() @@ -262,6 +307,11 @@ func (r *autoRedirect) nftablesUpdateLocalAddressSet() error { if slices.Equal(newLocalAddresses, r.localAddresses) { return nil } + if r.logger != nil { + r.logger.Debug("updating local address set to [", strings.Join(common.Map(newLocalAddresses, func(it netip.Prefix) string { + return it.String() + }), ", ")+"]") + } nft, err := nftables.New() if err != nil { return err diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index ba4ee872..ff0a2994 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -74,12 +74,11 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets( localAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool { return it.Addr().Is4() }) - updateAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool { - return it.Addr().Is4() - }) var update bool if len(lastAddresses) != 0 { - if !slices.Equal(localAddresses4, updateAddresses4) { + if !slices.Equal(localAddresses4, common.Filter(lastAddresses, func(it netip.Prefix) bool { + return it.Addr().Is4() + })) { update = true } } @@ -94,19 +93,14 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets( localAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool { return it.Addr().Is6() }) - updateAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool { - return it.Addr().Is6() - }) var update bool if len(lastAddresses) != 0 { - if !slices.Equal(localAddresses6, updateAddresses6) { + if !slices.Equal(localAddresses6, common.Filter(lastAddresses, func(it netip.Prefix) bool { + return it.Addr().Is6() + })) { update = true } } - localAddresses6 = common.Filter(localAddresses6, func(it netip.Prefix) bool { - address := it.Addr() - return address.IsLoopback() || address.IsGlobalUnicast() && !address.IsPrivate() - }) if len(lastAddresses) == 0 || update { _, err := nftablesCreateIPSet(nft, table, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, nil, localAddresses6, false, update) if err != nil { @@ -388,7 +382,7 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, - Data: binaryutil.BigEndian.PutUint32(r.tunOptions.IncludeUID[0].Start), + Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.IncludeUID[0].Start), }, &expr.Counter{}, &expr.Verdict{ diff --git a/redirect_route_linux.go b/redirect_route_linux.go new file mode 100644 index 00000000..6b86da12 --- /dev/null +++ b/redirect_route_linux.go @@ -0,0 +1,179 @@ +//go:build linux + +package tun + +import ( + "math/rand" + "net" + "net/netip" + + "github.com/sagernet/netlink" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/control" + + "golang.org/x/sys/unix" +) + +const redirectRouteRulePriority = 1 + +func (r *autoRedirect) setupRedirectRoutes() error { + for { + r.redirectRouteTableIndex = int(rand.Uint32()) + if r.redirectRouteTableIndex == r.tunOptions.IPRoute2TableIndex { + continue + } + routeList, fErr := netlink.RouteListFiltered(netlink.FAMILY_ALL, + &netlink.Route{Table: r.redirectRouteTableIndex}, + netlink.RT_FILTER_TABLE) + if len(routeList) == 0 || fErr != nil { + break + } + } + err := r.interfaceFinder.Update() + if err != nil { + return err + } + tunName := r.tunOptions.Name + r.redirectInterfaces = common.Filter(r.interfaceFinder.Interfaces(), func(it control.Interface) bool { + return it.Name != "lo" && it.Name != tunName && it.Flags&net.FlagUp != 0 + }) + r.cleanupRedirectRoutes() + for _, iface := range r.redirectInterfaces { + err = r.addRedirectRoutes(iface) + if err != nil { + return err + } + } + if r.enableIPv4 { + rule := netlink.NewRule() + rule.Priority = redirectRouteRulePriority + rule.Table = r.redirectRouteTableIndex + rule.Family = unix.AF_INET + err = netlink.RuleAdd(rule) + if err != nil { + return err + } + } + if r.enableIPv6 { + rule := netlink.NewRule() + rule.Priority = redirectRouteRulePriority + rule.Table = r.redirectRouteTableIndex + rule.Family = unix.AF_INET6 + err = netlink.RuleAdd(rule) + if err != nil { + return err + } + } + return nil +} + +func (r *autoRedirect) addRedirectRoutes(iface control.Interface) error { + if r.enableIPv4 && common.Any(iface.Addresses, func(it netip.Prefix) bool { + return it.Addr().Is4() + }) { + err := netlink.RouteAppend(&netlink.Route{ + LinkIndex: iface.Index, + Dst: &net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(32, 32)}, + Table: r.redirectRouteTableIndex, + Type: unix.RTN_LOCAL, + Scope: netlink.SCOPE_HOST, + }) + if err != nil { + return err + } + } + if r.enableIPv6 && common.Any(iface.Addresses, func(it netip.Prefix) bool { + return it.Addr().Is6() && !it.Addr().Is4In6() + }) { + err := netlink.RouteAppend(&netlink.Route{ + LinkIndex: iface.Index, + Dst: &net.IPNet{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, + Table: r.redirectRouteTableIndex, + Type: unix.RTN_LOCAL, + Scope: netlink.SCOPE_HOST, + }) + if err != nil { + return err + } + } + return nil +} + +func (r *autoRedirect) removeRedirectRoutes(linkIndex int) { + if r.enableIPv4 { + _ = netlink.RouteDel(&netlink.Route{ + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.IPv4(127, 0, 0, 1), Mask: net.CIDRMask(32, 32)}, + Table: r.redirectRouteTableIndex, + Type: unix.RTN_LOCAL, + }) + } + if r.enableIPv6 { + _ = netlink.RouteDel(&netlink.Route{ + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)}, + Table: r.redirectRouteTableIndex, + Type: unix.RTN_LOCAL, + }) + } +} + +func (r *autoRedirect) updateRedirectRoutes() error { + err := r.interfaceFinder.Update() + if err != nil { + return err + } + tunName := r.tunOptions.Name + newInterfaces := common.Filter(r.interfaceFinder.Interfaces(), func(it control.Interface) bool { + return it.Name != "lo" && it.Name != tunName && it.Flags&net.FlagUp != 0 + }) + oldMap := make(map[int]bool, len(r.redirectInterfaces)) + for _, iface := range r.redirectInterfaces { + oldMap[iface.Index] = true + } + newMap := make(map[int]bool, len(newInterfaces)) + for _, iface := range newInterfaces { + newMap[iface.Index] = true + } + for _, iface := range newInterfaces { + if !oldMap[iface.Index] { + err = r.addRedirectRoutes(iface) + if err != nil { + return err + } + } + } + for _, iface := range r.redirectInterfaces { + if !newMap[iface.Index] { + r.removeRedirectRoutes(iface.Index) + } + } + r.redirectInterfaces = newInterfaces + return nil +} + +func (r *autoRedirect) cleanupRedirectRoutes() { + if r.redirectRouteTableIndex == 0 { + return + } + routes, _ := netlink.RouteListFiltered(netlink.FAMILY_ALL, + &netlink.Route{Table: r.redirectRouteTableIndex}, + netlink.RT_FILTER_TABLE) + for _, route := range routes { + _ = netlink.RouteDel(&route) + } + if r.enableIPv4 { + rule := netlink.NewRule() + rule.Priority = redirectRouteRulePriority + rule.Table = r.redirectRouteTableIndex + rule.Family = unix.AF_INET + _ = netlink.RuleDel(rule) + } + if r.enableIPv6 { + rule := netlink.NewRule() + rule.Priority = redirectRouteRulePriority + rule.Table = r.redirectRouteTableIndex + rule.Family = unix.AF_INET6 + _ = netlink.RuleDel(rule) + } +} diff --git a/redirect_server.go b/redirect_server.go index 86abfd8c..7590b35d 100644 --- a/redirect_server.go +++ b/redirect_server.go @@ -7,9 +7,9 @@ import ( "errors" "net" "net/netip" + "sync/atomic" "time" - "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index aad97cf6..167cafba 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -52,7 +52,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress) tcpHdr := header.TCP(pkt.TransportHeader().Slice()) tcpHdr.SetChecksum(0) - tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), ))) f.tun.WritePacket(pkt) @@ -66,7 +66,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac ipHdr.SetSourceAddress(inet6LoopbackAddress) tcpHdr := header.TCP(pkt.TransportHeader().Slice()) tcpHdr.SetChecksum(0) - tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( + tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()), ))) f.tun.WritePacket(pkt) diff --git a/stack_mixed.go b/stack_mixed.go index a48639d4..bc1f08e1 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -3,6 +3,9 @@ package tun import ( + "errors" + "syscall" + "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" gHdr "github.com/sagernet/gvisor/pkg/tcpip/header" @@ -169,7 +172,7 @@ func (m *Mixed) batchLoopDarwin(darwinTUN DarwinTUN) { for { buffers, err := darwinTUN.BatchRead() if err != nil { - if E.IsClosed(err) { + if E.IsClosed(err) || errors.Is(err, syscall.EBADF) { return } m.logger.Error(E.Cause(err, "batch read packet")) diff --git a/stack_system.go b/stack_system.go index 825c5f26..b2f67877 100644 --- a/stack_system.go +++ b/stack_system.go @@ -267,7 +267,7 @@ func (s *System) batchLoopDarwin(darwinTUN DarwinTUN) { for { buffers, err := darwinTUN.BatchRead() if err != nil { - if E.IsClosed(err) { + if E.IsClosed(err) || errors.Is(err, syscall.EBADF) { return } s.logger.Error(E.Cause(err, "batch read packet")) @@ -422,14 +422,12 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err } } if !s.txChecksumOffload { - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) } else { tcpHdr.SetChecksum(0) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return true, nil } @@ -470,7 +468,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro if !s.txChecksumOffload { tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize))) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) @@ -520,7 +517,6 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err } } if !s.txChecksumOffload { - tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) @@ -651,8 +647,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error sourceAddress := ipHdr.SourceAddr() ipHdr.SetSourceAddr(ipHdr.DestinationAddr()) ipHdr.SetDestinationAddr(sourceAddress) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0))) - ipHdr.SetChecksum(0) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) return nil } @@ -686,7 +681,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e icmpHdr := header.ICMPv4(newIPHdr.Payload()) icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(code) - icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0))) + icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0)) copy(icmpHdr.Payload(), payload) if PacketOffset > 0 { newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET @@ -779,14 +774,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize)) if !w.txChecksumOffload { - udpHdr.SetChecksum(0) udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) } else { udpHdr.SetChecksum(0) } - ipHdr.SetChecksum(0) ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) if PacketOffset > 0 { PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version) @@ -820,7 +813,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S udpHdr.SetSourcePort(destination.Port) udpHdr.SetLength(udpLen) if !w.txChecksumOffload { - udpHdr.SetChecksum(0) udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum( header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()), ))) diff --git a/stack_system_nat.go b/stack_system_nat.go index 1d0216ed..66240a61 100644 --- a/stack_system_nat.go +++ b/stack_system_nat.go @@ -11,6 +11,7 @@ import ( ) type TCPNat struct { + timeout time.Duration portIndex uint16 portAccess sync.RWMutex addrAccess sync.RWMutex @@ -19,6 +20,7 @@ type TCPNat struct { } type TCPSession struct { + sync.Mutex Source netip.AddrPort Destination netip.AddrPort LastActive time.Time @@ -26,38 +28,41 @@ type TCPSession struct { func NewNat(ctx context.Context, timeout time.Duration) *TCPNat { natMap := &TCPNat{ + timeout: timeout, portIndex: 10000, addrMap: make(map[netip.AddrPort]uint16), portMap: make(map[uint16]*TCPSession), } - go natMap.loopCheckTimeout(ctx, timeout) + go natMap.loopCheckTimeout(ctx) return natMap } -func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) { - ticker := time.NewTicker(timeout) +func (n *TCPNat) loopCheckTimeout(ctx context.Context) { + ticker := time.NewTicker(n.timeout) defer ticker.Stop() for { select { case <-ticker.C: - n.checkTimeout(timeout) + n.checkTimeout() case <-ctx.Done(): return } } } -func (n *TCPNat) checkTimeout(timeout time.Duration) { +func (n *TCPNat) checkTimeout() { now := time.Now() n.portAccess.Lock() defer n.portAccess.Unlock() n.addrAccess.Lock() defer n.addrAccess.Unlock() for natPort, session := range n.portMap { - if now.Sub(session.LastActive) > timeout { + session.Lock() + if now.Sub(session.LastActive) > n.timeout { delete(n.addrMap, session.Source) delete(n.portMap, natPort) } + session.Unlock() } } @@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession { session := n.portMap[port] n.portAccess.RUnlock() if session != nil { - session.LastActive = time.Now() + session.Lock() + if time.Since(session.LastActive) > time.Second { + session.LastActive = time.Now() + } + session.Unlock() } return session } diff --git a/tun.go b/tun.go index 92eab64a..7fb3e8db 100644 --- a/tun.go +++ b/tun.go @@ -52,44 +52,46 @@ type DarwinTUN interface { } const ( - DefaultIPRoute2TableIndex = 2022 - DefaultIPRoute2RuleIndex = 9000 + DefaultIPRoute2TableIndex = 2022 + DefaultIPRoute2RuleIndex = 9000 + DefaultIPRoute2AutoRedirectFallbackRuleIndex = 32768 ) type Options struct { - Name string - Inet4Address []netip.Prefix - Inet6Address []netip.Prefix - MTU uint32 - GSO bool - AutoRoute bool - InterfaceScope bool - Inet4Gateway netip.Addr - Inet6Gateway netip.Addr - DNSServers []netip.Addr - IPRoute2TableIndex int - IPRoute2RuleIndex int - AutoRedirectMarkMode bool - AutoRedirectInputMark uint32 - AutoRedirectOutputMark uint32 - Inet4LoopbackAddress []netip.Addr - Inet6LoopbackAddress []netip.Addr - StrictRoute bool - Inet4RouteAddress []netip.Prefix - Inet6RouteAddress []netip.Prefix - Inet4RouteExcludeAddress []netip.Prefix - Inet6RouteExcludeAddress []netip.Prefix - IncludeInterface []string - ExcludeInterface []string - IncludeUID []ranges.Range[uint32] - ExcludeUID []ranges.Range[uint32] - IncludeAndroidUser []int - IncludePackage []string - ExcludePackage []string - InterfaceFinder control.InterfaceFinder - InterfaceMonitor DefaultInterfaceMonitor - FileDescriptor int - Logger logger.Logger + Name string + Inet4Address []netip.Prefix + Inet6Address []netip.Prefix + MTU uint32 + GSO bool + AutoRoute bool + InterfaceScope bool + Inet4Gateway netip.Addr + Inet6Gateway netip.Addr + DNSServers []netip.Addr + IPRoute2TableIndex int + IPRoute2RuleIndex int + IPRoute2AutoRedirectFallbackRuleIndex int + AutoRedirectMarkMode bool + AutoRedirectInputMark uint32 + AutoRedirectOutputMark uint32 + Inet4LoopbackAddress []netip.Addr + Inet6LoopbackAddress []netip.Addr + StrictRoute bool + Inet4RouteAddress []netip.Prefix + Inet6RouteAddress []netip.Prefix + Inet4RouteExcludeAddress []netip.Prefix + Inet6RouteExcludeAddress []netip.Prefix + IncludeInterface []string + ExcludeInterface []string + IncludeUID []ranges.Range[uint32] + ExcludeUID []ranges.Range[uint32] + IncludeAndroidUser []int + IncludePackage []string + ExcludePackage []string + InterfaceFinder control.InterfaceFinder + InterfaceMonitor DefaultInterfaceMonitor + FileDescriptor int + Logger logger.Logger // No work for TCP, do not use. _TXChecksumOffload bool diff --git a/tun_darwin.go b/tun_darwin.go index 45efa291..dd499767 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -152,7 +152,10 @@ func (t *NativeTun) Start() error { func (t *NativeTun) Close() error { defer flushDNSCache() - return E.Errors(t.unsetRoutes(), t.tunFile.Close()) + t.stopFd.Stop() + err := E.Errors(t.unsetRoutes(), t.tunFile.Close()) + t.stopFd.Close() + return err } func (t *NativeTun) Read(p []byte) (n int, err error) { @@ -347,6 +350,9 @@ func (t *NativeTun) BatchRead() ([]*buf.Buffer, error) { t.buffers = t.buffers[:0] return nil, errno } + if n < 0 { + return nil, os.ErrClosed + } if n < 1 { return nil, nil } diff --git a/tun_linux.go b/tun_linux.go index 6d7dfed9..cf1d4eda 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -17,7 +17,6 @@ import ( "github.com/sagernet/sing-tun/internal/gtcpip/checksum" "github.com/sagernet/sing-tun/internal/gtcpip/header" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/rw" @@ -40,6 +39,7 @@ type NativeTun struct { writeAccess sync.Mutex vnetHdr bool writeBuffer []byte + vnetHdrWriteBuf []byte gsoToWrite []int tcpGROTable *tcpGROTable udpGroAccess sync.Mutex @@ -130,7 +130,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { for _, address := range t.options.Inet4Address { addr4, _ := netlink.ParseAddr(address.String()) err = netlink.AddrAdd(tunLink, addr4) - if err != nil { + if err != nil && !errors.Is(err, unix.EEXIST) { return err } } @@ -139,7 +139,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { for _, address := range t.options.Inet6Address { addr6, _ := netlink.ParseAddr(address.String()) err = netlink.AddrAdd(tunLink, addr6) - if err != nil { + if err != nil && !errors.Is(err, unix.EEXIST) { return err } } @@ -148,7 +148,9 @@ func (t *NativeTun) configure(tunLink netlink.Link) error { if t.options.GSO { err = t.enableGSO() if err != nil { - t.options.Logger.Warn(err) + if t.options.Logger != nil { + t.options.Logger.Warn(err) + } } } @@ -273,7 +275,9 @@ func (t *NativeTun) Start() error { if err != nil { t.gro.disableTCPGRO() t.gro.disableUDPGRO() - t.options.Logger.Warn(E.Cause(err, "disabled TUN TCP & UDP GRO due to GRO probe error")) + if t.options.Logger != nil { + t.options.Logger.Warn(E.Cause(err, "disabled TUN TCP & UDP GRO due to GRO probe error")) + } } } @@ -315,6 +319,8 @@ func (t *NativeTun) Close() error { if t.interfaceCallback != nil { t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback) } + t.unsetSearchDomainForSystemdResolved() + t.unsetAddresses() return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } @@ -382,10 +388,7 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e func (t *NativeTun) Write(p []byte) (n int, err error) { if t.vnetHdr { - buffer := buf.Get(virtioNetHdrLen + len(p)) - copy(buffer[virtioNetHdrLen:], p) - _, err = t.BatchWrite([][]byte{buffer}, virtioNetHdrLen) - buf.Put(buffer) + _, err = t.BatchWrite([][]byte{p}, virtioNetHdrLen) if err != nil { return } @@ -616,6 +619,22 @@ func (t *NativeTun) rules() []*netlink.Rule { it.Family = unix.AF_INET6 rules = append(rules, it) } + // Fallback rules after system default rules (32766: main, 32767: default) + // Only reached when main and default tables have no route + if p4 { + it = netlink.NewRule() + it.Priority = t.options.IPRoute2AutoRedirectFallbackRuleIndex + it.Table = t.options.IPRoute2TableIndex + it.Family = unix.AF_INET + rules = append(rules, it) + } + if p6 { + it = netlink.NewRule() + it.Priority = t.options.IPRoute2AutoRedirectFallbackRuleIndex + it.Table = t.options.IPRoute2TableIndex + it.Family = unix.AF_INET6 + rules = append(rules, it) + } return rules } @@ -816,14 +835,6 @@ func (t *NativeTun) rules() []*netlink.Rule { it.Family = unix.AF_INET rules = append(rules, it) } - if p4 && !t.options.StrictRoute { - it = netlink.NewRule() - it.Priority = priority - it.IPProto = syscall.IPPROTO_ICMP - it.Goto = nopPriority - it.Family = unix.AF_INET - rules = append(rules, it) - } if p6 { it = netlink.NewRule() it.Priority = priority6 @@ -834,16 +845,6 @@ func (t *NativeTun) rules() []*netlink.Rule { it.Family = unix.AF_INET6 rules = append(rules, it) } - - if p6 && !t.options.StrictRoute { - it = netlink.NewRule() - it.Priority = priority6 - it.IPProto = syscall.IPPROTO_ICMPV6 - it.Goto = nopPriority - it.Family = unix.AF_INET6 - rules = append(rules, it) - priority6++ - } } if p4 { it = netlink.NewRule() @@ -1007,7 +1008,7 @@ func (t *NativeTun) unsetRules() error { for _, rule := range ruleList { ruleStart := t.options.IPRoute2RuleIndex ruleEnd := ruleStart + 10 - if rule.Priority >= ruleStart && rule.Priority <= ruleEnd { + if rule.Priority >= ruleStart && rule.Priority <= ruleEnd || (t.options.AutoRedirectMarkMode && rule.Priority == t.options.IPRoute2AutoRedirectFallbackRuleIndex) { ruleToDel := netlink.NewRule() ruleToDel.Family = rule.Family ruleToDel.Priority = rule.Priority @@ -1021,6 +1022,24 @@ func (t *NativeTun) unsetRules() error { return nil } +func (t *NativeTun) unsetAddresses() { + if t.options.FileDescriptor > 0 { + return + } + tunLink, err := netlink.LinkByName(t.options.Name) + if err != nil { + return + } + for _, address := range t.options.Inet4Address { + addr, _ := netlink.ParseAddr(address.String()) + _ = netlink.AddrDel(tunLink, addr) + } + for _, address := range t.options.Inet6Address { + addr, _ := netlink.ParseAddr(address.String()) + _ = netlink.AddrDel(tunLink, addr) + } +} + func (t *NativeTun) resetRules() error { t.unsetRules() return t.setRules() @@ -1064,3 +1083,14 @@ func (t *NativeTun) setSearchDomainForSystemdResolved() { _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run() }() } + +func (t *NativeTun) unsetSearchDomainForSystemdResolved() { + if t.options.EXP_DisableDNSHijack { + return + } + ctlPath, err := exec.LookPath("resolvectl") + if err != nil { + return + } + _ = shell.Exec(ctlPath, "revert", t.options.Name).Run() +} diff --git a/tun_linux_gvisor.go b/tun_linux_gvisor.go index cb0561b6..8adf4c5d 100644 --- a/tun_linux_gvisor.go +++ b/tun_linux_gvisor.go @@ -3,6 +3,8 @@ package tun import ( + "fmt" + "github.com/sagernet/gvisor/pkg/rawfile" "github.com/sagernet/gvisor/pkg/tcpip/link/fdbased" "github.com/sagernet/gvisor/pkg/tcpip/stack" @@ -18,6 +20,37 @@ var _ GVisorTun = (*NativeTun)(nil) func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) { iovecs := t.iovecsOutputDefault + if t.vnetHdr { + if t.vnetHdrWriteBuf == nil { + t.vnetHdrWriteBuf = make([]byte, virtioNetHdrLen) + } + vnetHdr := virtioNetHdr{} + if pkt.GSOOptions.Type != stack.GSONone { + vnetHdr.hdrLen = uint16(pkt.HeaderSize()) + if pkt.GSOOptions.NeedsCsum { + vnetHdr.flags = unix.VIRTIO_NET_HDR_F_NEEDS_CSUM + vnetHdr.csumStart = pkt.GSOOptions.L3HdrLen + vnetHdr.csumOffset = pkt.GSOOptions.CsumOffset + } + if uint16(pkt.Data().Size()) > pkt.GSOOptions.MSS { + switch pkt.GSOOptions.Type { + case stack.GSOTCPv4: + vnetHdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + case stack.GSOTCPv6: + vnetHdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + default: + panic(fmt.Sprintf("Unknown gso type: %v", pkt.GSOOptions.Type)) + } + vnetHdr.gsoSize = pkt.GSOOptions.MSS + } + } + if err := vnetHdr.encode(t.vnetHdrWriteBuf); err != nil { + return 0, err + } + iovec := unix.Iovec{Base: &t.vnetHdrWriteBuf[0]} + iovec.SetLen(virtioNetHdrLen) + iovecs = append(iovecs, iovec) + } var dataLen int for _, packetSlice := range pkt.AsSlices() { dataLen += len(packetSlice) diff --git a/tun_windows.go b/tun_windows.go index 66fb13d4..26be089f 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -9,6 +9,7 @@ import ( "net/netip" "os" "sync" + "sync/atomic" "time" "unsafe" @@ -16,7 +17,6 @@ import ( "github.com/sagernet/sing-tun/internal/winsys" "github.com/sagernet/sing-tun/internal/wintun" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/atomic" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/windnsapi" @@ -181,6 +181,13 @@ func (t *NativeTun) Start() error { return err } if t.options.StrictRoute { + major, _, _ := windows.RtlGetNtVersionNumbers() + if major < 10 { + if t.options.Logger != nil { + t.options.Logger.Warn("strict routing is not supported on Windows versions below 10") + } + return nil + } var engine uintptr session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC} err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine)) @@ -395,15 +402,16 @@ retry: func (t *NativeTun) ReadPacket() ([]byte, func(), error) { t.running.Add(1) - defer t.running.Done() retry: if t.close.Load() == 1 { + t.running.Done() return nil, nil, os.ErrClosed } start := nanotime() shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { if t.close.Load() == 1 { + t.running.Done() return nil, nil, os.ErrClosed } packet, err := t.session.ReceivePacket() @@ -411,7 +419,10 @@ retry: case nil: packetSize := len(packet) t.rate.update(uint64(packetSize)) - return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil + return packet, func() { + t.session.ReleaseReceivePacket(packet) + t.running.Done() + }, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(t.readWait, windows.INFINITE) @@ -420,10 +431,13 @@ retry: procyield(1) continue case windows.ERROR_HANDLE_EOF: + t.running.Done() return nil, nil, os.ErrClosed case windows.ERROR_INVALID_DATA: + t.running.Done() return nil, nil, errors.New("send ring corrupt") } + t.running.Done() return nil, nil, fmt.Errorf("read failed: %w", err) } }