summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/udp
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2020-12-28 22:05:49 +0000
committergVisor bot <gvisor-bot@google.com>2020-12-28 22:05:49 +0000
commit5c21c7c3bd1552f4d5f87ef588fc213e2a2278ef (patch)
treeb62b3f2c71f46e145c15d7740262f7d59c91c87f /pkg/tcpip/transport/udp
parentb0f23fb7e0cf908622bc6b8c90e2819de6de6ccb (diff)
parent3ff7324dfa7c096a50b628189d5c3f2d4d5ec2f6 (diff)
Merge release-20201208.0-89-g3ff7324df (automated)
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go120
-rw-r--r--pkg/tcpip/transport/udp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_state_autogen.go39
3 files changed, 107 insertions, 54 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 763d1d654..9b9e4deb0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -109,7 +109,6 @@ type endpoint struct {
multicastAddr tcpip.Address
multicastNICID tcpip.NICID
portFlags ports.Flags
- bindToDevice tcpip.NICID
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
@@ -226,6 +225,13 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
+// UpdateLastError implements tcpip.SocketOptionsHandler.UpdateLastError.
+func (e *endpoint) UpdateLastError(err *tcpip.Error) {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+}
+
// Abort implements stack.TransportEndpoint.Abort.
func (e *endpoint) Abort() {
e.Close()
@@ -511,6 +517,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
+ so := e.SocketOptions()
+ if so.GetRecvError() {
+ so.QueueLocalErr(
+ tcpip.ErrMessageTooLong,
+ route.NetProto,
+ header.UDPMaximumPacketSize,
+ tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ Port: dstPort,
+ },
+ v,
+ )
+ }
return 0, nil, tcpip.ErrMessageTooLong
}
@@ -638,6 +658,10 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
+func (e *endpoint) HasNIC(id int32) bool {
+ return id == 0 || e.stack.HasNIC(tcpip.NICID(id))
+}
+
// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
switch v := opt.(type) {
@@ -754,15 +778,6 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
delete(e.multicastMemberships, memToRemove)
- case *tcpip.BindToDeviceOption:
- id := tcpip.NICID(*v)
- if id != 0 && !e.stack.HasNIC(id) {
- return tcpip.ErrUnknownDevice
- }
- e.mu.Lock()
- e.bindToDevice = id
- e.mu.Unlock()
-
case *tcpip.SocketDetachFilterOption:
return nil
}
@@ -838,11 +853,6 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
}
e.mu.Unlock()
- case *tcpip.BindToDeviceOption:
- e.mu.RLock()
- *o = tcpip.BindToDeviceOption(e.bindToDevice)
- e.mu.RUnlock()
-
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -996,7 +1006,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
if err != nil {
return err
}
- defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: e.ID.LocalAddress,
@@ -1024,6 +1033,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
id, btd, err := e.registerWithStack(nicID, netProtos, id)
if err != nil {
+ r.Release()
return err
}
@@ -1034,7 +1044,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.ID = id
e.boundBindToDevice = btd
- e.route = r.Clone()
+ e.route = r
e.dstPort = addr.Port
e.RegisterNICID = nicID
e.effectiveNetProtos = netProtos
@@ -1092,21 +1102,22 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *tcp
}
func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) {
+ bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
if err != nil {
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
id.LocalPort = port
}
e.boundPortFlags = e.portFlags
- err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice)
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{})
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
e.boundPortFlags = ports.Flags{}
}
- return id, e.bindToDevice, err
+ return id, bindToDevice, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
@@ -1259,6 +1270,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet.
@@ -1267,10 +1279,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return
}
- // TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
- // packets at "Parse" instead of when handling a packet.
- pkt.Data.CapLength(int(hdr.PayloadLength()))
-
if !verifyChecksum(hdr, pkt) {
// Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment()
@@ -1304,7 +1312,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID,
@@ -1341,15 +1349,63 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
}
}
+func (e *endpoint) onICMPError(err *tcpip.Error, id stack.TransportEndpointID, errType byte, errCode byte, extra uint32, pkt *stack.PacketBuffer) {
+ // Update last error first.
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ // Update the error queue if IP_RECVERR is enabled.
+ if e.SocketOptions().GetRecvError() {
+ // Linux passes the payload without the UDP header.
+ var payload []byte
+ udp := header.UDP(pkt.Data.ToView())
+ if len(udp) >= header.UDPMinimumSize {
+ payload = udp.Payload()
+ }
+
+ e.SocketOptions().QueueErr(&tcpip.SockError{
+ Err: err,
+ ErrOrigin: header.ICMPOriginFromNetProto(pkt.NetworkProtocolNumber),
+ ErrType: errType,
+ ErrCode: errCode,
+ ErrInfo: extra,
+ Payload: payload,
+ Dst: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.RemoteAddress,
+ Port: id.RemotePort,
+ },
+ Offender: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ },
+ NetProto: pkt.NetworkProtocolNumber,
+ })
+ }
+
+ // Notify of the error.
+ e.waiterQueue.Notify(waiter.EventErr)
+}
+
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
if typ == stack.ControlPortUnreachable {
if e.EndpointState() == StateConnected {
- e.lastErrorMu.Lock()
- e.lastError = tcpip.ErrConnectionRefused
- e.lastErrorMu.Unlock()
-
- e.waiterQueue.Notify(waiter.EventErr)
+ var errType byte
+ var errCode byte
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ errType = byte(header.ICMPv4DstUnreachable)
+ errCode = byte(header.ICMPv4PortUnreachable)
+ case header.IPv6ProtocolNumber:
+ errType = byte(header.ICMPv6DstUnreachable)
+ errCode = byte(header.ICMPv6PortUnreachable)
+ default:
+ panic(fmt.Sprintf("unsupported net proto for infering ICMP type and code: %d", pkt.NetworkProtocolNumber))
+ }
+ e.onICMPError(tcpip.ErrConnectionRefused, id, errType, errCode, extra, pkt)
return
}
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index 14e4648cd..d7fc21f11 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -78,7 +78,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
route.ResolveWith(r.pkt.SourceLinkAddress())
ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue)
- if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, tcpip.NICID(ep.ops.GetBindToDevice())); err != nil {
ep.Close()
route.Release()
return nil, err
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
index ec0a8c902..2b7726097 100644
--- a/pkg/tcpip/transport/udp/udp_state_autogen.go
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -73,7 +73,6 @@ func (e *endpoint) StateFields() []string {
"multicastAddr",
"multicastNICID",
"portFlags",
- "bindToDevice",
"lastError",
"boundBindToDevice",
"boundPortFlags",
@@ -91,7 +90,7 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
var rcvBufSizeMaxValue int = e.saveRcvBufSizeMax()
stateSinkObject.SaveValue(6, rcvBufSizeMaxValue)
var lastErrorValue string = e.saveLastError()
- stateSinkObject.SaveValue(19, lastErrorValue)
+ stateSinkObject.SaveValue(18, lastErrorValue)
stateSinkObject.Save(0, &e.TransportEndpointInfo)
stateSinkObject.Save(1, &e.DefaultSocketOptionsHandler)
stateSinkObject.Save(2, &e.waiterQueue)
@@ -109,15 +108,14 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(15, &e.multicastAddr)
stateSinkObject.Save(16, &e.multicastNICID)
stateSinkObject.Save(17, &e.portFlags)
- stateSinkObject.Save(18, &e.bindToDevice)
- stateSinkObject.Save(20, &e.boundBindToDevice)
- stateSinkObject.Save(21, &e.boundPortFlags)
- stateSinkObject.Save(22, &e.sendTOS)
- stateSinkObject.Save(23, &e.shutdownFlags)
- stateSinkObject.Save(24, &e.multicastMemberships)
- stateSinkObject.Save(25, &e.effectiveNetProtos)
- stateSinkObject.Save(26, &e.owner)
- stateSinkObject.Save(27, &e.ops)
+ stateSinkObject.Save(19, &e.boundBindToDevice)
+ stateSinkObject.Save(20, &e.boundPortFlags)
+ stateSinkObject.Save(21, &e.sendTOS)
+ stateSinkObject.Save(22, &e.shutdownFlags)
+ stateSinkObject.Save(23, &e.multicastMemberships)
+ stateSinkObject.Save(24, &e.effectiveNetProtos)
+ stateSinkObject.Save(25, &e.owner)
+ stateSinkObject.Save(26, &e.ops)
}
func (e *endpoint) StateLoad(stateSourceObject state.Source) {
@@ -138,17 +136,16 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(15, &e.multicastAddr)
stateSourceObject.Load(16, &e.multicastNICID)
stateSourceObject.Load(17, &e.portFlags)
- stateSourceObject.Load(18, &e.bindToDevice)
- stateSourceObject.Load(20, &e.boundBindToDevice)
- stateSourceObject.Load(21, &e.boundPortFlags)
- stateSourceObject.Load(22, &e.sendTOS)
- stateSourceObject.Load(23, &e.shutdownFlags)
- stateSourceObject.Load(24, &e.multicastMemberships)
- stateSourceObject.Load(25, &e.effectiveNetProtos)
- stateSourceObject.Load(26, &e.owner)
- stateSourceObject.Load(27, &e.ops)
+ stateSourceObject.Load(19, &e.boundBindToDevice)
+ stateSourceObject.Load(20, &e.boundPortFlags)
+ stateSourceObject.Load(21, &e.sendTOS)
+ stateSourceObject.Load(22, &e.shutdownFlags)
+ stateSourceObject.Load(23, &e.multicastMemberships)
+ stateSourceObject.Load(24, &e.effectiveNetProtos)
+ stateSourceObject.Load(25, &e.owner)
+ stateSourceObject.Load(26, &e.ops)
stateSourceObject.LoadValue(6, new(int), func(y interface{}) { e.loadRcvBufSizeMax(y.(int)) })
- stateSourceObject.LoadValue(19, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
+ stateSourceObject.LoadValue(18, new(string), func(y interface{}) { e.loadLastError(y.(string)) })
stateSourceObject.AfterLoad(e.afterLoad)
}