summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/arp/arp.go16
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go44
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go24
-rw-r--r--pkg/tcpip/stack/registration.go36
-rw-r--r--pkg/tcpip/stack/stack.go45
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go27
-rw-r--r--pkg/tcpip/transport/raw/protocol.go9
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go22
-rw-r--r--pkg/tcpip/transport/udp/protocol.go12
9 files changed, 72 insertions, 163 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index fd6395fc1..26cf1c528 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -200,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..b7b07a6c1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -53,6 +50,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
}
// Number returns the ipv4 protocol number.
@@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..7de6a4546 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,9 +14,9 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
@@ -28,9 +28,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -160,14 +157,6 @@ func (*endpoint) Close() {}
type protocol struct{}
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
-}
-
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
@@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 07e4c770d..80101d4bb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -366,14 +366,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -381,34 +373,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index f7ba3cb0f..18d1704a5 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,11 +17,6 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
@@ -351,6 +346,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -359,10 +357,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -398,6 +392,12 @@ type Stack struct {
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -411,8 +411,9 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated endpoints raw endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
// New allocates a new networking stack with only the requested networking and
@@ -422,7 +423,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -438,17 +439,11 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
icmpRateLimiter: NewICMPRateLimiter(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -456,18 +451,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -602,7 +593,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 1eb790932..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,10 +14,9 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
@@ -34,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a13b2022..d5d8ab96a 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 068d9a272..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -182,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}