summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/testing
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/testing')
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD2
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go77
2 files changed, 59 insertions, 20 deletions
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
index 19b0d31c5..b33ec2087 100644
--- a/pkg/tcpip/transport/tcp/testing/context/BUILD
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -8,7 +8,7 @@ go_library(
srcs = ["context.go"],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context",
visibility = [
- "//:sandbox",
+ "//visibility:public",
],
deps = [
"//pkg/tcpip",
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index ef823e4ae..b0a376eba 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -231,14 +231,15 @@ func (c *Context) CheckNoPacket(errMsg string) {
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
@@ -259,14 +260,15 @@ func (c *Context) GetPacket() []byte {
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+
+ hdr := p.Pkt.Header.View()
+ b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
@@ -302,11 +304,19 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
copy(icmp[header.ICMPv4PayloadOffset:], p2)
// Inject packet.
- c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
+ return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
+}
+
+// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
+// payload and source and destination IPv4 addresses.
+func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -319,8 +329,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
- SrcAddr: TestAddr,
- DstAddr: StackAddr,
+ SrcAddr: src,
+ DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@@ -337,7 +347,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@@ -350,13 +360,26 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
- c.linkEP.Inject(ipv4.ProtocolNumber, s)
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: s,
+ })
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
- c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegment(payload, h),
+ })
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
+// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
+// provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+ c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
+ Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+ })
}
// SendAck sends an ACK packet.
@@ -462,14 +485,15 @@ func (c *Context) CreateV6Endpoint(v6only bool) {
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
+ c.t.Helper()
select {
case p := <-c.linkEP.C:
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
- b := make([]byte, len(p.Header)+len(p.Payload))
- copy(b, p.Header)
- copy(b[len(p.Header):], p.Payload)
+ b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+ copy(b, p.Pkt.Header.View())
+ copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
@@ -484,6 +508,13 @@ func (c *Context) GetV6Packet() []byte {
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+ c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@@ -494,8 +525,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
- SrcAddr: TestV6Addr,
- DstAddr: StackV6Addr,
+ SrcAddr: src,
+ DstAddr: dst,
})
// Initialize the TCP header.
@@ -511,14 +542,16 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
- xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
+ xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum))
// Inject packet.
- c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+ c.linkEP.InjectInbound(ipv6.ProtocolNumber, tcpip.PacketBuffer{
+ Data: buf.ToVectorisedView(),
+ })
}
// CreateConnected creates a connected TCP endpoint.
@@ -1059,3 +1092,9 @@ func (c *Context) SetGSOEnabled(enable bool) {
func (c *Context) MSSWithoutOptions() uint16 {
return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+ return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}