summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorKevin Krakauer <krakauer@google.com>2020-01-13 16:06:29 -0800
committerKevin Krakauer <krakauer@google.com>2020-01-13 16:06:29 -0800
commitd51eaa59c020cca9b7bc27cec0338ead089f3ed6 (patch)
tree3b41776af9426496567573ed17698562daf39006 /pkg/tcpip/transport/udp
parentd793677cd424fef10ac0b080871d181db0bcdec0 (diff)
parent1c3d3c70b93d483894dd49fb444171347f0ca250 (diff)
Merge branch 'iptables-write-input-drop' into iptables-write-filter-proto
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go92
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go48
3 files changed, 72 insertions, 69 deletions
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 97e4d5825..57ff123e3 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -30,6 +30,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/sleep",
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1ac4705af..a4ff29a7d 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -15,8 +15,7 @@
package udp
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
@@ -456,14 +455,9 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
- return nil
-}
-
-// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.NetProto != header.IPv6ProtocolNumber {
@@ -478,8 +472,20 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- e.v6only = v != 0
+ e.v6only = v
+ }
+
+ return nil
+}
+
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return nil
+}
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
case tcpip.TTLOption:
e.mu.Lock()
e.ttl = uint8(v)
@@ -624,19 +630,14 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.mu.Unlock()
case tcpip.BindToDeviceOption:
- e.mu.Lock()
- defer e.mu.Unlock()
- if v == "" {
- e.bindToDevice = 0
- return nil
- }
- for nicID, nic := range e.stack.NICInfo() {
- if nic.Name == string(v) {
- e.bindToDevice = nicID
- return nil
- }
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
}
- return tcpip.ErrUnknownDevice
+ e.mu.Lock()
+ e.bindToDevice = id
+ e.mu.Unlock()
+ return nil
case tcpip.BroadcastOption:
e.mu.Lock()
@@ -660,8 +661,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ return v, nil
+ }
+
+ return false, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
v := 0
@@ -695,22 +715,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.V6OnlyOption:
- // We only recognize this option on v6 endpoints.
- if e.NetProto != header.IPv6ProtocolNumber {
- return tcpip.ErrUnknownProtocolOption
- }
-
- e.mu.Lock()
- v := e.v6only
- e.mu.Unlock()
-
- *o = 0
- if v {
- *o = 1
- }
- return nil
-
case *tcpip.TTLOption:
e.mu.Lock()
*o = tcpip.TTLOption(e.ttl)
@@ -757,12 +761,8 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case *tcpip.BindToDeviceOption:
e.mu.RLock()
- defer e.mu.RUnlock()
- if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
- *o = tcpip.BindToDeviceOption(nic.Name)
- return nil
- }
- *o = tcpip.BindToDeviceOption("")
+ *o = tcpip.BindToDeviceOption(e.bindToDevice)
+ e.mu.RUnlock()
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 7051a7a9c..d33507156 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -335,7 +335,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) {
c.createEndpoint(flow.sockProto())
if flow.isV6Only() {
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
} else if flow.isBroadcast() {
@@ -508,46 +508,42 @@ func TestBindToDeviceOption(t *testing.T) {
}
defer ep.Close()
- if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
- t.Errorf("CreateNamedNIC failed: %v", err)
+ opts := stack.NICOptions{Name: "my_device"}
+ if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
}
- // Make an nameless NIC.
- if err := s.CreateNIC(54321, loopback.New()); err != nil {
- t.Errorf("CreateNIC failed: %v", err)
- }
-
- // strPtr is used instead of taking the address of string literals, which is
+ // nicIDPtr is used instead of taking the address of NICID literals, which is
// a compiler error.
- strPtr := func(s string) *string {
+ nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
return &s
}
testActions := []struct {
name string
- setBindToDevice *string
+ setBindToDevice *tcpip.NICID
setBindToDeviceError *tcpip.Error
getBindToDevice tcpip.BindToDeviceOption
}{
- {"GetDefaultValue", nil, nil, ""},
- {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
- {"BindToExistent", strPtr("my_device"), nil, "my_device"},
- {"UnbindToDevice", strPtr(""), nil, ""},
+ {"GetDefaultValue", nil, nil, 0},
+ {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+ {"BindToExistent", nicIDPtr(321), nil, 321},
+ {"UnbindToDevice", nicIDPtr(0), nil, 0},
}
for _, testAction := range testActions {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
}
}
- bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
- if ep.GetSockOpt(&bindToDevice) != nil {
- t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ bindToDevice := tcpip.BindToDeviceOption(88888)
+ if err := ep.GetSockOpt(&bindToDevice); err != nil {
+ t.Errorf("GetSockOpt got %v, want %v", err, nil)
}
if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %q, want %q", got, want)
+ t.Errorf("bindToDevice got %d, want %d", got, want)
}
})
}
@@ -1232,7 +1228,10 @@ func TestTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
if err != nil {
t.Fatal(err)
}
@@ -1265,7 +1264,10 @@ func TestSetTTL(t *testing.T) {
} else {
p = ipv6.NewProtocol()
}
- ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil)
+ ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ }))
if err != nil {
t.Fatal(err)
}