summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/link
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/link')
-rw-r--r--pkg/tcpip/link/ethernet/BUILD15
-rw-r--r--pkg/tcpip/link/ethernet/ethernet.go99
-rw-r--r--pkg/tcpip/link/pipe/BUILD15
-rw-r--r--pkg/tcpip/link/pipe/pipe.go115
-rw-r--r--pkg/tcpip/link/rawfile/BUILD13
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go2
-rw-r--r--pkg/tcpip/link/rawfile/errors.go8
-rw-r--r--pkg/tcpip/link/rawfile/errors_test.go53
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go36
-rw-r--r--pkg/tcpip/link/sniffer/BUILD1
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go60
-rw-r--r--pkg/tcpip/link/tun/BUILD15
-rw-r--r--pkg/tcpip/link/tun/device.go51
13 files changed, 411 insertions, 72 deletions
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD
new file mode 100644
index 000000000..ec92ed623
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "ethernet",
+ srcs = ["ethernet.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/nested",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go
new file mode 100644
index 000000000..3eef7cd56
--- /dev/null
+++ b/pkg/tcpip/link/ethernet/ethernet.go
@@ -0,0 +1,99 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package ethernet provides an implementation of an ethernet link endpoint that
+// wraps an inner link endpoint.
+package ethernet
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/nested"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.NetworkDispatcher = (*Endpoint)(nil)
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns an ethernet link endpoint that wraps an inner link endpoint.
+func New(ep stack.LinkEndpoint) *Endpoint {
+ var e Endpoint
+ e.Endpoint.Init(ep, &e)
+ return &e
+}
+
+// Endpoint is an ethernet endpoint.
+//
+// It adds an ethernet header to packets before sending them out through its
+// inner link endpoint and consumes an ethernet header before sending the
+// packet to the stack.
+type Endpoint struct {
+ nested.Endpoint
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.
+func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize)
+ if !ok {
+ return
+ }
+
+ eth := header.Ethernet(hdr)
+ if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) {
+ e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt)
+ }
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities()
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt)
+ return e.Endpoint.WritePacket(r, gso, proto, pkt)
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ linkAddr := e.Endpoint.LinkAddress()
+
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt)
+ }
+
+ return e.Endpoint.WritePackets(r, gso, pkts, proto)
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength()
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareEther
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
+ eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
+ fields := header.EthernetFields{
+ SrcAddr: local,
+ DstAddr: remote,
+ Type: proto,
+ }
+ eth.Encode(&fields)
+}
diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD
new file mode 100644
index 000000000..9f31c1ffc
--- /dev/null
+++ b/pkg/tcpip/link/pipe/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "pipe",
+ srcs = ["pipe.go"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go
new file mode 100644
index 000000000..523b0d24b
--- /dev/null
+++ b/pkg/tcpip/link/pipe/pipe.go
@@ -0,0 +1,115 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package pipe provides the implementation of pipe-like data-link layer
+// endpoints. Such endpoints allow packets to be sent between two interfaces.
+package pipe
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+var _ stack.LinkEndpoint = (*Endpoint)(nil)
+
+// New returns both ends of a new pipe.
+func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) {
+ ep1 := &Endpoint{
+ linkAddr: linkAddr1,
+ }
+ ep2 := &Endpoint{
+ linkAddr: linkAddr2,
+ }
+ ep1.linked = ep2
+ ep2.linked = ep1
+ return ep1, ep2
+}
+
+// Endpoint is one end of a pipe.
+type Endpoint struct {
+ dispatcher stack.NetworkDispatcher
+ linked *Endpoint
+ linkAddr tcpip.LinkAddress
+}
+
+// WritePacket implements stack.LinkEndpoint.
+func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error {
+ if !e.linked.IsAttached() {
+ return nil
+ }
+
+ // Note that the local address from the perspective of this endpoint is the
+ // remote address from the perspective of the other end of the pipe
+ // (e.linked). Similarly, the remote address from the perspective of this
+ // endpoint is the local address on the other end.
+ e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()),
+ }))
+
+ return nil
+}
+
+// WritePackets implements stack.LinkEndpoint.
+func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
+ panic("not implemented")
+}
+
+// WriteRawPacket implements stack.LinkEndpoint.
+func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error {
+ panic("not implemented")
+}
+
+// Attach implements stack.LinkEndpoint.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// IsAttached implements stack.LinkEndpoint.
+func (e *Endpoint) IsAttached() bool {
+ return e.dispatcher != nil
+}
+
+// Wait implements stack.LinkEndpoint.
+func (*Endpoint) Wait() {}
+
+// MTU implements stack.LinkEndpoint.
+func (*Endpoint) MTU() uint32 {
+ return header.IPv6MinimumMTU
+}
+
+// Capabilities implements stack.LinkEndpoint.
+func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.
+func (*Endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress implements stack.LinkEndpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+// ARPHardwareType implements stack.LinkEndpoint.
+func (*Endpoint) ARPHardwareType() header.ARPHardwareType {
+ return header.ARPHardwareNone
+}
+
+// AddHeader implements stack.LinkEndpoint.
+func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) {
+}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
index 14b527bc2..6c410c5a6 100644
--- a/pkg/tcpip/link/rawfile/BUILD
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -18,3 +18,14 @@ go_library(
"@org_golang_x_sys//unix:go_default_library",
],
)
+
+go_test(
+ name = "rawfile_test",
+ srcs = [
+ "errors_test.go",
+ ],
+ library = "rawfile",
+ deps = [
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index 99313ee25..5db4bf12b 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -14,7 +14,7 @@
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.16
+// +build !go1.17
// Check go:linkname function signatures when updating Go version.
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index a0a873c84..604868fd8 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error
// *tcpip.Error.
//
// Valid, but unrecognized errnos will be translated to
-// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
+// tcpip.ErrInvalidEndpointState (EINVAL).
func TranslateErrno(e syscall.Errno) *tcpip.Error {
- if err := translations[e]; err != nil {
- return err
+ if e > 0 && e < syscall.Errno(len(translations)) {
+ if err := translations[e]; err != nil {
+ return err
+ }
}
return tcpip.ErrInvalidEndpointState
}
diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go
new file mode 100644
index 000000000..e4cdc66bd
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors_test.go
@@ -0,0 +1,53 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// +build linux
+
+package rawfile
+
+import (
+ "syscall"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+func TestTranslateErrno(t *testing.T) {
+ for _, test := range []struct {
+ errno syscall.Errno
+ translated *tcpip.Error
+ }{
+ {
+ errno: syscall.Errno(0),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(maxErrno),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.Errno(514),
+ translated: tcpip.ErrInvalidEndpointState,
+ },
+ {
+ errno: syscall.EEXIST,
+ translated: tcpip.ErrDuplicateAddress,
+ },
+ } {
+ got := TranslateErrno(test.errno)
+ if got != test.translated {
+ t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
index dc239a0d0..2777f1411 100644
--- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) {
const count = 1000000
var wg sync.WaitGroup
+ defer wg.Wait()
wg.Add(1)
go func() {
defer wg.Done()
@@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) {
}
}()
- wg.Add(1)
- go func() {
- defer wg.Done()
- runtime.Gosched()
- for i := 0; i < count; i++ {
- n := 1 + rr.Intn(80)
- rb := rx.Pull()
- for rb == nil {
- rb = rx.Pull()
- }
+ for i := 0; i < count; i++ {
+ n := 1 + rr.Intn(80)
+ rb := rx.Pull()
+ for rb == nil {
+ rb = rx.Pull()
+ }
- if n != len(rb) {
- t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
- }
+ if n != len(rb) {
+ t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+ }
- for j := range rb {
- if v := byte(rr.Intn(256)); v != rb[j] {
- t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
- }
+ for j := range rb {
+ if v := byte(rr.Intn(256)); v != rb[j] {
+ t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
}
-
- rx.Flush()
}
- }()
- wg.Wait()
+ rx.Flush()
+ }
}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
index 7cbc305e7..4aac12a8c 100644
--- a/pkg/tcpip/link/sniffer/BUILD
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -14,6 +14,7 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
+ "//pkg/tcpip/header/parse",
"//pkg/tcpip/link/nested",
"//pkg/tcpip/stack",
],
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 4fb127978..560477926 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/header/parse"
"gvisor.dev/gvisor/pkg/tcpip/link/nested"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
var transProto uint8
src := tcpip.Address("unknown")
dst := tcpip.Address("unknown")
- id := 0
- size := uint16(0)
+ var size uint16
+ var id uint32
var fragmentOffset uint16
var moreFragments bool
- // Examine the packet using a new VV. Backing storage must not be written.
- vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views())
-
+ // Clone the packet buffer to not modify the original.
+ //
+ // We don't clone the original packet buffer so that the new packet buffer
+ // does not have any of its headers set.
+ pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())})
switch protocol {
case header.IPv4ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv4MinimumSize)
- if !ok {
+ if ok := parse.IPv4(pkt); !ok {
return
}
- ipv4 := header.IPv4(hdr)
+
+ ipv4 := header.IPv4(pkt.NetworkHeader().View())
fragmentOffset = ipv4.FragmentOffset()
moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
- vv.TrimFront(int(ipv4.HeaderLength()))
- id = int(ipv4.ID())
+ id = uint32(ipv4.ID())
case header.IPv6ProtocolNumber:
- hdr, ok := vv.PullUp(header.IPv6MinimumSize)
+ proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt)
if !ok {
return
}
- ipv6 := header.IPv6(hdr)
+
+ ipv6 := header.IPv6(pkt.NetworkHeader().View())
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
- transProto = ipv6.NextHeader()
+ transProto = uint8(proto)
size = ipv6.PayloadLength()
- vv.TrimFront(header.IPv6MinimumSize)
+ id = fragID
+ moreFragments = fragMore
+ fragmentOffset = fragOffset
case header.ARPProtocolNumber:
- hdr, ok := vv.PullUp(header.ARPSize)
- if !ok {
+ if parse.ARP(pkt) {
return
}
- vv.TrimFront(header.ARPSize)
- arp := header.ARP(hdr)
+
+ arp := header.ARP(pkt.NetworkHeader().View())
log.Infof(
"%s arp %s (%s) -> %s (%s) valid:%t",
prefix,
@@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv4MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
if !ok {
break
}
@@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.ICMPv6ProtocolNumber:
transName = "icmp"
- hdr, ok := vv.PullUp(header.ICMPv6MinimumSize)
+ hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
if !ok {
break
}
@@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.UDPProtocolNumber:
transName = "udp"
- hdr, ok := vv.PullUp(header.UDPMinimumSize)
- if !ok {
+ if ok := parse.UDP(pkt); !ok {
break
}
- udp := header.UDP(hdr)
+
+ udp := header.UDP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
@@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P
case header.TCPProtocolNumber:
transName = "tcp"
- hdr, ok := vv.PullUp(header.TCPMinimumSize)
- if !ok {
+ if ok := parse.TCP(pkt); !ok {
break
}
- tcp := header.TCP(hdr)
+
+ tcp := header.TCP(pkt.TransportHeader().View())
if fragmentOffset == 0 {
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
- if offset > vv.Size() && !moreFragments {
- details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size())
+ if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments {
+ details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size)
break
}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
index 6c137f693..86f14db76 100644
--- a/pkg/tcpip/link/tun/BUILD
+++ b/pkg/tcpip/link/tun/BUILD
@@ -1,19 +1,34 @@
load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "tun_endpoint_refs",
+ out = "tun_endpoint_refs.go",
+ package = "tun",
+ prefix = "tunEndpoint",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "tunEndpoint",
+ },
+)
+
go_library(
name = "tun",
srcs = [
"device.go",
"protocol.go",
+ "tun_endpoint_refs.go",
"tun_unsafe.go",
],
visibility = ["//visibility:public"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go
index 3b1510a33..cda6328a2 100644
--- a/pkg/tcpip/link/tun/device.go
+++ b/pkg/tcpip/link/tun/device.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -77,13 +76,29 @@ func (d *Device) Release(ctx context.Context) {
}
}
+// NICID returns the NIC ID of the device.
+//
+// Must only be called after the device has been attached to an endpoint.
+func (d *Device) NICID() tcpip.NICID {
+ d.mu.RLock()
+ defer d.mu.RUnlock()
+
+ if d.endpoint == nil {
+ panic("called NICID on a device that has not been attached")
+ }
+
+ return d.endpoint.nicID
+}
+
// SetIff services TUNSETIFF ioctl(2) request.
-func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
+//
+// Returns true if a new NIC was created; false if an existing one was attached.
+func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) {
d.mu.Lock()
defer d.mu.Unlock()
if d.endpoint != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
// Input validations.
@@ -91,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
isTap := flags&linux.IFF_TAP != 0
supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI)
if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
prefix := "tun"
@@ -104,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error {
linkCaps |= stack.CapabilityResolutionRequired
}
- endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps)
+ endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps)
if err != nil {
- return syserror.EINVAL
+ return false, syserror.EINVAL
}
d.endpoint = endpoint
d.notifyHandle = d.endpoint.AddNotify(d)
d.flags = flags
- return nil
+ return created, nil
}
-func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) {
+func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) {
for {
// 1. Try to attach to an existing NIC.
if name != "" {
- if nic, found := s.GetNICByName(name); found {
- endpoint, ok := nic.LinkEndpoint().(*tunEndpoint)
+ if linkEP := s.GetLinkEndpointByName(name); linkEP != nil {
+ endpoint, ok := linkEP.(*tunEndpoint)
if !ok {
// Not a NIC created by tun device.
- return nil, syserror.EOPNOTSUPP
+ return nil, false, syserror.EOPNOTSUPP
}
if !endpoint.TryIncRef() {
// Race detected: NIC got deleted in between.
continue
}
- return endpoint, nil
+ return endpoint, false, nil
}
}
@@ -142,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
name: name,
isTap: prefix == "tap",
}
+ endpoint.EnableLeakCheck()
endpoint.Endpoint.LinkEPCapabilities = linkCaps
if endpoint.name == "" {
endpoint.name = fmt.Sprintf("%s%d", prefix, id)
@@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE
})
switch err {
case nil:
- return endpoint, nil
+ return endpoint, true, nil
case tcpip.ErrDuplicateNICID:
// Race detected: A NIC has been created in between.
continue
default:
- return nil, syserror.EINVAL
+ return nil, false, syserror.EINVAL
}
}
}
@@ -331,19 +347,18 @@ func (d *Device) WriteNotify() {
// It is ref-counted as multiple opening files can attach to the same NIC.
// The last owner is responsible for deleting the NIC.
type tunEndpoint struct {
+ tunEndpointRefs
*channel.Endpoint
- refs.AtomicRefCount
-
stack *stack.Stack
nicID tcpip.NICID
name string
isTap bool
}
-// DecRef decrements refcount of e, removes NIC if refcount goes to 0.
+// DecRef decrements refcount of e, removing NIC if it reaches 0.
func (e *tunEndpoint) DecRef(ctx context.Context) {
- e.DecRefWithDestructor(ctx, func(context.Context) {
+ e.tunEndpointRefs.DecRef(func() {
e.stack.RemoveNIC(e.nicID)
})
}