diff options
Diffstat (limited to 'pkg/tcpip/stack/transport_test.go')
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 62 |
1 files changed, 42 insertions, 20 deletions
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 86c62be25..748ce4ea5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -82,7 +83,10 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions if err != nil { return 0, nil, err } - if err := f.route.WritePacket(nil /* gso */, hdr, buffer.View(v).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}); err != nil { + if err := f.route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: fakeTransNumber, TTL: 123, TOS: stack.DefaultTOS}, tcpip.PacketBuffer{ + Header: hdr, + Data: buffer.View(v).ToVectorisedView(), + }); err != nil { return 0, nil, err } @@ -144,6 +148,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -192,7 +200,7 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ tcpip.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ if f.acceptQueue != nil { @@ -209,7 +217,7 @@ func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportE } } -func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, tcpip.PacketBuffer) { // Increment the number of received control packets. f.proto.controlCount++ } @@ -218,15 +226,15 @@ func (f *fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) { -} +func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} func (f *fakeTransportEndpoint) IPTables() (iptables.IPTables, error) { return iptables.IPTables{}, nil } -func (f *fakeTransportEndpoint) Resume(*stack.Stack) { -} +func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} + +func (f *fakeTransportEndpoint) Wait() {} type fakeTransportGoodOption bool @@ -251,7 +259,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { @@ -266,7 +274,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, tcpip.PacketBuffer) bool { return true } @@ -337,7 +345,9 @@ func TestTransportReceive(t *testing.T) { // Make sure packet with wrong protocol is not delivered. buf[0] = 1 buf[2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -346,7 +356,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 3 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 0 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) } @@ -355,7 +367,9 @@ func TestTransportReceive(t *testing.T) { buf[0] = 1 buf[1] = 2 buf[2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.packetCount != 1 { t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) } @@ -408,7 +422,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 0 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = 0 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -417,7 +433,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 3 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 0 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) } @@ -426,7 +444,9 @@ func TestTransportControlReceive(t *testing.T) { buf[fakeNetHeaderLen+0] = 2 buf[fakeNetHeaderLen+1] = 1 buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + linkEP.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if fakeTrans.controlCount != 1 { t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) } @@ -579,7 +599,9 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - ep2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.InjectInbound(fakeNetNumber, tcpip.PacketBuffer{ + Data: req.ToVectorisedView(), + }) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -598,10 +620,10 @@ func TestTransportForwarding(t *testing.T) { t.Fatal("Response packet not forwarded") } - if dst := p.Header[0]; dst != 3 { + if dst := p.Pkt.Header.View()[0]; dst != 3 { t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) } - if src := p.Header[1]; src != 1 { + if src := p.Pkt.Header.View()[1]; src != 1 { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } |