diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2020-09-28 16:22:09 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-09-28 16:24:04 -0700 |
commit | a5acc0616c9552c7252e3f133f9ad4648422cc5f (patch) | |
tree | c62a2f05c8f8aee20671c39862ffe6721df35332 /pkg/tcpip/transport | |
parent | a0e0ba690f3f4946890010e12084db7f081d5bc7 (diff) |
Support creating protocol instances with Stack ref
Network or transport protocols may want to reach the stack. Support this
by letting the stack create the protocol instances so it can pass a
reference to itself at protocol creation time.
Note, protocols do not yet use the stack in this CL but later CLs will
make use of the stack from protocols.
PiperOrigin-RevId: 334260210
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/icmp/protocol.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 20 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 25 |
6 files changed, 31 insertions, 47 deletions
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 941c3c08d..7484f4ad9 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -13,12 +13,7 @@ // limitations under the License. // Package icmp contains the implementation of the ICMP and IPv6-ICMP transport -// protocols for use in ping. To use it in the networking stack, this package -// must be added to the project, and activated on the stack by passing -// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport -// protocols when calling stack.New(). Then endpoints can be created by passing -// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number -// when calling Stack.NewEndpoint(). +// protocols for use in ping. package icmp import ( @@ -135,11 +130,11 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol4 returns an ICMPv4 transport protocol. -func NewProtocol4() stack.TransportProtocol { +func NewProtocol4(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber4} } // NewProtocol6 returns an ICMPv6 transport protocol. -func NewProtocol6() stack.TransportProtocol { +func NewProtocol6(*stack.Stack) stack.TransportProtocol { return &protocol{ProtocolNumber6} } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 371067048..6a3c2c32b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package tcp contains the implementation of the TCP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing tcp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing tcp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package tcp contains the implementation of the TCP transport protocol. package tcp import ( @@ -510,7 +505,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a TCP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { p := protocol{ sendBufferSize: tcpip.TCPSendBufferSizeRangeOption{ Min: MinBufferSize, diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 8326736dc..dd810f594 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -4300,8 +4300,8 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { func TestDefaultBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4363,8 +4363,8 @@ func TestDefaultBufferSizes(t *testing.T) { func TestMinMaxBufferSizes(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) // Check the default values. @@ -4420,8 +4420,8 @@ func TestMinMaxBufferSizes(t *testing.T) { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}}) ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -4470,11 +4470,11 @@ func TestBindToDeviceOption(t *testing.T) { func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ - ipv4.NewProtocol(), - ipv6.NewProtocol(), + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + ipv6.NewProtocol, }, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) id := loopback.New() diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index ebbae6e2f..faf51ef95 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -155,8 +155,8 @@ type Context struct { // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, }) const sendBufferSize = 1 << 20 // 1 MiB diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index a1d0f49d9..e6fc23258 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -12,12 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package udp contains the implementation of the UDP transport protocol. To use -// it in the networking stack, this package must be added to the project, and -// activated on the stack by passing udp.NewProtocol() as one of the -// transport protocols when calling stack.New(). Then endpoints can be created -// by passing udp.ProtocolNumber as the transport protocol number when calling -// Stack.NewEndpoint(). +// Package udp contains the implementation of the UDP transport protocol. package udp import ( @@ -119,6 +114,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { } // NewProtocol returns a UDP transport protocol. -func NewProtocol() stack.TransportProtocol { +func NewProtocol(*stack.Stack) stack.TransportProtocol { return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 6813d57aa..0556ef879 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -294,8 +294,8 @@ type testContext struct { func newDualTestContext(t *testing.T, mtu uint32) *testContext { t.Helper() return newDualTestContextWithOptions(t, mtu, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) } @@ -532,8 +532,8 @@ func newMinPayload(minSize int) []byte { func TestBindToDeviceOption(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}}) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { @@ -807,8 +807,8 @@ func TestV4ReadSelfSource(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: tt.handleLocal, }) defer c.cleanup() @@ -1485,13 +1485,13 @@ func TestTTL(t *testing.T) { } else { var p stack.NetworkProtocol if flow.isV4() { - p = ipv4.NewProtocol() + p = ipv4.NewProtocol(nil) } else { - p = ipv6.NewProtocol() + p = ipv6.NewProtocol(nil) } ep := p.NewEndpoint(0, nil, nil, nil, nil, stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, })) wantTTL = ep.DefaultTTL() ep.Close() @@ -2345,9 +2345,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, - - TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) e := channel.New(0, defaultMTU, "") if err := s.CreateNIC(nicID1, e); err != nil { |