From 6f8fb7e0db2790ff1f5ba835780c03fe245e437f Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Thu, 27 Aug 2020 15:45:02 -0700 Subject: Improve type safety for socket options The existing implementation for {G,S}etSockOpt take arguments of an empty interface type which all types (implicitly) implement; any type may be passed to the functions. This change introduces marker interfaces for socket options that may be set or queried which socket option types implement to ensure that invalid types are caught at compile time. Different interfaces are used to allow the compiler to enforce read-only or set-only socket options. Fixes #3714. RELNOTES: n/a PiperOrigin-RevId: 328832161 --- pkg/sentry/socket/netstack/netstack.go | 55 +++++++++++++++++--------------- pkg/sentry/socket/unix/transport/unix.go | 14 ++++---- 2 files changed, 35 insertions(+), 34 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 8da77cc68..0bf21f7d8 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -236,7 +236,7 @@ type commonEndpoint interface { // SetSockOpt implements tcpip.Endpoint.SetSockOpt and // transport.Endpoint.SetSockOpt. - SetSockOpt(interface{}) *tcpip.Error + SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and // transport.Endpoint.SetSockOptBool. @@ -248,7 +248,7 @@ type commonEndpoint interface { // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. - GetSockOpt(interface{}) *tcpip.Error + GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and // transport.Endpoint.GetSockOpt. @@ -1778,8 +1778,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. @@ -1824,7 +1823,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + v := tcpip.BindToDeviceOption(0) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } s := t.NetworkContext() if s == nil { @@ -1832,7 +1832,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + v := tcpip.BindToDeviceOption(nicID) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) } } return syserr.ErrUnknownDevice @@ -1898,7 +1899,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + opt := tcpip.OutOfBandInlineOption(v) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1921,21 +1923,20 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } return syserr.TranslateNetstackError( - ep.SetSockOpt(tcpip.LingerOption{ + ep.SetSockOpt(&tcpip.LingerOption{ Enabled: v.OnOff != 0, Timeout: time.Second * time.Duration(v.Linger)})) case linux.SO_DETACH_FILTER: // optval is ignored. var v tcpip.SocketDetachFilterOption - return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) default: socket.SetSockOptEmitUnimplementedEvent(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. @@ -1982,7 +1983,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPIDLE { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIdleOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPINTVL: if len(optVal) < sizeOfInt32 { @@ -1993,7 +1995,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 1 || v > linux.MAX_TCP_KEEPINTVL { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + opt := tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_KEEPCNT: if len(optVal) < sizeOfInt32 { @@ -2015,11 +2018,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + opt := tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) - if err := ep.SetSockOpt(v); err != nil { + if err := ep.SetSockOpt(&v); err != nil { return syserr.TranslateNetstackError(err) } return nil @@ -2030,7 +2034,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := int32(usermem.ByteOrder.Uint32(optVal)) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_DEFER_ACCEPT: if len(optVal) < sizeOfInt32 { @@ -2040,7 +2045,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * if v < 0 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + opt := tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)) + return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) case linux.TCP_SYNCNT: if len(optVal) < sizeOfInt32 { @@ -2065,8 +2071,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * emitUnimplementedEventTCP(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. @@ -2144,8 +2149,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name emitUnimplementedEventIPv6(t, name) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } var ( @@ -2223,7 +2227,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change AddMembership to use the standard // any address representation. @@ -2237,7 +2241,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{ NIC: tcpip.NICID(req.InterfaceIndex), // TODO(igudger): Change DropMembership to use the standard // any address representation. @@ -2251,7 +2255,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{ + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), })) @@ -2375,8 +2379,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in t.Kernel().EmitUnimplementedEvent(t) } - // Default to the old behavior; hand off to network stack. - return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{})) + return nil } // emitUnimplementedEventTCP emits unimplemented event if name is valid. This diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 4bf06d4dc..cc9d650fb 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -172,9 +172,8 @@ type Endpoint interface { // connected. GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) - // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option - // types. - SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOpt sets a socket option. + SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error // SetSockOptBool sets a socket option for simple cases when a value has // the int type. @@ -184,9 +183,8 @@ type Endpoint interface { // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error - // GetSockOpt gets a socket option. opt should be a pointer to one of the - // tcpip.*Option types. - GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOpt gets a socket option. + GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error // GetSockOptBool gets a socket option for simple cases when a return // value has the int type. @@ -841,7 +839,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess } // SetSockOpt sets a socket option. Currently not supported. -func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { +func (e *baseEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return nil } @@ -943,7 +941,7 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { +func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { switch opt.(type) { case *tcpip.LingerOption: return nil -- cgit v1.2.3