diff options
author | Nayana Bidari <nybidari@google.com> | 2020-12-22 14:41:11 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-12-22 14:44:02 -0800 |
commit | 7c8ba72b026db3b79f12e679ab69078a25c143e8 (patch) | |
tree | 71c5c14dd973fc55b218c635f37b11a72a1de190 /pkg/tcpip/transport/udp | |
parent | 202e9fa3695e015ba8875c094f70d75bce18b95e (diff) |
Move SO_BINDTODEVICE to socketops.
PiperOrigin-RevId: 348696094
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 12 |
3 files changed, 16 insertions, 28 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 24d0c2cb9..9b9e4deb0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,7 +109,6 @@ type endpoint struct { multicastAddr tcpip.Address multicastNICID tcpip.NICID portFlags ports.Flags - bindToDevice tcpip.NICID lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -659,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } +func (e *endpoint) HasNIC(id int32) bool { + return id == 0 || e.stack.HasNIC(tcpip.NICID(id)) +} + // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { switch v := opt.(type) { @@ -775,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { delete(e.multicastMemberships, memToRemove) - case *tcpip.BindToDeviceOption: - id := tcpip.NICID(*v) - if id != 0 && !e.stack.HasNIC(id) { - return tcpip.ErrUnknownDevice - } - e.mu.Lock() - e.bindToDevice = id - e.mu.Unlock() - case *tcpip.SocketDetachFilterOption: return nil } @@ -859,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { } e.mu.Unlock() - case *tcpip.BindToDeviceOption: - e.mu.RLock() - *o = tcpip.BindToDeviceOption(e.bindToDevice) - e.mu.RUnlock() - default: return tcpip.ErrUnknownProtocolOption } @@ -1113,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp } func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) if err != nil { - return id, e.bindToDevice, err + return id, bindToDevice, err } id.LocalPort = port } e.boundPortFlags = e.portFlags - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } - return id, e.bindToDevice, err + return id, bindToDevice, err } func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 14e4648cd..d7fc21f11 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, route.ResolveWith(r.pkt.SourceLinkAddress()) ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) - if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil { ep.Close() route.Release() return nil, err diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6f89b6271..8429f34b4 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -554,7 +554,7 @@ func TestBindToDeviceOption(t *testing.T) { name string setBindToDevice *tcpip.NICID setBindToDeviceError *tcpip.Error - getBindToDevice tcpip.BindToDeviceOption + getBindToDevice int32 }{ {"GetDefaultValue", nil, nil, 0}, {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, @@ -564,15 +564,13 @@ func TestBindToDeviceOption(t *testing.T) { for _, testAction := range testActions { t.Run(testAction.name, func(t *testing.T) { if testAction.setBindToDevice != nil { - bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) - if gotErr, wantErr := ep.SetSockOpt(&bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + bindToDevice := int32(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SocketOptions().SetBindToDevice(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { t.Errorf("got SetSockOpt(&%T(%d)) = %s, want = %s", bindToDevice, bindToDevice, gotErr, wantErr) } } - bindToDevice := tcpip.BindToDeviceOption(88888) - if err := ep.GetSockOpt(&bindToDevice); err != nil { - t.Errorf("GetSockOpt(&%T): %s", bindToDevice, err) - } else if bindToDevice != testAction.getBindToDevice { + bindToDevice := ep.SocketOptions().GetBindToDevice() + if bindToDevice != testAction.getBindToDevice { t.Errorf("got bindToDevice = %d, want = %d", bindToDevice, testAction.getBindToDevice) } }) |