diff options
author | Josh Bleecher Snyder <josharian@gmail.com> | 2021-03-31 13:55:18 -0700 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-04-02 11:07:08 -0600 |
commit | 10533c3e73cdb6f4c4f19e01464782b69ace739e (patch) | |
tree | c19f5ce9c6785b22e72afec19d2a73a0d818e0c6 /conn/bind_std.go | |
parent | 8ed83e0427a693db6d909897dc73bf7ce6e22b21 (diff) |
all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.
Beneficial consequences:
* If there's no IPv6 support on a system,
conn.Bind.Open can choose not to return a receive function for it,
which is simpler than tracking that state in the bind.
This simplification removes existing data races from both
conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
the conn.Bind no longer needs to add a separate muxing layer.
Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
Diffstat (limited to 'conn/bind_std.go')
-rw-r--r-- | conn/bind_std.go | 69 |
1 files changed, 38 insertions, 31 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go index f8b8a1b..5261779 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -8,6 +8,7 @@ package conn import ( "errors" "net" + "sync" "syscall" ) @@ -16,6 +17,7 @@ import ( // It uses the Go's net package to implement networking. // See LinuxSocketBind for a proper implementation on the Linux platform. type StdNetBind struct { + mu sync.Mutex // protects following fields ipv4 *net.UDPConn ipv6 *net.UDPConn blackhole4 bool @@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) { return conn, uaddr.Port, nil } -func (bind *StdNetBind) Open(uport uint16) (uint16, error) { +func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + var err error var tries int if bind.ipv4 != nil || bind.ipv6 != nil { - return 0, ErrBindAlreadyOpen + return nil, 0, ErrBindAlreadyOpen } // Attempt to open ipv4 and ipv6 listeners on the same port. @@ -97,7 +102,7 @@ again: ipv4, port, err = listenNet("udp4", port) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return 0, err + return nil, 0, err } // Listen on the same port as we're using for ipv4. @@ -109,17 +114,27 @@ again: } if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { ipv4.Close() - return 0, err + return nil, 0, err } - if ipv4 == nil && ipv6 == nil { - return 0, syscall.EAFNOSUPPORT + var fns []ReceiveFunc + if ipv4 != nil { + fns = append(fns, makeReceiveFunc(ipv4, true)) + bind.ipv4 = ipv4 } - bind.ipv4 = ipv4 - bind.ipv6 = ipv6 - return uint16(port), nil + if ipv6 != nil { + fns = append(fns, makeReceiveFunc(ipv6, false)) + bind.ipv6 = ipv6 + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + return fns, uint16(port), nil } func (bind *StdNetBind) Close() error { + bind.mu.Lock() + defer bind.mu.Unlock() + var err1, err2 error if bind.ipv4 != nil { err1 = bind.ipv4.Close() @@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error { return err2 } -func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT +func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + n, endpoint, err := conn.ReadFromUDP(buff) + if isIPv4 && endpoint != nil { + endpoint.IP = endpoint.IP.To4() + } + return n, (*StdNetEndpoint)(endpoint), err } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*StdNetEndpoint)(endpoint), err -} - -func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*StdNetEndpoint)(endpoint), err } func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { @@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { if !ok { return ErrWrongEndpointType } - var conn *net.UDPConn - var blackhole bool - if nend.IP.To4() != nil { - blackhole = bind.blackhole4 - conn = bind.ipv4 - } else { + + bind.mu.Lock() + blackhole := bind.blackhole4 + conn := bind.ipv4 + if nend.IP.To4() == nil { blackhole = bind.blackhole6 conn = bind.ipv6 } + bind.mu.Unlock() + if blackhole { return nil } |