summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/buffer/view.go2
-rw-r--r--pkg/tcpip/header/icmpv4.go14
-rw-r--r--pkg/tcpip/header/ipv4.go8
-rw-r--r--pkg/tcpip/header/ipv6.go4
-rw-r--r--pkg/tcpip/header/tcp.go17
-rw-r--r--pkg/tcpip/iptables/BUILD18
-rw-r--r--pkg/tcpip/iptables/iptables.go81
-rw-r--r--pkg/tcpip/iptables/targets.go35
-rw-r--r--pkg/tcpip/iptables/types.go183
-rw-r--r--pkg/tcpip/link/rawfile/errors.go2
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go2
-rw-r--r--pkg/tcpip/network/arp/arp.go4
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go4
-rw-r--r--pkg/tcpip/network/ip_test.go10
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go24
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go49
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go7
-rw-r--r--pkg/tcpip/stack/registration.go20
-rw-r--r--pkg/tcpip/stack/route.go12
-rw-r--r--pkg/tcpip/stack/stack.go10
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_test.go5
-rw-r--r--pkg/tcpip/tcpip.go111
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go46
-rw-r--r--pkg/tcpip/transport/icmp/protocol.go7
-rw-r--r--pkg/tcpip/transport/raw/BUILD1
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go135
-rw-r--r--pkg/tcpip/transport/raw/protocol.go32
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go49
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go75
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go2
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go339
34 files changed, 1024 insertions, 295 deletions
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 1a9d40778..150310c11 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -50,7 +50,7 @@ func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v})
}
-// VectorisedView is a vectorised version of View using non contigous memory.
+// VectorisedView is a vectorised version of View using non contiguous memory.
// It supports all the convenience methods supported by View.
//
// +stateify savable
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
index c081de61f..c52c0d851 100644
--- a/pkg/tcpip/header/icmpv4.go
+++ b/pkg/tcpip/header/icmpv4.go
@@ -24,15 +24,11 @@ import (
type ICMPv4 []byte
const (
- // ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
- ICMPv4MinimumSize = 4
-
- // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
- ICMPv4EchoMinimumSize = 6
+ // ICMPv4PayloadOffset defines the start of ICMP payload.
+ ICMPv4PayloadOffset = 4
- // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
- // destination unreachable packet.
- ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
+ // ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
+ ICMPv4MinimumSize = 8
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
@@ -104,5 +100,5 @@ func (ICMPv4) SetDestinationPort(uint16) {
// Payload implements Transport.Payload.
func (b ICMPv4) Payload() []byte {
- return b[ICMPv4MinimumSize:]
+ return b[ICMPv4PayloadOffset:]
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index 7da4c4845..94a3af289 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -85,6 +85,10 @@ const (
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
+ // MinIPFragmentPayloadSize is the minimum number of payload bytes that
+ // the first fragment must carry when an IPv4 packet is fragmented.
+ MinIPFragmentPayloadSize = 8
+
// IPv4AddressSize is the size, in bytes, of an IPv4 address.
IPv4AddressSize = 4
@@ -268,6 +272,10 @@ func (b IPv4) IsValid(pktSize int) bool {
return false
}
+ if IPVersion(b) != IPv4Version {
+ return false
+ }
+
return true
}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index 7163eaa36..95fe8bfc3 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -184,6 +184,10 @@ func (b IPv6) IsValid(pktSize int) bool {
return false
}
+ if IPVersion(b) != IPv6Version {
+ return false
+ }
+
return true
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 1141443bb..82cfe785c 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -176,6 +176,21 @@ const (
// TCPProtocolNumber is TCP's transport protocol number.
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+
+ // TCPMinimumMSS is the minimum acceptable value for MSS. This is the
+ // same as the value TCP_MIN_MSS defined in net/tcp.h.
+ TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize
+
+ // TCPMaximumMSS is the maximum acceptable value for MSS.
+ TCPMaximumMSS = 0xffff
+
+ // TCPDefaultMSS is the MSS value that should be used if an MSS option
+ // is not received from the peer. It's also the value returned by
+ // TCP_MAXSEG option for a socket in an unconnected state.
+ //
+ // Per RFC 1122, page 85: "If an MSS option is not received at
+ // connection setup, TCP MUST assume a default send MSS of 536."
+ TCPDefaultMSS = 536
)
// SourcePort returns the "source port" field of the tcp header.
@@ -306,7 +321,7 @@ func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
synOpts := TCPSynOptions{
// Per RFC 1122, page 85: "If an MSS option is not received at
// connection setup, TCP MUST assume a default send MSS of 536."
- MSS: 536,
+ MSS: TCPDefaultMSS,
// If no window scale option is specified, WS in options is
// returned as -1; this is because the absence of the option
// indicates that the we cannot use window scaling on the
diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD
new file mode 100644
index 000000000..fc9abbb55
--- /dev/null
+++ b/pkg/tcpip/iptables/BUILD
@@ -0,0 +1,18 @@
+package(licenses = ["notice"])
+
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+go_library(
+ name = "iptables",
+ srcs = [
+ "iptables.go",
+ "targets.go",
+ "types.go",
+ ],
+ importpath = "gvisor.dev/gvisor/pkg/tcpip/iptables",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ ],
+)
diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go
new file mode 100644
index 000000000..f1e1d1fad
--- /dev/null
+++ b/pkg/tcpip/iptables/iptables.go
@@ -0,0 +1,81 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package iptables supports packet filtering and manipulation via the iptables
+// tool.
+package iptables
+
+const (
+ tablenameNat = "nat"
+ tablenameMangle = "mangle"
+)
+
+// Chain names as defined by net/ipv4/netfilter/ip_tables.c.
+const (
+ chainNamePrerouting = "PREROUTING"
+ chainNameInput = "INPUT"
+ chainNameForward = "FORWARD"
+ chainNameOutput = "OUTPUT"
+ chainNamePostrouting = "POSTROUTING"
+)
+
+// DefaultTables returns a default set of tables. Each chain is set to accept
+// all packets.
+func DefaultTables() *IPTables {
+ return &IPTables{
+ Tables: map[string]Table{
+ tablenameNat: Table{
+ BuiltinChains: map[Hook]Chain{
+ Prerouting: unconditionalAcceptChain(chainNamePrerouting),
+ Input: unconditionalAcceptChain(chainNameInput),
+ Output: unconditionalAcceptChain(chainNameOutput),
+ Postrouting: unconditionalAcceptChain(chainNamePostrouting),
+ },
+ DefaultTargets: map[Hook]Target{
+ Prerouting: UnconditionalAcceptTarget{},
+ Input: UnconditionalAcceptTarget{},
+ Output: UnconditionalAcceptTarget{},
+ Postrouting: UnconditionalAcceptTarget{},
+ },
+ UserChains: map[string]Chain{},
+ },
+ tablenameMangle: Table{
+ BuiltinChains: map[Hook]Chain{
+ Prerouting: unconditionalAcceptChain(chainNamePrerouting),
+ Output: unconditionalAcceptChain(chainNameOutput),
+ },
+ DefaultTargets: map[Hook]Target{
+ Prerouting: UnconditionalAcceptTarget{},
+ Output: UnconditionalAcceptTarget{},
+ },
+ UserChains: map[string]Chain{},
+ },
+ },
+ Priorities: map[Hook][]string{
+ Prerouting: []string{tablenameMangle, tablenameNat},
+ Output: []string{tablenameMangle, tablenameNat},
+ },
+ }
+}
+
+func unconditionalAcceptChain(name string) Chain {
+ return Chain{
+ Name: name,
+ Rules: []Rule{
+ Rule{
+ Target: UnconditionalAcceptTarget{},
+ },
+ },
+ }
+}
diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go
new file mode 100644
index 000000000..19a7f77e3
--- /dev/null
+++ b/pkg/tcpip/iptables/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file contains various Targets.
+
+package iptables
+
+import "gvisor.dev/gvisor/pkg/tcpip/buffer"
+
+// UnconditionalAcceptTarget accepts all packets.
+type UnconditionalAcceptTarget struct{}
+
+// Action implements Target.Action.
+func (UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+ return Accept, ""
+}
+
+// UnconditionalDropTarget denies all packets.
+type UnconditionalDropTarget struct{}
+
+// Action implements Target.Action.
+func (UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) {
+ return Drop, ""
+}
diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go
new file mode 100644
index 000000000..600bd9a10
--- /dev/null
+++ b/pkg/tcpip/iptables/types.go
@@ -0,0 +1,183 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package iptables
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+)
+
+// A Hook specifies one of the hooks built into the network stack.
+//
+// Userspace app Userspace app
+// ^ |
+// | v
+// [Input] [Output]
+// ^ |
+// | v
+// | routing
+// | |
+// | v
+// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]----->
+type Hook uint
+
+// These values correspond to values in include/uapi/linux/netfilter.h.
+const (
+ // Prerouting happens before a packet is routed to applications or to
+ // be forwarded.
+ Prerouting Hook = iota
+
+ // Input happens before a packet reaches an application.
+ Input
+
+ // Forward happens once it's decided that a packet should be forwarded
+ // to another host.
+ Forward
+
+ // Output happens after a packet is written by an application to be
+ // sent out.
+ Output
+
+ // Postrouting happens just before a packet goes out on the wire.
+ Postrouting
+
+ // The total number of hooks.
+ NumHooks
+)
+
+// A Verdict is returned by a rule's target to indicate how traversal of rules
+// should (or should not) continue.
+type Verdict int
+
+const (
+ // Accept indicates the packet should continue traversing netstack as
+ // normal.
+ Accept Verdict = iota
+
+ // Drop indicates the packet should be dropped, stopping traversing
+ // netstack.
+ Drop
+
+ // Stolen indicates the packet was co-opted by the target and should
+ // stop traversing netstack.
+ Stolen
+
+ // Queue indicates the packet should be queued for userspace processing.
+ Queue
+
+ // Repeat indicates the packet should re-traverse the chains for the
+ // current hook.
+ Repeat
+
+ // None indicates no verdict was reached.
+ None
+
+ // Jump indicates a jump to another chain.
+ Jump
+
+ // Continue indicates that traversal should continue at the next rule.
+ Continue
+
+ // Return indicates that traversal should return to the calling chain.
+ Return
+)
+
+// IPTables holds all the tables for a netstack.
+type IPTables struct {
+ // Tables maps table names to tables. User tables have arbitrary names.
+ Tables map[string]Table
+
+ // Priorities maps each hook to a list of table names. The order of the
+ // list is the order in which each table should be visited for that
+ // hook.
+ Priorities map[Hook][]string
+}
+
+// A Table defines a set of chains and hooks into the network stack. The
+// currently supported tables are:
+// * nat
+// * mangle
+type Table struct {
+ // BuiltinChains holds the un-deletable chains built into netstack. If
+ // a hook isn't present in the map, this table doesn't utilize that
+ // hook.
+ BuiltinChains map[Hook]Chain
+
+ // DefaultTargets holds a target for each hook that will be executed if
+ // chain traversal doesn't yield a verdict.
+ DefaultTargets map[Hook]Target
+
+ // UserChains holds user-defined chains keyed by name. Users
+ // can give their chains arbitrary names.
+ UserChains map[string]Chain
+
+ // Chains maps names to chains for both builtin and user-defined chains.
+ // Its entries point to Chains already either in BuiltinChains or
+ // UserChains, and its purpose is to make looking up chains by name
+ // fast.
+ Chains map[string]*Chain
+}
+
+// ValidHooks returns a bitmap of the builtin hooks for the given table.
+func (table *Table) ValidHooks() (uint32, *tcpip.Error) {
+ hooks := uint32(0)
+ for hook, _ := range table.BuiltinChains {
+ hooks |= 1 << hook
+ }
+ return hooks, nil
+}
+
+// A Chain defines a list of rules for packet processing. When a packet
+// traverses a chain, it is checked against each rule until either a rule
+// returns a verdict or the chain ends.
+//
+// By convention, builtin chains end with a rule that matches everything and
+// returns either Accept or Drop. User-defined chains end with Return. These
+// aren't strictly necessary here, but the iptables tool writes tables this way.
+type Chain struct {
+ // Name is the chain name.
+ Name string
+
+ // Rules is the list of rules to traverse.
+ Rules []Rule
+}
+
+// A Rule is a packet processing rule. It consists of two pieces. First, it
+// contains zero or more matchers, each specifying which packets the rule
+// applies to; with no matchers, the rule applies to any packet. Second, it
+// contains a target, the action invoked when every matcher matches.
+type Rule struct {
+ // Matchers is the list of matchers for this rule.
+ Matchers []Matcher
+
+ // Target is the action to invoke if all the matchers match the packet.
+ Target Target
+}
+
+// A Matcher is the interface for matching packets.
+type Matcher interface {
+ // Match returns whether the packet matches and whether the packet
+ // should be "hotdropped", i.e. dropped immediately. This is usually
+ // used for suspicious packets.
+ Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool)
+}
+
+// A Target is the interface for taking an action for a packet.
+type Target interface {
+ // Action takes an action on the packet and returns a verdict on how
+ // traversal should (or should not) continue. If the return value is
+ // Jump, it also returns the name of the chain to jump to.
+ Action(packet buffer.VectorisedView) (Verdict, string)
+}
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
index 80e91bb34..a0a873c84 100644
--- a/pkg/tcpip/link/rawfile/errors.go
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -30,7 +30,7 @@ var translations [maxErrno]*tcpip.Error
// TranslateErrno translate an errno from the syscall package into a
// *tcpip.Error.
//
-// Valid, but unreconigized errnos will be translated to
+// Valid, but unrecognized errnos will be translated to
// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
func TranslateErrno(e syscall.Errno) *tcpip.Error {
if err := translations[e]; err != nil {
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 08847f95f..e3fbb15c2 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -110,7 +110,7 @@ type PollEvent struct {
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
-// descirptor becomes readable.
+// descriptor becomes readable.
func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index ca3d6c0bf..cb35635fc 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -83,6 +83,10 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf
return tcpip.ErrNotSupported
}
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
h := header.ARP(v)
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
index 6822059d6..1628a82be 100644
--- a/pkg/tcpip/network/fragmentation/fragmentation.go
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -60,7 +60,7 @@ type Fragmentation struct {
// lowMemoryLimit specifies the limit on which we will reach by dropping
// fragments after reaching highMemoryLimit.
//
-// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
@@ -80,7 +80,7 @@ func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout t
}
}
-// Process processes an incoming fragment beloning to an ID
+// Process processes an incoming fragment belonging to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
f.mu.Lock()
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index db65ee7cc..8ff428445 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -282,10 +282,10 @@ func TestIPv4ReceiveControl(t *testing.T) {
{"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
{"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
{"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
- {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8},
{"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
- {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8},
}
r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb")
if err != nil {
@@ -301,7 +301,7 @@ func TestIPv4ReceiveControl(t *testing.T) {
}
defer ep.Close()
- const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize
view := buffer.NewView(dataOffset + 8)
// Create the outer IPv4 header.
@@ -319,10 +319,10 @@ func TestIPv4ReceiveControl(t *testing.T) {
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
icmp.SetType(header.ICMPv4DstUnreachable)
icmp.SetCode(c.code)
- copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+ copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef})
// Create the inner IPv4 header.
- ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:])
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:])
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: 100,
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
index bc7f1c42a..fbef6947d 100644
--- a/pkg/tcpip/network/ipv4/icmp.go
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -68,10 +68,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
switch h.Type() {
case header.ICMPv4Echo:
received.Echo.Increment()
- if len(v) < header.ICMPv4EchoMinimumSize {
- received.Invalid.Increment()
- return
- }
// Only send a reply if the checksum is valid.
wantChecksum := h.Checksum()
@@ -93,9 +89,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
vv := vv.Clone(nil)
- vv.TrimFront(header.ICMPv4EchoMinimumSize)
- hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize)
- pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ vv.TrimFront(header.ICMPv4MinimumSize)
+ hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize)
+ pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv4EchoReply)
pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0)))
@@ -108,25 +104,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V
case header.ICMPv4EchoReply:
received.EchoReply.Increment()
- if len(v) < header.ICMPv4EchoMinimumSize {
- received.Invalid.Increment()
- return
- }
+
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv)
case header.ICMPv4DstUnreachable:
received.DstUnreachable.Increment()
- if len(v) < header.ICMPv4DstUnreachableMinimumSize {
- received.Invalid.Increment()
- return
- }
- vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+
+ vv.TrimFront(header.ICMPv4MinimumSize)
switch h.Code() {
case header.ICMPv4PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, vv)
case header.ICMPv4FragmentationNeeded:
- mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:]))
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1e3a7425a..e44a73d96 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -232,6 +232,55 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+// WriteHeaderIncludedPacket writes a packet already containing a network
+// header through the given route.
+func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+ // The packet already has an IP header, but there are a few required
+ // checks.
+ ip := header.IPv4(payload.First())
+ if !ip.IsValid(payload.Size()) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ // Always set the total length.
+ ip.SetTotalLength(uint16(payload.Size()))
+
+ // Set the source address when zero.
+ if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
+ ip.SetSourceAddress(r.LocalAddress)
+ }
+
+ // Set the destination. If the packet already included a destination,
+ // it will be part of the route.
+ ip.SetDestinationAddress(r.RemoteAddress)
+
+ // Set the packet ID when zero.
+ if ip.ID() == 0 {
+ id := uint32(0)
+ if payload.Size() > header.IPv4MaximumHeaderSize+8 {
+ // Packets of 68 bytes or less are required by RFC 791 to not be
+ // fragmented, so we only assign ids to larger packets.
+ id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ }
+ ip.SetID(uint16(id))
+ }
+
+ // Always set the checksum.
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+
+ if loop&stack.PacketLoop != 0 {
+ e.HandlePacket(r, payload)
+ }
+ if loop&stack.PacketOut == 0 {
+ return nil
+ }
+
+ hdr := buffer.NewPrependableFromView(payload.ToView())
+ r.Stats().IP.PacketsSent.Increment()
+ return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
+}
+
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 27367d6c5..e3e8739fd 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -120,6 +120,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
}
+// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
+// supported by IPv6.
+func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+ // TODO(b/119580726): Support IPv6 header-included packets.
+ return tcpip.ErrNotSupported
+}
+
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0ecaa0833..462265281 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -174,6 +174,10 @@ type NetworkEndpoint interface {
// protocol.
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
+ // WriteHeaderIncludedPacket writes a packet that includes a network
+ // header through the given route.
+ WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
+
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
@@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol
// instantiate network protocols.
type NetworkProtocolFactory func() NetworkProtocol
+// UnassociatedEndpointFactory produces endpoints for writing packets not
+// associated with a particular transport protocol. Such endpoints can be used
+// to write arbitrary packets that include the IP header.
+type UnassociatedEndpointFactory interface {
+ NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+}
+
var (
transportProtocols = make(map[string]TransportProtocolFactory)
networkProtocols = make(map[string]NetworkProtocolFactory)
+ unassociatedFactory UnassociatedEndpointFactory
+
linkEPMu sync.RWMutex
nextLinkEndpointID tcpip.LinkEndpointID = 1
linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
@@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
networkProtocols[name] = p
}
+// RegisterUnassociatedFactory registers a factory to produce endpoints not
+// associated with any particular transport protocol. This function is intended
+// to be called by init() functions of the protocols.
+func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
+ unassociatedFactory = f
+}
+
// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
// ID that can be used to refer to it.
func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 36d7b6ac7..391ab4344 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
return err
}
+// WriteHeaderIncludedPacket writes a packet already containing a network
+// header through the given route.
+func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
+ if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
+ r.Stats().IP.OutgoingPacketErrors.Increment()
+ return err
+ }
+ r.ref.nic.stats.Tx.Packets.Increment()
+ r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size()))
+ return nil
+}
+
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
return r.ref.ep.DefaultTTL()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 2d7f56ca9..3e8fb2a6c 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -340,6 +340,8 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ unassociatedFactory UnassociatedEndpointFactory
+
demux *transportDemuxer
stats tcpip.Stats
@@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack {
}
}
+ s.unassociatedFactory = unassociatedFactory
+
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// NewRawEndpoint creates a new raw transport layer endpoint of the given
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
-func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
if !s.raw {
return nil, tcpip.ErrNotPermitted
}
+ if !associated {
+ return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
+ }
+
t, ok := s.transportProtocols[transport]
if !ok {
return nil, tcpip.ErrUnknownProtocol
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 69884af03..959071dbe 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
}
+func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
func (*fakeNetworkEndpoint) Close() {}
type fakeNetGoodOption bool
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 788ffcc8c..b418db046 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -90,6 +90,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch opt.(type) {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4aafb51ab..c5d79da5e 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -66,43 +66,44 @@ func (e *Error) IgnoreStats() bool {
// Errors that can be returned by the network stack.
var (
- ErrUnknownProtocol = &Error{msg: "unknown protocol"}
- ErrUnknownNICID = &Error{msg: "unknown nic id"}
- ErrUnknownDevice = &Error{msg: "unknown device"}
- ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
- ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
- ErrDuplicateAddress = &Error{msg: "duplicate address"}
- ErrNoRoute = &Error{msg: "no route"}
- ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
- ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
- ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
- ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
- ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
- ErrNoPortAvailable = &Error{msg: "no ports are available"}
- ErrPortInUse = &Error{msg: "port is in use"}
- ErrBadLocalAddress = &Error{msg: "bad local address"}
- ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
- ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
- ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
- ErrConnectionRefused = &Error{msg: "connection was refused"}
- ErrTimeout = &Error{msg: "operation timed out"}
- ErrAborted = &Error{msg: "operation aborted"}
- ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
- ErrDestinationRequired = &Error{msg: "destination address is required"}
- ErrNotSupported = &Error{msg: "operation not supported"}
- ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
- ErrNotConnected = &Error{msg: "endpoint not connected"}
- ErrConnectionReset = &Error{msg: "connection reset by peer"}
- ErrConnectionAborted = &Error{msg: "connection aborted"}
- ErrNoSuchFile = &Error{msg: "no such file"}
- ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
- ErrNoLinkAddress = &Error{msg: "no remote link address"}
- ErrBadAddress = &Error{msg: "bad address"}
- ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
- ErrMessageTooLong = &Error{msg: "message too long"}
- ErrNoBufferSpace = &Error{msg: "no buffer space available"}
- ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
- ErrNotPermitted = &Error{msg: "operation not permitted"}
+ ErrUnknownProtocol = &Error{msg: "unknown protocol"}
+ ErrUnknownNICID = &Error{msg: "unknown nic id"}
+ ErrUnknownDevice = &Error{msg: "unknown device"}
+ ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
+ ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
+ ErrDuplicateAddress = &Error{msg: "duplicate address"}
+ ErrNoRoute = &Error{msg: "no route"}
+ ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
+ ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
+ ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
+ ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
+ ErrNoPortAvailable = &Error{msg: "no ports are available"}
+ ErrPortInUse = &Error{msg: "port is in use"}
+ ErrBadLocalAddress = &Error{msg: "bad local address"}
+ ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
+ ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
+ ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
+ ErrConnectionRefused = &Error{msg: "connection was refused"}
+ ErrTimeout = &Error{msg: "operation timed out"}
+ ErrAborted = &Error{msg: "operation aborted"}
+ ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
+ ErrDestinationRequired = &Error{msg: "destination address is required"}
+ ErrNotSupported = &Error{msg: "operation not supported"}
+ ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
+ ErrNotConnected = &Error{msg: "endpoint not connected"}
+ ErrConnectionReset = &Error{msg: "connection reset by peer"}
+ ErrConnectionAborted = &Error{msg: "connection aborted"}
+ ErrNoSuchFile = &Error{msg: "no such file"}
+ ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
+ ErrNoLinkAddress = &Error{msg: "no remote link address"}
+ ErrBadAddress = &Error{msg: "bad address"}
+ ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
+ ErrMessageTooLong = &Error{msg: "message too long"}
+ ErrNoBufferSpace = &Error{msg: "no buffer space available"}
+ ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"}
+ ErrNotPermitted = &Error{msg: "operation not permitted"}
+ ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
)
// Errors related to Subnet
@@ -287,6 +288,12 @@ type ControlMessages struct {
// Timestamp is the time (in ns) that the last packed used to create
// the read data was received.
Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -339,6 +346,10 @@ type Endpoint interface {
// get the actual result. The first call to Connect after the socket has
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
+ //
+ // If address.Addr is empty, this means that Endpoint has to be
+ // disconnected if this is supported, otherwise
+ // ErrAddressFamilyNotSupported must be returned.
Connect(address FullAddress) *Error
// Shutdown closes the read and/or write end of the endpoint connection
@@ -378,6 +389,10 @@ type Endpoint interface {
// *Option types.
GetSockOpt(opt interface{}) *Error
+ // GetSockOptInt gets a socket option for simple cases where a return
+ // value has the int type.
+ GetSockOptInt(SockOpt) (int, *Error)
+
// State returns a socket's lifecycle state. The returned value is
// protocol-specific and is primarily used for diagnostics.
State() uint32
@@ -403,6 +418,18 @@ type WriteOptions struct {
EndOfRecord bool
}
+// SockOpt represents socket options whose values have the int type.
+type SockOpt int
+
+const (
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the number
+ // of unread bytes in the input buffer should be returned.
+ ReceiveQueueSizeOption SockOpt = iota
+
+ // TODO(b/137664753): convert all int socket options to be handled via
+ // GetSockOptInt.
+)
+
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
@@ -419,10 +446,6 @@ type ReceiveBufferSizeOption int
// unread bytes in the output buffer should be returned.
type SendQueueSizeOption int
-// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the input buffer should be returned.
-type ReceiveQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -491,6 +514,10 @@ type AvailableCongestionControlOption string
// buffer moderation.
type ModerateReceiveBufferOption bool
+// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current
+// Maximum Segment Size (MSS) value as specified using the TCP_MAXSEG option.
+type MaxSegOption int
+
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
@@ -534,7 +561,7 @@ type BroadcastOption int
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
-// masked target address matches the destination adddress in the row.
+// masked target address matches the destination address in the row.
type Route struct {
// Destination is the address that must be matched against the masked
// target address to check if this row is viable.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 33cefd937..ba6671c26 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -291,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
switch e.netProto {
case header.IPv4ProtocolNumber:
- err = e.send4(route, v)
+ err = send4(route, e.id.LocalPort, v)
case header.IPv6ProtocolNumber:
err = send6(route, e.id.LocalPort, v)
@@ -314,6 +314,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ v = p.data.Size()
+ }
+ e.rcvMu.Unlock()
+ return v, nil
+ }
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
@@ -332,17 +348,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.rcvMu.Unlock()
return nil
- case *tcpip.ReceiveQueueSizeOption:
- e.rcvMu.Lock()
- if e.rcvList.Empty() {
- *o = 0
- } else {
- p := e.rcvList.Front()
- *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
- }
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -352,20 +357,20 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
-func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
- if len(data) < header.ICMPv4EchoMinimumSize {
+func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv4MinimumSize {
return tcpip.ErrInvalidEndpointState
}
// Set the ident to the user-specified port. Sequence number should
// already be set by the user.
- binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
+ binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident)
- hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
+ hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
- icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
copy(icmpv4, data)
- data = data[header.ICMPv4EchoMinimumSize:]
+ data = data[header.ICMPv4MinimumSize:]
// Linux performs these basic checks.
if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
@@ -422,6 +427,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
+ if addr.Addr == "" {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index c89538131..7fdba5d56 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -90,19 +90,18 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt
func (p *protocol) MinimumPacketSize() int {
switch p.number {
case ProtocolNumber4:
- return header.ICMPv4EchoMinimumSize
+ return header.ICMPv4MinimumSize
case ProtocolNumber6:
return header.ICMPv6EchoMinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
-// ParsePorts returns the source and destination ports stored in the given icmp
-// packet.
+// ParsePorts, in the case of ICMP, sets src to 0, dst to the ICMP ID, and err to nil.
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
- return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil
case ProtocolNumber6:
return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
}
diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD
index 34a14bf7f..bc4b255b4 100644
--- a/pkg/tcpip/transport/raw/BUILD
+++ b/pkg/tcpip/transport/raw/BUILD
@@ -21,6 +21,7 @@ go_library(
"endpoint.go",
"endpoint_state.go",
"packet_list.go",
+ "protocol.go",
],
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 03f495e48..b633cd9d8 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -16,7 +16,7 @@
// sockets allow applications to:
//
// * manually write and inspect transport layer headers and payloads
-// * receive all traffic of a given transport protcol (e.g. ICMP or UDP)
+// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
// * optionally write and inspect network layer and link layer headers for
// packets
//
@@ -67,6 +67,7 @@ type endpoint struct {
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
+ associated bool
// The following fields are used to manage the receive queue and are
// protected by rcvMu.
@@ -97,8 +98,12 @@ type endpoint struct {
}
// NewEndpoint returns a raw endpoint for the given protocols.
-// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
if netProto != header.IPv4ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
+ associated: associated,
+ }
+
+ // Unassociated endpoints are write-only and users call Write() with IP
+ // headers included. Because they're write-only, we don't need to
+ // register with the stack.
+ if !associated {
+ ep.rcvBufSizeMax = 0
+ ep.waiterQueue = nil
+ return ep, nil
}
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
@@ -124,7 +139,7 @@ func (ep *endpoint) Close() {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.closed {
+ if ep.closed || !ep.associated {
return
}
@@ -142,8 +157,11 @@ func (ep *endpoint) Close() {
if ep.connected {
ep.route.Release()
+ ep.connected = false
}
+ ep.closed = true
+
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ if !ep.associated {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
+ }
+
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, tcpip.ErrInvalidEndpointState
}
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ // If this is an unassociated socket and the caller provided a nonzero
+ // destination address, route using that address.
+ if !ep.associated {
+ ip := header.IPv4(payloadBytes)
+ if !ip.IsValid(payload.Size()) {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+ dstAddr := ip.DestinationAddress()
+ // Use the destination address from the IP header, unless
+ // opts.To is set (e.g. if sendto specifies a specific
+ // address).
+ if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
+ opts.To = &tcpip.FullAddress{
+ NIC: 0, // NIC is unset.
+ Addr: dstAddr, // The address from the payload.
+ Port: 0, // There are no ports here.
+ }
+ }
+ }
+
// Did the user caller provide a destination? If not, use the connected
// destination.
if opts.To == nil {
@@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, tcpip.ErrInvalidEndpointState
}
- n, ch, err := ep.finishWrite(payload, savedRoute)
+ n, ch, err := ep.finishWrite(payloadBytes, savedRoute)
ep.mu.Unlock()
return n, ch, err
}
- n, ch, err := ep.finishWrite(payload, &ep.route)
+ n, ch, err := ep.finishWrite(payloadBytes, &ep.route)
ep.mu.RUnlock()
return n, ch, err
}
@@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
return 0, nil, err
}
- n, ch, err := ep.finishWrite(payload, &route)
+ n, ch, err := ep.finishWrite(payloadBytes, &route)
route.Release()
ep.mu.RUnlock()
return n, ch, err
@@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
// finishWrite writes the payload to a route. It resolves the route if
// necessary. It's really just a helper to make defer unnecessary in Write.
-func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
// We may need to resolve the route (match a link layer address to the
// network address). If that requires blocking (e.g. to use ARP),
// return a channel on which the caller can wait.
@@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint
}
}
- payloadBytes, err := payload.Get(payload.Size())
- if err != nil {
- return 0, nil, err
- }
-
switch ep.netProto {
case header.IPv4ProtocolNumber:
+ if !ep.associated {
+ if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
+ return 0, nil, err
+ }
+ break
+ }
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
return 0, nil, err
@@ -298,6 +348,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
+ if addr.Addr == "" {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
if ep.closed {
return tcpip.ErrInvalidEndpointState
}
@@ -330,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
}
defer route.Release()
- // Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
- return err
+ if ep.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ ep.registeredNIC = nic
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- // Save the route and NIC we've connected via.
+ // Save the route we've connected via.
ep.route = route.Clone()
- ep.registeredNIC = nic
ep.connected = true
return nil
@@ -381,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
- // Re-register the endpoint with the appropriate NIC.
- if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
- return err
+ if ep.associated {
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
}
- ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
- ep.registeredNIC = addr.NIC
- ep.boundNIC = addr.NIC
ep.boundAddr = addr.Addr
ep.bound = true
@@ -428,6 +487,23 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() {
+ p := ep.rcvList.Front()
+ v = p.data.Size()
+ }
+ ep.rcvMu.Unlock()
+ return v, nil
+ }
+
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
@@ -446,17 +522,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
ep.rcvMu.Unlock()
return nil
- case *tcpip.ReceiveQueueSizeOption:
- ep.rcvMu.Lock()
- if ep.rcvList.Empty() {
- *o = 0
- } else {
- p := ep.rcvList.Front()
- *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
- }
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
new file mode 100644
index 000000000..783c21e6b
--- /dev/null
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -0,0 +1,32 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type factory struct{}
+
+// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
+func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
+}
+
+func init() {
+ stack.RegisterUnassociatedFactory(factory{})
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ee60ebf58..89154391b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -117,6 +117,7 @@ const (
notifyDrain
notifyReset
notifyKeepaliveChanged
+ notifyMSSChanged
)
// SACKInfo holds TCP SACK related information for a given endpoint.
@@ -218,8 +219,6 @@ type endpoint struct {
mu sync.RWMutex `state:"nosave"`
id stack.TransportEndpointID
- // state endpointState `state:".(endpointState)"`
- // pState ProtocolState
state EndpointState `state:".(EndpointState)"`
isPortReserved bool `state:"manual"`
@@ -313,6 +312,10 @@ type endpoint struct {
// in SYN-RCVD state.
synRcvdCount int
+ // userMSS if non-zero is the MSS value explicitly set by the user
+ // for this endpoint using the TCP_MAXSEG setsockopt.
+ userMSS int
+
// The following fields are used to manage the send buffer. When
// segments are ready to be sent, they are added to sndQueue and the
// protocol goroutine is signaled via sndWaker.
@@ -917,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1086,6 +1100,15 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
return e.rcvBufUsed, nil
}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ return e.readyReceiveSize()
+ }
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
@@ -1096,6 +1119,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.lastErrorMu.Unlock()
return err
+ case *tcpip.MaxSegOption:
+ // This is just stubbed out. Linux never returns the user_mss
+ // value as it either returns the defaultMSS or returns the
+ // actual current MSS. Netstack just returns the defaultMSS
+ // always for now.
+ *o = header.TCPDefaultMSS
+ return nil
+
case *tcpip.SendBufferSizeOption:
e.sndBufMu.Lock()
*o = tcpip.SendBufferSizeOption(e.sndBufSize)
@@ -1108,15 +1139,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
e.rcvListMu.Unlock()
return nil
- case *tcpip.ReceiveQueueSizeOption:
- v, err := e.readyReceiveSize()
- if err != nil {
- return err
- }
-
- *o = tcpip.ReceiveQueueSizeOption(v)
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1271,6 +1293,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol
// Connect connects the endpoint to its peer.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Addr == "" && addr.Port == 0 {
+ // AF_UNSPEC isn't supported.
+ return tcpip.ErrAddressFamilyNotSupported
+ }
+
return e.connect(addr, true, true)
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index ec61a3886..b93959034 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error {
tcpip.ErrNoBufferSpace,
tcpip.ErrBroadcastDisabled,
tcpip.ErrNotPermitted,
+ tcpip.ErrAddressFamilyNotSupported,
}
messageToError = make(map[string]*tcpip.Error)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 630dd7925..bcc0f3e28 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -271,7 +271,7 @@ func (c *Context) GetPacketNonBlocking() []byte {
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
- buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2))
+ buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
@@ -291,8 +291,8 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
icmp.SetType(typ)
icmp.SetCode(code)
- copy(icmp[header.ICMPv4MinimumSize:], p1)
- copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2)
+ copy(icmp[header.ICMPv4PayloadOffset:], p1)
+ copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 99fdfb795..91f89a781 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -189,7 +189,6 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
p := e.rcvList.Front()
e.rcvList.Remove(p)
e.rcvBufSize -= p.data.Size()
-
e.rcvMu.Unlock()
if addr != nil {
@@ -253,7 +252,7 @@ func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stac
if nicid == 0 {
nicid = e.multicastNICID
}
- if localAddr == "" {
+ if localAddr == "" && nicid == 0 {
localAddr = e.multicastAddr
}
}
@@ -539,6 +538,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ v = p.data.Size()
+ }
+ e.rcvMu.Unlock()
+ return v, nil
+ }
+ return -1, tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
@@ -573,17 +588,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
- case *tcpip.ReceiveQueueSizeOption:
- e.rcvMu.Lock()
- if e.rcvList.Empty() {
- *o = 0
- } else {
- p := e.rcvList.Front()
- *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
- }
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.MulticastTTLOption:
e.mu.Lock()
*o = tcpip.MulticastTTLOption(e.multicastTTL)
@@ -671,6 +675,9 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
+ if len(addr.Addr) == 0 {
+ return netProto, nil
+ }
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
@@ -698,8 +705,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
+func (e *endpoint) disconnect() *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if e.state != stateConnected {
+ return nil
+ }
+ id := stack.TransportEndpointID{}
+ // Exclude ephemerally bound endpoints.
+ if e.bindNICID != 0 || e.id.LocalAddress == "" {
+ var err *tcpip.Error
+ id = stack.TransportEndpointID{
+ LocalPort: e.id.LocalPort,
+ LocalAddress: e.id.LocalAddress,
+ }
+ id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ return err
+ }
+ e.state = stateBound
+ } else {
+ e.state = stateInitial
+ }
+
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.id = id
+ e.route.Release()
+ e.route = stack.Route{}
+ e.dstPort = 0
+
+ return nil
+}
+
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Addr == "" {
+ return e.disconnect()
+ }
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
@@ -734,12 +777,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
defer r.Release()
id := stack.TransportEndpointID{
- LocalAddress: r.LocalAddress,
+ LocalAddress: e.id.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
+ if e.state == stateInitial {
+ id.LocalAddress = r.LocalAddress
+ }
+
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
index 701bdd72b..18e786397 100644
--- a/pkg/tcpip/transport/udp/endpoint_state.go
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() {
if err != nil {
panic(*err)
}
-
- e.id.LocalAddress = e.route.LocalAddress
} else if len(e.id.LocalAddress) != 0 { // stateBound
if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
panic(tcpip.ErrBadLocalAddress)
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 75129a2ff..958d5712e 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -847,9 +847,7 @@ func TestWriteIncrementsPacketsSent(t *testing.T) {
}
}
-func TestTTL(t *testing.T) {
- payload := tcpip.SlicePayload(buffer.View(newPayload()))
-
+func setSockOptVariants(t *testing.T, optFunc func(*testing.T, string, tcpip.NetworkProtocolNumber, string)) {
for _, name := range []string{"v4", "v6", "dual"} {
t.Run(name, func(t *testing.T) {
var networkProtocolNumber tcpip.NetworkProtocolNumber
@@ -874,134 +872,219 @@ func TestTTL(t *testing.T) {
for _, variant := range variants {
t.Run(variant, func(t *testing.T) {
- for _, typ := range []string{"unicast", "multicast"} {
- t.Run(typ, func(t *testing.T) {
- var addr tcpip.Address
- var port uint16
- switch typ {
- case "unicast":
- port = testPort
- switch variant {
- case "v4":
- addr = testAddr
- case "mapped":
- addr = testV4MappedAddr
- case "v6":
- addr = testV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- port = multicastPort
- switch variant {
- case "v4":
- addr = multicastAddr
- case "mapped":
- addr = multicastV4MappedAddr
- case "v6":
- addr = multicastV6Addr
- default:
- t.Fatal("unknown test variant")
- }
- default:
- t.Fatal("unknown test variant")
- }
-
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- var err *tcpip.Error
- c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- switch name {
- case "v4":
- case "v6":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- case "dual":
- if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
- default:
- t.Fatal("unknown test variant")
- }
+ optFunc(t, name, networkProtocolNumber, variant)
+ })
+ }
+ })
+ }
+}
- const multicastTTL = 42
- if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
- c.t.Fatalf("SetSockOpt failed: %v", err)
- }
+func TestTTL(t *testing.T) {
+ payload := tcpip.SlicePayload(buffer.View(newPayload()))
- n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
- if err != nil {
- c.t.Fatalf("Write failed: %v", err)
- }
- if n != uintptr(len(payload)) {
- c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
- }
+ setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
+ for _, typ := range []string{"unicast", "multicast"} {
+ t.Run(typ, func(t *testing.T) {
+ var addr tcpip.Address
+ var port uint16
+ switch typ {
+ case "unicast":
+ port = testPort
+ switch variant {
+ case "v4":
+ addr = testAddr
+ case "mapped":
+ addr = testV4MappedAddr
+ case "v6":
+ addr = testV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ port = multicastPort
+ switch variant {
+ case "v4":
+ addr = multicastAddr
+ case "mapped":
+ addr = multicastV4MappedAddr
+ case "v6":
+ addr = multicastV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ switch name {
+ case "v4":
+ case "v6":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ case "dual":
+ if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ const multicastTTL = 42
+ if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
+ if err != nil {
+ c.t.Fatalf("Write failed: %v", err)
+ }
+ if n != uintptr(len(payload)) {
+ c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
+ }
+
+ checkerFn := checker.IPv4
+ switch variant {
+ case "v4", "mapped":
+ case "v6":
+ checkerFn = checker.IPv6
+ default:
+ t.Fatal("unknown test variant")
+ }
+ var wantTTL uint8
+ var multicast bool
+ switch typ {
+ case "unicast":
+ multicast = false
+ switch variant {
+ case "v4", "mapped":
+ ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ case "v6":
+ ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ wantTTL = ep.DefaultTTL()
+ ep.Close()
+ default:
+ t.Fatal("unknown test variant")
+ }
+ case "multicast":
+ wantTTL = multicastTTL
+ multicast = true
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var networkProtocolNumber tcpip.NetworkProtocolNumber
+ switch variant {
+ case "v4", "mapped":
+ networkProtocolNumber = ipv4.ProtocolNumber
+ case "v6":
+ networkProtocolNumber = ipv6.ProtocolNumber
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ b := c.getPacket(networkProtocolNumber, multicast)
+ checkerFn(c.t, b,
+ checker.TTL(wantTTL),
+ checker.UDP(
+ checker.DstPort(port),
+ ),
+ )
+ })
+ }
+ })
+}
- checkerFn := checker.IPv4
- switch variant {
- case "v4", "mapped":
- case "v6":
- checkerFn = checker.IPv6
- default:
- t.Fatal("unknown test variant")
- }
- var wantTTL uint8
- var multicast bool
- switch typ {
- case "unicast":
- multicast = false
- switch variant {
- case "v4", "mapped":
- ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- case "v6":
- ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- wantTTL = ep.DefaultTTL()
- ep.Close()
- default:
- t.Fatal("unknown test variant")
- }
- case "multicast":
- wantTTL = multicastTTL
- multicast = true
- default:
- t.Fatal("unknown test variant")
+func TestMulticastInterfaceOption(t *testing.T) {
+ setSockOptVariants(t, func(t *testing.T, name string, networkProtocolNumber tcpip.NetworkProtocolNumber, variant string) {
+ for _, bindTyp := range []string{"bound", "unbound"} {
+ t.Run(bindTyp, func(t *testing.T) {
+ for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+ t.Run(optTyp, func(t *testing.T) {
+ var mcastAddr, localIfAddr tcpip.Address
+ switch variant {
+ case "v4":
+ mcastAddr = multicastAddr
+ localIfAddr = stackAddr
+ case "mapped":
+ mcastAddr = multicastV4MappedAddr
+ localIfAddr = stackAddr
+ case "v6":
+ mcastAddr = multicastV6Addr
+ localIfAddr = stackV6Addr
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ var ifoptSet tcpip.MulticastInterfaceOption
+ switch optTyp {
+ case "use local-addr":
+ ifoptSet.InterfaceAddr = localIfAddr
+ case "use NICID":
+ ifoptSet.NIC = 1
+ case "use local-addr and NIC":
+ ifoptSet.InterfaceAddr = localIfAddr
+ ifoptSet.NIC = 1
+ default:
+ t.Fatal("unknown test variant")
+ }
+
+ c := newDualTestContext(t, defaultMTU)
+ defer c.cleanup()
+
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if bindTyp == "bound" {
+ // Bind the socket by connecting to the multicast address.
+ // This may have an influence on how the multicast interface
+ // is set.
+ addr := tcpip.FullAddress{
+ Addr: mcastAddr,
+ Port: multicastPort,
}
-
- var networkProtocolNumber tcpip.NetworkProtocolNumber
- switch variant {
- case "v4", "mapped":
- networkProtocolNumber = ipv4.ProtocolNumber
- case "v6":
- networkProtocolNumber = ipv6.ProtocolNumber
- default:
- t.Fatal("unknown test variant")
+ if err := c.ep.Connect(addr); err != nil {
+ c.t.Fatalf("Connect failed: %v", err)
}
-
- b := c.getPacket(networkProtocolNumber, multicast)
- checkerFn(c.t, b,
- checker.TTL(wantTTL),
- checker.UDP(
- checker.DstPort(port),
- ),
- )
- })
- }
- })
- }
- })
- }
+ }
+
+ if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+
+ // Verify multicast interface addr and NIC were set correctly.
+ // Note that NIC must be 1 since this is our outgoing interface.
+ ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+ var ifoptGot tcpip.MulticastInterfaceOption
+ if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+ c.t.Fatalf("GetSockOpt failed: %v", err)
+ }
+ if ifoptGot != ifoptWant {
+ c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+ }
+ })
+ }
+ })
+ }
+ })
}