diff options
author | Josh Bleecher Snyder <josharian@gmail.com> | 2021-03-31 13:55:18 -0700 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-04-02 11:07:08 -0600 |
commit | 10533c3e73cdb6f4c4f19e01464782b69ace739e (patch) | |
tree | c19f5ce9c6785b22e72afec19d2a73a0d818e0c6 /conn/bind_linux.go | |
parent | 8ed83e0427a693db6d909897dc73bf7ce6e22b21 (diff) |
all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.
Beneficial consequences:
* If there's no IPv6 support on a system,
conn.Bind.Open can choose not to return a receive function for it,
which is simpler than tracking that state in the bind.
This simplification removes existing data races from both
conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
the conn.Bind no longer needs to add a separate muxing layer.
Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
Diffstat (limited to 'conn/bind_linux.go')
-rw-r--r-- | conn/bind_linux.go | 109 |
1 files changed, 54 insertions, 55 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go index 70ea609..9eec384 100644 --- a/conn/bind_linux.go +++ b/conn/bind_linux.go @@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. type LinuxSocketBind struct { - sock4 int - sock6 int - lastMark uint32 - closing sync.RWMutex + // mu guards sock4 and sock6 and the associated fds. + // As long as someone holds mu (read or write), the associated fds are valid. + mu sync.RWMutex + sock4 int + sock6 int } func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } @@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { return nil, errors.New("invalid IP address") } -func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { +func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) { + bind.mu.Lock() + defer bind.mu.Unlock() + var err error var newPort uint16 var tries int if bind.sock4 != -1 || bind.sock6 != -1 { - return 0, ErrBindAlreadyOpen + return nil, 0, ErrBindAlreadyOpen } originalPort := port again: port = originalPort + var sock4, sock6 int // Attempt ipv6 bind, update port if successful. - bind.sock6, newPort, err = create6(port) + sock6, newPort, err = create6(port) if err != nil { - if err != syscall.EAFNOSUPPORT { - return 0, err + if !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err } } else { port = newPort } // Attempt ipv4 bind, update port if successful. - bind.sock4, newPort, err = create4(port) + sock4, newPort, err = create4(port) if err != nil { - if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { - unix.Close(bind.sock6) + if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + unix.Close(sock6) tries++ goto again } - if err != syscall.EAFNOSUPPORT { - unix.Close(bind.sock6) - return 0, err + if !errors.Is(err, syscall.EAFNOSUPPORT) { + unix.Close(sock6) + return nil, 0, err } } else { port = newPort } - if bind.sock4 == -1 && bind.sock6 == -1 { - return 0, syscall.EAFNOSUPPORT + var fns []ReceiveFunc + if sock4 != -1 { + fns = append(fns, makeReceiveIPv4(sock4)) + bind.sock4 = sock4 + } + if sock6 != -1 { + fns = append(fns, makeReceiveIPv6(sock6)) + bind.sock6 = sock6 + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT } - return port, nil + return fns, port, nil } func (bind *LinuxSocketBind) SetMark(value uint32) error { - bind.closing.RLock() - defer bind.closing.RUnlock() + bind.mu.RLock() + defer bind.mu.RUnlock() if bind.sock6 != -1 { err := unix.SetsockoptInt( @@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error { } } - bind.lastMark = value return nil } func (bind *LinuxSocketBind) Close() error { - var err1, err2 error - bind.closing.RLock() + // Take a readlock to shut down the sockets... + bind.mu.RLock() if bind.sock6 != -1 { unix.Shutdown(bind.sock6, unix.SHUT_RDWR) } if bind.sock4 != -1 { unix.Shutdown(bind.sock4, unix.SHUT_RDWR) } - bind.closing.RUnlock() - bind.closing.Lock() + bind.mu.RUnlock() + // ...and a write lock to close the fd. + // This ensures that no one else is using the fd. + bind.mu.Lock() + defer bind.mu.Unlock() + var err1, err2 error if bind.sock6 != -1 { err1 = unix.Close(bind.sock6) bind.sock6 = -1 @@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error { err2 = unix.Close(bind.sock4) bind.sock4 = -1 } - bind.closing.Unlock() if err1 != nil { return err1 @@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error { return err2 } -func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end LinuxSocketEndpoint - if bind.sock6 == -1 { - return 0, nil, net.ErrClosed +func makeReceiveIPv6(sock int) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + var end LinuxSocketEndpoint + n, err := receive6(sock, buff, &end) + return n, &end, err } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err } -func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end LinuxSocketEndpoint - if bind.sock4 == -1 { - return 0, nil, net.ErrClosed +func makeReceiveIPv4(sock int) ReceiveFunc { + return func(buff []byte) (int, Endpoint, error) { + var end LinuxSocketEndpoint + n, err := receive4(sock, buff, &end) + return n, &end, err } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err } func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { - bind.closing.RLock() - defer bind.closing.RUnlock() - nend, ok := end.(*LinuxSocketEndpoint) if !ok { return ErrWrongEndpointType } + bind.mu.RLock() + defer bind.mu.RUnlock() if !nend.isV6 { if bind.sock4 == -1 { return net.ErrClosed |