From 6f8fb7e0db2790ff1f5ba835780c03fe245e437f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 27 Aug 2020 15:45:02 -0700 Subject: Improve type safety for socket options The existing implementation for {G,S}etSockOpt take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for socket options that may be set or queried which socket option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. Fixes #3714. RELNOTES: n/a PiperOrigin-RevId: 328832161 --- pkg/tcpip/transport/tcp/endpoint.go | 58 +++++++++++----------- pkg/tcpip/transport/tcp/tcp_test.go | 95 +++++++++++++++++++++++++------------ 2 files changed, 93 insertions(+), 60 deletions(-) (limited to 'pkg/tcpip/transport/tcp') diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8a5e993b5..c5d9eba5d 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1736,10 +1736,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // SetSockOpt sets a socket option. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { - case tcpip.BindToDeviceOption: - id := tcpip.NICID(v) + case *tcpip.BindToDeviceOption: + id := tcpip.NICID(*v) if id != 0 && !e.stack.HasNIC(id) { return tcpip.ErrUnknownDevice } @@ -1747,27 +1747,27 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.bindToDevice = id e.UnlockUser() - case tcpip.KeepaliveIdleOption: + case *tcpip.KeepaliveIdleOption: e.keepalive.Lock() - e.keepalive.idle = time.Duration(v) + e.keepalive.idle = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.KeepaliveIntervalOption: + case *tcpip.KeepaliveIntervalOption: e.keepalive.Lock() - e.keepalive.interval = time.Duration(v) + e.keepalive.interval = time.Duration(*v) e.keepalive.Unlock() e.notifyProtocolGoroutine(notifyKeepaliveChanged) - case tcpip.OutOfBandInlineOption: + case *tcpip.OutOfBandInlineOption: // We don't currently support disabling this option. - case tcpip.TCPUserTimeoutOption: + case *tcpip.TCPUserTimeoutOption: e.LockUser() - e.userTimeout = time.Duration(v) + e.userTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.CongestionControlOption: + case *tcpip.CongestionControlOption: // Query the available cc algorithms in the stack and // validate that the specified algorithm is actually // supported in the stack. @@ -1777,10 +1777,10 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } availCC := strings.Split(string(avail), " ") for _, cc := range availCC { - if v == tcpip.CongestionControlOption(cc) { + if *v == tcpip.CongestionControlOption(cc) { e.LockUser() state := e.EndpointState() - e.cc = v + e.cc = *v switch state { case StateEstablished: if e.EndpointState() == state { @@ -1796,43 +1796,43 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // control algorithm is specified. return tcpip.ErrNoSuchFile - case tcpip.TCPLingerTimeoutOption: + case *tcpip.TCPLingerTimeoutOption: e.LockUser() switch { - case v < 0: + case *v < 0: // Same as effectively disabling TCPLinger timeout. - v = -1 - case v == 0: + *v = -1 + case *v == 0: // Same as the stack default. var stackLingerTimeout tcpip.TCPLingerTimeoutOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &stackLingerTimeout); err != nil { panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %v", ProtocolNumber, &stackLingerTimeout, err)) } - v = stackLingerTimeout - case v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout): + *v = stackLingerTimeout + case *v > tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout): // Cap it to Stack's default TCP_LINGER2 timeout. - v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) + *v = tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) default: } - e.tcpLingerTimeout = time.Duration(v) + e.tcpLingerTimeout = time.Duration(*v) e.UnlockUser() - case tcpip.TCPDeferAcceptOption: + case *tcpip.TCPDeferAcceptOption: e.LockUser() - if time.Duration(v) > MaxRTO { - v = tcpip.TCPDeferAcceptOption(MaxRTO) + if time.Duration(*v) > MaxRTO { + *v = tcpip.TCPDeferAcceptOption(MaxRTO) } - e.deferAccept = time.Duration(v) + e.deferAccept = time.Duration(*v) e.UnlockUser() - case tcpip.SocketDetachFilterOption: + case *tcpip.SocketDetachFilterOption: return nil - case tcpip.LingerOption: + case *tcpip.LingerOption: e.LockUser() - e.linger = v + e.linger = *v e.UnlockUser() default: @@ -1993,7 +1993,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch o := opt.(type) { case *tcpip.BindToDeviceOption: e.LockUser() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 3d3034d50..adb32e428 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -1349,7 +1349,9 @@ func TestConnectBindToDevice(t *testing.T) { c.Create(-1) bindToDevice := tcpip.BindToDeviceOption(test.device) - c.EP.SetSockOpt(bindToDevice) + if err := c.EP.SetSockOpt(&bindToDevice); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%d)): %s", bindToDevice, bindToDevice, err) + } // Start connection attempt. waitEntry, _ := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) @@ -4321,16 +4323,15 @@ func TestBindToDeviceOption(t *testing.T) { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { - t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) + if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } bindToDevice := tcpip.BindToDeviceOption(88888) if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt got %s, want %v", err, nil) - } - if got, want := bindToDevice, testAction.getBindToDevice; got != want { - t.Errorf("bindToDevice got %d, want %d", got, want) + t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) + } else if bindToDevice != testAction.getBindToDevice { + t.Errorf("got bindToDevice = %d, want %d", bindToDevice, testAction.getBindToDevice) } }) } @@ -4806,20 +4807,20 @@ func TestEndpointSetCongestionControl(t *testing.T) { var oldCC tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&oldCC); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &oldCC, err) + t.Fatalf("c.EP.GetSockOpt(&%T) = %s", oldCC, err) } if connected { c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) } - if err := c.EP.SetSockOpt(tc.cc); err != tc.err { - t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err) + if err := c.EP.SetSockOpt(&tc.cc); err != tc.err { + t.Fatalf("got c.EP.SetSockOpt(&%#v) = %s, want %s", tc.cc, err, tc.err) } var cc tcpip.CongestionControlOption if err := c.EP.GetSockOpt(&cc); err != nil { - t.Fatalf("c.EP.SockOpt(%v) = %s", &cc, err) + t.Fatalf("c.EP.GetSockOpt(&%T): %s", cc, err) } got, want := cc, oldCC @@ -4831,7 +4832,7 @@ func TestEndpointSetCongestionControl(t *testing.T) { want = tc.cc } if got != want { - t.Fatalf("got congestion control: %v, want: %v", got, want) + t.Fatalf("got congestion control = %+v, want = %+v", got, want) } }) } @@ -4852,11 +4853,23 @@ func TestKeepalive(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) + keepAliveIdleOpt := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOpt, keepAliveIdle, err) + } + keepAliveIntervalOpt := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOpt, keepAliveInterval, err) + } c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // 5 unacked keepalives are sent. ACK each one, and check that the // connection stays alive after 5. @@ -6216,15 +6229,17 @@ func TestTCPLingerTimeout(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { - t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + v := tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("SetSockOpt(&%T(%s)) = %s", v, tc.tcpLingerTimeout, err) } - var v tcpip.TCPLingerTimeoutOption + + v = 0 if err := c.EP.GetSockOpt(&v); err != nil { - t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + t.Fatalf("GetSockOpt(&%T) = %s", v, err) } if got, want := time.Duration(v), tc.want; got != want { - t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + t.Fatalf("got linger timeout = %s, want = %s", got, want) } }) } @@ -6941,7 +6956,10 @@ func TestTCPUserTimeout(t *testing.T) { // expired. initRTO := 1 * time.Second userTimeout := initRTO / 2 - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + v := tcpip.TCPUserTimeoutOption(userTimeout) + if err := c.EP.SetSockOpt(&v); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s): %s", v, userTimeout, err) + } // Send some data and wait before ACKing it. view := buffer.NewView(3) @@ -7015,18 +7033,31 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + const keepAliveIdle = 100 * time.Millisecond const keepAliveInterval = 3 * time.Second - c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond)) - c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval)) - c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10) - c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true) + keepAliveIdleOption := tcpip.KeepaliveIdleOption(keepAliveIdle) + if err := c.EP.SetSockOpt(&keepAliveIdleOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIdleOption, keepAliveIdle, err) + } + keepAliveIntervalOption := tcpip.KeepaliveIntervalOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&keepAliveIntervalOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", keepAliveIntervalOption, keepAliveInterval, err) + } + if err := c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10); err != nil { + t.Fatalf("c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10): %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true); err != nil { + t.Fatalf("c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true): %s", err) + } // Set userTimeout to be the duration to be 1 keepalive // probes. Which means that after the first probe is sent // the second one should cause the connection to be // closed due to userTimeout being hit. - userTimeout := 1 * keepAliveInterval - c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + userTimeout := tcpip.TCPUserTimeoutOption(keepAliveInterval) + if err := c.EP.SetSockOpt(&userTimeout); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", userTimeout, keepAliveInterval, err) + } // Check that the connection is still alive. if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { @@ -7233,8 +7264,9 @@ func TestTCPDeferAccept(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOption := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOption); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)): %s", tcpDeferAcceptOption, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) @@ -7290,8 +7322,9 @@ func TestTCPDeferAcceptTimeout(t *testing.T) { } const tcpDeferAccept = 1 * time.Second - if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil { - t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s) failed: %s", tcpDeferAccept, err) + tcpDeferAcceptOpt := tcpip.TCPDeferAcceptOption(tcpDeferAccept) + if err := c.EP.SetSockOpt(&tcpDeferAcceptOpt); err != nil { + t.Fatalf("c.EP.SetSockOpt(&%T(%s)) failed: %s", tcpDeferAcceptOpt, tcpDeferAccept, err) } irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */) -- cgit v1.2.3