diff --git a/internal/gtcpip/README.md b/gtcpip/README.md similarity index 100% rename from internal/gtcpip/README.md rename to gtcpip/README.md diff --git a/internal/gtcpip/checksum/checksum.go b/gtcpip/checksum/checksum.go similarity index 100% rename from internal/gtcpip/checksum/checksum.go rename to gtcpip/checksum/checksum.go diff --git a/internal/gtcpip/checksum/checksum_default.go b/gtcpip/checksum/checksum_default.go similarity index 100% rename from internal/gtcpip/checksum/checksum_default.go rename to gtcpip/checksum/checksum_default.go diff --git a/internal/gtcpip/checksum/checksum_ts.go b/gtcpip/checksum/checksum_ts.go similarity index 100% rename from internal/gtcpip/checksum/checksum_ts.go rename to gtcpip/checksum/checksum_ts.go diff --git a/internal/gtcpip/checksum/checksum_unsafe.go b/gtcpip/checksum/checksum_unsafe.go similarity index 100% rename from internal/gtcpip/checksum/checksum_unsafe.go rename to gtcpip/checksum/checksum_unsafe.go diff --git a/internal/gtcpip/errors.go b/gtcpip/errors.go similarity index 100% rename from internal/gtcpip/errors.go rename to gtcpip/errors.go diff --git a/internal/gtcpip/header/checksum.go b/gtcpip/header/checksum.go similarity index 97% rename from internal/gtcpip/header/checksum.go rename to gtcpip/header/checksum.go index 2c21e6d3..303502cc 100644 --- a/internal/gtcpip/header/checksum.go +++ b/gtcpip/header/checksum.go @@ -20,8 +20,8 @@ import ( "encoding/binary" "fmt" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // PseudoHeaderChecksum calculates the pseudo-header checksum for the given diff --git a/internal/gtcpip/header/eth.go b/gtcpip/header/eth.go similarity index 99% rename from internal/gtcpip/header/eth.go rename to gtcpip/header/eth.go index 9d876ee6..613a72c6 100644 --- a/internal/gtcpip/header/eth.go +++ b/gtcpip/header/eth.go @@ -17,7 +17,7 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/icmpv4.go b/gtcpip/header/icmpv4.go similarity index 98% rename from internal/gtcpip/header/icmpv4.go rename to gtcpip/header/icmpv4.go index 580101c0..3b481041 100644 --- a/internal/gtcpip/header/icmpv4.go +++ b/gtcpip/header/icmpv4.go @@ -17,8 +17,8 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // ICMPv4 represents an ICMPv4 header stored in a byte array. diff --git a/internal/gtcpip/header/icmpv6.go b/gtcpip/header/icmpv6.go similarity index 98% rename from internal/gtcpip/header/icmpv6.go rename to gtcpip/header/icmpv6.go index 520b4036..7eae97ab 100644 --- a/internal/gtcpip/header/icmpv6.go +++ b/gtcpip/header/icmpv6.go @@ -17,8 +17,8 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) // ICMPv6 represents an ICMPv6 header stored in a byte array. diff --git a/internal/gtcpip/header/interfaces.go b/gtcpip/header/interfaces.go similarity index 98% rename from internal/gtcpip/header/interfaces.go rename to gtcpip/header/interfaces.go index fc13100c..c0bb410c 100644 --- a/internal/gtcpip/header/interfaces.go +++ b/gtcpip/header/interfaces.go @@ -17,7 +17,7 @@ package header import ( "net/netip" - tcpip "github.com/sagernet/sing-tun/internal/gtcpip" + tcpip "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ipv4.go b/gtcpip/header/ipv4.go similarity index 99% rename from internal/gtcpip/header/ipv4.go rename to gtcpip/header/ipv4.go index ad06f38c..d5ffbf1d 100644 --- a/internal/gtcpip/header/ipv4.go +++ b/gtcpip/header/ipv4.go @@ -20,8 +20,8 @@ import ( "net/netip" "time" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ipv6.go b/gtcpip/header/ipv6.go similarity index 99% rename from internal/gtcpip/header/ipv6.go rename to gtcpip/header/ipv6.go index 1a5a7a05..4de30737 100644 --- a/internal/gtcpip/header/ipv6.go +++ b/gtcpip/header/ipv6.go @@ -20,7 +20,7 @@ import ( "fmt" "net/netip" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ipv6_extension_headers.go b/gtcpip/header/ipv6_extension_headers.go similarity index 99% rename from internal/gtcpip/header/ipv6_extension_headers.go rename to gtcpip/header/ipv6_extension_headers.go index 20064d8b..6c48b1bf 100644 --- a/internal/gtcpip/header/ipv6_extension_headers.go +++ b/gtcpip/header/ipv6_extension_headers.go @@ -20,7 +20,7 @@ import ( "fmt" "math" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ipv6_fragment.go b/gtcpip/header/ipv6_fragment.go similarity index 99% rename from internal/gtcpip/header/ipv6_fragment.go rename to gtcpip/header/ipv6_fragment.go index 49aaca71..38f0b202 100644 --- a/internal/gtcpip/header/ipv6_fragment.go +++ b/gtcpip/header/ipv6_fragment.go @@ -17,7 +17,7 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" ) const ( diff --git a/internal/gtcpip/header/ndp_neighbor_advert.go b/gtcpip/header/ndp_neighbor_advert.go similarity index 98% rename from internal/gtcpip/header/ndp_neighbor_advert.go rename to gtcpip/header/ndp_neighbor_advert.go index 7a934cce..8f36765a 100644 --- a/internal/gtcpip/header/ndp_neighbor_advert.go +++ b/gtcpip/header/ndp_neighbor_advert.go @@ -14,7 +14,7 @@ package header -import "github.com/sagernet/sing-tun/internal/gtcpip" +import "github.com/sagernet/sing-tun/gtcpip" // NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will // only contain the body of an ICMPv6 packet. diff --git a/internal/gtcpip/header/ndp_neighbor_solicit.go b/gtcpip/header/ndp_neighbor_solicit.go similarity index 97% rename from internal/gtcpip/header/ndp_neighbor_solicit.go rename to gtcpip/header/ndp_neighbor_solicit.go index 61d61a8a..b4af20ce 100644 --- a/internal/gtcpip/header/ndp_neighbor_solicit.go +++ b/gtcpip/header/ndp_neighbor_solicit.go @@ -14,7 +14,7 @@ package header -import "github.com/sagernet/sing-tun/internal/gtcpip" +import "github.com/sagernet/sing-tun/gtcpip" // NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only // contain the body of an ICMPv6 packet. diff --git a/internal/gtcpip/header/ndp_options.go b/gtcpip/header/ndp_options.go similarity index 99% rename from internal/gtcpip/header/ndp_options.go rename to gtcpip/header/ndp_options.go index ba293398..365329a2 100644 --- a/internal/gtcpip/header/ndp_options.go +++ b/gtcpip/header/ndp_options.go @@ -23,7 +23,7 @@ import ( "math" "time" - "github.com/sagernet/sing-tun/internal/gtcpip" + "github.com/sagernet/sing-tun/gtcpip" "github.com/sagernet/sing/common" ) diff --git a/internal/gtcpip/header/ndp_router_advert.go b/gtcpip/header/ndp_router_advert.go similarity index 100% rename from internal/gtcpip/header/ndp_router_advert.go rename to gtcpip/header/ndp_router_advert.go diff --git a/internal/gtcpip/header/ndp_router_solicit.go b/gtcpip/header/ndp_router_solicit.go similarity index 100% rename from internal/gtcpip/header/ndp_router_solicit.go rename to gtcpip/header/ndp_router_solicit.go diff --git a/internal/gtcpip/header/ndpoptionidentifier_string.go b/gtcpip/header/ndpoptionidentifier_string.go similarity index 100% rename from internal/gtcpip/header/ndpoptionidentifier_string.go rename to gtcpip/header/ndpoptionidentifier_string.go diff --git a/internal/gtcpip/header/netip.go b/gtcpip/header/netip.go similarity index 100% rename from internal/gtcpip/header/netip.go rename to gtcpip/header/netip.go diff --git a/internal/gtcpip/header/tcp.go b/gtcpip/header/tcp.go similarity index 99% rename from internal/gtcpip/header/tcp.go rename to gtcpip/header/tcp.go index 1b58df86..824b08c8 100644 --- a/internal/gtcpip/header/tcp.go +++ b/gtcpip/header/tcp.go @@ -17,9 +17,9 @@ package header import ( "encoding/binary" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/seqnum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/seqnum" "github.com/google/btree" ) diff --git a/internal/gtcpip/header/udp.go b/gtcpip/header/udp.go similarity index 98% rename from internal/gtcpip/header/udp.go rename to gtcpip/header/udp.go index a995a172..ce7708e1 100644 --- a/internal/gtcpip/header/udp.go +++ b/gtcpip/header/udp.go @@ -18,8 +18,8 @@ import ( "encoding/binary" "math" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" ) const ( diff --git a/internal/gtcpip/seqnum/seqnum.go b/gtcpip/seqnum/seqnum.go similarity index 100% rename from internal/gtcpip/seqnum/seqnum.go rename to gtcpip/seqnum/seqnum.go diff --git a/internal/gtcpip/tcpip.go b/gtcpip/tcpip.go similarity index 100% rename from internal/gtcpip/tcpip.go rename to gtcpip/tcpip.go diff --git a/internal/checksum_test/sum_bench_test.go b/internal/checksum_test/sum_bench_test.go index 35ee021c..2d07fff6 100644 --- a/internal/checksum_test/sum_bench_test.go +++ b/internal/checksum_test/sum_bench_test.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "testing" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing-tun/internal/tschecksum" ) diff --git a/nfqueue_linux.go b/nfqueue_linux.go index baaefb54..9eed52fc 100644 --- a/nfqueue_linux.go +++ b/nfqueue_linux.go @@ -7,7 +7,7 @@ import ( "errors" "sync/atomic" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" diff --git a/ping/destination.go b/ping/destination.go index 8648ecc8..36f35f42 100644 --- a/ping/destination.go +++ b/ping/destination.go @@ -10,7 +10,7 @@ import ( "time" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" diff --git a/ping/destination_rewriter.go b/ping/destination_rewriter.go index a61e1556..26bb3551 100644 --- a/ping/destination_rewriter.go +++ b/ping/destination_rewriter.go @@ -4,7 +4,7 @@ import ( "net/netip" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/buf" ) diff --git a/ping/ping.go b/ping/ping.go index 248987c2..1b0c89f2 100644 --- a/ping/ping.go +++ b/ping/ping.go @@ -9,7 +9,7 @@ import ( "sync/atomic" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" diff --git a/ping/ping_test.go b/ping/ping_test.go index 5a04be17..cf50b05f 100644 --- a/ping/ping_test.go +++ b/ping/ping_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/sagernet/gvisor/pkg/rand" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing-tun/ping" "github.com/sagernet/sing/common/buf" diff --git a/ping/socket_linux_unprivileged.go b/ping/socket_linux_unprivileged.go index 3742cc83..f709684a 100644 --- a/ping/socket_linux_unprivileged.go +++ b/ping/socket_linux_unprivileged.go @@ -8,7 +8,7 @@ import ( "sync" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" diff --git a/ping/source_rewriter.go b/ping/source_rewriter.go index 480c6a78..545560de 100644 --- a/ping/source_rewriter.go +++ b/ping/source_rewriter.go @@ -6,7 +6,7 @@ import ( "sync" "github.com/sagernet/sing-tun" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/logger" ) diff --git a/redirect_linux.go b/redirect_linux.go index e9c892c8..d575d192 100644 --- a/redirect_linux.go +++ b/redirect_linux.go @@ -44,6 +44,8 @@ type autoRedirect struct { nfqueueEnabled bool redirectRouteTableIndex int redirectInterfaces []control.Interface + dockerFirewallMonitor *nftables.Monitor + dockerFirewallDone chan struct{} } func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) { diff --git a/redirect_nftables.go b/redirect_nftables.go index 266bbe91..f17e1f36 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -283,11 +283,15 @@ func (r *autoRedirect) setupNFTables() error { if err != nil { return E.Cause(err, "configure openwrt firewall4") } - err = nft.Flush() if err != nil { return E.Cause(err, "flush nftables") } + r.startDockerFirewallMonitor() + err = r.configureDockerFirewall(false) + if err != nil && r.logger != nil { + r.logger.Warn("configure docker firewall: ", err) + } r.networkListener = r.networkMonitor.RegisterCallback(func() { err = r.nftablesUpdateLocalAddressSet() @@ -361,6 +365,7 @@ func (r *autoRedirect) cleanupNFTables() { if r.networkListener != nil { r.networkMonitor.UnregisterCallback(r.networkListener) } + r.stopDockerFirewallMonitor() nft, err := nftables.New() if err != nil { return @@ -372,6 +377,10 @@ func (r *autoRedirect) cleanupNFTables() { _ = r.configureOpenWRTFirewall4(nft, true) _ = nft.Flush() _ = nft.CloseLasting() + err = r.configureDockerFirewall(true) + if err != nil && r.logger != nil { + r.logger.Warn("cleanup docker firewall: ", err) + } } func (r *autoRedirect) nftablesCreatePreMatchChains(nft *nftables.Conn, table *nftables.Table) error { diff --git a/redirect_nftables_docker.go b/redirect_nftables_docker.go new file mode 100644 index 00000000..0c6fc896 --- /dev/null +++ b/redirect_nftables_docker.go @@ -0,0 +1,266 @@ +//go:build linux + +package tun + +import ( + "bytes" + "strings" + + "github.com/sagernet/nftables" + "github.com/sagernet/nftables/expr" + "github.com/sagernet/nftables/userdata" + E "github.com/sagernet/sing/common/exceptions" +) + +const ( + nftablesDockerFilterTable = "filter" + nftablesDockerUserChain = "DOCKER-USER" +) + +func (r *autoRedirect) startDockerFirewallMonitor() { + if r.dockerFirewallMonitor != nil { + return + } + doneCh := make(chan struct{}) + r.dockerFirewallDone = doneCh + monitor := nftables.NewMonitor( + nftables.WithMonitorAction(nftables.MonitorActionAny), + nftables.WithMonitorObject(nftables.MonitorObjectRuleset), + nftables.WithMonitorEventBuffer(16), + ) + nft, err := nftables.New() + if err != nil { + if r.logger != nil { + r.logger.Warn("create nftables monitor connection: ", err) + } + close(doneCh) + r.dockerFirewallDone = nil + return + } + events, err := nft.AddGenerationalMonitor(monitor) + _ = nft.CloseLasting() + if err != nil { + if r.logger != nil { + r.logger.Warn("start nftables monitor: ", err) + } + close(doneCh) + r.dockerFirewallDone = nil + return + } + r.dockerFirewallMonitor = monitor + go r.loopDockerFirewallMonitor(events, doneCh) +} + +func (r *autoRedirect) stopDockerFirewallMonitor() { + if r.dockerFirewallMonitor == nil { + return + } + _ = r.dockerFirewallMonitor.Close() + <-r.dockerFirewallDone + r.dockerFirewallMonitor = nil + r.dockerFirewallDone = nil +} + +func (r *autoRedirect) loopDockerFirewallMonitor(events <-chan *nftables.MonitorEvents, doneCh chan<- struct{}) { + defer close(doneCh) + for monitorEvents := range events { + if monitorEvents != nil && monitorEvents.GeneratedBy != nil && monitorEvents.GeneratedBy.Error != nil { + if r.logger != nil { + r.logger.Warn("nftables monitor closed: ", monitorEvents.GeneratedBy.Error) + } + return + } + if !nftablesDockerFirewallEventsRelevant(monitorEvents) { + continue + } + err := r.configureDockerFirewall(false) + if err != nil && r.logger != nil { + r.logger.Warn("update docker firewall: ", err) + } + } +} + +func (r *autoRedirect) configureDockerFirewall(cleanup bool) error { + nft, err := nftables.New() + if err != nil { + return E.Cause(err, "create nftables connection") + } + defer nft.CloseLasting() + + err = r.configureDockerFirewallWithConn(nft, cleanup) + if err != nil { + return err + } + return nft.Flush() +} + +func (r *autoRedirect) configureDockerFirewallWithConn(nft *nftables.Conn, cleanup bool) error { + var err error + if r.enableIPv4 { + err = E.Errors(err, r.configureDockerFirewallForFamily(nft, nftables.TableFamilyIPv4, cleanup)) + } + if r.enableIPv6 { + err = E.Errors(err, r.configureDockerFirewallForFamily(nft, nftables.TableFamilyIPv6, cleanup)) + } + return err +} + +func (r *autoRedirect) configureDockerFirewallForFamily(nft *nftables.Conn, family nftables.TableFamily, cleanup bool) error { + table, chain, loaded, err := nftablesLoadDockerUserChain(nft, family) + if err != nil || !loaded { + return err + } + err = r.configureDockerFirewallRules(nft, table, chain, cleanup) + return err +} + +func (r *autoRedirect) configureDockerFirewallRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, cleanup bool) error { + rules, err := nft.GetRules(table, chain) + if err != nil { + return E.Cause(err, "list docker user rules") + } + if cleanup { + return r.cleanupDockerFirewallRules(nft, rules) + } + return r.reconcileDockerFirewallRules(nft, table, chain, rules) +} + +func nftablesLoadDockerUserChain(nft *nftables.Conn, family nftables.TableFamily) (*nftables.Table, *nftables.Chain, bool, error) { + table, err := nft.ListTableOfFamily(nftablesDockerFilterTable, family) + if err != nil { + return nil, nil, false, nil + } + chain, err := nft.ListChain(table, nftablesDockerUserChain) + if err != nil { + return nil, nil, false, nil + } + return table, chain, true, nil +} + +func nftablesDockerFirewallEventsRelevant(events *nftables.MonitorEvents) bool { + if events == nil { + return false + } + for _, event := range events.Changes { + if nftablesDockerFirewallEventRelevant(event) { + return true + } + } + return false +} + +func nftablesDockerFirewallEventRelevant(event *nftables.MonitorEvent) bool { + if event == nil || event.Error != nil { + return false + } + switch data := event.Data.(type) { + case *nftables.Table: + return nftablesIsDockerFirewallTable(data) + case *nftables.Chain: + return data.Name == nftablesDockerUserChain && nftablesIsDockerFirewallTable(data.Table) + case *nftables.Rule: + return data.Chain != nil && data.Chain.Name == nftablesDockerUserChain && nftablesIsDockerFirewallTable(data.Table) + default: + return false + } +} + +func nftablesIsDockerFirewallTable(table *nftables.Table) bool { + return table != nil && + table.Name == nftablesDockerFilterTable && + (table.Family == nftables.TableFamilyIPv4 || table.Family == nftables.TableFamilyIPv6) +} + +func (r *autoRedirect) cleanupDockerFirewallRules(nft *nftables.Conn, rules []*nftables.Rule) error { + var deleteErr error + for _, rule := range rules { + if r.nftablesIsDockerCompatibilityRule(rule) { + deleteErr = E.Errors(deleteErr, nft.DelRule(rule)) + } + } + return deleteErr +} + +func (r *autoRedirect) reconcileDockerFirewallRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, rules []*nftables.Rule) error { + outputComment := r.nftablesDockerCompatibilityComment("output to tun") + inputComment := r.nftablesDockerCompatibilityComment("input from tun") + var hasOutputRule bool + var hasInputRule bool + var deleteErr error + for _, rule := range rules { + if nftablesDockerCompatibilityRuleMatches(rule, r.tunOptions.Name, expr.MetaKeyOIFNAME, outputComment) && !hasOutputRule { + hasOutputRule = true + } else if nftablesDockerCompatibilityRuleMatches(rule, r.tunOptions.Name, expr.MetaKeyIIFNAME, inputComment) && !hasInputRule { + hasInputRule = true + } else if r.nftablesIsDockerCompatibilityRule(rule) { + deleteErr = E.Errors(deleteErr, nft.DelRule(rule)) + } + } + if deleteErr != nil { + return deleteErr + } + if !hasOutputRule { + nft.InsertRule(nftablesDockerCompatibilityRule(table, chain, r.tunOptions.Name, expr.MetaKeyOIFNAME, outputComment)) + } + if !hasInputRule { + nft.InsertRule(nftablesDockerCompatibilityRule(table, chain, r.tunOptions.Name, expr.MetaKeyIIFNAME, inputComment)) + } + return nil +} + +func nftablesDockerCompatibilityRule(table *nftables.Table, chain *nftables.Chain, ifName string, ifNameKey expr.MetaKey, comment string) *nftables.Rule { + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: ifNameKey, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: nftablesIfname(ifName), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + UserData: userdata.AppendString(nil, userdata.TypeComment, comment), + } +} + +func nftablesDockerCompatibilityRuleMatches(rule *nftables.Rule, ifName string, ifNameKey expr.MetaKey, comment string) bool { + ruleComment, loaded := userdata.GetString(rule.UserData, userdata.TypeComment) + if !loaded || ruleComment != comment || len(rule.Exprs) != 4 { + return false + } + meta, loaded := rule.Exprs[0].(*expr.Meta) + if !loaded || meta.Key != ifNameKey || meta.Register != 1 { + return false + } + cmp, loaded := rule.Exprs[1].(*expr.Cmp) + if !loaded || cmp.Op != expr.CmpOpEq || cmp.Register != 1 || !bytes.Equal(cmp.Data, nftablesIfname(ifName)) { + return false + } + _, loaded = rule.Exprs[2].(*expr.Counter) + if !loaded { + return false + } + verdict, loaded := rule.Exprs[3].(*expr.Verdict) + return loaded && verdict.Kind == expr.VerdictAccept +} + +func (r *autoRedirect) nftablesIsDockerCompatibilityRule(rule *nftables.Rule) bool { + comment, loaded := userdata.GetString(rule.UserData, userdata.TypeComment) + return loaded && strings.HasPrefix(comment, r.nftablesDockerCompatibilityCommentPrefix()) +} + +func (r *autoRedirect) nftablesDockerCompatibilityComment(direction string) string { + return r.nftablesDockerCompatibilityCommentPrefix() + direction +} + +func (r *autoRedirect) nftablesDockerCompatibilityCommentPrefix() string { + return "!" + r.tableName + ": Docker compatibility " +} diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index dddd9c66..1ef5c19b 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -3,6 +3,7 @@ package tun import ( + "net" "net/netip" _ "unsafe" @@ -376,6 +377,149 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft }) } } + if len(r.tunOptions.IncludeMACAddress) > 0 { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFTYPE, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: binaryutil.NativeEndian.PutUint16(unix.ARPHRD_ETHER), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + if len(r.tunOptions.IncludeMACAddress) > 1 { + includeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(includeMACSet, common.Map(r.tunOptions.IncludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: includeMACSet.ID, + SetName: includeMACSet.Name, + Invert: true, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte(r.tunOptions.IncludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } + if len(r.tunOptions.ExcludeMACAddress) > 0 { + if len(r.tunOptions.ExcludeMACAddress) > 1 { + excludeMACSet := &nftables.Set{ + Table: table, + Anonymous: true, + Constant: true, + KeyType: nftables.TypeEtherAddr, + } + err := nft.AddSet(excludeMACSet, common.Map(r.tunOptions.ExcludeMACAddress, func(it net.HardwareAddr) nftables.SetElement { + return nftables.SetElement{ + Key: []byte(it), + } + })) + if err != nil { + return err + } + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Lookup{ + SourceRegister: 1, + SetID: excludeMACSet.ID, + SetName: excludeMACSet.Name, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } else { + nft.AddRule(&nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Payload{ + OperationType: expr.PayloadLoad, + DestRegister: 1, + Base: expr.PayloadBaseLLHeader, + Offset: 6, + Len: 6, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(r.tunOptions.ExcludeMACAddress[0]), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictReturn, + }, + }, + }) + } + } } else { if len(r.tunOptions.IncludeUID) > 0 { if len(r.tunOptions.IncludeUID) > 1 || r.tunOptions.IncludeUID[0].Start != r.tunOptions.IncludeUID[0].End { @@ -531,7 +675,7 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteExcludeAddress.ID, inet6RouteExcludeAddress.Name, nftables.TableFamilyIPv6, false) } - if !r.tunOptions.EXP_DisableDNSHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) || + if r.tunOptions.DNSModeOrDefault() == DNSModeHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) || (r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT)) { if r.enableIPv4 { err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4, 5, "inet4_local_address_set") @@ -853,23 +997,19 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily( if err != nil { return E.Cause(err, "add dns protocol set") } - dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool { - return it.Is4() == (family == nftables.TableFamilyIPv4) - }) - if !dnsServer.IsValid() { - if family == nftables.TableFamilyIPv4 { - if HasNextAddress(r.tunOptions.Inet4Address[0], 1) { - dnsServer = r.tunOptions.Inet4Address[0].Addr().Next() - } - } else { - if HasNextAddress(r.tunOptions.Inet6Address[0], 1) { - dnsServer = r.tunOptions.Inet6Address[0].Addr().Next() - } - } + var dnsServers []netip.Addr + if family == nftables.TableFamilyIPv4 { + dnsServers, err = r.tunOptions.Inet4DNSAddress() + } else { + dnsServers, err = r.tunOptions.Inet6DNSAddress() + } + if err != nil { + return err } - if !dnsServer.IsValid() { + if len(dnsServers) == 0 { return nil } + dnsServer := dnsServers[0] exprs := []expr.Any{ &expr.Meta{ Key: expr.MetaKeyNFPROTO, diff --git a/stack_gvisor_lazy.go b/stack_gvisor_lazy.go index f5e2e6e6..dcbcafb9 100644 --- a/stack_gvisor_lazy.go +++ b/stack_gvisor_lazy.go @@ -19,7 +19,7 @@ import ( ) type gLazyConn struct { - tcpConn *gonet.TCPConn + tcpConn *gTCPConn parentCtx context.Context stack *stack.Stack request *tcp.ForwarderRequest @@ -31,9 +31,6 @@ type gLazyConn struct { } func (c *gLazyConn) HandshakeContext(ctx context.Context) error { - if c.handshakeDone { - return c.handshakeErr - } c.handshakeAccess.Lock() defer c.handshakeAccess.Unlock() if c.handshakeDone { @@ -66,15 +63,12 @@ func (c *gLazyConn) HandshakeContext(ctx context.Context) error { endpoint.SocketOptions().SetKeepAlive(true) endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second))) endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second))) - tcpConn := gonet.NewTCPConn(&wq, endpoint) + tcpConn := newGTCPConn(&wq, endpoint, c.localAddr, c.remoteAddr) c.tcpConn = tcpConn return nil } func (c *gLazyConn) HandshakeFailure(err error) error { - if c.handshakeDone { - return os.ErrInvalid - } c.handshakeAccess.Lock() defer c.handshakeAccess.Unlock() if c.handshakeDone { @@ -90,6 +84,18 @@ func (c *gLazyConn) HandshakeSuccess() error { return c.HandshakeContext(context.Background()) } +func (c *gLazyConn) NeedHandshakeForRead() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + return !c.handshakeDone +} + +func (c *gLazyConn) NeedHandshakeForWrite() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() + return !c.handshakeDone +} + func (c *gLazyConn) Read(b []byte) (n int, err error) { err = c.HandshakeContext(context.Background()) if err != nil { @@ -139,57 +145,38 @@ func (c *gLazyConn) SetWriteDeadline(t time.Time) error { } func (c *gLazyConn) Close() error { - if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil - } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { + if c.closeBeforeHandshake() { return nil } return c.tcpConn.Close() } func (c *gLazyConn) CloseRead() error { - if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { - c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil - } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { + if c.closeBeforeHandshake() { return nil } return c.tcpConn.CloseRead() } func (c *gLazyConn) CloseWrite() error { + if c.closeBeforeHandshake() { + return nil + } + return c.tcpConn.CloseWrite() +} + +func (c *gLazyConn) closeBeforeHandshake() bool { + c.handshakeAccess.Lock() + defer c.handshakeAccess.Unlock() if !c.handshakeDone { - c.handshakeAccess.Lock() - if !c.handshakeDone { + if c.request != nil { c.request.Complete(true) - c.handshakeErr = net.ErrClosed - c.handshakeDone = true - return nil - } else if c.handshakeErr != nil { - return nil } - c.handshakeAccess.Unlock() - } else if c.handshakeErr != nil { - return nil + c.handshakeErr = net.ErrClosed + c.handshakeDone = true + return true } - return c.tcpConn.CloseRead() + return c.handshakeErr != nil } func (c *gLazyConn) ReaderReplaceable() bool { diff --git a/stack_gvisor_tcp.go b/stack_gvisor_tcp.go index 0c63ee11..ba8af6df 100644 --- a/stack_gvisor_tcp.go +++ b/stack_gvisor_tcp.go @@ -11,7 +11,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/header" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/checksum" "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" diff --git a/stack_gvisor_tcp_conn.go b/stack_gvisor_tcp_conn.go new file mode 100644 index 00000000..ad48d42f --- /dev/null +++ b/stack_gvisor_tcp_conn.go @@ -0,0 +1,325 @@ +//go:build with_gvisor + +package tun + +import ( + "bytes" + "errors" + "io" + "net" + "os" + "time" + + "github.com/sagernet/gvisor/pkg/sync" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +var ( + _ net.Conn = (*gTCPConn)(nil) + _ N.ReadWaiter = (*gTCPConn)(nil) +) + +type gTCPConn struct { + gTCPDeadline + + wq *waiter.Queue + ep tcpip.Endpoint + + localAddr net.Addr + remoteAddr net.Addr + + readMu sync.Mutex + readWaitOption N.ReadWaitOptions +} + +func newGTCPConn(wq *waiter.Queue, ep tcpip.Endpoint, localAddr net.Addr, remoteAddr net.Addr) *gTCPConn { + conn := &gTCPConn{ + wq: wq, + ep: ep, + localAddr: localAddr, + remoteAddr: remoteAddr, + } + conn.gTCPDeadline.init() + return conn +} + +func (c *gTCPConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOption = options + return false +} + +func (c *gTCPConn) WaitReadBuffer() (*buf.Buffer, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + deadline := c.readCancel() + for { + if err := c.waitReadable(deadline); err != nil { + return nil, err + } + buffer := c.readWaitOption.NewBuffer() + writer := tcpip.SliceWriter(buffer.FreeBytes()) + result, err := c.ep.Read(&writer, tcpip.ReadOptions{}) + if _, wouldBlock := err.(*tcpip.ErrWouldBlock); wouldBlock { + buffer.Release() + continue + } + if err != nil { + buffer.Release() + return nil, c.translateReadError(err) + } + if result.Count == 0 { + buffer.Release() + continue + } + buffer.Truncate(result.Count) + c.readWaitOption.PostReturn(buffer) + c.ep.ModerateRecvBuf(result.Count) + return buffer, nil + } +} + +func (c *gTCPConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + writer := tcpip.SliceWriter(b) + n, err := c.readTo(&writer, c.readCancel()) + if n != 0 { + c.ep.ModerateRecvBuf(n) + } + return n, err +} + +func (c *gTCPConn) readTo(writer io.Writer, deadline <-chan struct{}) (int, error) { + select { + case <-deadline: + return 0, c.newOpError("read", os.ErrDeadlineExceeded) + default: + } + + result, err := c.ep.Read(writer, tcpip.ReadOptions{}) + if _, wouldBlock := err.(*tcpip.ErrWouldBlock); wouldBlock { + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + c.wq.EventRegister(&waitEntry) + defer c.wq.EventUnregister(&waitEntry) + for { + result, err = c.ep.Read(writer, tcpip.ReadOptions{}) + if _, wouldBlock = err.(*tcpip.ErrWouldBlock); !wouldBlock { + break + } + select { + case <-deadline: + return 0, c.newOpError("read", os.ErrDeadlineExceeded) + case <-notifyCh: + } + } + } + + if err != nil { + return 0, c.translateReadError(err) + } + return result.Count, nil +} + +func (c *gTCPConn) waitReadable(deadline <-chan struct{}) error { + select { + case <-deadline: + return c.newOpError("read", os.ErrDeadlineExceeded) + default: + } + if c.ep.Readiness(waiter.ReadableEvents)&waiter.ReadableEvents != 0 { + return nil + } + + waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents) + c.wq.EventRegister(&waitEntry) + defer c.wq.EventUnregister(&waitEntry) + for c.ep.Readiness(waiter.ReadableEvents)&waiter.ReadableEvents == 0 { + select { + case <-deadline: + return c.newOpError("read", os.ErrDeadlineExceeded) + case <-notifyCh: + } + } + return nil +} + +func (c *gTCPConn) translateReadError(err tcpip.Error) error { + if _, closed := err.(*tcpip.ErrClosedForReceive); closed { + return io.EOF + } + return c.newOpError("read", gonet.TranslateNetstackError(err)) +} + +func (c *gTCPConn) Write(b []byte) (int, error) { + deadline := c.writeCancel() + + select { + case <-deadline: + return 0, c.newOpError("write", os.ErrDeadlineExceeded) + default: + } + + var ( + reader bytes.Reader + nBytes int + entry waiter.Entry + ch <-chan struct{} + ) + for nBytes != len(b) { + reader.Reset(b[nBytes:]) + n, err := c.ep.Write(&reader, tcpip.WriteOptions{}) + nBytes += int(n) + switch err.(type) { + case nil: + case *tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(waiter.WritableEvents) + c.wq.EventRegister(&entry) + defer c.wq.EventUnregister(&entry) + } else { + select { + case <-deadline: + return nBytes, c.newOpError("write", os.ErrDeadlineExceeded) + case <-ch: + continue + } + } + default: + return nBytes, c.newOpError("write", gonet.TranslateNetstackError(err)) + } + } + return nBytes, nil +} + +func (c *gTCPConn) Close() error { + c.ep.Close() + return nil +} + +func (c *gTCPConn) CloseRead() error { + if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { + return c.newOpError("close", errors.New(err.String())) + } + return nil +} + +func (c *gTCPConn) CloseWrite() error { + if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { + return c.newOpError("close", errors.New(err.String())) + } + return nil +} + +func (c *gTCPConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *gTCPConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *gTCPConn) SetDeadline(t time.Time) error { + return c.gTCPDeadline.SetDeadline(t) +} + +func (c *gTCPConn) SetReadDeadline(t time.Time) error { + return c.gTCPDeadline.SetReadDeadline(t) +} + +func (c *gTCPConn) SetWriteDeadline(t time.Time) error { + return c.gTCPDeadline.SetWriteDeadline(t) +} + +func (c *gTCPConn) newOpError(op string, err error) *net.OpError { + return &net.OpError{ + Op: op, + Net: "tcp", + Source: c.localAddr, + Addr: c.remoteAddr, + Err: err, + } +} + +type gTCPDeadline struct { + mu sync.Mutex + + readTimer *time.Timer + readCancelCh chan struct{} + writeTimer *time.Timer + writeCancelCh chan struct{} +} + +func (d *gTCPDeadline) init() { + d.readCancelCh = make(chan struct{}) + d.writeCancelCh = make(chan struct{}) +} + +func (d *gTCPDeadline) readCancel() <-chan struct{} { + d.mu.Lock() + cancelCh := d.readCancelCh + d.mu.Unlock() + return cancelCh +} + +func (d *gTCPDeadline) writeCancel() <-chan struct{} { + d.mu.Lock() + cancelCh := d.writeCancelCh + d.mu.Unlock() + return cancelCh +} + +func (d *gTCPDeadline) SetDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) SetReadDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.readCancelCh, &d.readTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) SetWriteDeadline(t time.Time) error { + d.mu.Lock() + d.setDeadline(&d.writeCancelCh, &d.writeTimer, t) + d.mu.Unlock() + return nil +} + +func (d *gTCPDeadline) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) { + if *timer != nil && !(*timer).Stop() { + *cancelCh = make(chan struct{}) + } + + select { + case <-*cancelCh: + *cancelCh = make(chan struct{}) + default: + } + + if t.IsZero() { + *timer = nil + return + } + + timeout := time.Until(t) + if timeout <= 0 { + close(*cancelCh) + return + } + + ch := *cancelCh + *timer = time.AfterFunc(timeout, func() { + close(ch) + }) +} diff --git a/stack_mixed.go b/stack_mixed.go index 8836d6ba..33284053 100644 --- a/stack_mixed.go +++ b/stack_mixed.go @@ -12,7 +12,7 @@ import ( "github.com/sagernet/gvisor/pkg/tcpip/link/channel" "github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" ) diff --git a/stack_system.go b/stack_system.go index 030eee17..1c917e0f 100644 --- a/stack_system.go +++ b/stack_system.go @@ -8,8 +8,8 @@ import ( "syscall" "time" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" diff --git a/stack_system_packet.go b/stack_system_packet.go index 34fe51e4..a8f8076e 100644 --- a/stack_system_packet.go +++ b/stack_system_packet.go @@ -4,7 +4,7 @@ import ( "net/netip" "syscall" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" ) diff --git a/tun.go b/tun.go index 35cd0956..abfe67fa 100644 --- a/tun.go +++ b/tun.go @@ -9,8 +9,10 @@ import ( "strings" "time" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" @@ -68,6 +70,12 @@ const ( DefaultIPRoute2AutoRedirectFallbackRuleIndex = 32768 ) +const ( + DNSModeDisabled = "disabled" + DNSModeNative = "native" + DNSModeHijack = "hijack" +) + type Options struct { Name string Inet4Address []netip.Prefix @@ -78,7 +86,8 @@ type Options struct { InterfaceScope bool Inet4Gateway netip.Addr Inet6Gateway netip.Addr - DNSServers []netip.Addr + DNSMode string + DNSAddress []netip.Addr IPRoute2TableIndex int IPRoute2RuleIndex int IPRoute2AutoRedirectFallbackRuleIndex int @@ -102,6 +111,8 @@ type Options struct { IncludeAndroidUser []int IncludePackage []string ExcludePackage []string + IncludeMACAddress []net.HardwareAddr + ExcludeMACAddress []net.HardwareAddr InterfaceFinder control.InterfaceFinder InterfaceMonitor DefaultInterfaceMonitor FileDescriptor int @@ -122,6 +133,57 @@ type Options struct { EXP_SendMsgX bool } +func (o *Options) DNSModeOrDefault() string { + if o.DNSMode == "" { + return DNSModeHijack + } + return o.DNSMode +} + +func (o *Options) DNSServerAddress() ([]netip.Addr, error) { + inet4DNS, err := o.Inet4DNSAddress() + if err != nil { + return nil, err + } + inet6DNS, err := o.Inet6DNSAddress() + if err != nil { + return nil, err + } + return append(inet4DNS, inet6DNS...), nil +} + +func (o *Options) Inet4DNSAddress() ([]netip.Addr, error) { + if len(o.Inet4Address) == 0 { + return nil, nil + } + if len(o.DNSAddress) > 0 { + return common.Filter(o.DNSAddress, netip.Addr.Is4), nil + } + if HasNextAddress(o.Inet4Address[0], 1) { + return []netip.Addr{o.Inet4Address[0].Addr().Next()}, nil + } + if !(len(o.Inet6Address) > 0 && HasNextAddress(o.Inet6Address[0], 1)) { + return nil, E.New("no IPv4 server configured and no usable next address in ", o.Inet6Address[0], " for DNS") + } + return nil, nil +} + +func (o *Options) Inet6DNSAddress() ([]netip.Addr, error) { + if len(o.Inet6Address) == 0 { + return nil, nil + } + if len(o.DNSAddress) > 0 { + return common.Filter(o.DNSAddress, netip.Addr.Is6), nil + } + if HasNextAddress(o.Inet6Address[0], 1) { + return []netip.Addr{o.Inet6Address[0].Addr().Next()}, nil + } + if !(len(o.Inet4Address) > 0 && HasNextAddress(o.Inet4Address[0], 1)) { + return nil, E.New("no IPv6 server configured and no usable next address in ", o.Inet6Address[0], " for DNS") + } + return nil, nil +} + func (o *Options) Inet4GatewayAddr() netip.Addr { if o.Inet4Gateway.IsValid() { return o.Inet4Gateway diff --git a/tun_darwin.go b/tun_darwin.go index 8aa6923f..f4ca0edd 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -9,7 +9,7 @@ import ( "syscall" "unsafe" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing-tun/internal/rawfile_darwin" "github.com/sagernet/sing-tun/internal/stopfd_darwin" "github.com/sagernet/sing/common" diff --git a/tun_linux.go b/tun_linux.go index 20fdce23..dc1a02b7 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -14,8 +14,8 @@ import ( "unsafe" "github.com/sagernet/netlink" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" @@ -317,7 +317,12 @@ func (t *NativeTun) Start() error { return E.Cause(err, "set rules") } - t.setSearchDomainForSystemdResolved() + if t.options.DNSMode != DNSModeDisabled { + err = t.setSearchDomainForSystemdResolved() + if err != nil { + return E.Cause(err, "set search domain") + } + } if t.options.AutoRoute && runtime.GOOS == "android" { t.interfaceCallback = t.options.InterfaceMonitor.RegisterCallback(t.routeUpdate) @@ -332,7 +337,9 @@ func (t *NativeTun) Close() error { if t.options.EXP_ExternalConfiguration { return common.Close(common.PtrOrNil(t.tunFile)) } - t.unsetSearchDomainForSystemdResolved() + if t.options.DNSMode != DNSModeDisabled { + t.unsetSearchDomainForSystemdResolved() + } t.unsetAddresses() return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile))) } @@ -1073,37 +1080,24 @@ func (t *NativeTun) routeUpdate(_ *control.Interface, flags int) { } } -func (t *NativeTun) setSearchDomainForSystemdResolved() { - if t.options.EXP_DisableDNSHijack { - return - } +func (t *NativeTun) setSearchDomainForSystemdResolved() error { ctlPath, err := exec.LookPath("resolvectl") if err != nil { - return - } - dnsServer := t.options.DNSServers - if len(dnsServer) == 0 { - if len(t.options.Inet4Address) > 0 && HasNextAddress(t.options.Inet4Address[0], 1) { - dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next()) - } - if len(t.options.Inet6Address) > 0 && HasNextAddress(t.options.Inet6Address[0], 1) { - dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next()) - } + return nil } - if len(dnsServer) == 0 { - return + dnsAddress, err := t.options.DNSServerAddress() + if err != nil { + return err } go func() { _ = shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run() _ = shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run() - _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run() + _ = shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsAddress, netip.Addr.String)...)...).Run() }() + return nil } func (t *NativeTun) unsetSearchDomainForSystemdResolved() { - if t.options.EXP_DisableDNSHijack { - return - } ctlPath, err := exec.LookPath("resolvectl") if err != nil { return diff --git a/tun_offload.go b/tun_offload.go index a0eee82f..83c833af 100644 --- a/tun_offload.go +++ b/tun_offload.go @@ -4,9 +4,9 @@ import ( "encoding/binary" "fmt" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" ) const ( diff --git a/tun_offload_linux.go b/tun_offload_linux.go index 77337607..a3085304 100644 --- a/tun_offload_linux.go +++ b/tun_offload_linux.go @@ -13,9 +13,9 @@ import ( "io" "unsafe" - "github.com/sagernet/sing-tun/internal/gtcpip" - "github.com/sagernet/sing-tun/internal/gtcpip/checksum" - "github.com/sagernet/sing-tun/internal/gtcpip/header" + "github.com/sagernet/sing-tun/gtcpip" + "github.com/sagernet/sing-tun/gtcpip/checksum" + "github.com/sagernet/sing-tun/gtcpip/header" "golang.org/x/sys/unix" ) diff --git a/tun_windows.go b/tun_windows.go index d00d51db..6dfce2f2 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -16,7 +16,6 @@ import ( "github.com/sagernet/sing-tun/internal/winipcfg" "github.com/sagernet/sing-tun/internal/winsys" "github.com/sagernet/sing-tun/internal/wintun" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/windnsapi" @@ -81,16 +80,14 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv4 address") } - if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { - dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is4) - if len(dnsServers) == 0 && HasNextAddress(t.options.Inet4Address[0], 1) { - dnsServers = []netip.Addr{t.options.Inet4Address[0].Addr().Next()} + if t.options.AutoRoute && t.options.DNSModeOrDefault() != DNSModeDisabled { + dnsServers, err := t.options.Inet4DNSAddress() + if err != nil { + return err } - if len(dnsServers) > 0 { - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil) - if err != nil { - return E.Cause(err, "set ipv4 dns") - } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv4 dns") } } else { err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET), nil, nil) @@ -104,16 +101,14 @@ func (t *NativeTun) configure() error { if err != nil { return E.Cause(err, "set ipv6 address") } - if t.options.AutoRoute && !t.options.EXP_DisableDNSHijack { - dnsServers := common.Filter(t.options.DNSServers, netip.Addr.Is6) - if len(dnsServers) == 0 && HasNextAddress(t.options.Inet6Address[0], 1) { - dnsServers = []netip.Addr{t.options.Inet6Address[0].Addr().Next()} + if t.options.AutoRoute && t.options.DNSModeOrDefault() != DNSModeDisabled { + dnsServers, err := t.options.Inet6DNSAddress() + if err != nil { + return err } - if len(dnsServers) > 0 { - err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil) - if err != nil { - return E.Cause(err, "set ipv6 dns") - } + err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), dnsServers, nil) + if err != nil { + return E.Cause(err, "set ipv6 dns") } } else { err = luid.SetDNS(winipcfg.AddressFamily(windows.AF_INET6), nil, nil) @@ -334,7 +329,7 @@ func (t *NativeTun) Start() error { } } - if !t.options.EXP_DisableDNSHijack { + if t.options.DNSModeOrDefault() == DNSModeHijack { blockDNSCondition := make([]winsys.FWPM_FILTER_CONDITION0, 1) blockDNSCondition[0].FieldKey = winsys.FWPM_CONDITION_IP_REMOTE_PORT blockDNSCondition[0].MatchType = winsys.FWP_MATCH_EQUAL