From 4076153be6840c50ade746087b221a12d7bd2b3b Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 17 Sep 2021 15:29:25 -0700 Subject: Fix lock ordering violation This fixes lock ordering violations introduced in https://github.com/google/gvisor/commit/ae3bd32011889fe59bb89946532dd7ee14973696 and https://github.com/google/gvisor/commit/477d7e5e10378e2f80f21ac9f536d12c4b94d7ce when connecting/binding sockets races with handling of packets/errors, as the connect/bind path takes the transport/internal/network.Endpoint.mu lock before taking stack.endpointsByNIC.mu but the locks are taken in the reverse order when handling packets/errors. The fix is to revert the change in https://github.com/google/gvisor/commit/477d7e5e10378e2f80f21ac9f536d12c4b94d7ce that used a lock instead of atomics, and to introduce a new lock protecting only the endpoint info in transport/internal/network.Endpoint. ``` goroutine 60 [semacquire]: sync.runtime_Semacquire(0x62c957) go/gc/src/runtime/sema.go:56 +0x25 gvisor/pkg/sync/sync.(*CrossGoroutineRWMutex).RLock(0xc0006c4870) gvisor/pkg/sync/rwmutex_unsafe.go:76 +0x57 gvisor/pkg/sync/sync.(*RWMutex).RLock(...) gvisor/pkg/sync/rwmutex_unsafe.go:254 gvisor/pkg/tcpip/transport/internal/network/network.(*Endpoint).State(0xc0006c4858) gvisor/pkg/tcpip/transport/internal/network/endpoint.go:123 +0x3c gvisor/pkg/tcpip/transport/udp/udp.(*endpoint).HandleError(0xc0006c4840, {0x1c3a418, 0x2847498}, 0xc0006bdeea) gvisor/pkg/tcpip/transport/udp/endpoint.go:983 +0x5c gvisor/pkg/tcpip/stack/stack.(*endpointsByNIC).handleError(0xc00003dd70, 0xc0000f08c0, {0x75e1, {0xc0005da110, 0x10}, 0xdeea, {0xc0005da120, 0x10}}, {0x1c3a418, 0x2847498}, ...) 
gvisor/pkg/tcpip/stack/transport_demuxer.go:203 +0x254 gvisor/pkg/tcpip/stack/stack.(*transportDemuxer).deliverError(0xc00047c588, 0xc000688ca8, 0x86dd, 0x11, {0x1c3a418, 0x2847498}, 0xdf2345, {0x75e1, {0xc0005da110, 0x10}, ...}) gvisor/pkg/tcpip/stack/transport_demuxer.go:631 +0x205 gvisor/pkg/tcpip/stack/stack.(*nic).DeliverTransportError(0xc0000f08c0, {0xc0005da110, 0x10}, {0xc0005da120, 0x10}, 0x62c985, 0x0, {0x1c3a418, 0x2847498}, 0xc000299000) gvisor/pkg/tcpip/stack/nic.go:922 +0x253 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).handleControl(0xc00045d000, {0x1c3a418, 0x2847498}, 0xc000299000) gvisor/pkg/tcpip/network/ipv6/icmp.go:209 +0x3ac gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).handleICMP(0xc00045d000, 0xc000299000, 0x0, 0x10) gvisor/pkg/tcpip/network/ipv6/icmp.go:353 +0x96c gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).processExtensionHeaders(0xc00045d000, {0xc0005b7f0e, 0x28, 0x30}, 0xc000299000, 0x0) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1554 +0x849 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).handleValidatedPacket(0xc00045d000, {0xc0005b7f0e, 0x28, 0x2b206370203a3033}, 0xc000299000, {0x18baf5d, 0x2}) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1191 +0x396 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).HandlePacket(0xc00045d000, 0xc000031310) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1107 +0x538 gvisor/pkg/tcpip/stack/stack.(*nic).DeliverNetworkPacket(0xc0000f08c0, {0x0, 0xc000688c38}, {0xc0005da09a, 0x6}, 0x86dd, 0xc000299000) gvisor/pkg/tcpip/stack/nic.go:779 +0x3fd gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).DeliverNetworkPacket(0xc0003d1f10, {0xc0005da08a, 0x6}, {0xc0005da09a, 0x6}, 0x62c985, 0x962610) gvisor/pkg/tcpip/link/nested/nested.go:59 +0xd1 gvisor/pkg/tcpip/link/sniffer/sniffer.(*endpoint).DeliverNetworkPacket(0xc0003d1f10, {0xc0005da08a, 0x6}, {0xc0005da09a, 0x6}, 0x610f56, 0x6) gvisor/pkg/tcpip/link/sniffer/sniffer.go:140 +0x87 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).DeliverNetworkPacket(0xc0005200f0, {0xc0005da08a, 0x6}, 
{0xc0005da09a, 0x6}, 0x397800, 0x200) gvisor/pkg/tcpip/link/nested/nested.go:59 +0xd1 gvisor/pkg/tcpip/link/ethernet/ethernet.(*Endpoint).DeliverNetworkPacket(0xc0005200f0, {0xc0005032c0, 0x4}, {0x4, 0x26e}, 0x60d600, 0x6) gvisor/pkg/tcpip/link/ethernet/ethernet.go:63 +0x1ad gvisor/pkg/tcpip/link/loopback/loopback.(*endpoint).WriteRawPacket(0xc00019a540, 0xc000298f00) gvisor/pkg/tcpip/link/loopback/loopback.go:107 +0x191 gvisor/pkg/tcpip/link/loopback/loopback.(*endpoint).WritePacket(0x62c985, {{{0xc0005da060, 0x10}, {0xc0005da070, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/loopback/loopback.go:80 +0x37 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).WritePacket(...) gvisor/pkg/tcpip/link/nested/nested.go:107 gvisor/pkg/tcpip/link/ethernet/ethernet.(*Endpoint).WritePacket(0xc0005200f0, {{{0xc0005da060, 0x10}, {0xc0005da070, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/ethernet/ethernet.go:78 +0x142 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).WritePacket(...) gvisor/pkg/tcpip/link/nested/nested.go:107 gvisor/pkg/tcpip/link/sniffer/sniffer.(*endpoint).WritePacket(0xc0003d1f10, {{{0xc0005da060, 0x10}, {0xc0005da070, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/sniffer/sniffer.go:169 +0x108 gvisor/pkg/tcpip/stack/stack.(*nic).writePacket(0xc0000f08c0, {{{0xc0005da060, 0x10}, {0xc0005da070, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/stack/nic.go:380 +0x264 gvisor/pkg/tcpip/stack/stack.(*nic).writePacketBuffer(0xc0006c3540, {{{0xc0005da060, 0x10}, {0xc0005da070, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) 
gvisor/pkg/tcpip/stack/nic.go:324 +0xec gvisor/pkg/tcpip/stack/stack.(*nic).enqueuePacketBuffer(0xc0000f08c0, 0x62c985, 0xfc2c55, {0x1bfdac0, 0xc000298f00}) gvisor/pkg/tcpip/stack/nic.go:339 +0x234 gvisor/pkg/tcpip/stack/stack.(*nic).WritePacket(0xc000298f00, 0xffd8, 0x41a000, 0x4) gvisor/pkg/tcpip/stack/nic.go:317 +0x50 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).writePacket(0xc00045d000, 0xc0006c3540, 0xc000298f00, 0x3, 0x0) gvisor/pkg/tcpip/network/ipv6/ipv6.go:823 +0x427 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).WritePacket(0xc00045d000, 0xc0006c3540, {0x86dd, 0x0, 0x0}, 0xc000298f00) gvisor/pkg/tcpip/network/ipv6/ipv6.go:774 +0x2db gvisor/pkg/tcpip/stack/stack.(*Route).WritePacket(0xc0006c3540, {0x37a9f0, 0xc0, 0x0}, 0x86dd) gvisor/pkg/tcpip/stack/route.go:462 +0xe4 gvisor/pkg/tcpip/network/ipv6/ipv6.(*protocol).returnError(0xc000298400, {0x1c253e8, 0x2847498}, 0xc000298e00) gvisor/pkg/tcpip/network/ipv6/icmp.go:1277 +0x15f8 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).processExtensionHeaders(0xc00045d000, {0xc0005b7ece, 0x28, 0x30}, 0xc000298e00, 0x0) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1565 +0x12e5 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).handleValidatedPacket(0xc00045d000, {0xc0005b7ece, 0x28, 0x0}, 0xc000298e00, {0x18baf5d, 0x2}) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1191 +0x396 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).HandlePacket(0xc00045d000, 0xc0003df610) gvisor/pkg/tcpip/network/ipv6/ipv6.go:1107 +0x538 gvisor/pkg/tcpip/stack/stack.(*nic).DeliverNetworkPacket(0xc0000f08c0, {0x0, 0xc000688838}, {0xc000663fea, 0x6}, 0x86dd, 0xc000298e00) gvisor/pkg/tcpip/stack/nic.go:779 +0x3fd gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).DeliverNetworkPacket(0xc0003d1f10, {0xc000663fda, 0x6}, {0xc000663fea, 0x6}, 0x62c985, 0x962610) gvisor/pkg/tcpip/link/nested/nested.go:59 +0xd1 gvisor/pkg/tcpip/link/sniffer/sniffer.(*endpoint).DeliverNetworkPacket(0xc0003d1f10, {0xc000663fda, 0x6}, {0xc000663fea, 0x6}, 0x610f56, 0x6) 
gvisor/pkg/tcpip/link/sniffer/sniffer.go:140 +0x87 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).DeliverNetworkPacket(0xc0005200f0, {0xc000663fda, 0x6}, {0xc000663fea, 0x6}, 0x397800, 0x200) gvisor/pkg/tcpip/link/nested/nested.go:59 +0xd1 gvisor/pkg/tcpip/link/ethernet/ethernet.(*Endpoint).DeliverNetworkPacket(0xc0005200f0, {0xc00003dec0, 0x2}, {0x2, 0x23e}, 0x60d600, 0x6) gvisor/pkg/tcpip/link/ethernet/ethernet.go:63 +0x1ad gvisor/pkg/tcpip/link/loopback/loopback.(*endpoint).WriteRawPacket(0xc00019a540, 0xc000298d00) gvisor/pkg/tcpip/link/loopback/loopback.go:107 +0x191 gvisor/pkg/tcpip/link/loopback/loopback.(*endpoint).WritePacket(0x62c985, {{{0xc000663fa0, 0x10}, {0xc000378f40, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/loopback/loopback.go:80 +0x37 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).WritePacket(...) gvisor/pkg/tcpip/link/nested/nested.go:107 gvisor/pkg/tcpip/link/ethernet/ethernet.(*Endpoint).WritePacket(0xc0005200f0, {{{0xc000663fa0, 0x10}, {0xc000378f40, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/ethernet/ethernet.go:78 +0x142 gvisor/pkg/tcpip/link/nested/nested.(*Endpoint).WritePacket(...) gvisor/pkg/tcpip/link/nested/nested.go:107 gvisor/pkg/tcpip/link/sniffer/sniffer.(*endpoint).WritePacket(0xc0003d1f10, {{{0xc000663fa0, 0x10}, {0xc000378f40, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/link/sniffer/sniffer.go:169 +0x108 gvisor/pkg/tcpip/stack/stack.(*nic).writePacket(0xc0000f08c0, {{{0xc000663fa0, 0x10}, {0xc000378f40, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) gvisor/pkg/tcpip/stack/nic.go:380 +0x264 gvisor/pkg/tcpip/stack/stack.(*nic).writePacketBuffer(0xc0006c2fa0, {{{0xc000663fa0, 0x10}, {0xc000378f40, 0x10}, {0x1bf6590, 0x6}, {0x0, 0x0}, 0x86dd, ...}, ...}, ...) 
gvisor/pkg/tcpip/stack/nic.go:324 +0xec gvisor/pkg/tcpip/stack/stack.(*nic).enqueuePacketBuffer(0xc0000f08c0, 0x62c985, 0xfc2c55, {0x1bfdac0, 0xc000298d00}) gvisor/pkg/tcpip/stack/nic.go:339 +0x234 gvisor/pkg/tcpip/stack/stack.(*nic).WritePacket(0xc000298d00, 0xffd8, 0x41a000, 0x4) gvisor/pkg/tcpip/stack/nic.go:317 +0x50 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).writePacket(0xc00045d000, 0xc0006c2fa0, 0xc000298d00, 0x3, 0x0) gvisor/pkg/tcpip/network/ipv6/ipv6.go:823 +0x427 gvisor/pkg/tcpip/network/ipv6/ipv6.(*endpoint).WritePacket(0xc00045d000, 0xc0006c2fa0, {0x86dd, 0x0, 0x0}, 0xc000298d00) gvisor/pkg/tcpip/network/ipv6/ipv6.go:774 +0x2db gvisor/pkg/tcpip/stack/stack.(*Route).WritePacket(0xc0006c2fa0, {0x2080000, 0xea, 0xde}, 0x6) gvisor/pkg/tcpip/stack/route.go:462 +0xe4 gvisor/pkg/tcpip/transport/internal/network/network.(*WriteContext).WritePacket(0xc0003e05e0, 0xc000298d00, 0x0) gvisor/pkg/tcpip/transport/internal/network/endpoint.go:212 +0x154 gvisor/pkg/tcpip/transport/udp/udp.(*endpoint).write(0xc0006c4840, {0x1c23ad0, 0xc0006cfd60}, {0xc0002ecf00, 0xf0, 0xdb, 0x3}) gvisor/pkg/tcpip/transport/udp/endpoint.go:457 +0x74c gvisor/pkg/tcpip/transport/udp/udp.(*endpoint).Write(0xc0006c4840, {0x1c23ad0, 0xc0006cfd60}, {0xc0002ecf00, 0x85, 0xc9, 0x62}) gvisor/pkg/tcpip/transport/udp/endpoint.go:323 +0x74 goroutine 133 [semacquire]: sync.runtime_Semacquire(0xc00003dd70) go/gc/src/runtime/sema.go:56 +0x25 gvisor/pkg/sync/sync.(*CrossGoroutineRWMutex).Lock(0xc00003dd70) gvisor/pkg/sync/rwmutex_unsafe.go:151 +0x79 gvisor/pkg/sync/sync.(*RWMutex).Lock(...) gvisor/pkg/sync/rwmutex_unsafe.go:286 gvisor/pkg/tcpip/stack/stack.(*endpointsByNIC).unregisterEndpoint(0xc00003dd70, 0x37a300, {0x1c3a558, 0xc0006c4840}, {0x0, 0x0, 0x0}) gvisor/pkg/tcpip/stack/transport_demuxer.go:246 +0x72 gvisor/pkg/tcpip/stack/stack.(*transportEndpoints).unregisterEndpoint(0xc0004b3f40, {0x75e1, {0x0, 0x0}, 0x0, {0x0, 0x0}}, {0x1c3a558, 0xc0006c4840}, {0x0, ...}, ...) 
gvisor/pkg/tcpip/stack/transport_demuxer.go:52 +0x193 gvisor/pkg/tcpip/stack/stack.(*transportDemuxer).unregisterEndpoint(0xc00047c588, {0xc000663fc8, 0x2, 0x0}, 0x11, {0x75e1, {0x0, 0x0}, 0x0, {0x0, ...}}, ...) gvisor/pkg/tcpip/stack/transport_demuxer.go:527 +0x1d4 gvisor/pkg/tcpip/stack/stack.(*Stack).UnregisterTransportEndpoint(...) gvisor/pkg/tcpip/stack/stack.go:1417 gvisor/pkg/tcpip/transport/udp/udp.(*endpoint).Connect.func1(0x86dd, {0x75e1, {0x0, 0x0}, 0x0, {0x0, 0x0}}, {0x75e1, {0x0, 0x0}, ...}) gvisor/pkg/tcpip/transport/udp/endpoint.go:619 +0x433 gvisor/pkg/tcpip/transport/internal/network/network.(*Endpoint).ConnectAndThen(0xc0006c4858, {0x0, {0xc000144270, 0xa0000eade88c0}, 0xabc5}, 0xc000353518) gvisor/pkg/tcpip/transport/internal/network/endpoint.go:408 +0x3cc gvisor/pkg/tcpip/transport/udp/udp.(*endpoint).Connect(0xc0006c4840, {0x37b9e0, {0xc000144270, 0xc000328a80}, 0xc1a0}) gvisor/pkg/tcpip/transport/udp/endpoint.go:593 +0x149 ``` PiperOrigin-RevId: 397412256 --- pkg/tcpip/transport/internal/network/endpoint.go | 168 ++++++++++++++------- .../transport/internal/network/endpoint_state.go | 16 +- 2 files changed, 120 insertions(+), 64 deletions(-) diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 3cb821475..e3094f59f 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -18,6 +18,7 @@ package network import ( "fmt" + "sync/atomic" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -39,11 +40,7 @@ type Endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu - state transport.DatagramEndpointState - // +checklocks:mu wasBound bool - // +checklocks:mu - info stack.TransportEndpointInfo // owner is the owner of transmitted packets. // // +checklocks:mu @@ -72,6 +69,34 @@ type Endpoint struct { ipv4TOS uint8 // +checklocks:mu ipv6TClass uint8 + + // Lock ordering: mu > infoMu. 
+ infoMu sync.RWMutex `state:"nosave"` + // info has a dedicated mutex so that we can avoid lock ordering violations + // when reading the endpoint's info. If we used mu, we would need to guarantee + // that any lock taken while mu is held is not held when calling Info() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock) but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setInfo. + // + // +checklocks:infoMu + info stack.TransportEndpointInfo + + // state holds a transport.DatagramEndpointState. + // + // state must be accessed with atomics so that we can avoid lock ordering + // violations when reading the state. If we used mu, we would need to guarantee + // that any lock taken while mu is held is not held when calling State() + // which is not true as of writing (we hold mu while registering transport + // endpoints (taking the transport demuxer lock) but we also hold the demuxer + // lock when delivering packets/errors to endpoints). + // + // Writes must be performed through setEndpointState. + // + // +checkatomics + state uint32 } // +stateify savable @@ -101,7 +126,6 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr netProto: netProto, transProto: transProto, - state: transport.DatagramEndpointStateInitial, info: stack.TransportEndpointInfo{ NetProto: netProto, TransProto: transProto, @@ -111,6 +135,10 @@ func (e *Endpoint) Init(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, tr multicastTTL: 1, multicastMemberships: make(map[multicastMembership]struct{}), } + + e.mu.Lock() + defer e.mu.Unlock() + e.setEndpointState(transport.DatagramEndpointStateInitial) } // NetProto returns the network protocol the endpoint was initialized with. 
@@ -118,11 +146,19 @@ func (e *Endpoint) NetProto() tcpip.NetworkProtocolNumber { return e.netProto } +// setEndpointState sets the state of the endpoint. +// +// e.mu must be held to synchronize changes to state with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setEndpointState(state transport.DatagramEndpointState) { + atomic.StoreUint32(&e.state, uint32(state)) +} + // State returns the state of the endpoint. func (e *Endpoint) State() transport.DatagramEndpointState { - e.mu.RLock() - defer e.mu.RUnlock() - return e.state + return transport.DatagramEndpointState(atomic.LoadUint32(&e.state)) } // Close cleans the endpoint's resources and leaves the endpoint in a closed @@ -131,7 +167,7 @@ func (e *Endpoint) Close() { e.mu.Lock() defer e.mu.Unlock() - if e.state == transport.DatagramEndpointStateClosed { + if e.State() == transport.DatagramEndpointStateClosed { return } @@ -145,7 +181,7 @@ func (e *Endpoint) Close() { e.connectedRoute = nil } - e.state = transport.DatagramEndpointStateClosed + e.setEndpointState(transport.DatagramEndpointStateClosed) } // SetOwner sets the owner of transmitted packets. @@ -226,7 +262,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrInvalidOptionValue{} } - if e.state == transport.DatagramEndpointStateClosed { + if e.State() == transport.DatagramEndpointStateClosed { return WriteContext{}, &tcpip.ErrInvalidEndpointState{} } @@ -238,7 +274,7 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if opts.To == nil { // If the user doesn't specify a destination, they should have // connected to another address. 
- if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return WriteContext{}, &tcpip.ErrDestinationRequired{} } @@ -250,18 +286,19 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext if nicID == 0 { nicID = tcpip.NICID(e.ops.GetBindToDevice()) } - if e.info.BindNICID != 0 { - if nicID != 0 && nicID != e.info.BindNICID { + info := e.Info() + if info.BindNICID != 0 { + if nicID != 0 && nicID != info.BindNICID { return WriteContext{}, &tcpip.ErrNoRoute{} } - nicID = e.info.BindNICID + nicID = info.BindNICID } if nicID == 0 { - nicID = e.info.RegisterNICID + nicID = info.RegisterNICID } - dst, netProto, err := e.checkV4MappedRLocked(*opts.To) + dst, netProto, err := e.checkV4Mapped(*opts.To) if err != nil { return WriteContext{}, err } @@ -301,20 +338,22 @@ func (e *Endpoint) Disconnect() { e.mu.Lock() defer e.mu.Unlock() - if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return } + info := e.Info() // Exclude ephemerally bound endpoints. if e.wasBound { - e.info.ID = stack.TransportEndpointID{ - LocalAddress: e.info.BindAddr, + info.ID = stack.TransportEndpointID{ + LocalAddress: info.BindAddr, } - e.state = transport.DatagramEndpointStateBound + e.setEndpointState(transport.DatagramEndpointStateBound) } else { - e.info.ID = stack.TransportEndpointID{} - e.state = transport.DatagramEndpointStateInitial + info.ID = stack.TransportEndpointID{} + e.setEndpointState(transport.DatagramEndpointStateInitial) } + e.setInfo(info) e.connectedRoute.Release() e.connectedRoute = nil @@ -327,7 +366,7 @@ func (e *Endpoint) Disconnect() { // TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. 
// +checklocks:e.mu func (e *Endpoint) connectRouteRLocked(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (*stack.Route, tcpip.NICID, tcpip.Error) { - localAddr := e.info.ID.LocalAddress + localAddr := e.Info().ID.LocalAddress if e.isBroadcastOrMulticast(nicID, netProto, localAddr) { // A packet can only originate from a unicast address (i.e., an interface). localAddr = "" @@ -370,24 +409,25 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. e.mu.Lock() defer e.mu.Unlock() + info := e.Info() nicID := addr.NIC - switch e.state { + switch e.State() { case transport.DatagramEndpointStateInitial: case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: - if e.info.BindNICID == 0 { + if info.BindNICID == 0 { break } - if nicID != 0 && nicID != e.info.BindNICID { + if nicID != 0 && nicID != info.BindNICID { return &tcpip.ErrInvalidEndpointState{} } - nicID = e.info.BindNICID + nicID = info.BindNICID default: return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedRLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -398,14 +438,14 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. } id := stack.TransportEndpointID{ - LocalAddress: e.info.ID.LocalAddress, + LocalAddress: info.ID.LocalAddress, RemoteAddress: r.RemoteAddress(), } - if e.state == transport.DatagramEndpointStateInitial { + if e.State() == transport.DatagramEndpointStateInitial { id.LocalAddress = r.LocalAddress() } - if err := f(r.NetProto(), e.info.ID, id); err != nil { + if err := f(r.NetProto(), info.ID, id); err != nil { return err } @@ -414,10 +454,11 @@ func (e *Endpoint) ConnectAndThen(addr tcpip.FullAddress, f func(netProto tcpip. 
e.connectedRoute.Release() } e.connectedRoute = r - e.info.ID = id - e.info.RegisterNICID = nicID + info.ID = id + info.RegisterNICID = nicID + e.setInfo(info) e.effectiveNetProto = netProto - e.state = transport.DatagramEndpointStateConnected + e.setEndpointState(transport.DatagramEndpointStateConnected) return nil } @@ -426,7 +467,7 @@ func (e *Endpoint) Shutdown() tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - switch state := e.state; state { + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: @@ -439,11 +480,9 @@ func (e *Endpoint) Shutdown() tcpip.Error { // checkV4MappedRLocked determines the effective network protocol and converts // addr to its canonical form. -// -// TODO(https://gvisor.dev/issue/6590): Annotate read lock requirement. -// +checklocks:e.mu -func (e *Endpoint) checkV4MappedRLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) +func (e *Endpoint) checkV4Mapped(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { + info := e.Info() + unwrapped, netProto, err := info.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -474,11 +513,11 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto // Don't allow binding once endpoint is not in the initial state // anymore. 
- if e.state != transport.DatagramEndpointStateInitial { + if e.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedRLocked(addr) + addr, netProto, err := e.checkV4Mapped(addr) if err != nil { return err } @@ -497,14 +536,16 @@ func (e *Endpoint) BindAndThen(addr tcpip.FullAddress, f func(tcpip.NetworkProto e.wasBound = true - e.info.ID = stack.TransportEndpointID{ + info := e.Info() + info.ID = stack.TransportEndpointID{ LocalAddress: addr.Addr, } - e.info.BindNICID = addr.NIC - e.info.RegisterNICID = nicID - e.info.BindAddr = addr.Addr + info.BindNICID = addr.NIC + info.RegisterNICID = nicID + info.BindAddr = addr.Addr + e.setInfo(info) e.effectiveNetProto = netProto - e.state = transport.DatagramEndpointStateBound + e.setEndpointState(transport.DatagramEndpointStateBound) return nil } @@ -520,13 +561,14 @@ func (e *Endpoint) GetLocalAddress() tcpip.FullAddress { e.mu.RLock() defer e.mu.RUnlock() - addr := e.info.BindAddr - if e.state == transport.DatagramEndpointStateConnected { + info := e.Info() + addr := info.BindAddr + if e.State() == transport.DatagramEndpointStateConnected { addr = e.connectedRoute.LocalAddress() } return tcpip.FullAddress{ - NIC: e.info.RegisterNICID, + NIC: info.RegisterNICID, Addr: addr, } } @@ -536,13 +578,13 @@ func (e *Endpoint) GetRemoteAddress() (tcpip.FullAddress, bool) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != transport.DatagramEndpointStateConnected { + if e.State() != transport.DatagramEndpointStateConnected { return tcpip.FullAddress{}, false } return tcpip.FullAddress{ Addr: e.connectedRoute.RemoteAddress(), - NIC: e.info.RegisterNICID, + NIC: e.Info().RegisterNICID, }, true } @@ -624,7 +666,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { defer e.mu.Unlock() fa := tcpip.FullAddress{Addr: v.InterfaceAddr} - fa, netProto, err := e.checkV4MappedRLocked(fa) + fa, netProto, err := e.checkV4Mapped(fa) if err 
!= nil { return err } @@ -648,7 +690,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } } - if e.info.BindNICID != 0 && e.info.BindNICID != nic { + if info := e.Info(); info.BindNICID != 0 && info.BindNICID != nic { return &tcpip.ErrInvalidEndpointState{} } @@ -751,7 +793,19 @@ func (e *Endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { // Info returns a copy of the endpoint info. func (e *Endpoint) Info() stack.TransportEndpointInfo { - e.mu.RLock() - defer e.mu.RUnlock() + e.infoMu.RLock() + defer e.infoMu.RUnlock() return e.info } + +// setInfo sets the endpoint's info. +// +// e.mu must be held to synchronize changes to info with the rest of the +// endpoint. +// +// +checklocks:e.mu +func (e *Endpoint) setInfo(info stack.TransportEndpointInfo) { + e.infoMu.Lock() + defer e.infoMu.Unlock() + e.info = info +} diff --git a/pkg/tcpip/transport/internal/network/endpoint_state.go b/pkg/tcpip/transport/internal/network/endpoint_state.go index 173197512..68bd1fbf6 100644 --- a/pkg/tcpip/transport/internal/network/endpoint_state.go +++ b/pkg/tcpip/transport/internal/network/endpoint_state.go @@ -35,22 +35,24 @@ func (e *Endpoint) Resume(s *stack.Stack) { } } - switch e.state { + info := e.Info() + + switch state := e.State(); state { case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: case transport.DatagramEndpointStateBound: - if len(e.info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) { - if e.stack.CheckLocalAddress(e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress) == 0 { - panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", e.info.RegisterNICID, e.effectiveNetProto, e.info.ID.LocalAddress)) + if len(info.ID.LocalAddress) != 0 && !e.isBroadcastOrMulticast(info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress) { + if e.stack.CheckLocalAddress(info.RegisterNICID, 
e.effectiveNetProto, info.ID.LocalAddress) == 0 { + panic(fmt.Sprintf("got e.stack.CheckLocalAddress(%d, %d, %s) = 0, want != 0", info.RegisterNICID, e.effectiveNetProto, info.ID.LocalAddress)) } } case transport.DatagramEndpointStateConnected: var err tcpip.Error multicastLoop := e.ops.GetMulticastLoop() - e.connectedRoute, err = e.stack.FindRoute(e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) + e.connectedRoute, err = e.stack.FindRoute(info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop) if err != nil { - panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", e.info.RegisterNICID, e.info.ID.LocalAddress, e.info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) + panic(fmt.Sprintf("e.stack.FindRoute(%d, %s, %s, %d, %t): %s", info.RegisterNICID, info.ID.LocalAddress, info.ID.RemoteAddress, e.effectiveNetProto, multicastLoop, err)) } default: - panic(fmt.Sprintf("unhandled state = %s", e.state)) + panic(fmt.Sprintf("unhandled state = %s", state)) } } -- cgit v1.2.3 From 7dacdbef528f7b556f23c1b02a360363dc556e31 Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Fri, 17 Sep 2021 15:31:19 -0700 Subject: Allow rebinding packet socket protocol ...to change the network protocol a packet socket may receive packets from. This CL is a portion of an originally larger CL that was split with https://github.com/google/gvisor/commit/a8ad692fd36cbaf7f5a6b9af39d601053dbee338 being the dependent CL. That CL (accidentally) included the change in the endpoint's `afterLoad` method to take the required lock when accessing the endpoint's netProto field. That change should have been in this CL. The CL that made the change mentioned in the commit message is cl/396946187. 
PiperOrigin-RevId: 397412582 --- pkg/sentry/socket/netstack/netstack.go | 5 +- pkg/tcpip/transport/packet/endpoint.go | 23 ++--- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/packet_socket.cc | 171 ++++++++++++++++++++++++++++++++- 4 files changed, 180 insertions(+), 20 deletions(-) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index f79bda922..aa081e90d 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -672,13 +672,10 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) - if a.Protocol != uint16(s.protocol) { - return syserr.ErrInvalidArgument - } - addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + Port: socket.Ntohs(a.Protocol), } } else { if s.minSockAddrLen() > len(sockaddr) { diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 2c9786175..1f30e5adb 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -59,13 +59,11 @@ type packet struct { // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` - netProto tcpip.NetworkProtocolNumber waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions @@ -84,6 +82,8 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` // +checklocks:mu + netProto tcpip.NetworkProtocolNumber + // +checklocks:mu closed bool // +checklocks:mu bound bool @@ -98,10 +98,7 @@ type endpoint struct { // NewEndpoint returns a new packet endpoint. 
func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - }, + stack: s, cooked: cooked, netProto: netProto, waiterQueue: waiterQueue, @@ -214,13 +211,13 @@ func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tc ep.mu.Lock() closed := ep.closed nicID := ep.boundNIC + proto := ep.netProto ep.mu.Unlock() if closed { return 0, &tcpip.ErrClosedForSend{} } var remote tcpip.LinkAddress - proto := ep.netProto if to := opts.To; to != nil { remote = tcpip.LinkAddress(to.Addr) @@ -296,7 +293,8 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound && ep.boundNIC == addr.NIC { + netProto := tcpip.NetworkProtocolNumber(addr.Port) + if ep.bound && ep.boundNIC == addr.NIC && ep.netProto == netProto { // If the NIC being bound is the same then just return success. return nil } @@ -306,12 +304,13 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { ep.bound = false // Bind endpoint to receive packets from specific interface. - if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, netProto, ep); err != nil { return err } ep.bound = true ep.boundNIC = addr.NIC + ep.netProto = netProto return nil } @@ -473,10 +472,8 @@ func (*endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (ep *endpoint) Info() tcpip.EndpointInfo { ep.mu.RLock() - // Make a copy of the endpoint info. - ret := ep.TransportEndpointInfo - ep.mu.RUnlock() - return &ret + defer ep.mu.RUnlock() + return &stack.TransportEndpointInfo{NetProto: ep.netProto} } // Stats returns a pointer to the endpoint stats. 
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 85fa58970..5efb3e620 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1487,6 +1487,7 @@ cc_binary( srcs = ["packet_socket.cc"], linkstatic = 1, deps = [ + ":ip_socket_test_util", "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:socket_util", diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc index c8d1e1d4a..43828a52e 100644 --- a/test/syscalls/linux/packet_socket.cc +++ b/test/syscalls/linux/packet_socket.cc @@ -14,10 +14,15 @@ #include #include +#include +#include +#include +#include #include #include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" #include "test/util/socket_util.h" @@ -27,10 +32,13 @@ namespace testing { namespace { +using ::testing::AnyOf; using ::testing::Combine; +using ::testing::Eq; using ::testing::Values; -class PacketSocketTest : public ::testing::TestWithParam> { +class PacketSocketCreationTest + : public ::testing::TestWithParam> { protected: void SetUp() override { if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { @@ -42,18 +50,175 @@ class PacketSocketTest : public ::testing::TestWithParam> { } }; -TEST_P(PacketSocketTest, Create) { +TEST_P(PacketSocketCreationTest, Create) { const auto [type, protocol] = GetParam(); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, type, htons(protocol))); EXPECT_GE(fd.get(), 0); } -INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketCreationTest, Combine(Values(SOCK_DGRAM, SOCK_RAW), Values(0, 1, 255, ETH_P_IP, ETH_P_IPV6, std::numeric_limits::max()))); +class PacketSocketTest : public ::testing::TestWithParam { + protected: + void SetUp() override { + if (!ASSERT_NO_ERRNO_AND_VALUE(HavePacketSocketCapability())) { + 
ASSERT_THAT(socket(AF_PACKET, GetParam(), 0), + SyscallFailsWithErrno(EPERM)); + GTEST_SKIP() << "Missing packet socket capability"; + } + + socket_ = ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, GetParam(), 0)); + } + + FileDescriptor socket_; +}; + +TEST_P(PacketSocketTest, RebindProtocol) { + const bool kEthHdrIncluded = GetParam() == SOCK_RAW; + + sockaddr_in udp_bind_addr = { + .sin_family = AF_INET, + .sin_addr = {.s_addr = htonl(INADDR_LOOPBACK)}, + }; + + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0)); + { + // Bind the socket so that we have something to send packets to. + // + // If we didn't do this, the UDP packets we send will be responded to with + // ICMP Destination Port Unreachable errors. + ASSERT_THAT( + bind(udp_sock.get(), reinterpret_cast(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + socklen_t addrlen = sizeof(udp_bind_addr); + ASSERT_THAT( + getsockname(udp_sock.get(), reinterpret_cast(&udp_bind_addr), + &addrlen), + SyscallSucceeds()); + ASSERT_THAT(addrlen, sizeof(udp_bind_addr)); + } + + const int loopback_index = ASSERT_NO_ERRNO_AND_VALUE(GetLoopbackIndex()); + + auto send_udp_message = [&](const uint64_t v) { + ASSERT_THAT( + sendto(udp_sock.get(), reinterpret_cast(&v), sizeof(v), + 0 /* flags */, reinterpret_cast(&udp_bind_addr), + sizeof(udp_bind_addr)), + SyscallSucceeds()); + }; + + auto bind_to_network_protocol = [&](uint16_t protocol) { + const sockaddr_ll packet_bind_addr = { + .sll_family = AF_PACKET, + .sll_protocol = htons(protocol), + .sll_ifindex = loopback_index, + }; + + ASSERT_THAT(bind(socket_.get(), + reinterpret_cast(&packet_bind_addr), + sizeof(packet_bind_addr)), + SyscallSucceeds()); + }; + + auto test_recv = [&, this](const uint64_t v) { + constexpr int kInfiniteTimeout = -1; + pollfd pfd = { + .fd = socket_.get(), + .events = POLLIN, + }; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, kInfiniteTimeout), + SyscallSucceedsWithValue(1)); + + struct { + ethhdr eth; + 
iphdr ip; + udphdr udp; + uint64_t payload; + char unused; + } ABSL_ATTRIBUTE_PACKED read_pkt; + sockaddr_ll src; + socklen_t src_len = sizeof(src); + + char* buf = reinterpret_cast(&read_pkt); + size_t buflen = sizeof(read_pkt); + size_t expected_read_len = sizeof(read_pkt) - sizeof(read_pkt.unused); + if (!kEthHdrIncluded) { + buf += sizeof(read_pkt.eth); + buflen -= sizeof(read_pkt.eth); + expected_read_len -= sizeof(read_pkt.eth); + } + + ASSERT_THAT(recvfrom(socket_.get(), buf, buflen, 0, + reinterpret_cast(&src), &src_len), + SyscallSucceedsWithValue(expected_read_len)); + // sockaddr_ll ends with an 8 byte physical address field, but ethernet + // addresses only use 6 bytes. Linux used to return sizeof(sockaddr_ll)-2 + // here, but returns sizeof(sockaddr_ll) since + // https://github.com/torvalds/linux/commit/b2cf86e1563e33a14a1c69b3e508d15dc12f804c. + ASSERT_THAT(src_len, ::testing::AnyOf( + ::testing::Eq(sizeof(src)), + ::testing::Eq(sizeof(src) - sizeof(src.sll_addr) + + ETH_ALEN))); + EXPECT_EQ(src.sll_family, AF_PACKET); + EXPECT_EQ(src.sll_ifindex, loopback_index); + EXPECT_EQ(src.sll_halen, ETH_ALEN); + EXPECT_EQ(ntohs(src.sll_protocol), ETH_P_IP); + // This came from the loopback device, so the address is all 0s. + constexpr uint8_t allZeroesMAC[ETH_ALEN] = {}; + EXPECT_EQ(memcmp(src.sll_addr, allZeroesMAC, sizeof(allZeroesMAC)), 0); + if (kEthHdrIncluded) { + EXPECT_EQ(memcmp(read_pkt.eth.h_dest, allZeroesMAC, sizeof(allZeroesMAC)), + 0); + EXPECT_EQ( + memcmp(read_pkt.eth.h_source, allZeroesMAC, sizeof(allZeroesMAC)), 0); + EXPECT_EQ(ntohs(read_pkt.eth.h_proto), ETH_P_IP); + } + // IHL holds the size of the header in 4 byte units. 
+ EXPECT_EQ(read_pkt.ip.ihl, sizeof(iphdr) / 4); + EXPECT_EQ(read_pkt.ip.version, IPVERSION); + const uint16_t ip_pkt_size = + sizeof(read_pkt) - sizeof(read_pkt.eth) - sizeof(read_pkt.unused); + EXPECT_EQ(ntohs(read_pkt.ip.tot_len), ip_pkt_size); + EXPECT_EQ(read_pkt.ip.protocol, IPPROTO_UDP); + EXPECT_EQ(ntohl(read_pkt.ip.daddr), INADDR_LOOPBACK); + EXPECT_EQ(ntohl(read_pkt.ip.saddr), INADDR_LOOPBACK); + EXPECT_EQ(read_pkt.udp.source, udp_bind_addr.sin_port); + EXPECT_EQ(read_pkt.udp.dest, udp_bind_addr.sin_port); + EXPECT_EQ(ntohs(read_pkt.udp.len), ip_pkt_size - sizeof(read_pkt.ip)); + EXPECT_EQ(read_pkt.payload, v); + }; + + // The packet socket is not bound to IPv4 so we should not receive the sent + // message. + uint64_t counter = 0; + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind to IPv4 and expect to receive the UDP packet we send after binding. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); + + // Bind the packet socket to a random protocol. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(255)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + + // Bind back to IPv4 and expect to receive the UDP packet we send after + // binding back to IPv4. + ASSERT_NO_FATAL_FAILURE(bind_to_network_protocol(ETH_P_IP)); + ASSERT_NO_FATAL_FAILURE(send_udp_message(++counter)); + ASSERT_NO_FATAL_FAILURE(test_recv(counter)); +} + +INSTANTIATE_TEST_SUITE_P(AllPacketSocketTests, PacketSocketTest, + Values(SOCK_DGRAM, SOCK_RAW)); + } // namespace } // namespace testing -- cgit v1.2.3