summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack/nic.go
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-03-13 10:43:09 -0700
committergVisor bot <gvisor-bot@google.com>2020-03-13 10:44:23 -0700
commit28d26d2c4f231c447a10bcbcfb8223a804c9d8bc (patch)
tree949a227ddd4dc40282473b71076d10279a9bbd2b /pkg/tcpip/stack/nic.go
parent8f8f16efafd48da3c5e4db329a90bb76620b2324 (diff)
Honour the link's MaxHeaderLength when forwarding
LinkEndpoints may expect/assume that the a tcpip.PacketBuffer's Header has enough capacity for its own headers, as per documentation for LinkEndpoint.MaxHeaderLength. Test: stack_test.TestNICForwarding PiperOrigin-RevId: 300784192
Diffstat (limited to 'pkg/tcpip/stack/nic.go')
-rw-r--r--pkg/tcpip/stack/nic.go18
1 files changed, 17 insertions, 1 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 3cd5fec71..230ee0697 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -15,6 +15,7 @@
package stack
import (
+ "fmt"
"log"
"reflect"
"sort"
@@ -1259,9 +1260,24 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
func (n *NIC) forwardPacket(r *Route, protocol tcpip.NetworkProtocolNumber, pkt tcpip.PacketBuffer) {
// TODO(b/143425874) Decrease the TTL field in forwarded packets.
- pkt.Header = buffer.NewPrependableFromView(pkt.Data.First())
+
+ firstData := pkt.Data.First()
pkt.Data.RemoveFirst()
+ if linkHeaderLen := int(n.linkEP.MaxHeaderLength()); linkHeaderLen == 0 {
+ pkt.Header = buffer.NewPrependableFromView(firstData)
+ } else {
+ firstDataLen := len(firstData)
+
+ // pkt.Header should have enough capacity to hold n.linkEP's headers.
+ pkt.Header = buffer.NewPrependable(firstDataLen + linkHeaderLen)
+
+ // TODO(b/151227689): avoid copying the packet when forwarding
+ if n := copy(pkt.Header.Prepend(firstDataLen), firstData); n != firstDataLen {
+ panic(fmt.Sprintf("copied %d bytes, expected %d", n, firstDataLen))
+ }
+ }
+
if err := n.linkEP.WritePacket(r, nil /* gso */, protocol, pkt); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
return