summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-08-27 15:45:02 -0700
committergVisor bot <gvisor-bot@google.com>2020-08-27 15:46:44 -0700
commit6f8fb7e0db2790ff1f5ba835780c03fe245e437f (patch)
treeac9514f801e358e5397fe983053705e1320ccf7b /pkg/tcpip/transport/udp
parent29d528399cf93ea2f27afe245738ae13ed4deace (diff)
Improve type safety for socket options
The existing implementation for {G,S}etSockOpt take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for socket options that may be set or queried which socket option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. Fixes #3714. RELNOTES: n/a PiperOrigin-RevId: 328832161
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go16
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go45
2 files changed, 29 insertions, 32 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 1d5ebe3f2..c74bc4d94 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -683,9 +683,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
- case tcpip.MulticastInterfaceOption:
+ case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
defer e.mu.Unlock()
@@ -721,7 +721,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastNICID = nic
e.multicastAddr = addr
- case tcpip.AddMembershipOption:
+ case *tcpip.AddMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -764,7 +764,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships = append(e.multicastMemberships, memToInsert)
- case tcpip.RemoveMembershipOption:
+ case *tcpip.RemoveMembershipOption:
if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
return tcpip.ErrInvalidOptionValue
}
@@ -808,8 +808,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
- case tcpip.BindToDeviceOption:
- id := tcpip.NICID(v)
+ case *tcpip.BindToDeviceOption:
+ id := tcpip.NICID(*v)
if id != 0 && !e.stack.HasNIC(id) {
return tcpip.ErrUnknownDevice
}
@@ -817,7 +817,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.bindToDevice = id
e.mu.Unlock()
- case tcpip.SocketDetachFilterOption:
+ case *tcpip.SocketDetachFilterOption:
return nil
}
return nil
@@ -960,7 +960,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
switch o := opt.(type) {
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index bd1c8ac31..0cbc045d8 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -539,7 +539,7 @@ func TestBindToDeviceOption(t *testing.T) {
opts := stack.NICOptions{Name: "my_device"}
if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
- t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+ t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %s", opts, err)
}
// nicIDPtr is used instead of taking the address of NICID literals, which is
@@ -563,16 +563,15 @@ func TestBindToDeviceOption(t *testing.T) {
t.Run(testAction.name, func(t *testing.T) {
if testAction.setBindToDevice != nil {
bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
- if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
- t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+ if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+ t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr)
}
}
bindToDevice := tcpip.BindToDeviceOption(88888)
if err := ep.GetSockOpt(&bindToDevice); err != nil {
- t.Errorf("GetSockOpt got %v, want %v", err, nil)
- }
- if got, want := bindToDevice, testAction.getBindToDevice; got != want {
- t.Errorf("bindToDevice got %d, want %d", got, want)
+ t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err)
+ } else if bindToDevice != testAction.getBindToDevice {
+ t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice)
}
})
}
@@ -628,12 +627,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Check the peer address.
h := flow.header4Tuple(incoming)
if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("unexpected remote address: got %s, want %v", addr.Addr, h.srcAddr)
+ c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
- c.t.Fatalf("bad payload: got %x, want %x", v, payload)
+ c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
@@ -694,7 +693,7 @@ func TestBindReservedPort(t *testing.T) {
}
defer ep.Close()
if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
}
@@ -707,7 +706,7 @@ func TestBindReservedPort(t *testing.T) {
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
- t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+ t.Fatalf("got ep.Bind(...) = %s, want = %s", got, want)
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
@@ -830,7 +829,7 @@ func TestV4ReadSelfSource(t *testing.T) {
}
if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+ t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
}
})
}
@@ -871,8 +870,8 @@ func TestReadOnBoundToMulticast(t *testing.T) {
// Join multicast group.
ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatal("SetSockOpt failed:", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Check that we receive multicast packets but not unicast or broadcast
@@ -1403,8 +1402,8 @@ func TestReadIPPacketInfo(t *testing.T) {
if test.flow.isMulticast() {
ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s:", ifoptSet, err)
}
}
@@ -1547,7 +1546,7 @@ func TestSetTOS(t *testing.T) {
}
// Test for expected default value.
if v != 0 {
- c.t.Errorf("got GetSockOpt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
+ c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
}
if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
@@ -1708,19 +1707,17 @@ func TestMulticastInterfaceOption(t *testing.T) {
}
}
- if err := c.ep.SetSockOpt(ifoptSet); err != nil {
- c.t.Fatalf("SetSockOpt failed: %s", err)
+ if err := c.ep.SetSockOpt(&ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt(&%#v): %s", ifoptSet, err)
}
// Verify multicast interface addr and NIC were set correctly.
// Note that NIC must be 1 since this is our outgoing interface.
- ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
var ifoptGot tcpip.MulticastInterfaceOption
if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
- c.t.Fatalf("GetSockOpt failed: %s", err)
- }
- if ifoptGot != ifoptWant {
- c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+ c.t.Fatalf("GetSockOpt(&%T): %s", ifoptGot, err)
+ } else if ifoptWant := (tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}); ifoptGot != ifoptWant {
+ c.t.Errorf("got multicast interface option = %#v, want = %#v", ifoptGot, ifoptWant)
}
})
}