summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/checker/BUILD1
-rw-r--r--pkg/tcpip/checker/checker.go26
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go26
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go6
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go3
-rw-r--r--pkg/tcpip/stack/packet_buffer.go10
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/tcpip.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go9
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go99
11 files changed, 160 insertions, 28 deletions
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
index ed434807f..c984470e6 100644
--- a/pkg/tcpip/checker/BUILD
+++ b/pkg/tcpip/checker/BUILD
@@ -12,5 +12,6 @@ go_library(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/seqnum",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index ee264b726..1e5f5abf2 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -21,6 +21,7 @@ import (
"reflect"
"testing"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTClass {
- t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want)
- }
- if got := cm.TClass; got != want {
- t.Fatalf("got cm.TClass = %d, want %d", got, want)
+ t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass)
+ } else if got := cm.TClass; got != want {
+ t.Errorf("got cm.TClass = %d, want %d", got, want)
}
}
}
@@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker {
return func(t *testing.T, cm tcpip.ControlMessages) {
t.Helper()
if !cm.HasTOS {
- t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want)
+ t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS)
+ } else if got := cm.TOS; got != want {
+ t.Errorf("got cm.TOS = %d, want %d", got, want)
}
- if got := cm.TOS; got != want {
- t.Fatalf("got cm.TOS = %d, want %d", got, want)
+ }
+}
+
+// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in
+// ControlMessages.
+func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker {
+ return func(t *testing.T, cm tcpip.ControlMessages) {
+ t.Helper()
+ if !cm.HasIPPacketInfo {
+ t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo)
+ } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" {
+ t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff)
}
}
}
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index 467083239..fc1e34fc7 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -199,7 +199,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // TODO(gvisor.dev/issue/3267/): Queue these packets as well once
+ // TODO(gvisor.dev/issue/3267): Queue these packets as well once
// WriteRawPacket takes PacketBuffer instead of VectorisedView.
return e.lower.WriteRawPacket(vv)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 6c4f0ae3e..9ff27a363 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -173,9 +173,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayload := pkt.Data.Clone(nil)
newPayload.CapLength(innerMTU)
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ NetworkProtocolNumber: header.IPv4ProtocolNumber,
}); err != nil {
return err
}
@@ -192,9 +193,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
newPayloadLength := outerMTU - pkt.Header.UsedLength()
newPayload.CapLength(newPayloadLength)
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: pkt.Header,
- Data: newPayload,
- NetworkHeader: buffer.View(h),
+ Header: pkt.Header,
+ Data: newPayload,
+ NetworkHeader: buffer.View(h),
+ NetworkProtocolNumber: header.IPv4ProtocolNumber,
}); err != nil {
return err
}
@@ -206,9 +208,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int,
startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU)
emptyVV := buffer.NewVectorisedView(0, []buffer.View{})
if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{
- Header: startOfHdr,
- Data: emptyVV,
- NetworkHeader: buffer.View(h),
+ Header: startOfHdr,
+ Data: emptyVV,
+ NetworkHeader: buffer.View(h),
+ NetworkProtocolNumber: header.IPv4ProtocolNumber,
}); err != nil {
return err
}
@@ -249,10 +252,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
- nicName := e.stack.FindNICNameFromID(e.NICID())
// iptables filtering. All packets that reach here are locally
// generated.
+ nicName := e.stack.FindNICNameFromID(e.NICID())
ipt := e.stack.IPTables()
if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok {
// iptables is telling us to drop the packet.
@@ -304,6 +308,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pkt := pkts.Front(); pkt != nil; {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
pkt = pkt.Next()
}
@@ -570,6 +575,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu
parseTransportHeader = false
}
+ pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(len(hdr))
pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr))
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index ded97ac64..63e2c36c2 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -150,6 +150,9 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI
if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want {
t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want)
}
+ if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want {
+ t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want)
+ }
if i < len(packets)-1 {
sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset)
} else {
@@ -285,7 +288,8 @@ func TestFragmentation(t *testing.T) {
source := &stack.PacketBuffer{
Header: hdr,
// Save the source payload because WritePacket will modify it.
- Data: payload.Clone(nil),
+ Data: payload.Clone(nil),
+ NetworkProtocolNumber: header.IPv4ProtocolNumber,
}
c := buildContext(t, nil, ft.mtu)
err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 4a0b53c45..d7d7fc611 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -117,6 +117,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params)
pkt.NetworkHeader = buffer.View(ip)
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
if r.Loop&stack.PacketLoop != 0 {
// The inbound path expects the network header to still be in
@@ -152,6 +153,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
for pb := pkts.Front(); pb != nil; pb = pb.Next() {
ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params)
pb.NetworkHeader = buffer.View(ip)
+ pb.NetworkProtocolNumber = header.IPv6ProtocolNumber
}
n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber)
@@ -586,6 +588,7 @@ traverseExtensions:
}
ipHdr = header.IPv6(hdr)
+ pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
pkt.NetworkHeader = hdr
pkt.Data.TrimFront(len(hdr))
pkt.Data.CapLength(int(ipHdr.PayloadLength()))
diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go
index 5d6865e35..9e871f968 100644
--- a/pkg/tcpip/stack/packet_buffer.go
+++ b/pkg/tcpip/stack/packet_buffer.go
@@ -62,6 +62,11 @@ type PacketBuffer struct {
NetworkHeader buffer.View
TransportHeader buffer.View
+ // NetworkProtocol is only valid when NetworkHeader is set.
+ // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol
+ // numbers in registration APIs that take a PacketBuffer.
+ NetworkProtocolNumber tcpip.NetworkProtocolNumber
+
// Hash is the transport layer hash of this packet. A value of zero
// indicates no valid hash has been set.
Hash uint32
@@ -72,9 +77,8 @@ type PacketBuffer struct {
// The following fields are only set by the qdisc layer when the packet
// is added to a queue.
- EgressRoute *Route
- GSOOptions *GSO
- NetworkProtocolNumber tcpip.NetworkProtocolNumber
+ EgressRoute *Route
+ GSOOptions *GSO
// NatDone indicates if the packet has been manipulated as per NAT
// iptables rule.
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 8604c4259..4570e8969 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -249,8 +249,8 @@ type NetworkEndpoint interface {
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
- // protocol. It takes ownership of pkt. pkt.TransportHeader must have already
- // been set.
+ // protocol. It takes ownership of pkt. pkt.TransportHeader must have
+ // already been set.
WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error
// WritePackets writes packets to the given destination address and
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 45f59b60f..091bc5281 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -968,7 +968,7 @@ type IPPacketInfo struct {
// LocalAddr is the local address.
LocalAddr Address
- // DestinationAddr is the destination address.
+ // DestinationAddr is the destination address found in the IP header.
DestinationAddr Address
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 444b5b01c..4a2b6c03a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1444,13 +1444,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
switch r.NetProto {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
- packet.packetInfo.LocalAddr = r.LocalAddress
- packet.packetInfo.DestinationAddr = r.RemoteAddress
- packet.packetInfo.NIC = r.NICID()
case header.IPv6ProtocolNumber:
packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
}
+ // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast
+ // address. packetInfo.LocalAddr should hold a unicast address that can be
+ // used to respond to the incoming packet.
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.LocalAddress
+ packet.packetInfo.NIC = r.NICID()
packet.timestamp = e.stack.Clock().NowNanoseconds()
e.rcvMu.Unlock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 66e8911c8..1a32622ca 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -1309,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) {
}
}
+func TestReadIPPacketInfo(t *testing.T) {
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ flow testFlow
+ expectedLocalAddr tcpip.Address
+ expectedDestAddr tcpip.Address
+ }{
+ {
+ name: "IPv4 unicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: unicastV4,
+ expectedLocalAddr: stackAddr,
+ expectedDestAddr: stackAddr,
+ },
+ {
+ name: "IPv4 multicast",
+ proto: header.IPv4ProtocolNumber,
+ flow: multicastV4,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastAddr,
+ expectedDestAddr: multicastAddr,
+ },
+ {
+ name: "IPv4 broadcast",
+ proto: header.IPv4ProtocolNumber,
+ flow: broadcast,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: broadcastAddr,
+ expectedDestAddr: broadcastAddr,
+ },
+ {
+ name: "IPv6 unicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: unicastV6,
+ expectedLocalAddr: stackV6Addr,
+ expectedDestAddr: stackV6Addr,
+ },
+ {
+ name: "IPv6 multicast",
+ proto: header.IPv6ProtocolNumber,
+ flow: multicastV6,
+ // This should actually be a unicast address assigned to the interface.
+ //
+ // TODO(gvisor.dev/issue/3556): This check is validating incorrect
+ // behaviour. We still include the test so that once the bug is
+ // resolved, this test will start to fail and the individual tasked
+ // with fixing this bug knows to also fix this test :).
+ expectedLocalAddr: multicastV6Addr,
+ expectedDestAddr: multicastV6Addr,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ c.createEndpoint(test.proto)
+
+ bindAddr := tcpip.FullAddress{Port: stackPort}
+ if err := c.ep.Bind(bindAddr); err != nil {
+ t.Fatalf("Bind(%+v): %s", bindAddr, err)
+ }
+
+ if test.flow.isMulticast() {
+ ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err)
+ }
+ }
+
+ if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil {
+ t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err)
+ }
+
+ testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{
+ NIC: 1,
+ LocalAddr: test.expectedLocalAddr,
+ DestinationAddr: test.expectedDestAddr,
+ }))
+
+ if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 {
+ t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got)
+ }
+ })
+ }
+}
+
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()