summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go45
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go204
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go34
3 files changed, 269 insertions, 14 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cb805522b..67c21be42 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -389,8 +389,8 @@ var loopbackSubnet = func() tcpip.Subnet {
}()
// deliverPacket attempts to find one or more matching transport endpoints, and
-// then, if matches are found, delivers the packet to them. Returns true if it
-// found one or more endpoints, false otherwise.
+// then, if matches are found, delivers the packet to them. Returns true if
+// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
@@ -400,13 +400,38 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RLock()
// Determine which transport endpoint or endpoints to deliver this packet to.
- // If the packet is a broadcast or multicast, then find all matching
- // transport endpoints.
+ // If the packet is a UDP broadcast or multicast, then find all matching
+ // transport endpoints. If the packet is a TCP packet with a non-unicast
+ // source or destination address, then do nothing further and instruct
+ // the caller to do the same.
var destEps []*endpointsByNic
- if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
- destEps = d.findAllEndpointsLocked(eps, id)
- } else if ep := d.findEndpointLocked(eps, id); ep != nil {
- destEps = append(destEps, ep)
+ switch protocol {
+ case header.UDPProtocolNumber:
+ if isMulticastOrBroadcast(id.LocalAddress) {
+ destEps = d.findAllEndpointsLocked(eps, id)
+ break
+ }
+
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
+
+ case header.TCPProtocolNumber:
+ if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
+ // TCP can only be used to communicate between a single
+ // source and a single destination; the addresses must
+ // be unicast.
+ eps.mu.RUnlock()
+ r.Stats().TCP.InvalidSegmentsReceived.Increment()
+ return true
+ }
+
+ fallthrough
+
+ default:
+ if ep := d.findEndpointLocked(eps, id); ep != nil {
+ destEps = append(destEps, ep)
+ }
}
eps.mu.RUnlock()
@@ -587,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
func isMulticastOrBroadcast(addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}
+
+func isUnicast(addr tcpip.Address) bool {
+ return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 84579ce52..b443fe9dc 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -4242,6 +4242,210 @@ func TestListenBacklogFull(t *testing.T) {
}
}
+// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
+// non unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+ multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+ otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv4Any,
+ context.StackAddr,
+ },
+ {
+ "SourceBroadcast",
+ header.IPv4Broadcast,
+ context.StackAddr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackAddr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackAddr,
+ },
+ {
+ "DestUnspecified",
+ context.TestAddr,
+ header.IPv4Any,
+ },
+ {
+ "DestBroadcast",
+ context.TestAddr,
+ header.IPv4Broadcast,
+ },
+ {
+ "DestOurMulticast",
+ context.TestAddr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestAddr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendPacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestAddr, context.StackAddr)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
+// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
+// non unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+ multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+ otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+ tests := []struct {
+ name string
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ }{
+ {
+ "SourceUnspecified",
+ header.IPv6Any,
+ context.StackV6Addr,
+ },
+ {
+ "SourceAllNodes",
+ header.IPv6AllNodesMulticastAddress,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOurMulticast",
+ multicastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "SourceOtherMulticast",
+ otherMulticastAddr,
+ context.StackV6Addr,
+ },
+ {
+ "DestUnspecified",
+ context.TestV6Addr,
+ header.IPv6Any,
+ },
+ {
+ "DestAllNodes",
+ context.TestV6Addr,
+ header.IPv6AllNodesMulticastAddress,
+ },
+ {
+ "DestOurMulticast",
+ context.TestV6Addr,
+ multicastAddr,
+ },
+ {
+ "DestOtherMulticast",
+ context.TestV6Addr,
+ otherMulticastAddr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ t.Parallel()
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+ t.Fatalf("JoinGroup failed: %s", err)
+ }
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := c.EP.Listen(1); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ irs := seqnum.Value(789)
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, test.srcAddr, test.dstAddr)
+ c.CheckNoPacket("Should not have received a response")
+
+ // Handle normal packet.
+ c.SendV6PacketWithAddrs(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ }, context.TestV6Addr, context.StackV6Addr)
+ checker.IPv6(t, c.GetV6Packet(),
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1)))
+ })
+ }
+}
+
func TestListenSynRcvdQueueFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 4854e719d..0a733fa94 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -309,6 +309,12 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -321,8 +327,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
+ SrcAddr: src,
+ DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -339,7 +345,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -365,6 +371,15 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
})
}
+// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
+// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
+// provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
+}
+
// SendAck sends an ACK packet.
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
c.SendAckWithSACK(seq, bytesReceived, nil)
@@ -490,6 +505,13 @@ func (c *Context) GetV6Packet() []byte {
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -500,8 +522,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: TestV6Addr,
- DstAddr: StackV6Addr,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
@@ -517,7 +539,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)