summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/checker/checker.go18
-rw-r--r--pkg/tcpip/network/BUILD1
-rw-r--r--pkg/tcpip/network/ip_test.go405
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go40
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go44
-rw-r--r--pkg/tcpip/tcpip.go2
6 files changed, 483 insertions, 27 deletions
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index d4d785cca..6f81b0164 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -178,6 +178,24 @@ func PayloadLen(payloadLength int) NetworkChecker {
}
}
+// IPPayload creates a checker that checks the payload.
+func IPPayload(payload []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ t.Helper()
+
+ got := h[0].Payload()
+
+ // cmp.Diff does not consider nil slices equal to empty slices, but we do.
+ if len(got) == 0 && len(payload) == 0 {
+ return
+ }
+
+ if diff := cmp.Diff(payload, got); diff != "" {
+ t.Errorf("payload mismatch (-want +got):\n%s", diff)
+ }
+ }
+}
+
// IPv4Options returns a checker that checks the options in an IPv4 packet.
func IPv4Options(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
index 59710352b..c118a2929 100644
--- a/pkg/tcpip/network/BUILD
+++ b/pkg/tcpip/network/BUILD
@@ -12,6 +12,7 @@ go_test(
"//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index d436873b6..c1848a19c 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,11 +15,13 @@
package ip_test
import (
+ "strings"
"testing"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -1020,3 +1022,406 @@ func truncatedPacket(view buffer.View, trunc, netHdrLen int) *stack.PacketBuffer
_, _ = pkt.NetworkHeader().Consume(netHdrLen)
return pkt
}
+
+func TestWriteHeaderIncludedPacket(t *testing.T) {
+ const (
+ nicID = 1
+ transportProto = 5
+
+ dataLen = 4
+ optionsLen = 4
+ )
+
+ dataBuf := [dataLen]byte{1, 2, 3, 4}
+ data := dataBuf[:]
+
+ ipv4OptionsBuf := [optionsLen]byte{0, 1, 0, 1}
+ ipv4Options := ipv4OptionsBuf[:]
+
+ ipv6FragmentExtHdrBuf := [header.IPv6FragmentExtHdrLength]byte{transportProto, 0, 62, 4, 1, 2, 3, 4}
+ ipv6FragmentExtHdr := ipv6FragmentExtHdrBuf[:]
+
+ var ipv6PayloadWithExtHdrBuf [dataLen + header.IPv6FragmentExtHdrLength]byte
+ ipv6PayloadWithExtHdr := ipv6PayloadWithExtHdrBuf[:]
+ if n := copy(ipv6PayloadWithExtHdr, ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ if n := copy(ipv6PayloadWithExtHdr[header.IPv6FragmentExtHdrLength:], data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+
+ tests := []struct {
+ name string
+ protoFactory stack.NetworkProtocolFactory
+ protoNum tcpip.NetworkProtocolNumber
+ nicAddr tcpip.Address
+ remoteAddr tcpip.Address
+ pktGen func(*testing.T, tcpip.Address) buffer.View
+ checker func(*testing.T, *stack.PacketBuffer, tcpip.Address)
+ expectedErr *tcpip.Error
+ }{
+ {
+ name: "IPv4",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(uint16(header.IPv4MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv4 with IHL too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv4MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize - 1,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 too small",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ {
+ name: "IPv4 minimum size",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv4(make([]byte, header.IPv4MinimumSize))
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv4MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv4MinimumSize)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(header.IPv4MinimumSize),
+ checker.IPFullLength(header.IPv4MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv4 with options",
+ protoFactory: ipv4.NewProtocol,
+ protoNum: ipv4.ProtocolNumber,
+ nicAddr: localIPv4Addr,
+ remoteAddr: remoteIPv4Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ipHdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ totalLen := ipHdrLen + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv4(hdr.Prepend(ipHdrLen))
+ ip.Encode(&header.IPv4Fields{
+ IHL: uint8(ipHdrLen),
+ Protocol: transportProto,
+ TTL: ipv4.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options))
+ }
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv4Any {
+ src = localIPv4Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ hdrLen := header.IPv4MinimumSize + len(ipv4Options)
+ if len(netHdr.View()) != hdrLen {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen)
+ }
+
+ checker.IPv4(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv4Addr),
+ checker.IPv4HeaderLength(hdrLen),
+ checker.IPFullLength(uint16(hdrLen+len(data))),
+ checker.IPv4Options(ipv4Options),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(data))),
+ checker.IPPayload(data),
+ )
+ },
+ },
+ {
+ name: "IPv6 with extension header",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(data)), data); n != len(data) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(data))
+ }
+ if n := copy(hdr.Prepend(len(ipv6FragmentExtHdr)), ipv6FragmentExtHdr); n != len(ipv6FragmentExtHdr) {
+ t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv6FragmentExtHdr))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: uint8(header.IPv6FragmentExtHdrIdentifier),
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return hdr.View()
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if want := header.IPv6MinimumSize + len(ipv6FragmentExtHdr); len(netHdr.View()) != want {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), want)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(uint16(header.IPv6MinimumSize+len(ipv6PayloadWithExtHdr))),
+ checker.IPPayload(ipv6PayloadWithExtHdr),
+ )
+ },
+ },
+ {
+ name: "IPv6 minimum size",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip)
+ },
+ checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) {
+ if src == header.IPv6Any {
+ src = localIPv6Addr
+ }
+
+ netHdr := pkt.NetworkHeader()
+
+ if len(netHdr.View()) != header.IPv6MinimumSize {
+ t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), header.IPv6MinimumSize)
+ }
+
+ checker.IPv6(t, stack.PayloadSince(netHdr),
+ checker.SrcAddr(src),
+ checker.DstAddr(remoteIPv6Addr),
+ checker.IPFullLength(header.IPv6MinimumSize),
+ checker.IPPayload(nil),
+ )
+ },
+ },
+ {
+ name: "IPv6 too small",
+ protoFactory: ipv6.NewProtocol,
+ protoNum: ipv6.ProtocolNumber,
+ nicAddr: localIPv6Addr,
+ remoteAddr: remoteIPv6Addr,
+ pktGen: func(t *testing.T, src tcpip.Address) buffer.View {
+ ip := header.IPv6(make([]byte, header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ NextHeader: transportProto,
+ HopLimit: ipv6.DefaultTTL,
+ SrcAddr: src,
+ DstAddr: header.IPv4Any,
+ })
+ return buffer.View(ip[:len(ip)-1])
+ },
+ expectedErr: tcpip.ErrMalformedHeader,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ subTests := []struct {
+ name string
+ srcAddr tcpip.Address
+ }{
+ {
+ name: "unspecified source",
+ srcAddr: tcpip.Address(strings.Repeat("\x00", len(test.nicAddr))),
+ },
+ {
+ name: "random source",
+ srcAddr: tcpip.Address(strings.Repeat("\xab", len(test.nicAddr))),
+ },
+ }
+
+ for _, subTest := range subTests {
+ t.Run(subTest.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{test.protoFactory},
+ })
+ e := channel.New(1, 1280, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, test.protoNum, test.nicAddr); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.protoNum, test.nicAddr, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{Destination: test.remoteAddr.WithPrefix().Subnet(), NIC: nicID}})
+
+ r, err := s.FindRoute(nicID, test.nicAddr, test.remoteAddr, test.protoNum, false /* multicastLoop */)
+ if err != nil {
+ t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", nicID, test.remoteAddr, test.nicAddr, test.protoNum, err)
+ }
+ defer r.Release()
+
+ if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(),
+ })); err != test.expectedErr {
+ t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr)
+ }
+
+ if test.expectedErr != nil {
+ return
+ }
+
+ pkt, ok := e.Read()
+ if !ok {
+ t.Fatal("expected a packet to be written")
+ }
+ test.checker(t, pkt.Pkt, subTest.srcAddr)
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index c5ac7b8b5..7e2e53523 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -237,7 +237,10 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -347,30 +350,27 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacket writes a packet already containing a network
-// header through the given route.
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
// The packet already has an IP header, but there are a few required
// checks.
h, ok := pkt.Data.PullUp(header.IPv4MinimumSize)
if !ok {
- return tcpip.ErrInvalidOptionValue
+ return tcpip.ErrMalformedHeader
}
ip := header.IPv4(h)
- if !ip.IsValid(pkt.Data.Size()) {
- return tcpip.ErrInvalidOptionValue
- }
// Always set the total length.
- ip.SetTotalLength(uint16(pkt.Data.Size()))
+ pktSize := pkt.Data.Size()
+ ip.SetTotalLength(uint16(pktSize))
// Set the source address when zero.
- if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ if ip.SourceAddress() == header.IPv4Any {
ip.SetSourceAddress(r.LocalAddress)
}
- // Set the destination. If the packet already included a destination,
- // it will be part of the route.
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
ip.SetDestinationAddress(r.RemoteAddress)
// Set the packet ID when zero.
@@ -387,19 +387,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
- if r.Loop&stack.PacketLoop != 0 {
- e.HandlePacket(r, pkt.Clone())
- }
- if r.Loop&stack.PacketOut == 0 {
- return nil
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ if !parse.IPv4(pkt) || !header.IPv4(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
}
- if err := e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, pkt); err != nil {
- r.Stats().IP.OutgoingPacketErrors.Increment()
- return err
- }
- r.Stats().IP.PacketsSent.Increment()
- return nil
+ return e.writePacket(r, nil /* gso */, pkt)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 2bd8f4ece..632914dd6 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -426,7 +426,10 @@ func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, p
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error {
e.addIPHeader(r, pkt, params)
+ return e.writePacket(r, gso, pkt, params.Protocol)
+}
+func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.PacketBuffer, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
// iptables filtering. All packets that reach here are locally
// generated.
nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
@@ -468,7 +471,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
}
if e.packetMustBeFragmented(pkt, gso) {
- sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
+ sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error {
// TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each
// fragment one by one using WritePacket() (current strategy) or if we
// want to create a PacketBufferList from the fragments and feed it to
@@ -569,11 +572,40 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
return n + len(dropped), nil
}
-// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet
-// supported by IPv6.
-func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
- // TODO(b/146666412): Support IPv6 header-included packets.
- return tcpip.ErrNotSupported
+// WriteHeaderIncludedPacker implements stack.NetworkEndpoint.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required checks.
+ h, ok := pkt.Data.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return tcpip.ErrMalformedHeader
+ }
+ ip := header.IPv6(h)
+
+ // Always set the payload length.
+ pktSize := pkt.Data.Size()
+ ip.SetPayloadLength(uint16(pktSize - header.IPv6MinimumSize))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == header.IPv6Any {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination, it will
+ // be part of the route anyways.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Populate the packet buffer's network header and don't allow an invalid
+ // packet to be sent.
+ //
+ // Note that parsing only makes sure that the packet is well formed as per the
+ // wire format. We also want to check if the header's fields are valid before
+ // sending the packet.
+ proto, _, _, _, ok := parse.IPv6(pkt)
+ if !ok || !header.IPv6(pkt.NetworkHeader().View()).IsValid(pktSize) {
+ return tcpip.ErrMalformedHeader
+ }
+
+ return e.writePacket(r, nil /* gso */, pkt, proto)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index c42bb0991..d77848d61 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -111,6 +111,7 @@ var (
ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
+ ErrMalformedHeader = &Error{msg: "header is malformed"}
)
var messageToError map[string]*Error
@@ -159,6 +160,7 @@ func StringToError(s string) *Error {
ErrBroadcastDisabled,
ErrNotPermitted,
ErrAddressFamilyNotSupported,
+ ErrMalformedHeader,
}
messageToError = make(map[string]*Error)