diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 42 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 104 |
7 files changed, 153 insertions, 27 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 62d1acad4..678f4e016 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index a8f8454dd..57b7f5c19 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -278,7 +278,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 766c7648e..c2e9fd29f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -108,6 +109,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -263,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -353,7 +351,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - if !e.associated { + if e.hdrIncluded { if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ Data: buffer.View(payloadBytes).ToVectorisedView(), }); err != nil { @@ -508,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -577,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -616,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index caac6ef57..83dc10ed0 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1792,6 +1792,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0584ec8dc..a14643ae8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -816,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -1377,6 +1380,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + // Verify checksum unless RX checksum offload is enabled. // On IPv4, UDP checksum is optional, and a zero value means // the transmitter omitted the checksum generation (RFC768). @@ -1395,10 +1407,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk } } - e.rcvMu.Lock() e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 91ba031fa..90781cf49 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -872,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { |