diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/checker/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/checker/checker.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/link/qdisc/fifo/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4_test.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/stack/packet_buffer.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 99 |
11 files changed, 160 insertions, 28 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index ed434807f..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -12,5 +12,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index ee264b726..1e5f5abf2 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) } } } @@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) } } } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 467083239..fc1e34fc7 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -199,7 +199,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // TODO(gvisor.dev/issue/3267): Queue these packets as well once // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6c4f0ae3e..9ff27a363 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -173,9 +173,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -192,9 +193,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -206,9 +208,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -249,10 +252,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) ipt := e.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. @@ -304,6 +308,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt = pkt.Next() } @@ -570,6 +575,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu parseTransportHeader = false } + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ded97ac64..63e2c36c2 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -150,6 +150,9 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -285,7 +288,8 @@ func TestFragmentation(t *testing.T) { source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), + Data: payload.Clone(nil), + NetworkProtocolNumber: header.IPv4ProtocolNumber, } c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4a0b53c45..d7d7fc611 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -117,6 +117,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in @@ -152,6 +153,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) pb.NetworkHeader = buffer.View(ip) + pb.NetworkProtocolNumber = header.IPv6ProtocolNumber } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -586,6 +588,7 @@ traverseExtensions: } ipHdr = header.IPv6(hdr) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 5d6865e35..9e871f968 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -62,6 +62,11 @@ type PacketBuffer struct { NetworkHeader buffer.View TransportHeader buffer.View + // NetworkProtocol is only valid when NetworkHeader is set. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -72,9 +77,8 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 8604c4259..4570e8969 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -249,8 +249,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45f59b60f..091bc5281 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -968,7 +968,7 @@ type IPPacketInfo struct { // LocalAddr is the local address. LocalAddr Address - // DestinationAddr is the destination address. + // DestinationAddr is the destination address found in the IP header. DestinationAddr Address } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 444b5b01c..4a2b6c03a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1444,13 +1444,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.RemoteAddress - packet.packetInfo.NIC = r.NICID() case header.IPv6ProtocolNumber: packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } + // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast + // address. packetInfo.LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.LocalAddress + packet.packetInfo.NIC = r.NICID() packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 66e8911c8..1a32622ca 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1309,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { } } +func TestReadIPPacketInfo(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedLocalAddr tcpip.Address + expectedDestAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedLocalAddr: stackAddr, + expectedDestAddr: stackAddr, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastAddr, + expectedDestAddr: multicastAddr, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: broadcastAddr, + expectedDestAddr: broadcastAddr, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedLocalAddr: stackV6Addr, + expectedDestAddr: stackV6Addr, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastV6Addr, + expectedDestAddr: multicastV6Addr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err) + } + } + + if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { + t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) + } + + testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: 1, + LocalAddr: test.expectedLocalAddr, + DestinationAddr: test.expectedDestAddr, + })) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() |