summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2020-08-28 11:47:58 -0700
committergVisor bot <gvisor-bot@google.com>2020-08-28 11:50:17 -0700
commitbdd5996a73b14d6f6600ab7aa00cdaed459cab16 (patch)
tree21f6e20395347d83f8deec04dba3e833eff3a3fa /pkg/tcpip/stack
parent8b9cb36d1c74f71da5bc70b73330291f1df298ad (diff)
Improve type safety for network protocol options
The existing implementation for NetworkProtocol.{Set}Option take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for network protocol options that may be set or queried which network protocol option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. PiperOrigin-RevId: 328980359
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/forwarder_test.go10
-rw-r--r--pkg/tcpip/stack/nic_test.go4
-rw-r--r--pkg/tcpip/stack/registration.go4
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/stack/stack_test.go68
5 files changed, 28 insertions, 62 deletions
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index 91165ebc7..54759091a 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -154,17 +154,17 @@ func (f *fwdTestNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCac
}
}
-func (f *fwdTestNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (*fwdTestNetworkProtocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func (f *fwdTestNetworkProtocol) Close() {}
+func (*fwdTestNetworkProtocol) Close() {}
-func (f *fwdTestNetworkProtocol) Wait() {}
+func (*fwdTestNetworkProtocol) Wait() {}
func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.onLinkAddressResolved != nil {
@@ -182,7 +182,7 @@ func (f *fwdTestNetworkProtocol) ResolveStaticAddress(addr tcpip.Address) (tcpip
return "", false
}
-func (f *fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+func (*fwdTestNetworkProtocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index 1e065b5c1..dd6474297 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -201,12 +201,12 @@ func (p *testIPv6Protocol) NewEndpoint(nicID tcpip.NICID, _ LinkAddressCache, _
}
// SetOption implements NetworkProtocol.SetOption.
-func (*testIPv6Protocol) SetOption(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) SetOption(tcpip.SettableNetworkProtocolOption) *tcpip.Error {
return nil
}
// Option implements NetworkProtocol.Option.
-func (*testIPv6Protocol) Option(interface{}) *tcpip.Error {
+func (*testIPv6Protocol) Option(tcpip.GettableNetworkProtocolOption) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 21ac38583..2d88fa1f7 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -303,12 +303,12 @@ type NetworkProtocol interface {
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
- SetOption(option interface{}) *tcpip.Error
+ SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
- Option(option interface{}) *tcpip.Error
+ Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error
// Close requests that any worker goroutines owned by the protocol
// stop.
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7f5ed9e83..c86ee1c13 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -785,7 +785,7 @@ func (s *Stack) UniqueID() uint64 {
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
-func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
@@ -802,7 +802,7 @@ func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, op
// if err != nil {
// ...
// }
-func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 1deeccb89..60b54c244 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -158,23 +158,13 @@ func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack
func (*fakeNetworkEndpoint) Close() {}
-type fakeNetGoodOption bool
-
-type fakeNetBadOption bool
-
-type fakeNetInvalidValueOption int
-
-type fakeNetOptions struct {
- good bool
-}
-
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
- opts fakeNetOptions
+ defaultTTL uint8
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
@@ -206,22 +196,20 @@ func (f *fakeNetworkProtocol) NewEndpoint(nicID tcpip.NICID, _ stack.LinkAddress
}
}
-func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) SetOption(option tcpip.SettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case fakeNetGoodOption:
- f.opts.good = bool(v)
+ case *tcpip.DefaultTTLOption:
+ f.defaultTTL = uint8(*v)
return nil
- case fakeNetInvalidValueOption:
- return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
-func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+func (f *fakeNetworkProtocol) Option(option tcpip.GettableNetworkProtocolOption) *tcpip.Error {
switch v := option.(type) {
- case *fakeNetGoodOption:
- *v = fakeNetGoodOption(f.opts.good)
+ case *tcpip.DefaultTTLOption:
+ *v = tcpip.DefaultTTLOption(f.defaultTTL)
return nil
default:
return tcpip.ErrUnknownProtocolOption
@@ -1640,46 +1628,24 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
}
}
-func TestNetworkOptions(t *testing.T) {
+func TestNetworkOption(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
TransportProtocols: []stack.TransportProtocol{},
})
- // Try an unsupported network protocol.
- if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
- t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+ opt := tcpip.DefaultTTLOption(5)
+ if err := s.SetNetworkProtocolOption(fakeNetNumber, &opt); err != nil {
+ t.Fatalf("s.SetNetworkProtocolOption(%d, &%T(%d)): %s", fakeNetNumber, opt, opt, err)
}
- testCases := []struct {
- option interface{}
- wantErr *tcpip.Error
- verifier func(t *testing.T, p stack.NetworkProtocol)
- }{
- {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
- t.Helper()
- fakeNet := p.(*fakeNetworkProtocol)
- if fakeNet.opts.good != true {
- t.Fatalf("fakeNet.opts.good = false, want = true")
- }
- var v fakeNetGoodOption
- if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
- }
- if v != true {
- t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
- }
- }},
- {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
- {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+ var optGot tcpip.DefaultTTLOption
+ if err := s.NetworkProtocolOption(fakeNetNumber, &optGot); err != nil {
+ t.Fatalf("s.NetworkProtocolOption(%d, &%T): %s", fakeNetNumber, optGot, err)
}
- for _, tc := range testCases {
- if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
- t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
- }
- if tc.verifier != nil {
- tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
- }
+
+ if opt != optGot {
+ t.Errorf("got optGot = %d, want = %d", optGot, opt)
}
}