summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/socketops.go5
-rw-r--r--pkg/tcpip/transport/icmp/BUILD21
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go20
-rw-r--r--pkg/tcpip/transport/icmp/icmp_test.go235
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go39
5 files changed, 294 insertions, 26 deletions
diff --git a/pkg/tcpip/socketops.go b/pkg/tcpip/socketops.go
index b7c2de652..0ea85f9ed 100644
--- a/pkg/tcpip/socketops.go
+++ b/pkg/tcpip/socketops.go
@@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 {
return atomic.LoadInt32(&so.bindToDevice)
}
-// SetBindToDevice sets value for SO_BINDTODEVICE option.
+// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is
+// zero, the socket device binding is removed.
func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
- if !so.handler.HasNIC(bindToDevice) {
+ if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) {
return &ErrUnknownDevice{}
}
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD
index 7e5c79776..bbc0e3ecc 100644
--- a/pkg/tcpip/transport/icmp/BUILD
+++ b/pkg/tcpip/transport/icmp/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
@@ -38,3 +38,22 @@ go_library(
"//pkg/waiter",
],
)
+
+go_test(
+ name = "icmp_x_test",
+ size = "small",
+ srcs = ["icmp_test.go"],
+ deps = [
+ ":icmp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/testutil",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 39f526023..fb77febcf 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// TODO(https://gvisor.dev/issues/5623): Unit test this package.
-
// +stateify savable
type icmpPacket struct {
icmpPacketEntry
@@ -134,7 +132,8 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
+ e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
}
// Close the receive list and drain it.
@@ -305,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
+ if nicID == 0 {
+ nicID = tcpip.NICID(e.ops.GetBindToDevice())
+ }
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
return 0, &tcpip.ErrNoRoute{}
@@ -349,6 +351,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
+func (e *endpoint) HasNIC(id int32) bool {
+ return e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt sets a socket option.
func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
return nil
@@ -608,17 +617,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
}
func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
+ err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/icmp/icmp_test.go b/pkg/tcpip/transport/icmp/icmp_test.go
new file mode 100644
index 000000000..cc950cbde
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_test.go
@@ -0,0 +1,235 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp_test
+
+import (
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/testutil"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
+// See the issue for remaining areas of work.
+
+var (
+ localV4Addr1 = testutil.MustParse4("10.0.0.1")
+ localV4Addr2 = testutil.MustParse4("10.0.0.2")
+ remoteV4Addr = testutil.MustParse4("10.0.0.3")
+)
+
+func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
+ t.Helper()
+
+ ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
+ t.Cleanup(ep.Close)
+
+ wep := stack.LinkEndpoint(ep)
+ if testing.Verbose() {
+ wep = sniffer.New(ep)
+ }
+
+ opts := stack.NICOptions{Name: name}
+ if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
+ }
+
+ if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
+ }
+
+ s.AddRoute(tcpip.Route{
+ Destination: header.IPv4EmptySubnet,
+ NIC: id,
+ })
+
+ return ep
+}
+
+func writePayload(buf []byte) {
+ for i := range buf {
+ buf[i] = byte(i)
+ }
+}
+
+func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
+ buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
+ writePayload(buf[header.ICMPv4MinimumSize:])
+
+ icmp := header.ICMPv4(buf)
+ icmp.SetType(header.ICMPv4Echo)
+ // No need to set the checksum; it is reset by the socket before the packet
+ // is sent.
+
+ return buf
+}
+
+// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
+// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
+//
+// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
+// the version of IP.
+func TestWriteUnboundWithBindToDevice(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
+ HandleLocal: true,
+ })
+
+ // Add two NICs, both with default routes on the same subnet. The first NIC
+ // added will be the default NIC for that subnet.
+ defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
+ alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)
+
+ socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
+ }
+ defer socket.Close()
+
+ echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize
+
+ // Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
+ // to be added is the default NIC to send packets when not explicitly bound.
+ {
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+
+ // Send a packet with SO_BINDTODEVICE. This exercises reliance on
+ // SO_BINDTODEVICE to route the packet to the alternate NIC.
+ {
+ // Use SO_BINDTODEVICE to send over the alternate NIC by default.
+ socket.SocketOptions().SetBindToDevice(2)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was not sent out the default NIC.
+ if p, ok := defaultEP.Read(); ok {
+ t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
+ }
+
+ // Verify the packet was sent out the alternate NIC.
+ p, ok := alternateEP.Read()
+ if !ok {
+ t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr2),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+ }
+
+ // Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
+ // the device binding will fallback to using the default NIC to send
+ // packets.
+ {
+ socket.SocketOptions().SetBindToDevice(0)
+
+ buf := newICMPv4EchoRequest(echoPayloadSize)
+ r := buf.Reader()
+ n, err := socket.Write(&r, tcpip.WriteOptions{
+ To: &tcpip.FullAddress{Addr: remoteV4Addr},
+ })
+ if err != nil {
+ t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
+ }
+ if n != int64(len(buf)) {
+ t.Fatalf("got n = %d, want n = %d", n, len(buf))
+ }
+
+ // Verify the packet was sent out the default NIC.
+ p, ok := defaultEP.Read()
+ if !ok {
+ t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
+ }
+
+ vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
+ b := vv.ToView()
+
+ checker.IPv4(t, b, []checker.NetworkChecker{
+ checker.SrcAddr(localV4Addr1),
+ checker.DstAddr(remoteV4Addr),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo),
+ checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
+ ),
+ }...)
+
+ // Verify the packet was not sent out the alternate NIC.
+ if p, ok := alternateEP.Read(); ok {
+ t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index b964e446b..def9d7186 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -54,7 +54,7 @@ const (
StateClosed
)
-// String implements fmt.Stringer.String.
+// String implements fmt.Stringer.
func (s EndpointState) String() string {
switch s {
case StateInitial:
@@ -214,7 +214,7 @@ func (e *endpoint) EndpointState() EndpointState {
return EndpointState(atomic.LoadUint32(&e.state))
}
-// UniqueID implements stack.TransportEndpoint.UniqueID.
+// UniqueID implements stack.TransportEndpoint.
func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
@@ -228,14 +228,14 @@ func (e *endpoint) LastError() tcpip.Error {
return err
}
-// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+// UpdateLastError implements tcpip.SocketOptionsHandler.
func (e *endpoint) UpdateLastError(err tcpip.Error) {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
}
-// Abort implements stack.TransportEndpoint.Abort.
+// Abort implements stack.TransportEndpoint.
func (e *endpoint) Abort() {
e.Close()
}
@@ -291,10 +291,10 @@ func (e *endpoint) Close() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
+// ModerateRecvBuf implements tcpip.Endpoint.
func (*endpoint) ModerateRecvBuf(int) {}
-// Read implements tcpip.Endpoint.Read.
+// Read implements tcpip.Endpoint.
func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
@@ -583,21 +583,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}
-// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
+// OnReuseAddressSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
e.portFlags.MostRecent = v
e.mu.Unlock()
}
-// OnReusePortSet implements tcpip.SocketOptionsHandler.OnReusePortSet.
+// OnReusePortSet implements tcpip.SocketOptionsHandler.
func (e *endpoint) OnReusePortSet(v bool) {
e.mu.Lock()
e.portFlags.LoadBalanced = v
e.mu.Unlock()
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+// SetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
switch opt {
case tcpip.MTUDiscoverOption:
@@ -631,11 +631,14 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
return nil
}
+var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)
+
+// HasNIC implements tcpip.SocketOptionsHandler.
func (e *endpoint) HasNIC(id int32) bool {
- return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+ return e.stack.HasNIC(tcpip.NICID(id))
}
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+// SetSockOpt implements tcpip.Endpoint.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
switch v := opt.(type) {
case *tcpip.MulticastInterfaceOption:
@@ -751,7 +754,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
return nil
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+// GetSockOptInt implements tcpip.Endpoint.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
switch opt {
case tcpip.IPv4TOSOption:
@@ -797,7 +800,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
}
}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+// GetSockOpt implements tcpip.Endpoint.
func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
@@ -874,7 +877,7 @@ func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddres
return unwrapped, netProto, nil
}
-// Disconnect implements tcpip.Endpoint.Disconnect.
+// Disconnect implements tcpip.Endpoint.
func (e *endpoint) Disconnect() tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -1388,7 +1391,7 @@ func (e *endpoint) HandleError(transErr stack.TransportError, pkt *stack.PacketB
}
}
-// State implements tcpip.Endpoint.State.
+// State implements tcpip.Endpoint.
func (e *endpoint) State() uint32 {
return uint32(e.EndpointState())
}
@@ -1407,19 +1410,19 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
return &e.stats
}
-// Wait implements tcpip.Endpoint.Wait.
+// Wait implements tcpip.Endpoint.
func (*endpoint) Wait() {}
func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || e.stack.IsSubnetBroadcast(nicID, netProto, addr)
}
-// SetOwner implements tcpip.Endpoint.SetOwner.
+// SetOwner implements tcpip.Endpoint.
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// SocketOptions implements tcpip.Endpoint.SocketOptions.
+// SocketOptions implements tcpip.Endpoint.
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
return &e.ops
}