summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go8
-rw-r--r--pkg/tcpip/transport/tcp/segment.go29
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go4
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/udp/protocol.go9
5 files changed, 35 insertions, 21 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index feef8dca0..b1d820372 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -747,15 +747,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
// Only accept echo replies.
switch e.NetProto {
case header.IPv4ProtocolNumber:
- h := header.ICMPv4(pkt.Data.First())
- if h.Type() != header.ICMPv4EchoReply {
+ h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize)
+ if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
}
case header.IPv6ProtocolNumber:
- h := header.ICMPv6(pkt.Data.First())
- if h.Type() != header.ICMPv6EchoReply {
+ h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize)
+ if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply {
e.stack.Stats().DroppedPackets.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
return
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 40461fd31..7712ce652 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -144,7 +144,11 @@ func (s *segment) logicalLen() seqnum.Size {
// TCP checksum and stores the checksum and result of checksum verification in
// the csum and csumValid fields of the segment.
func (s *segment) parse() bool {
- h := header.TCP(s.data.First())
+ h, ok := s.data.PullUp(header.TCPMinimumSize)
+ if !ok {
+ return false
+ }
+ hdr := header.TCP(h)
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
@@ -156,12 +160,16 @@ func (s *segment) parse() bool {
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
- offset := int(h.DataOffset())
- if offset < header.TCPMinimumSize || offset > len(h) {
+ offset := int(hdr.DataOffset())
+ if offset < header.TCPMinimumSize {
+ return false
+ }
+ hdrWithOpts, ok := s.data.PullUp(offset)
+ if !ok {
return false
}
- s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.options = []byte(hdrWithOpts[header.TCPMinimumSize:])
s.parsedOptions = header.ParseTCPOptions(s.options)
// Query the link capabilities to decide if checksum validation is
@@ -173,18 +181,19 @@ func (s *segment) parse() bool {
s.data.TrimFront(offset)
}
if verifyChecksum {
- s.csum = h.Checksum()
+ hdr = header.TCP(hdrWithOpts)
+ s.csum = hdr.Checksum()
xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
- xsum = h.CalculateChecksum(xsum)
+ xsum = hdr.CalculateChecksum(xsum)
s.data.TrimFront(offset)
xsum = header.ChecksumVV(s.data, xsum)
s.csumValid = xsum == 0xffff
}
- s.sequenceNumber = seqnum.Value(h.SequenceNumber())
- s.ackNumber = seqnum.Value(h.AckNumber())
- s.flags = h.Flags()
- s.window = seqnum.Size(h.WindowSize())
+ s.sequenceNumber = seqnum.Value(hdr.SequenceNumber())
+ s.ackNumber = seqnum.Value(hdr.AckNumber())
+ s.flags = hdr.Flags()
+ s.window = seqnum.Size(hdr.WindowSize())
return true
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 7e574859b..33e2b9a09 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -3556,7 +3556,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4
c.SendSegment(vv)
@@ -3583,7 +3583,7 @@ func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
- tcpbuf := vv.First()[header.IPv4MinimumSize:]
+ tcpbuf := vv.ToView()[header.IPv4MinimumSize:]
// Overwrite a byte in the payload which should cause checksum
// verification to fail.
tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index edb54f0be..756ab913a 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -1250,8 +1250,8 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) {
// Get the header then trim it from the view.
- hdr := header.UDP(pkt.Data.First())
- if int(hdr.Length()) > pkt.Data.Size() {
+ hdr, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok || int(header.UDP(hdr).Length()) > pkt.Data.Size() {
// Malformed packet.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
@@ -1286,7 +1286,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
- Port: hdr.SourcePort(),
+ Port: header.UDP(hdr).SourcePort(),
},
}
packet.data = pkt.Data
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 6e31a9bac..52af6de22 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -68,8 +68,13 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt stack.PacketBuffer) bool {
// Get the header then trim it from the view.
- hdr := header.UDP(pkt.Data.First())
- if int(hdr.Length()) > pkt.Data.Size() {
+ h, ok := pkt.Data.PullUp(header.UDPMinimumSize)
+ if !ok {
+ // Malformed packet.
+ r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+ return true
+ }
+ if int(header.UDP(h).Length()) > pkt.Data.Size() {
// Malformed packet.
r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
return true