diff options
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 9 |
2 files changed, 12 insertions, 10 deletions
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index e6fc23258..da5b1deb2 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -45,6 +45,7 @@ const ( ) type protocol struct { + stack *stack.Stack } // Number returns the udp protocol number. @@ -53,14 +54,14 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { } // NewEndpoint creates a new udp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil +func (p *protocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(p.stack, netProto, waiterQueue), nil } // NewRawEndpoint creates a new raw UDP endpoint. It implements // stack.TransportProtocol.NewRawEndpoint. -func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +func (p *protocol) NewRawEndpoint(netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(p.stack, netProto, header.UDPProtocolNumber, waiterQueue) } // MinimumPacketSize returns the minimum valid udp packet size. @@ -114,6 +115,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol(*stack.Stack) stack.TransportProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.TransportProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 22a809efa..7aaedb708 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1486,6 +1486,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) LinkEndpoint() stack.LinkEndpoint { + return nil +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1509,10 +1513,7 @@ func TestTTL(t *testing.T) { } else { p = ipv6.NewProtocol(nil) } - ep := p.NewEndpoint(&testInterface{}, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, - })) + ep := p.NewEndpoint(&testInterface{}, nil, nil, nil) wantTTL = ep.DefaultTTL() ep.Close() } |