summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/buffer/BUILD5
-rw-r--r--pkg/tcpip/buffer/prependable.go18
-rw-r--r--pkg/tcpip/buffer/prependable_test.go50
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go4
-rw-r--r--pkg/tcpip/stack/nic.go7
-rw-r--r--pkg/tcpip/stack/stack.go75
-rw-r--r--pkg/tcpip/stack/stack_test.go5
-rw-r--r--pkg/tcpip/stack/transport_test.go2
8 files changed, 141 insertions, 25 deletions
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index d6c31bfa2..a7bf0c4dc 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -16,6 +16,9 @@ go_library(
go_test(
name = "buffer_test",
size = "small",
- srcs = ["view_test.go"],
+ srcs = [
+ "prependable_test.go",
+ "view_test.go",
+ ],
embed = [":buffer"],
)
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
index 48a2a2713..2f9a23d61 100644
--- a/pkg/tcpip/buffer/prependable.go
+++ b/pkg/tcpip/buffer/prependable.go
@@ -32,13 +32,19 @@ func NewPrependable(size int) Prependable {
return Prependable{buf: NewView(size), usedIdx: size}
}
-// NewPrependableFromView creates an entirely-used Prependable from a View.
+// NewPrependableFromView creates a Prependable from a View and allocates
+// additional space if needed.
//
-// NewPrependableFromView takes ownership of v. Note that since the entire
-// prependable is used, further attempts to call Prepend will note that size >
-// p.usedIdx and return nil.
-func NewPrependableFromView(v View) Prependable {
- return Prependable{buf: v, usedIdx: 0}
+// NewPrependableFromView takes ownership of v. Note that if the entire
+// prependable is used, further attempts to call Prepend will note that
+// size > p.usedIdx and return nil.
+func NewPrependableFromView(v View, extraCap int) Prependable {
+ if extraCap == 0 {
+ return Prependable{buf: v, usedIdx: 0}
+ }
+ buf := make([]byte, extraCap, extraCap + len(v))
+ buf = append(buf, v...)
+ return Prependable{buf: buf, usedIdx: extraCap}
}
// NewEmptyPrependableFromView creates a new prependable buffer from a View.
diff --git a/pkg/tcpip/buffer/prependable_test.go b/pkg/tcpip/buffer/prependable_test.go
new file mode 100644
index 000000000..43660c307
--- /dev/null
+++ b/pkg/tcpip/buffer/prependable_test.go
@@ -0,0 +1,50 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package buffer
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestNewPrependableFromView(t *testing.T) {
+ tests := []struct {
+ comment string
+ view View
+ extraSize int
+ want Prependable
+ }{
+ {
+ comment: "Reserve extra space",
+ view: View("abc"),
+ extraSize: 2,
+ want: Prependable{buf: View("\x00\x00abc"), usedIdx: 2},
+ },
+ {
+ comment: "Don't reserve extra space",
+ view: View("abc"),
+ extraSize: 0,
+ want: Prependable{buf: View("abc"), usedIdx: 0},
+ },
+ }
+
+ for _, testCase := range tests {
+ t.Run(testCase.comment, func(t *testing.T) {
+ prep := NewPrependableFromView(testCase.view, testCase.extraSize)
+ if !reflect.DeepEqual(prep, testCase.want) {
+ t.Errorf("NewPrependableFromView(%#v, %d) = %#v; want %#v", testCase.view, testCase.extraSize, prep, testCase.want)
+ }
+ } )
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1339f8474..90f4406e5 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -307,7 +307,9 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
return nil
}
- hdr := buffer.NewPrependableFromView(payload.ToView())
+ // If we want to send the packet to a link-layer,
+ // we have to reserve space for an Ethernet header.
+ hdr := buffer.NewPrependableFromView(payload.ToView(), int(e.linkEP.MaxHeaderLength()))
r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a867f8c00..ab6798aa6 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -780,7 +780,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
- if n.stack.Forwarding() {
+ if n.stack.Forwarding(protocol) {
r, err := n.stack.FindRoute(0, "", dst, protocol, false /* multicastLoop */)
if err != nil {
n.stack.stats.IP.InvalidAddressesReceived.Increment()
@@ -805,9 +805,12 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
- hdr := buffer.NewPrependableFromView(vv.First())
+ // If we want to send the packet to a link-layer,
+ // we have to reserve space for an Ethernet header.
+ hdr := buffer.NewPrependableFromView(vv.First(), int(n.linkEP.MaxHeaderLength()))
vv.RemoveFirst()
+ // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6.
// TODO(b/128629022): use route.WritePacket.
if err := n.linkEP.WritePacket(&r, nil /* gso */, hdr, vv, protocol); err != nil {
r.Stats().IP.OutgoingPacketErrors.Increment()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 242d2150c..71e0618f4 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -21,7 +21,9 @@ package stack
import (
"encoding/binary"
+ "math"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -48,6 +50,42 @@ const (
DefaultTOS = 0
)
+const (
+ // fakeNetNumber is used as a protocol number in tests.
+ //
+ // This constant should match fakeNetNumber in stack_test.go.
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+)
+
+type forwardingFlag uint32
+
+// Packet forwarding flags. Forwarding settings for different network protocols
+// are stored as bit flags in an uint32 number.
+const (
+ forwardingIPv4 forwardingFlag = 1 << iota
+ forwardingIPv6
+
+ // forwardingFake is used to test package forwarding with a fake protocol.
+ forwardingFake
+)
+
+func getForwardingFlag(protocol tcpip.NetworkProtocolNumber) forwardingFlag {
+ var flag forwardingFlag
+ switch protocol {
+ case header.IPv4ProtocolNumber:
+ flag = forwardingIPv4
+ case header.IPv6ProtocolNumber:
+ flag = forwardingIPv6
+ case fakeNetNumber:
+ // This network protocol number is used in stack_test to test
+ // packet forwarding.
+ flag = forwardingFake
+ default:
+ // We only support forwarding for IPv4 and IPv6.
+ }
+ return flag
+}
+
type transportProtocolState struct {
proto TransportProtocol
defaultHandler func(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool
@@ -363,7 +401,10 @@ type Stack struct {
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
- forwarding bool
+
+ // forwarding contains the enable bits for packet forwarding for different
+ // network protocols.
+ forwarding uint32
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
@@ -630,20 +671,28 @@ func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
-// SetForwarding enables or disables the packet forwarding between NICs.
-func (s *Stack) SetForwarding(enable bool) {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.Lock()
- s.forwarding = enable
- s.mu.Unlock()
+// SetForwarding enables or disables packet forwarding between NICs.
+func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) {
+ flag := getForwardingFlag(protocol)
+ for {
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ var newValue forwardingFlag
+ if enable {
+ newValue = forwarding | flag
+ } else {
+ newValue = forwarding & ^flag
+ }
+ if atomic.CompareAndSwapUint32(&s.forwarding, uint32(forwarding), uint32(newValue)) {
+ break
+ }
+ }
}
-// Forwarding returns if the packet forwarding between NICs is enabled.
-func (s *Stack) Forwarding() bool {
- // TODO(igudger, bgeffon): Expose via /proc/sys/net/ipv4/ip_forward.
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.forwarding
+// Forwarding returns if packet forwarding between NICs is enabled.
+func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
+ flag := getForwardingFlag(protocol)
+ forwarding := forwardingFlag(atomic.LoadUint32(&s.forwarding))
+ return forwarding & flag != 0
}
// SetRouteTable assigns the route table to be used by this stack. It
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 9dae853d0..ef3d1beb0 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -36,6 +36,9 @@ import (
)
const (
+ // fakeNetNumber is used as a protocol number in tests.
+ //
+ // This constant should match fakeNetNumber in stack.go.
fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fakeNetHeaderLen = 12
fakeDefaultPrefixLen = 8
@@ -1825,7 +1828,7 @@ func TestNICForwarding(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 86c62be25..6d3daed24 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -528,7 +528,7 @@ func TestTransportForwarding(t *testing.T) {
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
})
- s.SetForwarding(true)
+ s.SetForwarding(fakeNetNumber, true)
// TODO(b/123449044): Change this to a channel NIC.
ep1 := loopback.New()