From d58eb9ce828fd7c831f30e922e01f1d2b84e462c Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 31 May 2019 16:14:04 -0700 Subject: Add basic iptables structures to netstack. Change-Id: Ib589906175a59dae315405a28f2d7f525ff8877f --- pkg/tcpip/iptables/BUILD | 20 +++++ pkg/tcpip/iptables/iptables.go | 97 +++++++++++++++++++++ pkg/tcpip/iptables/targets.go | 35 ++++++++ pkg/tcpip/iptables/types.go | 190 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 342 insertions(+) create mode 100644 pkg/tcpip/iptables/BUILD create mode 100644 pkg/tcpip/iptables/iptables.go create mode 100644 pkg/tcpip/iptables/targets.go create mode 100644 pkg/tcpip/iptables/types.go (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/BUILD b/pkg/tcpip/iptables/BUILD new file mode 100644 index 000000000..173f148da --- /dev/null +++ b/pkg/tcpip/iptables/BUILD @@ -0,0 +1,20 @@ +package(licenses = ["notice"]) + +load("//tools/go_stateify:defs.bzl", "go_library", "go_test") + +go_library( + name = "iptables", + srcs = [ + "iptables.go", + "targets.go", + "types.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/iptables", + visibility = ["//visibility:public"], + deps = [ + "//pkg/state", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + ], +) diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go new file mode 100644 index 000000000..ee1ed4666 --- /dev/null +++ b/pkg/tcpip/iptables/iptables.go @@ -0,0 +1,97 @@ +// Copyright 2019 The gVisor authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package iptables supports packet filtering and manipulation via the iptables +// tool. +package iptables + +const ( + tablenameNat = "nat" + tablenameMangle = "mangle" +) + +// Chain names as defined by net/ipv4/netfilter/ip_tables.c. +const ( + chainNamePrerouting = "PREROUTING" + chainNameInput = "INPUT" + chainNameForward = "FORWARD" + chainNameOutput = "OUTPUT" + chainNamePostrouting = "POSTROUTING" +) + +// DefaultTables returns a default set of tables. Each chain is set to accept +// all packets. +func DefaultTables() *IPTables { + tables := IPTables{ + Tables: map[string]*Table{ + tablenameNat: &Table{ + BuiltinChains: map[Hook]*Chain{ + Prerouting: unconditionalAcceptChain(chainNamePrerouting), + Input: unconditionalAcceptChain(chainNameInput), + Output: unconditionalAcceptChain(chainNameOutput), + Postrouting: unconditionalAcceptChain(chainNamePostrouting), + }, + DefaultTargets: map[Hook]Target{ + Prerouting: UnconditionalAcceptTarget{}, + Input: UnconditionalAcceptTarget{}, + Output: UnconditionalAcceptTarget{}, + Postrouting: UnconditionalAcceptTarget{}, + }, + UserChains: map[string]*Chain{}, + }, + tablenameMangle: &Table{ + BuiltinChains: map[Hook]*Chain{ + Prerouting: unconditionalAcceptChain(chainNamePrerouting), + Output: unconditionalAcceptChain(chainNameOutput), + }, + DefaultTargets: map[Hook]Target{ + Prerouting: UnconditionalAcceptTarget{}, + Output: UnconditionalAcceptTarget{}, + }, + UserChains: map[string]*Chain{}, + }, + }, + Priorities: map[Hook][]string{ + Prerouting: []string{tablenameMangle, tablenameNat}, + Output: []string{tablenameMangle, tablenameNat}, + }, + } + + // Initialize each table's Chains field. + tables.Tables[tablenameNat].Chains = map[string]*Chain{ + chainNamePrerouting: tables.Tables[tablenameNat].BuiltinChains[Prerouting], + chainNameInput: tables.Tables[tablenameNat].BuiltinChains[Input], + chainNameOutput: tables.Tables[tablenameNat].BuiltinChains[Output], + chainNamePostrouting: tables.Tables[tablenameNat].BuiltinChains[Postrouting], + } + tables.Tables[tablenameMangle].Chains = map[string]*Chain{ + chainNamePrerouting: tables.Tables[tablenameMangle].BuiltinChains[Prerouting], + chainNameInput: tables.Tables[tablenameMangle].BuiltinChains[Input], + chainNameOutput: tables.Tables[tablenameMangle].BuiltinChains[Output], + chainNamePostrouting: tables.Tables[tablenameMangle].BuiltinChains[Postrouting], + } + + return &tables +} + +func unconditionalAcceptChain(name string) *Chain { + return &Chain{ + Name: name, + Rules: []*Rule{ + &Rule{ + Target: UnconditionalAcceptTarget{}, + }, + }, + } +} diff --git a/pkg/tcpip/iptables/targets.go b/pkg/tcpip/iptables/targets.go new file mode 100644 index 000000000..028978e3a --- /dev/null +++ b/pkg/tcpip/iptables/targets.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + +// This file contains various Targets. + +// UnconditionalAcceptTarget accepts all packets. +type UnconditionalAcceptTarget struct{} + +// Action implements Target.Action. +func (_ UnconditionalAcceptTarget) Action(packet buffer.VectorisedView) (Verdict, string) { + return Accept, "" +} + +// UnconditionalDropTarget denies all packets. +type UnconditionalDropTarget struct{} + +// Action implements Target.Action. +func (_ UnconditionalDropTarget) Action(packet buffer.VectorisedView) (Verdict, string) { + return Drop, "" +} diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go new file mode 100644 index 000000000..65bfc7b1d --- /dev/null +++ b/pkg/tcpip/iptables/types.go @@ -0,0 +1,190 @@ +// Copyright 2019 The gVisor authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package iptables + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +// Hook specifies one of the hooks built into the network stack. +// +// Userspace app Userspace app +// ^ | +// | v +// [Input] [Output] +// ^ | +// | v +// | routing +// | | +// | v +// ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> +type Hook uint + +// These values correspond to values in include/uapi/linux/netfilter.h. +const ( + // Prerouting happens before a packet is routed to applications or to + // be forwarded. + Prerouting Hook = iota + + // Input happens before a packet reaches an application. + Input + + // Forward happens once it's decided that a packet should be forwarded + // to another host. + Forward + + // Output happens after a packet is written by an application to be + // sent out. + Output + + // Postrouting happens just before a packet goes out on the wire. + Postrouting + + // The total number of hooks. + NumHooks +) + +// Verdict is returned by a rule's target to indicate how traversal of rules +// should (or should not) continue. +type Verdict int + +const ( + // Accept indicates the packet should continue traversing netstack as + // normal. + Accept Verdict = iota + + // Drop inicates the packet should be dropped, stopping traversing + // netstack. + Drop + + // Stolen indicates the packet was co-opted by the target and should + // stop traversing netstack. + Stolen + + // Queue indicates the packet should be queued for userspace processing. + Queue + + // Repeat indicates the packet should re-traverse the chains for the + // current hook. + Repeat + + // None indicates no verdict was reached. + None + + // Jump indicates a jump to another chain. + Jump + + // Continue indicates that traversal should continue at the next rule. + Continue + + // Return indicates that traversal should return to the calling chain. + Return +) + +// IPTables holds all the tables for a netstack. +type IPTables struct { + // mu protects the entire struct. + mu sync.RWMutex + + // Tables maps table names to tables. User tables have arbitrary names. + Tables map[string]*Table + + // Priorities maps each hook to a list of table names. The order of the + // list is the order in which each table should be visited for that + // hook. + Priorities map[Hook][]string +} + +// Table defines a set of chains and hooks into the network stack. The +// currently supported tables are: +// * nat +// * mangle +type Table struct { + // BuiltinChains holds the un-deletable chains built into netstack. If + // a hook isn't present in the map, this table doesn't utilize that + // hook. + BuiltinChains map[Hook]*Chain + + // DefaultTargets holds a target for each hook that will be executed if + // chain traversal doesn't yield a verdict. + DefaultTargets map[Hook]Target + + // UserChains holds user-defined chains for the keyed by name. Users + // can give their chains arbitrary names. + UserChains map[string]*Chain + + // Chains maps names to chains for both builtin and user-defined chains. + // Its entries point to Chains already either in BuiltinChains and + // UserChains, and its purpose is to make looking up tables by name + // fast. + Chains map[string]*Chain +} + +// ValidHooks returns a bitmap of the builtin hooks for the given table. +// +// Precondition: IPTables.mu must be locked for reading. +func (table *Table) ValidHooks() (uint32, *tcpip.Error) { + hooks := uint32(0) + for hook, _ := range table.BuiltinChains { + hooks |= 1 << hook + } + return hooks, nil +} + +// Chain defines a list of rules for packet processing. When a packet traverses +// a chain, it is checked against each rule until either a rule returns a +// verdict or the chain ends. +// +// By convention, builtin chains end with a rule that matches everything and +// returns either Accept or Drop. User-defined chains end with Return. These +// aren't strictly necessary here, but the iptables tool writes tables this way. +type Chain struct { + // Name is the chain name. + Name string + + // Rules is the list of rules to traverse. + Rules []*Rule +} + +// Rule is a packet processing rule. It consists of two pieces. First it +// contains zero or more matchers, each of which is a specification of which +// packets this rule applies to. If there are no matchers in the rule, it +// applies to any packet. +type Rule struct { + // Matchers is the list of matchers for this rule. + Matchers []Matcher + + // Target is the action to invoke if all the matchers match the packet. + Target Target +} + +// Matcher is the interface for matching packets. +type Matcher interface { + // Match returns whether the packet matches and whether the packet + // should be "hotdropped", i.e. dropped immediately. This is usually + // used for suspicious packets. + Match(hook Hook, packet buffer.VectorisedView, interfaceName string) (matches bool, hotdrop bool) +} + +// Target is the interface for taking an action for a packet. +type Target interface { + // Action takes an action on the packet and returns a verdict on how + // traversal should (or should not) continue. If the return value is + // Jump, it also returns the name of the chain to jump to. + Action(packet buffer.VectorisedView) (Verdict, string) +} -- cgit v1.2.3 From 8afbd974da2483d8f81e3abde5c9d689719263cb Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Fri, 7 Jun 2019 12:54:53 -0700 Subject: Address Ian's comments. Change-Id: I7445033b1970cbba3f2ed0682fe520dce02d8fad --- pkg/tcpip/iptables/iptables.go | 36 +++++++++++------------------------- pkg/tcpip/iptables/types.go | 12 ++++++------ 2 files changed, 17 insertions(+), 31 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index ee1ed4666..bd54ef5a6 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -34,9 +34,9 @@ const ( // all packets. func DefaultTables() *IPTables { tables := IPTables{ - Tables: map[string]*Table{ - tablenameNat: &Table{ - BuiltinChains: map[Hook]*Chain{ + Tables: map[string]Table{ + tablenameNat: Table{ + BuiltinChains: map[Hook]Chain{ Prerouting: unconditionalAcceptChain(chainNamePrerouting), Input: unconditionalAcceptChain(chainNameInput), Output: unconditionalAcceptChain(chainNameOutput), @@ -48,10 +48,10 @@ func DefaultTables() *IPTables { Output: UnconditionalAcceptTarget{}, Postrouting: UnconditionalAcceptTarget{}, }, - UserChains: map[string]*Chain{}, + UserChains: map[string]Chain{}, }, - tablenameMangle: &Table{ - BuiltinChains: map[Hook]*Chain{ + tablenameMangle: Table{ + BuiltinChains: map[Hook]Chain{ Prerouting: unconditionalAcceptChain(chainNamePrerouting), Output: unconditionalAcceptChain(chainNameOutput), }, @@ -59,7 +59,7 @@ func DefaultTables() *IPTables { Prerouting: UnconditionalAcceptTarget{}, Output: UnconditionalAcceptTarget{}, }, - UserChains: map[string]*Chain{}, + UserChains: map[string]Chain{}, }, }, Priorities: map[Hook][]string{ @@ -68,28 +68,14 @@ func DefaultTables() *IPTables { }, } - // Initialize each table's Chains field. - tables.Tables[tablenameNat].Chains = map[string]*Chain{ - chainNamePrerouting: tables.Tables[tablenameNat].BuiltinChains[Prerouting], - chainNameInput: tables.Tables[tablenameNat].BuiltinChains[Input], - chainNameOutput: tables.Tables[tablenameNat].BuiltinChains[Output], - chainNamePostrouting: tables.Tables[tablenameNat].BuiltinChains[Postrouting], - } - tables.Tables[tablenameMangle].Chains = map[string]*Chain{ - chainNamePrerouting: tables.Tables[tablenameMangle].BuiltinChains[Prerouting], - chainNameInput: tables.Tables[tablenameMangle].BuiltinChains[Input], - chainNameOutput: tables.Tables[tablenameMangle].BuiltinChains[Output], - chainNamePostrouting: tables.Tables[tablenameMangle].BuiltinChains[Postrouting], - } - return &tables } -func unconditionalAcceptChain(name string) *Chain { - return &Chain{ +func unconditionalAcceptChain(name string) Chain { + return Chain{ Name: name, - Rules: []*Rule{ - &Rule{ + Rules: []Rule{ + Rule{ Target: UnconditionalAcceptTarget{}, }, }, diff --git a/pkg/tcpip/iptables/types.go b/pkg/tcpip/iptables/types.go index 65bfc7b1d..cdfb6ba28 100644 --- a/pkg/tcpip/iptables/types.go +++ b/pkg/tcpip/iptables/types.go @@ -98,11 +98,11 @@ const ( // IPTables holds all the tables for a netstack. type IPTables struct { - // mu protects the entire struct. - mu sync.RWMutex + // Mu protects the entire struct. + Mu sync.RWMutex // Tables maps table names to tables. User tables have arbitrary names. - Tables map[string]*Table + Tables map[string]Table // Priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that @@ -118,7 +118,7 @@ type Table struct { // BuiltinChains holds the un-deletable chains built into netstack. If // a hook isn't present in the map, this table doesn't utilize that // hook. - BuiltinChains map[Hook]*Chain + BuiltinChains map[Hook]Chain // DefaultTargets holds a target for each hook that will be executed if // chain traversal doesn't yield a verdict. @@ -126,7 +126,7 @@ type Table struct { // UserChains holds user-defined chains for the keyed by name. Users // can give their chains arbitrary names. - UserChains map[string]*Chain + UserChains map[string]Chain // Chains maps names to chains for both builtin and user-defined chains. // Its entries point to Chains already either in BuiltinChains and @@ -158,7 +158,7 @@ type Chain struct { Name string // Rules is the list of rules to traverse. - Rules []*Rule + Rules []Rule } // Rule is a packet processing rule. It consists of two pieces. First it -- cgit v1.2.3 From 06a83df533244dc2b3b8adfc1bf0608d3753c1d9 Mon Sep 17 00:00:00 2001 From: Kevin Krakauer Date: Mon, 10 Jun 2019 12:43:54 -0700 Subject: Address more comments. Change-Id: I83ae1079f3dcba6b018f59ab7898decab5c211d2 --- pkg/tcpip/iptables/iptables.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/tcpip/iptables/iptables.go b/pkg/tcpip/iptables/iptables.go index bd54ef5a6..f1e1d1fad 100644 --- a/pkg/tcpip/iptables/iptables.go +++ b/pkg/tcpip/iptables/iptables.go @@ -33,7 +33,7 @@ const ( // DefaultTables returns a default set of tables. Each chain is set to accept // all packets. func DefaultTables() *IPTables { - tables := IPTables{ + return &IPTables{ Tables: map[string]Table{ tablenameNat: Table{ BuiltinChains: map[Hook]Chain{ @@ -67,8 +67,6 @@ func DefaultTables() *IPTables { Output: []string{tablenameMangle, tablenameNat}, }, } - - return &tables } func unconditionalAcceptChain(name string) Chain { -- cgit v1.2.3 From 116cac053e2e4e167caa9707439065af7c7b82b3 Mon Sep 17 00:00:00 2001 From: Andrei Vagin Date: Wed, 3 Jul 2019 13:57:24 -0700 Subject: netstack/udp: connect with the AF_UNSPEC address family means disconnect PiperOrigin-RevId: 256433283 --- pkg/sentry/socket/epsocket/epsocket.go | 17 ++-- pkg/sentry/socket/unix/unix.go | 2 +- pkg/sentry/strace/socket.go | 2 +- pkg/syserr/netstack.go | 75 ++++++++------- pkg/tcpip/tcpip.go | 79 ++++++++------- pkg/tcpip/transport/icmp/endpoint.go | 5 + pkg/tcpip/transport/raw/endpoint.go | 5 + pkg/tcpip/transport/tcp/endpoint.go | 5 + pkg/tcpip/transport/tcp/endpoint_state.go | 1 + pkg/tcpip/transport/udp/endpoint.go | 42 +++++++- pkg/tcpip/transport/udp/endpoint_state.go | 2 - test/syscalls/linux/udp_socket.cc | 155 +++++++++++++++++++++++++++++- 12 files changed, 299 insertions(+), 91 deletions(-) (limited to 'pkg/tcpip') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index b2b2d98a1..9d1bcfd41 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -285,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // GetAddress reads an sockaddr struct from the given address and converts it // to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6 // addresses. -func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { +func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, syserr.ErrInvalidArgument } family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) { + if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -317,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET: var a linux.SockAddrInet if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) @@ -330,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { case linux.AF_INET6: var a linux.SockAddrInet6 if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, syserr.ErrBadAddress + return tcpip.FullAddress{}, syserr.ErrInvalidArgument } binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) @@ -343,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) { } return out, nil + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, nil + default: return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported } @@ -465,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, false /* strict */) if err != nil { return err } @@ -498,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, err := GetAddress(s.family, sockaddr) + addr, err := GetAddress(s.family, sockaddr, true /* strict */) if err != nil { return err } @@ -1922,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, err := GetAddress(s.family, to) + addrBuf, err := GetAddress(s.family, to, true /* strict */) if err != nil { return 0, err } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b30871a90..637168714 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -110,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr) + addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f9cf2eb21..386b40af7 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -332,7 +332,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, err := epsocket.GetAddress(int(family), b) + fa, err := epsocket.GetAddress(int(family), b, true /* strict */) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 965808a27..8ff922c69 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -49,43 +49,44 @@ var ( ) var netstackErrorTranslations = map[*tcpip.Error]*Error{ - tcpip.ErrUnknownProtocol: ErrUnknownProtocol, - tcpip.ErrUnknownNICID: ErrUnknownNICID, - tcpip.ErrUnknownDevice: ErrUnknownDevice, - tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, - tcpip.ErrDuplicateNICID: ErrDuplicateNICID, - tcpip.ErrDuplicateAddress: ErrDuplicateAddress, - tcpip.ErrNoRoute: ErrNoRoute, - tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, - tcpip.ErrAlreadyBound: ErrAlreadyBound, - tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, - tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, - tcpip.ErrAlreadyConnected: ErrAlreadyConnected, - tcpip.ErrNoPortAvailable: ErrNoPortAvailable, - tcpip.ErrPortInUse: ErrPortInUse, - tcpip.ErrBadLocalAddress: ErrBadLocalAddress, - tcpip.ErrClosedForSend: ErrClosedForSend, - tcpip.ErrClosedForReceive: ErrClosedForReceive, - tcpip.ErrWouldBlock: ErrWouldBlock, - tcpip.ErrConnectionRefused: ErrConnectionRefused, - tcpip.ErrTimeout: ErrTimeout, - tcpip.ErrAborted: ErrAborted, - tcpip.ErrConnectStarted: ErrConnectStarted, - tcpip.ErrDestinationRequired: ErrDestinationRequired, - tcpip.ErrNotSupported: ErrNotSupported, - tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, - tcpip.ErrNotConnected: ErrNotConnected, - tcpip.ErrConnectionReset: ErrConnectionReset, - tcpip.ErrConnectionAborted: ErrConnectionAborted, - tcpip.ErrNoSuchFile: ErrNoSuchFile, - tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, - tcpip.ErrNoLinkAddress: ErrHostDown, - tcpip.ErrBadAddress: ErrBadAddress, - tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, - tcpip.ErrMessageTooLong: ErrMessageTooLong, - tcpip.ErrNoBufferSpace: ErrNoBufferSpace, - tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, - tcpip.ErrNotPermitted: ErrNotPermittedNet, + tcpip.ErrUnknownProtocol: ErrUnknownProtocol, + tcpip.ErrUnknownNICID: ErrUnknownNICID, + tcpip.ErrUnknownDevice: ErrUnknownDevice, + tcpip.ErrUnknownProtocolOption: ErrUnknownProtocolOption, + tcpip.ErrDuplicateNICID: ErrDuplicateNICID, + tcpip.ErrDuplicateAddress: ErrDuplicateAddress, + tcpip.ErrNoRoute: ErrNoRoute, + tcpip.ErrBadLinkEndpoint: ErrBadLinkEndpoint, + tcpip.ErrAlreadyBound: ErrAlreadyBound, + tcpip.ErrInvalidEndpointState: ErrInvalidEndpointState, + tcpip.ErrAlreadyConnecting: ErrAlreadyConnecting, + tcpip.ErrAlreadyConnected: ErrAlreadyConnected, + tcpip.ErrNoPortAvailable: ErrNoPortAvailable, + tcpip.ErrPortInUse: ErrPortInUse, + tcpip.ErrBadLocalAddress: ErrBadLocalAddress, + tcpip.ErrClosedForSend: ErrClosedForSend, + tcpip.ErrClosedForReceive: ErrClosedForReceive, + tcpip.ErrWouldBlock: ErrWouldBlock, + tcpip.ErrConnectionRefused: ErrConnectionRefused, + tcpip.ErrTimeout: ErrTimeout, + tcpip.ErrAborted: ErrAborted, + tcpip.ErrConnectStarted: ErrConnectStarted, + tcpip.ErrDestinationRequired: ErrDestinationRequired, + tcpip.ErrNotSupported: ErrNotSupported, + tcpip.ErrQueueSizeNotSupported: ErrQueueSizeNotSupported, + tcpip.ErrNotConnected: ErrNotConnected, + tcpip.ErrConnectionReset: ErrConnectionReset, + tcpip.ErrConnectionAborted: ErrConnectionAborted, + tcpip.ErrNoSuchFile: ErrNoSuchFile, + tcpip.ErrInvalidOptionValue: ErrInvalidOptionValue, + tcpip.ErrNoLinkAddress: ErrHostDown, + tcpip.ErrBadAddress: ErrBadAddress, + tcpip.ErrNetworkUnreachable: ErrNetworkUnreachable, + tcpip.ErrMessageTooLong: ErrMessageTooLong, + tcpip.ErrNoBufferSpace: ErrNoBufferSpace, + tcpip.ErrBroadcastDisabled: ErrBroadcastDisabled, + tcpip.ErrNotPermitted: ErrNotPermittedNet, + tcpip.ErrAddressFamilyNotSupported: ErrAddressFamilyNotSupported, } // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 966b3acfd..c61f96fb0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -66,43 +66,44 @@ func (e *Error) IgnoreStats() bool { // Errors that can be returned by the network stack. var ( - ErrUnknownProtocol = &Error{msg: "unknown protocol"} - ErrUnknownNICID = &Error{msg: "unknown nic id"} - ErrUnknownDevice = &Error{msg: "unknown device"} - ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} - ErrDuplicateNICID = &Error{msg: "duplicate nic id"} - ErrDuplicateAddress = &Error{msg: "duplicate address"} - ErrNoRoute = &Error{msg: "no route"} - ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} - ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} - ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} - ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} - ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} - ErrNoPortAvailable = &Error{msg: "no ports are available"} - ErrPortInUse = &Error{msg: "port is in use"} - ErrBadLocalAddress = &Error{msg: "bad local address"} - ErrClosedForSend = &Error{msg: "endpoint is closed for send"} - ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} - ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} - ErrConnectionRefused = &Error{msg: "connection was refused"} - ErrTimeout = &Error{msg: "operation timed out"} - ErrAborted = &Error{msg: "operation aborted"} - ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} - ErrDestinationRequired = &Error{msg: "destination address is required"} - ErrNotSupported = &Error{msg: "operation not supported"} - ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} - ErrNotConnected = &Error{msg: "endpoint not connected"} - ErrConnectionReset = &Error{msg: "connection reset by peer"} - ErrConnectionAborted = &Error{msg: "connection aborted"} - ErrNoSuchFile = &Error{msg: "no such file"} - ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} - ErrNoLinkAddress = &Error{msg: "no remote link address"} - ErrBadAddress = &Error{msg: "bad address"} - ErrNetworkUnreachable = &Error{msg: "network is unreachable"} - ErrMessageTooLong = &Error{msg: "message too long"} - ErrNoBufferSpace = &Error{msg: "no buffer space available"} - ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} - ErrNotPermitted = &Error{msg: "operation not permitted"} + ErrUnknownProtocol = &Error{msg: "unknown protocol"} + ErrUnknownNICID = &Error{msg: "unknown nic id"} + ErrUnknownDevice = &Error{msg: "unknown device"} + ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} + ErrDuplicateNICID = &Error{msg: "duplicate nic id"} + ErrDuplicateAddress = &Error{msg: "duplicate address"} + ErrNoRoute = &Error{msg: "no route"} + ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} + ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} + ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} + ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} + ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} + ErrNoPortAvailable = &Error{msg: "no ports are available"} + ErrPortInUse = &Error{msg: "port is in use"} + ErrBadLocalAddress = &Error{msg: "bad local address"} + ErrClosedForSend = &Error{msg: "endpoint is closed for send"} + ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} + ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} + ErrConnectionRefused = &Error{msg: "connection was refused"} + ErrTimeout = &Error{msg: "operation timed out"} + ErrAborted = &Error{msg: "operation aborted"} + ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} + ErrDestinationRequired = &Error{msg: "destination address is required"} + ErrNotSupported = &Error{msg: "operation not supported"} + ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} + ErrNotConnected = &Error{msg: "endpoint not connected"} + ErrConnectionReset = &Error{msg: "connection reset by peer"} + ErrConnectionAborted = &Error{msg: "connection aborted"} + ErrNoSuchFile = &Error{msg: "no such file"} + ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} + ErrNoLinkAddress = &Error{msg: "no remote link address"} + ErrBadAddress = &Error{msg: "bad address"} + ErrNetworkUnreachable = &Error{msg: "network is unreachable"} + ErrMessageTooLong = &Error{msg: "message too long"} + ErrNoBufferSpace = &Error{msg: "no buffer space available"} + ErrBroadcastDisabled = &Error{msg: "broadcast socket option disabled"} + ErrNotPermitted = &Error{msg: "operation not permitted"} + ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ) // Errors related to Subnet @@ -339,6 +340,10 @@ type Endpoint interface { // get the actual result. The first call to Connect after the socket has // connected returns nil. Calling connect again results in ErrAlreadyConnected. // Anything else -- the attempt to connect failed. + // + // If address.Addr is empty, this means that Enpoint has to be + // disconnected if this is supported, otherwise + // ErrAddressFamilyNotSupported must be returned. Connect(address FullAddress) *Error // Shutdown closes the read and/or write end of the endpoint connection diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 33cefd937..ab9e80747 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -422,6 +422,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() + if addr.Addr == "" { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + nicid := addr.NIC localPort := uint16(0) switch e.state { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 73599dd9d..42aded77f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -298,6 +298,11 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() + if addr.Addr == "" { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + if ep.closed { return tcpip.ErrInvalidEndpointState } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ee60ebf58..cb40fea94 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1271,6 +1271,11 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol // Connect connects the endpoint to its peer. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Addr == "" && addr.Port == 0 { + // AF_UNSPEC isn't supported. + return tcpip.ErrAddressFamilyNotSupported + } + return e.connect(addr, true, true) } diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index ec61a3886..b93959034 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -342,6 +342,7 @@ func loadError(s string) *tcpip.Error { tcpip.ErrNoBufferSpace, tcpip.ErrBroadcastDisabled, tcpip.ErrNotPermitted, + tcpip.ErrAddressFamilyNotSupported, } messageToError = make(map[string]*tcpip.Error) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 99fdfb795..cb0ea83a6 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -698,8 +698,44 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } +func (e *endpoint) disconnect() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.state != stateConnected { + return nil + } + id := stack.TransportEndpointID{} + // Exclude ephemerally bound endpoints. + if e.bindNICID != 0 || e.id.LocalAddress == "" { + var err *tcpip.Error + id = stack.TransportEndpointID{ + LocalPort: e.id.LocalPort, + LocalAddress: e.id.LocalAddress, + } + id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + return err + } + e.state = stateBound + } else { + e.state = stateInitial + } + + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.id = id + e.route.Release() + e.route = stack.Route{} + e.dstPort = 0 + + return nil +} + // Connect connects the endpoint to its peer. Specifying a NIC is optional. func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Addr == "" { + return e.disconnect() + } if addr.Port == 0 { // We don't support connecting to port zero. return tcpip.ErrInvalidEndpointState @@ -734,12 +770,16 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { defer r.Release() id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress, + LocalAddress: e.id.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, RemoteAddress: r.RemoteAddress, } + if e.state == stateInitial { + id.LocalAddress = r.LocalAddress + } + // Even if we're connected, this endpoint can still be used to send // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 701bdd72b..18e786397 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -92,8 +92,6 @@ func (e *endpoint) afterLoad() { if err != nil { panic(*err) } - - e.id.LocalAddress = e.route.LocalAddress } else if len(e.id.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { panic(tcpip.ErrBadLocalAddress) diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 31db8a2ad..1bb0307c4 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include #include @@ -304,12 +305,50 @@ TEST_P(UdpSocketTest, ReceiveAfterConnect) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); } +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + // Connect s_ to loopback:TestPort, and bind t_ to loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(bind(t_, addr_[0], addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(t_, addr_[1], addrlen_), SyscallSucceeds()); + + // Get the address s_ was bound to during connect. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + for (int i = 0; i < 2; i++) { + // Send from t_ to s_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + ASSERT_THAT(sendto(t_, buf, sizeof(buf), 0, + reinterpret_cast(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(s_, received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect s_. + struct sockaddr addr = {}; + addr.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), SyscallSucceeds()); + // Connect s_ loopback:TestPort. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + } +} + TEST_P(UdpSocketTest, Connect) { ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -335,6 +374,112 @@ TEST_P(UdpSocketTest, Connect) { EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); } +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_THAT(bind(s_, addr_[1], addrlen_), SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, addr_[1], addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + struct sockaddr_storage baddr = {}; + socklen_t addrlen; + auto port = *Port(reinterpret_cast(addr_[1])); + if (addr_[0]->sa_family == AF_INET) { + auto addr_in = reinterpret_cast(&baddr); + addr_in->sin_family = AF_INET; + addr_in->sin_port = port; + inet_pton(AF_INET, "0.0.0.0", + reinterpret_cast(&addr_in->sin_addr.s_addr)); + } else { + auto addr_in = reinterpret_cast(&baddr); + addr_in->sin6_family = AF_INET6; + addr_in->sin6_port = port; + inet_pton(AF_INET6, + "::", reinterpret_cast(&addr_in->sin6_addr.s6_addr)); + addr_in->sin6_scope_id = 0; + } + ASSERT_THAT(bind(s_, reinterpret_cast(&baddr), addrlen_), + SyscallSucceeds()); + // Connect the socket. + ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, &baddr, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(s_, addr_[2], addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, addr_[2], addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT( + connect(s_, reinterpret_cast(&addr), sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT(getpeername(s_, reinterpret_cast(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT(getsockname(s_, reinterpret_cast(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), 0); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = addr_[0]->sa_family; + ASSERT_THAT(connect(s_, &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { ASSERT_THAT(connect(s_, addr_[0], addrlen_), SyscallSucceeds()); @@ -397,7 +542,7 @@ TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -419,7 +564,7 @@ TEST_P(UdpSocketTest, SendAndReceiveConnected) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -462,7 +607,7 @@ TEST_P(UdpSocketTest, ReceiveBeforeConnect) { ASSERT_THAT(connect(s_, addr_[1], addrlen_), SyscallSucceeds()); // Receive the data. It works because it was sent before the connect. - char received[512]; + char received[sizeof(buf)]; EXPECT_THAT(recv(s_, received, sizeof(received), 0), SyscallSucceedsWithValue(sizeof(received))); EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); @@ -491,7 +636,7 @@ TEST_P(UdpSocketTest, ReceiveFrom) { SyscallSucceedsWithValue(sizeof(buf))); // Receive the data and sender address. - char received[512]; + char received[sizeof(buf)]; struct sockaddr_storage addr; socklen_t addrlen = sizeof(addr); EXPECT_THAT(recvfrom(s_, received, sizeof(received), 0, -- cgit v1.2.3