diff options
Diffstat (limited to 'pkg/tcpip/link')
-rw-r--r-- | pkg/tcpip/link/ethernet/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/link/ethernet/ethernet.go | 99 | ||||
-rw-r--r-- | pkg/tcpip/link/pipe/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/link/pipe/pipe.go | 115 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/BUILD | 13 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/errors.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/link/rawfile/errors_test.go | 53 | ||||
-rw-r--r-- | pkg/tcpip/link/sharedmem/pipe/pipe_test.go | 36 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/link/sniffer/sniffer.go | 60 | ||||
-rw-r--r-- | pkg/tcpip/link/tun/BUILD | 15 | ||||
-rw-r--r-- | pkg/tcpip/link/tun/device.go | 51 |
13 files changed, 411 insertions, 72 deletions
diff --git a/pkg/tcpip/link/ethernet/BUILD b/pkg/tcpip/link/ethernet/BUILD new file mode 100644 index 000000000..ec92ed623 --- /dev/null +++ b/pkg/tcpip/link/ethernet/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "ethernet", + srcs = ["ethernet.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/header", + "//pkg/tcpip/link/nested", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/ethernet/ethernet.go b/pkg/tcpip/link/ethernet/ethernet.go new file mode 100644 index 000000000..3eef7cd56 --- /dev/null +++ b/pkg/tcpip/link/ethernet/ethernet.go @@ -0,0 +1,99 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ethernet provides an implementation of an ethernet link endpoint that +// wraps an inner link endpoint. +package ethernet + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/nested" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.NetworkDispatcher = (*Endpoint)(nil) +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns an ethernet link endpoint that wraps an inner link endpoint. +func New(ep stack.LinkEndpoint) *Endpoint { + var e Endpoint + e.Endpoint.Init(ep, &e) + return &e +} + +// Endpoint is an ethernet endpoint. +// +// It adds an ethernet header to packets before sending them out through its +// inner link endpoint and consumes an ethernet header before sending the +// packet to the stack. +type Endpoint struct { + nested.Endpoint +} + +// DeliverNetworkPacket implements stack.NetworkDispatcher. +func (e *Endpoint) DeliverNetworkPacket(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + hdr, ok := pkt.LinkHeader().Consume(header.EthernetMinimumSize) + if !ok { + return + } + + eth := header.Ethernet(hdr) + if dst := eth.DestinationAddress(); dst == e.Endpoint.LinkAddress() || dst == header.EthernetBroadcastAddress || header.IsMulticastEthernetAddress(dst) { + e.Endpoint.DeliverNetworkPacket(eth.SourceAddress() /* remote */, dst /* local */, eth.Type() /* protocol */, pkt) + } +} + +// Capabilities implements stack.LinkEndpoint. +func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityResolutionRequired | e.Endpoint.Capabilities() +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, gso *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + e.AddHeader(e.Endpoint.LinkAddress(), r.RemoteLinkAddress, proto, pkt) + return e.Endpoint.WritePacket(r, gso, proto, pkt) +} + +// WritePackets implements stack.LinkEndpoint. +func (e *Endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.PacketBufferList, proto tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + linkAddr := e.Endpoint.LinkAddress() + + for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { + e.AddHeader(linkAddr, r.RemoteLinkAddress, proto, pkt) + } + + return e.Endpoint.WritePackets(r, gso, pkts, proto) +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (e *Endpoint) MaxHeaderLength() uint16 { + return header.EthernetMinimumSize + e.Endpoint.MaxHeaderLength() +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareEther +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(local, remote tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize)) + fields := header.EthernetFields{ + SrcAddr: local, + DstAddr: remote, + Type: proto, + } + eth.Encode(&fields) +} diff --git a/pkg/tcpip/link/pipe/BUILD b/pkg/tcpip/link/pipe/BUILD new file mode 100644 index 000000000..9f31c1ffc --- /dev/null +++ b/pkg/tcpip/link/pipe/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "pipe", + srcs = ["pipe.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + ], +) diff --git a/pkg/tcpip/link/pipe/pipe.go b/pkg/tcpip/link/pipe/pipe.go new file mode 100644 index 000000000..523b0d24b --- /dev/null +++ b/pkg/tcpip/link/pipe/pipe.go @@ -0,0 +1,115 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pipe provides the implementation of pipe-like data-link layer +// endpoints. Such endpoints allow packets to be sent between two interfaces. +package pipe + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +var _ stack.LinkEndpoint = (*Endpoint)(nil) + +// New returns both ends of a new pipe. +func New(linkAddr1, linkAddr2 tcpip.LinkAddress) (*Endpoint, *Endpoint) { + ep1 := &Endpoint{ + linkAddr: linkAddr1, + } + ep2 := &Endpoint{ + linkAddr: linkAddr2, + } + ep1.linked = ep2 + ep2.linked = ep1 + return ep1, ep2 +} + +// Endpoint is one end of a pipe. +type Endpoint struct { + dispatcher stack.NetworkDispatcher + linked *Endpoint + linkAddr tcpip.LinkAddress +} + +// WritePacket implements stack.LinkEndpoint. +func (e *Endpoint) WritePacket(r *stack.Route, _ *stack.GSO, proto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + if !e.linked.IsAttached() { + return nil + } + + // Note that the local address from the perspective of this endpoint is the + // remote address from the perspective of the other end of the pipe + // (e.linked). Similarly, the remote address from the perspective of this + // endpoint is the local address on the other end. + e.linked.dispatcher.DeliverNetworkPacket(r.LocalLinkAddress /* remote */, r.RemoteLinkAddress /* local */, proto, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), + })) + + return nil +} + +// WritePackets implements stack.LinkEndpoint. +func (*Endpoint) WritePackets(*stack.Route, *stack.GSO, stack.PacketBufferList, tcpip.NetworkProtocolNumber) (int, *tcpip.Error) { + panic("not implemented") +} + +// WriteRawPacket implements stack.LinkEndpoint. +func (*Endpoint) WriteRawPacket(buffer.VectorisedView) *tcpip.Error { + panic("not implemented") +} + +// Attach implements stack.LinkEndpoint. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// Wait implements stack.LinkEndpoint. +func (*Endpoint) Wait() {} + +// MTU implements stack.LinkEndpoint. +func (*Endpoint) MTU() uint32 { + return header.IPv6MinimumMTU +} + +// Capabilities implements stack.LinkEndpoint. +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength implements stack.LinkEndpoint. +func (*Endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress implements stack.LinkEndpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +// ARPHardwareType implements stack.LinkEndpoint. +func (*Endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +// AddHeader implements stack.LinkEndpoint. +func (*Endpoint) AddHeader(_, _ tcpip.LinkAddress, _ tcpip.NetworkProtocolNumber, _ *stack.PacketBuffer) { +} diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD index 14b527bc2..6c410c5a6 100644 --- a/pkg/tcpip/link/rawfile/BUILD +++ b/pkg/tcpip/link/rawfile/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -18,3 +18,14 @@ go_library( "@org_golang_x_sys//unix:go_default_library", ], ) + +go_test( + name = "rawfile_test", + srcs = [ + "errors_test.go", + ], + library = "rawfile", + deps = [ + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index 99313ee25..5db4bf12b 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -14,7 +14,7 @@ // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.16 +// +build !go1.17 // Check go:linkname function signatures when updating Go version. diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go index a0a873c84..604868fd8 100644 --- a/pkg/tcpip/link/rawfile/errors.go +++ b/pkg/tcpip/link/rawfile/errors.go @@ -31,10 +31,12 @@ var translations [maxErrno]*tcpip.Error // *tcpip.Error. // // Valid, but unrecognized errnos will be translated to -// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. +// tcpip.ErrInvalidEndpointState (EINVAL). func TranslateErrno(e syscall.Errno) *tcpip.Error { - if err := translations[e]; err != nil { - return err + if e > 0 && e < syscall.Errno(len(translations)) { + if err := translations[e]; err != nil { + return err + } } return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/link/rawfile/errors_test.go b/pkg/tcpip/link/rawfile/errors_test.go new file mode 100644 index 000000000..e4cdc66bd --- /dev/null +++ b/pkg/tcpip/link/rawfile/errors_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" +) + +func TestTranslateErrno(t *testing.T) { + for _, test := range []struct { + errno syscall.Errno + translated *tcpip.Error + }{ + { + errno: syscall.Errno(0), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(maxErrno), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.Errno(514), + translated: tcpip.ErrInvalidEndpointState, + }, + { + errno: syscall.EEXIST, + translated: tcpip.ErrDuplicateAddress, + }, + } { + got := TranslateErrno(test.errno) + if got != test.translated { + t.Errorf("TranslateErrno(%q) = %q, want %q", test.errno, got, test.translated) + } + } +} diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD index 7cbc305e7..4aac12a8c 100644 --- a/pkg/tcpip/link/sniffer/BUILD +++ b/pkg/tcpip/link/sniffer/BUILD @@ -14,6 +14,7 @@ go_library( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", + "//pkg/tcpip/header/parse", "//pkg/tcpip/link/nested", "//pkg/tcpip/stack", ], diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 4fb127978..560477926 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/header/parse" "gvisor.dev/gvisor/pkg/tcpip/link/nested" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -195,49 +196,52 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P var transProto uint8 src := tcpip.Address("unknown") dst := tcpip.Address("unknown") - id := 0 - size := uint16(0) + var size uint16 + var id uint32 var fragmentOffset uint16 var moreFragments bool - // Examine the packet using a new VV. Backing storage must not be written. - vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) - + // Clone the packet buffer to not modify the original. + // + // We don't clone the original packet buffer so that the new packet buffer + // does not have any of its headers set. + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) switch protocol { case header.IPv4ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv4MinimumSize) - if !ok { + if ok := parse.IPv4(pkt); !ok { return } - ipv4 := header.IPv4(hdr) + + ipv4 := header.IPv4(pkt.NetworkHeader().View()) fragmentOffset = ipv4.FragmentOffset() moreFragments = ipv4.Flags()&header.IPv4FlagMoreFragments == header.IPv4FlagMoreFragments src = ipv4.SourceAddress() dst = ipv4.DestinationAddress() transProto = ipv4.Protocol() size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) - vv.TrimFront(int(ipv4.HeaderLength())) - id = int(ipv4.ID()) + id = uint32(ipv4.ID()) case header.IPv6ProtocolNumber: - hdr, ok := vv.PullUp(header.IPv6MinimumSize) + proto, fragID, fragOffset, fragMore, ok := parse.IPv6(pkt) if !ok { return } - ipv6 := header.IPv6(hdr) + + ipv6 := header.IPv6(pkt.NetworkHeader().View()) src = ipv6.SourceAddress() dst = ipv6.DestinationAddress() - transProto = ipv6.NextHeader() + transProto = uint8(proto) size = ipv6.PayloadLength() - vv.TrimFront(header.IPv6MinimumSize) + id = fragID + moreFragments = fragMore + fragmentOffset = fragOffset case header.ARPProtocolNumber: - hdr, ok := vv.PullUp(header.ARPSize) - if !ok { + if parse.ARP(pkt) { return } - vv.TrimFront(header.ARPSize) - arp := header.ARP(hdr) + + arp := header.ARP(pkt.NetworkHeader().View()) log.Infof( "%s arp %s (%s) -> %s (%s) valid:%t", prefix, @@ -259,7 +263,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P switch tcpip.TransportProtocolNumber(transProto) { case header.ICMPv4ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv4MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) if !ok { break } @@ -296,7 +300,7 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.ICMPv6ProtocolNumber: transName = "icmp" - hdr, ok := vv.PullUp(header.ICMPv6MinimumSize) + hdr, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) if !ok { break } @@ -331,11 +335,11 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.UDPProtocolNumber: transName = "udp" - hdr, ok := vv.PullUp(header.UDPMinimumSize) - if !ok { + if ok := parse.UDP(pkt); !ok { break } - udp := header.UDP(hdr) + + udp := header.UDP(pkt.TransportHeader().View()) if fragmentOffset == 0 { srcPort = udp.SourcePort() dstPort = udp.DestinationPort() @@ -345,19 +349,19 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P case header.TCPProtocolNumber: transName = "tcp" - hdr, ok := vv.PullUp(header.TCPMinimumSize) - if !ok { + if ok := parse.TCP(pkt); !ok { break } - tcp := header.TCP(hdr) + + tcp := header.TCP(pkt.TransportHeader().View()) if fragmentOffset == 0 { offset := int(tcp.DataOffset()) if offset < header.TCPMinimumSize { details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) break } - if offset > vv.Size() && !moreFragments { - details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, vv.Size()) + if size := pkt.Data.Size() + len(tcp); offset > size && !moreFragments { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than tcp packet length %d", offset, size) break } diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 6c137f693..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -1,19 +1,34 @@ load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "tun_endpoint_refs", + out = "tun_endpoint_refs.go", + package = "tun", + prefix = "tunEndpoint", + template = "//pkg/refsvfs2:refs_template", + types = { + "T": "tunEndpoint", + }, +) + go_library( name = "tun", srcs = [ "device.go", "protocol.go", + "tun_endpoint_refs.go", "tun_unsafe.go", ], visibility = ["//visibility:public"], deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index 3b1510a33..cda6328a2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -19,7 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -77,13 +76,29 @@ func (d *Device) Release(ctx context.Context) { } } +// NICID returns the NIC ID of the device. +// +// Must only be called after the device has been attached to an endpoint. +func (d *Device) NICID() tcpip.NICID { + d.mu.RLock() + defer d.mu.RUnlock() + + if d.endpoint == nil { + panic("called NICID on a device that has not been attached") + } + + return d.endpoint.nicID +} + // SetIff services TUNSETIFF ioctl(2) request. -func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { +// +// Returns true if a new NIC was created; false if an existing one was attached. +func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) (bool, error) { d.mu.Lock() defer d.mu.Unlock() if d.endpoint != nil { - return syserror.EINVAL + return false, syserror.EINVAL } // Input validations. @@ -91,7 +106,7 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { isTap := flags&linux.IFF_TAP != 0 supportedFlags := uint16(linux.IFF_TUN | linux.IFF_TAP | linux.IFF_NO_PI) if isTap && isTun || !isTap && !isTun || flags&^supportedFlags != 0 { - return syserror.EINVAL + return false, syserror.EINVAL } prefix := "tun" @@ -104,32 +119,32 @@ func (d *Device) SetIff(s *stack.Stack, name string, flags uint16) error { linkCaps |= stack.CapabilityResolutionRequired } - endpoint, err := attachOrCreateNIC(s, name, prefix, linkCaps) + endpoint, created, err := attachOrCreateNIC(s, name, prefix, linkCaps) if err != nil { - return syserror.EINVAL + return false, syserror.EINVAL } d.endpoint = endpoint d.notifyHandle = d.endpoint.AddNotify(d) d.flags = flags - return nil + return created, nil } -func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, error) { +func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkEndpointCapabilities) (*tunEndpoint, bool, error) { for { // 1. Try to attach to an existing NIC. if name != "" { - if nic, found := s.GetNICByName(name); found { - endpoint, ok := nic.LinkEndpoint().(*tunEndpoint) + if linkEP := s.GetLinkEndpointByName(name); linkEP != nil { + endpoint, ok := linkEP.(*tunEndpoint) if !ok { // Not a NIC created by tun device. - return nil, syserror.EOPNOTSUPP + return nil, false, syserror.EOPNOTSUPP } if !endpoint.TryIncRef() { // Race detected: NIC got deleted in between. continue } - return endpoint, nil + return endpoint, false, nil } } @@ -142,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.EnableLeakCheck() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) @@ -151,12 +167,12 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE }) switch err { case nil: - return endpoint, nil + return endpoint, true, nil case tcpip.ErrDuplicateNICID: // Race detected: A NIC has been created in between. continue default: - return nil, syserror.EINVAL + return nil, false, syserror.EINVAL } } } @@ -331,19 +347,18 @@ func (d *Device) WriteNotify() { // It is ref-counted as multiple opening files can attach to the same NIC. // The last owner is responsible for deleting the NIC. type tunEndpoint struct { + tunEndpointRefs *channel.Endpoint - refs.AtomicRefCount - stack *stack.Stack nicID tcpip.NICID name string isTap bool } -// DecRef decrements refcount of e, removes NIC if refcount goes to 0. +// DecRef decrements refcount of e, removing NIC if it reaches 0. func (e *tunEndpoint) DecRef(ctx context.Context) { - e.DecRefWithDestructor(ctx, func(context.Context) { + e.tunEndpointRefs.DecRef(func() { e.stack.RemoveNIC(e.nicID) }) } |