summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/link')
-rw-r--r--pkg/tcpip/link/loopback/loopback.go10
-rw-r--r--pkg/tcpip/link/rawfile/BUILD9
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_test.go46
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go6
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go2
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go65
6 files changed, 114 insertions, 24 deletions
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index 1e2255bfa..073c84ef9 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -98,13 +98,13 @@ func (e *endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList
// WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket.
func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
- // Reject the packet if it's shorter than an ethernet header.
- if vv.Size() < header.EthernetMinimumSize {
+ // There should be an ethernet header at the beginning of vv.
+ hdr, ok := vv.PullUp(header.EthernetMinimumSize)
+ if !ok {
+ // Reject the packet if it's shorter than an ethernet header.
return tcpip.ErrBadAddress
}
-
- // There should be an ethernet header at the beginning of vv.
- linkHeader := header.Ethernet(vv.First()[:header.EthernetMinimumSize])
+ linkHeader := header.Ethernet(hdr)
vv.TrimFront(len(linkHeader))
e.dispatcher.DeliverNetworkPacket(e, "" /* remote */, "" /* local */, linkHeader.Type(), stack.PacketBuffer{
Data: vv,
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..9cc08d0e2 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,10 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ size = "small",
+ srcs = ["rawfile_test.go"],
+ library = ":rawfile",
+)
diff --git a/pkg/tcpip/link/rawfile/rawfile_test.go b/pkg/tcpip/link/rawfile/rawfile_test.go
new file mode 100644
index 000000000..8f14ba761
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_test.go
@@ -0,0 +1,46 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+)
+
+func TestNonBlockingWrite3ZeroLength(t *testing.T) {
+ fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ t.Fatalf("failed to open /dev/null: %v", err)
+ }
+ defer syscall.Close(fd)
+
+ if err := NonBlockingWrite3(fd, []byte{}, []byte{0}, nil); err != nil {
+ t.Fatalf("failed to write: %v", err)
+ }
+}
+
+func TestNonBlockingWrite3Nil(t *testing.T) {
+ fd, err := syscall.Open("/dev/null", syscall.O_WRONLY, 0)
+ if err != nil {
+ t.Fatalf("failed to open /dev/null: %v", err)
+ }
+ defer syscall.Close(fd)
+
+ if err := NonBlockingWrite3(fd, nil, []byte{0}, nil); err != nil {
+ t.Fatalf("failed to write: %v", err)
+ }
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 44e25d475..92efd0bf8 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -76,9 +76,13 @@ func NonBlockingWrite3(fd int, b1, b2, b3 []byte) *tcpip.Error {
// We have two buffers. Build the iovec that represents them and issue
// a writev syscall.
+ var base *byte
+ if len(b1) > 0 {
+ base = &b1[0]
+ }
iovec := [3]syscall.Iovec{
{
- Base: &b1[0],
+ Base: base,
Len: uint64(len(b1)),
},
{
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
index 27ea3f531..33f640b85 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem_test.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -674,7 +674,7 @@ func TestSimpleReceive(t *testing.T) {
// Wait for packet to be received, then check it.
c.waitForPackets(1, time.After(5*time.Second), "Timeout waiting for packet")
c.mu.Lock()
- rcvd := []byte(c.packets[0].vv.First())
+ rcvd := []byte(c.packets[0].vv.ToView())
c.packets = c.packets[:0]
c.mu.Unlock()
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index be2537a82..0799c8f4d 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -171,11 +171,7 @@ func (e *endpoint) GSOMaxSize() uint32 {
func (e *endpoint) dumpPacket(prefix string, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
writer := e.writer
if writer == nil && atomic.LoadUint32(&LogPackets) == 1 {
- first := pkt.Header.View()
- if len(first) == 0 {
- first = pkt.Data.First()
- }
- logPacket(prefix, protocol, first, gso)
+ logPacket(prefix, protocol, pkt, gso)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
totalLength := pkt.Header.UsedLength() + pkt.Data.Size()
@@ -238,7 +234,7 @@ func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error {
// Wait implements stack.LinkEndpoint.Wait.
func (e *endpoint) Wait() { e.lower.Wait() }
-func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
+func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
@@ -247,28 +243,49 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
size := uint16(0)
var fragmentOffset uint16
var moreFragments bool
+
+ // Create a clone of pkt, including any headers if present. Avoid allocating
+ // backing memory for the clone.
+ views := [8]buffer.View{}
+ vv := buffer.NewVectorisedView(0, views[:0])
+ vv.AppendView(pkt.Header.View())
+ vv.Append(pkt.Data)
+
switch protocol {
case header.IPv4ProtocolNumber:
- ipv4 := header.IPv4(b)
+ hdr, ok := vv.PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return
+ }
+ ipv4 := header.IPv4(hdr)
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- b = b[ipv4.HeaderLength():]
+ vv.TrimFront(int(ipv4.HeaderLength()))
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
- ipv6 := header.IPv6(b)
+ hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ if !ok {
+ return
+ }
+ ipv6 := header.IPv6(hdr)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
- b = b[header.IPv6MinimumSize:]
+ vv.TrimFront(header.IPv6MinimumSize)
case header.ARPProtocolNumber:
- arp := header.ARP(b)
+ hdr, ok := vv.PullUp(header.ARPSize)
+ if !ok {
+ return
+ }
+ vv.TrimFront(header.ARPSize)
+ arp := header.ARP(hdr)
log.Infof(
"%s arp %v (%v) -> %v (%v) valid:%v",
prefix,
@@ -284,7 +301,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
// We aren't guaranteed to have a transport header - it's possible for
// writes via raw endpoints to contain only network headers.
- if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && len(b) < minSize {
+ if minSize, ok := transportProtocolMinSizes[tcpip.TransportProtocolNumber(transProto)]; ok && vv.Size() < minSize {
log.Infof("%s %v -> %v transport protocol: %d, but no transport header found (possible raw packet)", prefix, src, dst, transProto)
return
}
@@ -297,7 +314,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv4(b)
+ hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv4(hdr)
icmpType := "unknown"
if fragmentOffset == 0 {
switch icmp.Type() {
@@ -330,7 +351,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- icmp := header.ICMPv6(b)
+ hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ if !ok {
+ break
+ }
+ icmp := header.ICMPv6(hdr)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
@@ -361,7 +386,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.UDPProtocolNumber:
transName = "udp"
- udp := header.UDP(b)
+ hdr, ok := vv.PullUp(header.UDPMinimumSize)
+ if !ok {
+ break
+ }
+ udp := header.UDP(hdr)
if fragmentOffset == 0 && len(udp) >= header.UDPMinimumSize {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -371,7 +400,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.Vie
case header.TCPProtocolNumber:
transName = "tcp"
- tcp := header.TCP(b)
+ hdr, ok := vv.PullUp(header.TCPMinimumSize)
+ if !ok {
+ break
+ }
+ tcp := header.TCP(hdr)
if fragmentOffset == 0 && len(tcp) >= header.TCPMinimumSize {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {