diff options
Diffstat (limited to 'pkg/tcpip/transport/icmp/endpoint.go')
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 82 |
1 files changed, 49 insertions, 33 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 451d3880e..e17e737e4 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,7 +15,6 @@ package icmp import ( - "encoding/binary" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -85,6 +84,7 @@ type endpoint struct { // NIC. regNICID tcpip.NICID route stack.Route `state:"manual"` + ttl uint8 } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -105,7 +105,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -205,7 +205,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpi // Write writes data to the endpoint's peer. This method does not block // if the data cannot be written. -func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) if opts.More { return 0, nil, tcpip.ErrInvalidOptionValue @@ -290,17 +290,17 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-cha } } - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } switch e.netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.id.LocalPort, v) + err = send4(route, e.id.LocalPort, v, e.ttl) case header.IPv6ProtocolNumber: - err = send6(route, e.id.LocalPort, v) + err = send6(route, e.id.LocalPort, v, e.ttl) } if err != nil { @@ -315,8 +315,20 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } -// SetSockOpt sets a socket option. Currently not supported. +// SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(o) + e.mu.Unlock() + } + + return nil +} + +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { return nil } @@ -332,6 +344,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { } e.rcvMu.Unlock() return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + } return -1, tcpip.ErrUnknownProtocolOption } @@ -342,40 +366,33 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { case tcpip.ErrorOption: return nil - case *tcpip.SendBufferSizeOption: - e.mu.Lock() - *o = tcpip.SendBufferSizeOption(e.sndBufSize) - e.mu.Unlock() + case *tcpip.KeepaliveEnabledOption: + *o = 0 return nil - case *tcpip.ReceiveBufferSizeOption: + case *tcpip.TTLOption: e.rcvMu.Lock() - *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + *o = tcpip.TTLOption(e.ttl) e.rcvMu.Unlock() return nil - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: return tcpip.ErrUnknownProtocolOption } } -func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident to the user-specified port. Sequence number should - // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. @@ -386,22 +403,21 @@ func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) } -func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return tcpip.ErrInvalidEndpointState } - // Set the ident. Sequence number is provided by the user. - binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) - - hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) - icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) copy(icmpv6, data) - data = data[header.ICMPv6EchoMinimumSize:] + // Set the ident. Sequence number is provided by the user. + icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return tcpip.ErrInvalidEndpointState @@ -410,7 +426,7 @@ func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { @@ -541,14 +557,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil |