summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go29
-rw-r--r--pkg/tcpip/stack/stack.go8
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go89
-rw-r--r--pkg/tcpip/tcpip.go2
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go27
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go42
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go27
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go31
8 files changed, 127 insertions, 128 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9e0d69046..764f11a6b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -985,13 +985,23 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- if len(v) == 0 {
+ if v == 0 {
return []byte{}, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
}
- return append([]byte(v), 0), nil
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1438,7 +1448,20 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
if n == -1 {
n = len(optVal)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index e2a2edb2c..41bf9fd9b 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -901,6 +901,14 @@ type NICInfo struct {
Context NICContext
}
+// HasNIC returns true if the NICID is defined in the stack.
+func (s *Stack) HasNIC(id tcpip.NICID) bool {
+ s.mu.RLock()
+ _, ok := s.nics[id]
+ s.mu.RUnlock()
+ return ok
+}
+
// NICInfo returns a map of NICIDs to their associated information.
func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
s.mu.RLock()
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index df5ced887..5e9237de9 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -41,7 +41,7 @@ const (
type testContext struct {
t *testing.T
- linkEPs map[string]*channel.Endpoint
+ linkEps map[tcpip.NICID]*channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
@@ -66,27 +66,24 @@ func (c *testContext) createV6Endpoint(v6only bool) {
}
}
-// newDualTestContextMultiNic creates the testing context and also linkEpNames
-// named NICs.
-func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+// newDualTestContextMultiNIC creates the testing context and also linkEpIDs NICs.
+func newDualTestContextMultiNIC(t *testing.T, mtu uint32, linkEpIDs []tcpip.NICID) *testContext {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- linkEPs := make(map[string]*channel.Endpoint)
- for i, linkEpName := range linkEpNames {
- channelEP := channel.New(256, mtu, "")
- nicID := tcpip.NICID(i + 1)
- opts := stack.NICOptions{Name: linkEpName}
- if err := s.CreateNICWithOptions(nicID, channelEP, opts); err != nil {
- t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ linkEps := make(map[tcpip.NICID]*channel.Endpoint)
+ for _, linkEpID := range linkEpIDs {
+ channelEp := channel.New(256, mtu, "")
+ if err := s.CreateNIC(linkEpID, channelEp); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
}
- linkEPs[linkEpName] = channelEP
+ linkEps[linkEpID] = channelEp
- if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress IPv4 failed: %v", err)
}
- if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ if err := s.AddAddress(linkEpID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress IPv6 failed: %v", err)
}
}
@@ -105,7 +102,7 @@ func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string)
return &testContext{
t: t,
s: s,
- linkEPs: linkEPs,
+ linkEps: linkEps,
}
}
@@ -122,7 +119,7 @@ func newPayload() []byte {
return b
}
-func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpID tcpip.NICID) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -153,7 +150,7 @@ func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string
u.SetChecksum(^u.CalculateChecksum(xsum))
// Inject packet.
- c.linkEPs[linkEpName].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ c.linkEps[linkEpID].InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
Data: buf.ToVectorisedView(),
})
}
@@ -183,7 +180,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
func TestDistribution(t *testing.T) {
type endpointSockopts struct {
reuse int
- bindToDevice string
+ bindToDevice tcpip.NICID
}
for _, test := range []struct {
name string
@@ -191,71 +188,71 @@ func TestDistribution(t *testing.T) {
endpoints []endpointSockopts
// wantedDistribution is the wanted ratio of packets received on each
// endpoint for each NIC on which packets are injected.
- wantedDistributions map[string][]float64
+ wantedDistributions map[tcpip.NICID][]float64
}{
{
"BindPortReuse",
// 5 endpoints that all have reuse set.
[]endpointSockopts{
- {1, ""},
- {1, ""},
- {1, ""},
- {1, ""},
- {1, ""},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed evenly.
- "dev0": {0.2, 0.2, 0.2, 0.2, 0.2},
+ 1: {0.2, 0.2, 0.2, 0.2, 0.2},
},
},
{
"BindToDevice",
// 3 endpoints with various bindings.
[]endpointSockopts{
- {0, "dev0"},
- {0, "dev1"},
- {0, "dev2"},
+ {0, 1},
+ {0, 2},
+ {0, 3},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 go only to the endpoint bound to dev0.
- "dev0": {1, 0, 0},
+ 1: {1, 0, 0},
// Injected packets on dev1 go only to the endpoint bound to dev1.
- "dev1": {0, 1, 0},
+ 2: {0, 1, 0},
// Injected packets on dev2 go only to the endpoint bound to dev2.
- "dev2": {0, 0, 1},
+ 3: {0, 0, 1},
},
},
{
"ReuseAndBindToDevice",
// 6 endpoints with various bindings.
[]endpointSockopts{
- {1, "dev0"},
- {1, "dev0"},
- {1, "dev1"},
- {1, "dev1"},
- {1, "dev1"},
- {1, ""},
+ {1, 1},
+ {1, 1},
+ {1, 2},
+ {1, 2},
+ {1, 2},
+ {1, 0},
},
- map[string][]float64{
+ map[tcpip.NICID][]float64{
// Injected packets on dev0 get distributed among endpoints bound to
// dev0.
- "dev0": {0.5, 0.5, 0, 0, 0, 0},
+ 1: {0.5, 0.5, 0, 0, 0, 0},
// Injected packets on dev1 get distributed among endpoints bound to
// dev1 or unbound.
- "dev1": {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ 2: {0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
// Injected packets on dev999 go only to the unbound.
- "dev999": {0, 0, 0, 0, 0, 1},
+ 1000: {0, 0, 0, 0, 0, 1},
},
},
} {
t.Run(test.name, func(t *testing.T) {
for device, wantedDistribution := range test.wantedDistributions {
- t.Run(device, func(t *testing.T) {
- var devices []string
+ t.Run(string(device), func(t *testing.T) {
+ var devices []tcpip.NICID
for d := range test.wantedDistributions {
devices = append(devices, d)
}
- c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ c := newDualTestContextMultiNIC(t, defaultMTU, devices)
defer c.cleanup()
c.createV6Endpoint(false)
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 1eca76c30..72b5ce179 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -552,7 +552,7 @@ type ReusePortOption int
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
-type BindToDeviceOption string
+type BindToDeviceOption NICID
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 2ac1b6877..920b24975 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1279,19 +1279,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.QuickAckOption:
if v == 0 {
@@ -1550,12 +1545,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = ""
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.QuickAckOption:
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 15745ebd4..1aa0733d0 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -1083,12 +1083,12 @@ func TestTrafficClassV6(t *testing.T) {
func TestConnectBindToDevice(t *testing.T) {
for _, test := range []struct {
name string
- device string
+ device tcpip.NICID
want tcp.EndpointState
}{
- {"RightDevice", "nic1", tcp.StateEstablished},
- {"WrongDevice", "nic2", tcp.StateSynSent},
- {"AnyDevice", "", tcp.StateEstablished},
+ {"RightDevice", 1, tcp.StateEstablished},
+ {"WrongDevice", 2, tcp.StateSynSent},
+ {"AnyDevice", 0, tcp.StateEstablished},
} {
t.Run(test.name, func(t *testing.T) {
c := context.New(t, defaultMTU)
@@ -3794,47 +3794,41 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- opts := stack.NICOptions{Name: "my_device"}
- if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
- }
-
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ if err := s.CreateNIC(321, loopback.New()); err != nil {
t.Errorf("CreateNIC failed: %v", err)
}
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1a5ee6317..864dc8733 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -631,19 +631,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -767,12 +762,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 149fff999..0a82bc4fa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -513,42 +513,37 @@ func TestBindToDeviceOption(t *testing.T) {
t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}