diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/checker/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/checker/checker.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 99 |
5 files changed, 126 insertions, 11 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index ed434807f..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -12,5 +12,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index ee264b726..1e5f5abf2 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) } } } @@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) } } } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45f59b60f..091bc5281 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -968,7 +968,7 @@ type IPPacketInfo struct { // LocalAddr is the local address. LocalAddr Address - // DestinationAddr is the destination address. + // DestinationAddr is the destination address found in the IP header. DestinationAddr Address } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 444b5b01c..4a2b6c03a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1444,13 +1444,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.RemoteAddress - packet.packetInfo.NIC = r.NICID() case header.IPv6ProtocolNumber: packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } + // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast + // address. packetInfo.LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.LocalAddress + packet.packetInfo.NIC = r.NICID() packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 66e8911c8..1a32622ca 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1309,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { } } +func TestReadIPPacketInfo(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedLocalAddr tcpip.Address + expectedDestAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedLocalAddr: stackAddr, + expectedDestAddr: stackAddr, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastAddr, + expectedDestAddr: multicastAddr, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: broadcastAddr, + expectedDestAddr: broadcastAddr, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedLocalAddr: stackV6Addr, + expectedDestAddr: stackV6Addr, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastV6Addr, + expectedDestAddr: multicastV6Addr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err) + } + } + + if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { + t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) + } + + testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: 1, + LocalAddr: test.expectedLocalAddr, + DestinationAddr: test.expectedDestAddr, + })) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() |