diff options
-rw-r--r-- | pkg/tcpip/stack/stack.go | 12 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 3 | ||||
-rw-r--r-- | test/syscalls/linux/BUILD | 9 | ||||
-rw-r--r-- | test/syscalls/linux/mkdir.cc | 2 | ||||
-rw-r--r-- | test/syscalls/linux/open_create.cc | 2 | ||||
-rw-r--r-- | test/util/BUILD | 6 | ||||
-rw-r--r-- | test/util/temp_umask.h (renamed from test/syscalls/linux/temp_umask.h) | 6 |
10 files changed, 58 insertions, 50 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index ebb6c5e3b..13354d884 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -551,11 +551,13 @@ type TransportEndpointInfo struct { RegisterNICID tcpip.NICID } -// AddrNetProto unwraps the specified address if it is a V4-mapped V6 address -// and returns the network protocol number to be used to communicate with the -// specified address. It returns an error if the passed address is incompatible -// with the receiver. -func (e *TransportEndpointInfo) AddrNetProto(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { +// AddrNetProtoLocked unwraps the specified address if it is a V4-mapped V6 +// address and returns the network protocol number to be used to communicate +// with the specified address. It returns an error if the passed address is +// incompatible with the receiver. +// +// Preconditon: the parent endpoint mu must be held while calling this method. +func (e *TransportEndpointInfo) AddrNetProtoLocked(addr tcpip.FullAddress, v6only bool) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { netProto := e.NetProto switch len(addr.Addr) { case header.IPv4AddressSize: diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 426da1ee6..2a396e9bc 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,15 +291,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - toCopy := *to - to = &toCopy - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - // Find the enpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, to.Addr, netProto, false /* multicastLoop */) + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) if err != nil { return 0, nil, err } @@ -480,13 +478,14 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err }) } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, false /* v6only */) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -517,7 +516,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -630,7 +629,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8b9154e69..40cc664c0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1874,13 +1874,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -1910,7 +1911,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc connectingAddr := addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } @@ -2276,7 +2277,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } e.BindAddr = addr.Addr - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 1c6a600b8..0af4514e1 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -443,19 +443,19 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrBroadcastDisabled } - netProto, err := e.checkV4Mapped(to) + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err } - r, _, err := e.connectRoute(nicID, *to, netProto) + r, _, err := e.connectRoute(nicID, dst, netProto) if err != nil { return 0, nil, err } defer r.Release() route = &r - dstPort = to.Port + dstPort = dst.Port } if route.IsResolutionRequired() { @@ -566,7 +566,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - netProto, err := e.checkV4Mapped(&fa) + fa, netProto, err := e.checkV4MappedLocked(fa) if err != nil { return err } @@ -927,13 +927,14 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u return nil } -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProto(*addr, e.v6only) +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) if err != nil { - return 0, err + return tcpip.FullAddress{}, 0, err } - *addr = unwrapped - return netProto, nil + return unwrapped, netProto, nil } // Disconnect implements tcpip.Endpoint.Disconnect. @@ -981,10 +982,6 @@ func (e *endpoint) Disconnect() *tcpip.Error { // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -1012,6 +1009,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + r, nicID, err := e.connectRoute(nicID, addr, netProto) if err != nil { return err @@ -1139,7 +1141,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr) + addr, netProto, err := e.checkV4MappedLocked(addr) if err != nil { return err } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 43fb047ed..466bd9381 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -69,6 +69,9 @@ func (e *endpoint) afterLoad() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { + e.mu.Lock() + defer e.mu.Unlock() + e.stack = s for _, m := range e.multicastMemberships { diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 70c120e42..dae2b1077 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -167,11 +167,6 @@ cc_library( ) cc_library( - name = "temp_umask", - hdrs = ["temp_umask.h"], -) - -cc_library( name = "unix_domain_socket_test_util", testonly = 1, srcs = ["unix_domain_socket_test_util.cc"], @@ -1140,11 +1135,11 @@ cc_binary( srcs = ["mkdir.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], @@ -1299,12 +1294,12 @@ cc_binary( srcs = ["open_create.cc"], linkstatic = 1, deps = [ - ":temp_umask", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:fs_util", gtest, "//test/util:temp_path", + "//test/util:temp_umask", "//test/util:test_main", "//test/util:test_util", ], diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index cf138d328..def4c50a4 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -18,10 +18,10 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 902d0a0dc..51eacf3f2 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -19,11 +19,11 @@ #include <unistd.h> #include "gtest/gtest.h" -#include "test/syscalls/linux/temp_umask.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" #include "test/util/temp_path.h" +#include "test/util/temp_umask.h" #include "test/util/test_util.h" namespace gvisor { diff --git a/test/util/BUILD b/test/util/BUILD index 8b5a0f25c..2a17c33ee 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -350,3 +350,9 @@ cc_library( ":save_util", ], ) + +cc_library( + name = "temp_umask", + testonly = 1, + hdrs = ["temp_umask.h"], +) diff --git a/test/syscalls/linux/temp_umask.h b/test/util/temp_umask.h index 81a25440c..e7de84a54 100644 --- a/test/syscalls/linux/temp_umask.h +++ b/test/util/temp_umask.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ -#define GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#ifndef GVISOR_TEST_UTIL_TEMP_UMASK_H_ +#define GVISOR_TEST_UTIL_TEMP_UMASK_H_ #include <sys/stat.h> #include <sys/types.h> @@ -36,4 +36,4 @@ class TempUmask { } // namespace testing } // namespace gvisor -#endif // GVISOR_TEST_SYSCALLS_TEMP_UMASK_H_ +#endif // GVISOR_TEST_UTIL_TEMP_UMASK_H_ |