summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/conntrack.go46
-rw-r--r--pkg/tcpip/stack/forwarder_test.go78
-rw-r--r--pkg/tcpip/stack/iptables_targets.go5
-rw-r--r--pkg/tcpip/stack/nic.go51
-rw-r--r--pkg/tcpip/stack/registration.go13
-rw-r--r--pkg/tcpip/stack/stack_test.go70
-rw-r--r--pkg/tcpip/stack/transport_test.go19
7 files changed, 154 insertions, 128 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index d4053be08..05bf62788 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -20,7 +20,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
@@ -147,44 +146,6 @@ type ConnTrackTable struct {
Seed uint32
}
-// parseHeaders sets headers in the packet.
-func parseHeaders(pkt *PacketBuffer) {
- newPkt := pkt.Clone()
-
- // Set network header.
- hdr, ok := newPkt.Data.PullUp(header.IPv4MinimumSize)
- if !ok {
- return
- }
- netHeader := header.IPv4(hdr)
- newPkt.NetworkHeader = hdr
- length := int(netHeader.HeaderLength())
-
- // TODO(gvisor.dev/issue/170): Need to support for other
- // protocols as well.
- // Set transport header.
- switch protocol := netHeader.TransportProtocol(); protocol {
- case header.UDPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.UDPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.UDP(h[length:]))
- }
- case header.TCPProtocolNumber:
- if newPkt.TransportHeader == nil {
- h, ok := newPkt.Data.PullUp(length + header.TCPMinimumSize)
- if !ok {
- return
- }
- newPkt.TransportHeader = buffer.View(header.TCP(h[length:]))
- }
- }
- pkt.NetworkHeader = newPkt.NetworkHeader
- pkt.TransportHeader = newPkt.TransportHeader
-}
-
// packetToTuple converts packet to a tuple in original direction.
func packetToTuple(pkt *PacketBuffer, hook Hook) (connTrackTuple, *tcpip.Error) {
var tuple connTrackTuple
@@ -257,13 +218,6 @@ func (ct *ConnTrackTable) getTupleHash(tuple connTrackTuple) uint32 {
// TODO(gvisor.dev/issue/170): Only TCP packets are supported. Need to support other
// transport protocols.
func (ct *ConnTrackTable) connTrackForPacket(pkt *PacketBuffer, hook Hook, createConn bool) (*connTrack, ctDirection) {
- if hook == Prerouting {
- // Headers will not be set in Prerouting.
- // TODO(gvisor.dev/issue/170): Change this after parsing headers
- // code is added.
- parseHeaders(pkt)
- }
-
var dir ctDirection
tuple, err := packetToTuple(pkt, hook)
if err != nil {
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 63537aaad..a6546cef0 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -33,6 +33,10 @@ const (
// except where another value is explicitly used. It is chosen to match
// the MTU of loopback interfaces on linux systems.
fwdTestNetDefaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fwdTestNetworkEndpoint is a network-layer protocol endpoint.
@@ -69,15 +73,8 @@ func (f *fwdTestNetworkEndpoint) ID() *NetworkEndpointID {
}
func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) {
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fwdTestNetHeaderLen)
-
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -100,9 +97,9 @@ func (f *fwdTestNetworkEndpoint) WritePacket(r *Route, gso *GSO, params NetworkH
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := pkt.Header.Prepend(fwdTestNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ b[dstAddrOffset] = r.RemoteAddress[0]
+ b[srcAddrOffset] = f.id.LocalAddress[0]
+ b[protocolNumberOffset] = byte(params.Protocol)
return f.ep.WritePacket(r, gso, fwdTestNetNumber, pkt)
}
@@ -140,7 +137,17 @@ func (f *fwdTestNetworkProtocol) DefaultPrefixLen() int {
}
func (*fwdTestNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
+}
+
+func (*fwdTestNetworkProtocol) Parse(pkt *PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ netHeader, ok := pkt.Data.PullUp(fwdTestNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = netHeader
+ pkt.Data.TrimFront(fwdTestNetHeaderLen)
+ return tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), true, true
}
func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, ep LinkEndpoint, _ *Stack) (NetworkEndpoint, *tcpip.Error) {
@@ -361,7 +368,7 @@ func TestForwardingWithStaticResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -398,7 +405,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -429,7 +436,7 @@ func TestForwardingWithNoResolver(t *testing.T) {
// inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -459,7 +466,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 4 on NIC 1. This packet should
// not be forwarded.
buf := buffer.NewView(30)
- buf[0] = 4
+ buf[dstAddrOffset] = 4
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -467,7 +474,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Inject an inbound packet to address 3 on NIC 1, and see if it is
// forwarded to NIC 2.
buf = buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -480,9 +487,8 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -509,7 +515,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Inject two inbound packets to address 3 on NIC 1.
for i := 0; i < 2; i++ {
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -524,9 +530,8 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] != 3 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want = 3", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
@@ -554,7 +559,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
for i := 0; i < maxPendingPacketsPerResolution+5; i++ {
// Inject inbound 'maxPendingPacketsPerResolution + 5' packets on NIC 1.
buf := buffer.NewView(30)
- buf[0] = 3
+ buf[dstAddrOffset] = 3
// Set the packet sequence number.
binary.BigEndian.PutUint16(buf[fwdTestNetHeaderLen:], uint16(i))
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
@@ -571,14 +576,18 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
t.Fatal("packet not forwarded")
}
- b := p.Pkt.Data.ToView()
- if b[0] != 3 {
- t.Fatalf("got b[0] = %d, want = 3", b[0])
+ if b := p.Pkt.Header.View(); b[dstAddrOffset] != 3 {
+ t.Fatalf("got b[dstAddrOffset] = %d, want = 3", b[dstAddrOffset])
+ }
+ seqNumBuf, ok := p.Pkt.Data.PullUp(2) // The sequence number is a uint16 (2 bytes).
+ if !ok {
+ t.Fatalf("p.Pkt.Data is too short to hold a sequence number: %d", p.Pkt.Data.Size())
}
- // The first 5 packets should not be forwarded so the the
- // sequemnce number should start with 5.
+
+ // The first 5 packets should not be forwarded so the sequence number should
+ // start with 5.
want := uint16(i + 5)
- if n := binary.BigEndian.Uint16(b[fwdTestNetHeaderLen:]); n != want {
+ if n := binary.BigEndian.Uint16(seqNumBuf); n != want {
t.Fatalf("got the packet #%d, want = #%d", n, want)
}
@@ -609,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Each packet has a different destination address (3 to
// maxPendingResolutions + 7).
buf := buffer.NewView(30)
- buf[0] = byte(3 + i)
+ buf[dstAddrOffset] = byte(3 + i)
ep1.InjectInbound(fwdTestNetNumber, &PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -626,9 +635,8 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// The first 5 packets (address 3 to 7) should not be forwarded
// because their address resolutions are interrupted.
- b := p.Pkt.Data.ToView()
- if b[0] < 8 {
- t.Fatalf("got b[0] = %d, want b[0] >= 8", b[0])
+ if p.Pkt.NetworkHeader[dstAddrOffset] < 8 {
+ t.Fatalf("got p.Pkt.NetworkHeader[dstAddrOffset] = %d, want p.Pkt.NetworkHeader[dstAddrOffset] >= 8", p.Pkt.NetworkHeader[dstAddrOffset])
}
// Test that the address resolution happened correctly.
diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go
index 36cc6275d..92e31643e 100644
--- a/pkg/tcpip/stack/iptables_targets.go
+++ b/pkg/tcpip/stack/iptables_targets.go
@@ -98,11 +98,6 @@ func (rt RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrackTable, hook Hook
return RuleAccept, 0
}
- // Set network header.
- if hook == Prerouting {
- parseHeaders(pkt)
- }
-
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader == nil || pkt.TransportHeader == nil {
return RuleDrop, 0
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 6664aea06..d756ae6f5 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1212,12 +1212,21 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
n.stack.stats.IP.PacketsReceived.Increment()
}
- netHeader, ok := pkt.Data.PullUp(netProto.MinimumPacketSize())
+ // Parse headers.
+ transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt)
if !ok {
+ // The packet is too small to contain a network header.
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- src, dst := netProto.ParseAddresses(netHeader)
+ if hasTransportHdr {
+ // Parse the transport header if present.
+ if state, ok := n.stack.transportProtocols[transProtoNum]; ok {
+ state.proto.Parse(pkt)
+ }
+ }
+
+ src, dst := netProto.ParseAddresses(pkt.NetworkHeader)
if n.stack.handleLocal && !n.isLoopback() && n.getRef(protocol, src) != nil {
// The source address is one of our own, so we never should have gotten a
@@ -1301,8 +1310,18 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen != 0 {
- pkt.Header = buffer.NewPrependable(linkHeaderLen)
+ // TODO(b/151227689): Avoid copying the packet when forwarding. We can do this
+ // by having lower layers explicity write each header instead of just
+ // pkt.Header.
+
+ // pkt may have set its NetworkHeader and TransportHeader. If we're
+ // forwarding, we'll have to copy them into pkt.Header.
+ pkt.Header = buffer.NewPrependable(int(n.linkEP.MaxHeaderLength()) + len(pkt.NetworkHeader) + len(pkt.TransportHeader))
+ if n := copy(pkt.Header.Prepend(len(pkt.TransportHeader)), pkt.TransportHeader); n != len(pkt.TransportHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.TransportHeader)))
+ }
+ if n := copy(pkt.Header.Prepend(len(pkt.NetworkHeader)), pkt.NetworkHeader); n != len(pkt.NetworkHeader) {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, len(pkt.NetworkHeader)))
}
// WritePacket takes ownership of pkt, calculate numBytes first.
@@ -1333,13 +1352,31 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// validly formed.
n.stack.demux.deliverRawPacket(r, protocol, pkt)
- transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
- if !ok {
+ // TransportHeader is nil only when pkt is an ICMP packet or was reassembled
+ // from fragments.
+ if pkt.TransportHeader == nil {
+ // TODO(gvisor.dev/issue/170): ICMP packets don't have their
+ // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a
+ // full explanation.
+ if protocol == header.ICMPv4ProtocolNumber || protocol == header.ICMPv6ProtocolNumber {
+ transHeader, ok := pkt.Data.PullUp(transProto.MinimumPacketSize())
+ if !ok {
+ n.stack.stats.MalformedRcvdPackets.Increment()
+ return
+ }
+ pkt.TransportHeader = transHeader
+ } else {
+ // This is either a bad packet or was re-assembled from fragments.
+ transProto.Parse(pkt)
+ }
+ }
+
+ if len(pkt.TransportHeader) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
- srcPort, dstPort, err := transProto.ParsePorts(transHeader)
+ srcPort, dstPort, err := transProto.ParsePorts(pkt.TransportHeader)
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 94f177841..5cbc946b6 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -168,6 +168,11 @@ type TransportProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.TransportHeader and trims pkt.Data appropriately. It does
+ // neither and returns false if pkt.Data is too small, i.e. pkt.Data.Size() <
+ // MinimumPacketSize()
+ Parse(pkt *PacketBuffer) (ok bool)
}
// TransportDispatcher contains the methods used by the network stack to deliver
@@ -313,6 +318,14 @@ type NetworkProtocol interface {
// Wait waits for any worker goroutines owned by the protocol to stop.
Wait()
+
+ // Parse sets pkt.NetworkHeader and trims pkt.Data appropriately. It
+ // returns:
+ // - The encapsulated protocol, if present.
+ // - Whether there is an encapsulated transport protocol payload (e.g. ARP
+ // does not encapsulate anything).
+ // - Whether pkt.Data was large enough to parse and set pkt.NetworkHeader.
+ Parse(pkt *PacketBuffer) (proto tcpip.TransportProtocolNumber, hasTransportHdr bool, ok bool)
}
// NetworkDispatcher contains the methods used by the network stack to deliver
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index f6ddc3ced..ffef9bc2c 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -52,6 +52,10 @@ const (
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
+
+ dstAddrOffset = 0
+ srcAddrOffset = 1
+ protocolNumberOffset = 2
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
@@ -94,26 +98,24 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
- // Consume the network header.
- b, ok := pkt.Data.PullUp(fakeNetHeaderLen)
- if !ok {
- return
- }
- pkt.Data.TrimFront(fakeNetHeaderLen)
-
// Handle control packets.
- if b[2] == uint8(fakeControlProtocol) {
+ if pkt.NetworkHeader[protocolNumberOffset] == uint8(fakeControlProtocol) {
nb, ok := pkt.Data.PullUp(fakeNetHeaderLen)
if !ok {
return
}
pkt.Data.TrimFront(fakeNetHeaderLen)
- f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, pkt)
+ f.dispatcher.DeliverTransportControlPacket(
+ tcpip.Address(nb[srcAddrOffset:srcAddrOffset+1]),
+ tcpip.Address(nb[dstAddrOffset:dstAddrOffset+1]),
+ fakeNetNumber,
+ tcpip.TransportProtocolNumber(nb[protocolNumberOffset]),
+ stack.ControlPortUnreachable, 0, pkt)
return
}
// Dispatch the packet to the transport protocol.
- f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), pkt)
+ f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader[protocolNumberOffset]), pkt)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
@@ -138,18 +140,13 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params
// Add the protocol's header to the packet and send it to the link
// endpoint.
- b := pkt.Header.Prepend(fakeNetHeaderLen)
- b[0] = r.RemoteAddress[0]
- b[1] = f.id.LocalAddress[0]
- b[2] = byte(params.Protocol)
+ pkt.NetworkHeader = pkt.Header.Prepend(fakeNetHeaderLen)
+ pkt.NetworkHeader[dstAddrOffset] = r.RemoteAddress[0]
+ pkt.NetworkHeader[srcAddrOffset] = f.id.LocalAddress[0]
+ pkt.NetworkHeader[protocolNumberOffset] = byte(params.Protocol)
if r.Loop&stack.PacketLoop != 0 {
- views := make([]buffer.View, 1, 1+len(pkt.Data.Views()))
- views[0] = pkt.Header.View()
- views = append(views, pkt.Data.Views()...)
- f.HandlePacket(r, &stack.PacketBuffer{
- Data: buffer.NewVectorisedView(len(views[0])+pkt.Data.Size(), views),
- })
+ f.HandlePacket(r, pkt)
}
if r.Loop&stack.PacketOut == 0 {
return nil
@@ -205,7 +202,7 @@ func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int {
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
- return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+ return tcpip.Address(v[srcAddrOffset : srcAddrOffset+1]), tcpip.Address(v[dstAddrOffset : dstAddrOffset+1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint, _ *stack.Stack) (stack.NetworkEndpoint, *tcpip.Error) {
@@ -247,6 +244,17 @@ func (*fakeNetworkProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeNetworkProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProtocolNumber, bool, bool) {
+ hdr, ok := pkt.Data.PullUp(fakeNetHeaderLen)
+ if !ok {
+ return 0, false, false
+ }
+ pkt.NetworkHeader = hdr
+ pkt.Data.TrimFront(fakeNetHeaderLen)
+ return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true
+}
+
func fakeNetFactory() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
}
@@ -292,7 +300,7 @@ func TestNetworkReceive(t *testing.T) {
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
- buf[0] = 3
+ buf[dstAddrOffset] = 3
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -304,7 +312,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to first endpoint.
- buf[0] = 1
+ buf[dstAddrOffset] = 1
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -316,7 +324,7 @@ func TestNetworkReceive(t *testing.T) {
}
// Make sure packet is delivered to second endpoint.
- buf[0] = 2
+ buf[dstAddrOffset] = 2
ep.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
@@ -982,7 +990,7 @@ func TestAddressRemoval(t *testing.T) {
buf := buffer.NewView(30)
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1032,7 +1040,7 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
}
// Send and receive packets, and verify they are received.
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testRecv(t, fakeNet, localAddrByte, ep, buf)
testSend(t, r, ep, nil)
testSendTo(t, s, remoteAddr, ep, nil)
@@ -1114,7 +1122,7 @@ func TestEndpointExpiration(t *testing.T) {
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
if promiscuous {
if err := s.SetPromiscuousMode(nicID, true); err != nil {
@@ -1277,7 +1285,7 @@ func TestPromiscuousMode(t *testing.T) {
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
testFailingRecv(t, fakeNet, localAddrByte, ep, buf)
// Set promiscuous mode, then check that packet is delivered.
@@ -1658,7 +1666,7 @@ func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -1766,7 +1774,7 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
buf := buffer.NewView(30)
const localAddrByte byte = 0x01
- buf[0] = localAddrByte
+ buf[dstAddrOffset] = localAddrByte
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatal("NewSubnet failed:", err)
@@ -2344,7 +2352,7 @@ func TestNICForwarding(t *testing.T) {
// Send a packet to dstAddr.
buf := buffer.NewView(30)
- buf[0] = dstAddr[0]
+ buf[dstAddrOffset] = dstAddr[0]
ep1.InjectInbound(fakeNetNumber, &stack.PacketBuffer{
Data: buf.ToVectorisedView(),
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index cb350ead3..ad61c09d6 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -83,7 +83,8 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, nil, tcpip.ErrNoRoute
}
- hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()) + fakeTransHeaderLen)
+ hdr.Prepend(fakeTransHeaderLen)
v, err := p.FullPayload()
if err != nil {
return 0, nil, err
@@ -324,6 +325,17 @@ func (*fakeTransportProtocol) Close() {}
// Wait implements TransportProtocol.Wait.
func (*fakeTransportProtocol) Wait() {}
+// Parse implements TransportProtocol.Parse.
+func (*fakeTransportProtocol) Parse(pkt *stack.PacketBuffer) bool {
+ hdr, ok := pkt.Data.PullUp(fakeTransHeaderLen)
+ if !ok {
+ return false
+ }
+ pkt.TransportHeader = hdr
+ pkt.Data.TrimFront(fakeTransHeaderLen)
+ return true
+}
+
func fakeTransFactory() stack.TransportProtocol {
return &fakeTransportProtocol{}
}
@@ -642,11 +654,10 @@ func TestTransportForwarding(t *testing.T) {
t.Fatal("Response packet not forwarded")
}
- hdrs := p.Pkt.Data.ToView()
- if dst := hdrs[0]; dst != 3 {
+ if dst := p.Pkt.NetworkHeader[0]; dst != 3 {
t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
}
- if src := hdrs[1]; src != 1 {
+ if src := p.Pkt.NetworkHeader[1]; src != 1 {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}