diff options
-rw-r--r-- | pkg/tcpip/network/BUILD | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ip_test.go | 93 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 3 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 3 |
4 files changed, 99 insertions, 2 deletions
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD index 7b1ff44f4..c0179104a 100644 --- a/pkg/tcpip/network/BUILD +++ b/pkg/tcpip/network/BUILD @@ -23,8 +23,10 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/testutil", "//pkg/tcpip/transport/icmp", + "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index 771b9173a..2179302d3 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "bytes" "fmt" "strings" "testing" @@ -32,8 +33,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/testutil" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) const nicID = 1 @@ -2032,3 +2035,93 @@ func TestJoinLeaveAllRoutersGroup(t *testing.T) { }) } } + +func TestSetNICIDBeforeDeliveringToRawEndpoint(t *testing.T) { + const nicID = 1 + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.AddressWithPrefix + payloadOffset int + }{ + { + name: "IPv4", + proto: header.IPv4ProtocolNumber, + addr: localIPv4AddrWithPrefix, + payloadOffset: header.IPv4MinimumSize, + }, + { + name: "IPv6", + proto: header.IPv6ProtocolNumber, + addr: localIPv6AddrWithPrefix, + payloadOffset: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, + }, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + RawFactory: raw.EndpointFactory{}, + }) + if err := s.CreateNIC(nicID, loopback.New()); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddressWithPrefix(nicID, test.proto, test.addr); err != nil { + t.Fatalf("AddAddressWithPrefix(%d, %d, %s): %s", nicID, test.proto, test.addr, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: test.addr.Subnet(), + NIC: nicID, + }, + }) + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + ep, err := s.NewRawEndpoint(udp.ProtocolNumber, test.proto, &wq, true /* associated */) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.proto, err) + } + defer ep.Close() + + writeOpts := tcpip.WriteOptions{ + To: &tcpip.FullAddress{ + Addr: test.addr.Address, + }, + } + data := []byte{1, 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, writeOpts); err != nil { + t.Fatalf("ep.Write(_, _): %s", err) + } else if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, _) = (%d, nil), want = (%d, nil)", n, want) + } + + // Wait for the endpoint to become readable. + <-ch + + var w bytes.Buffer + rr, err := ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if err != nil { + t.Fatalf("ep.Read(...): %s", err) + } + if diff := cmp.Diff(data, w.Bytes()[test.payloadOffset:]); diff != "" { + t.Errorf("payload mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tcpip.FullAddress{Addr: test.addr.Address, NIC: nicID}, rr.RemoteAddr); diff != "" { + t.Errorf("remote addr mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 44c85bdb8..e2472c851 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -856,6 +856,8 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv4, and that they not // be fragmented. @@ -863,7 +865,6 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) } - pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index b1aec5312..d4bd61748 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -1127,11 +1127,12 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum } func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { + pkt.NICID = e.nic.ID() + // Raw socket packets are delivered based solely on the transport protocol // number. We only require that the packet be valid IPv6. e.dispatcher.DeliverRawPacket(h.TransportProtocol(), pkt) - pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() |