summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go45
-rw-r--r--pkg/sentry/socket/epsocket/stack.go4
-rw-r--r--pkg/sentry/socket/hostinet/stack.go8
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go254
-rw-r--r--pkg/sentry/socket/unix/unix.go2
5 files changed, 285 insertions, 28 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 4d8a5ac22..635042263 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -291,18 +291,22 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// GetAddress reads an sockaddr struct from the given address and converts it
-// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
-// addresses.
-func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
+// AF_INET6 addresses.
+//
+// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
+//
+// AddressAndFamily returns an address, its family.
+func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
// Get the rest of the fields based on the address family.
@@ -310,7 +314,7 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
case linux.AF_UNIX:
path := addr[2:]
if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
// Drop the terminating NUL (if one exists) and everything after
// it for filesystem (non-abstract) addresses.
@@ -321,12 +325,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
}
return tcpip.FullAddress{
Addr: tcpip.Address(path),
- }, nil
+ }, family, nil
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -334,12 +338,12 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
Addr: bytesToIPAddress(a.Addr[:]),
Port: ntohs(a.Port),
}
- return out, nil
+ return out, family, nil
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -350,13 +354,13 @@ func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syse
if isLinkLocal(out.Addr) {
out.NIC = tcpip.NICID(a.Scope_id)
}
- return out, nil
+ return out, family, nil
case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, nil
+ return tcpip.FullAddress{}, family, nil
default:
- return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
}
}
@@ -482,11 +486,18 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, false /* strict */)
+ addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
// Always return right away in the non-blocking case.
if !blocking {
return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
@@ -515,7 +526,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr, true /* strict */)
+ addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -2023,7 +2034,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to, true /* strict */)
+ addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 8f1572bf4..1b11f4b2d 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -198,8 +198,8 @@ func (s *Stack) IPTables() (iptables.IPTables, error) {
// FillDefaultIPTables sets the stack's iptables to the default tables, which
// allow and do not modify all traffic.
-func (s *Stack) FillDefaultIPTables() error {
- return netfilter.FillDefaultIPTables(s.Stack)
+func (s *Stack) FillDefaultIPTables() {
+ netfilter.FillDefaultIPTables(s.Stack)
}
// Resume implements inet.Stack.Resume.
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 1902fe155..3a4fdec47 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -203,8 +203,14 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
inetRoute.DstAddr = attr.Value
case syscall.RTA_SRC:
inetRoute.SrcAddr = attr.Value
- case syscall.RTA_OIF:
+ case syscall.RTA_GATEWAY:
inetRoute.GatewayAddr = attr.Value
+ case syscall.RTA_OIF:
+ expected := int(binary.Size(inetRoute.OutputInterface))
+ if len(attr.Value) != expected {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+ }
+ binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
}
}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index efdb42903..9f87c32f1 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,7 +17,10 @@
package netfilter
import (
+ "fmt"
+
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -26,21 +29,258 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// metadata is opaque to netstack. It holds data that we need to translate
+// between Linux's and netstack's iptables representations.
+type metadata struct {
+ HookEntry [linux.NF_INET_NUMHOOKS]uint32
+ Underflow [linux.NF_INET_NUMHOOKS]uint32
+ NumEntries uint32
+ Size uint32
+}
+
// GetInfo returns information about iptables.
func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
- // TODO(b/129292233): Implement.
- return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
+ // Read in the struct and table name.
+ var info linux.IPTGetinfo
+ if _, err := t.CopyIn(outPtr, &info); err != nil {
+ return linux.IPTGetinfo{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, info.TableName())
+ if err != nil {
+ return linux.IPTGetinfo{}, err
+ }
+
+ // Get the hooks that apply to this table.
+ info.ValidHooks = table.ValidHooks()
+
+ // Grab the metadata struct, which is used to store information (e.g.
+ // the number of entries) that applies to the user's encoding of
+ // iptables, but not netstack's.
+ metadata := table.Metadata().(metadata)
+
+ // Set values from metadata.
+ info.HookEntry = metadata.HookEntry
+ info.Underflow = metadata.Underflow
+ info.NumEntries = metadata.NumEntries
+ info.Size = metadata.Size
+
+ return info, nil
}
// GetEntries returns netstack's iptables rules encoded for the iptables tool.
func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
- // TODO(b/129292233): Implement.
- return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ // Read in the struct and table name.
+ var userEntries linux.IPTGetEntries
+ if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+ }
+
+ // Find the appropriate table.
+ table, err := findTable(ep, userEntries.TableName())
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary(userEntries.TableName(), table)
+ if err != nil {
+ return linux.KernelIPTGetEntries{}, err
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+
+ return entries, nil
+}
+
+func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) {
+ ipt, err := ep.IPTables()
+ if err != nil {
+ return iptables.Table{}, syserr.FromError(err)
+ }
+ table, ok := ipt.Tables[tableName]
+ if !ok {
+ return iptables.Table{}, syserr.ErrInvalidArgument
+ }
+ return table, nil
}
// FillDefaultIPTables sets stack's IPTables to the default tables and
// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) error {
- stack.SetIPTables(iptables.DefaultTables())
- return nil
+func FillDefaultIPTables(stack *stack.Stack) {
+ ipt := iptables.DefaultTables()
+
+ // In order to fill in the metadata, we have to translate ipt from its
+ // netstack format to Linux's giant-binary-blob format.
+ for name, table := range ipt.Tables {
+ _, metadata, err := convertNetstackToBinary(name, table)
+ if err != nil {
+ panic(fmt.Errorf("Unable to set default IP tables: %v", err))
+ }
+ table.SetMetadata(metadata)
+ ipt.Tables[name] = table
+ }
+
+ stack.SetIPTables(ipt)
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) {
+ // Return values.
+ var entries linux.KernelIPTGetEntries
+ var meta metadata
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(name) {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+ copy(entries.Name[:], name)
+
+ // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of
+ // these chains ends with an unconditional policy entry.
+ for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ {
+ chain, ok := table.BuiltinChains[hook]
+ if !ok {
+ // This table doesn't support this hook.
+ continue
+ }
+
+ // Sanity check.
+ if len(chain.Rules) < 1 {
+ return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument
+ }
+
+ for ruleIdx, rule := range chain.Rules {
+ // If this is the first rule of a builtin chain, set
+ // the metadata hook entry point.
+ if ruleIdx == 0 {
+ meta.HookEntry[hook] = entries.Size
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+ entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+
+ // The underflow rule is the last rule in the chain,
+ // and is an unconditional rule (i.e. it matches any
+ // packet). This is enforced when saving iptables.
+ if ruleIdx == len(chain.Rules)-1 {
+ meta.Underflow[hook] = entries.Size
+ }
+
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ meta.NumEntries++
+ }
+
+ }
+
+ // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of
+ // these starts with an error node holding the chain's name and ends
+ // with an unconditional return.
+
+ // Lastly, each table ends with an unconditional error target rule as
+ // its final entry.
+ errorEntry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ var errorTarget linux.XTErrorTarget
+ errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget
+ copy(errorTarget.ErrorName[:], errorTargetName)
+ copy(errorTarget.Target.Name[:], errorTargetName)
+
+ // Serialize and add it to the list of entries.
+ errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget)
+ errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...)
+ errorEntry.NextOffset += uint16(len(serializedErrorTarget))
+
+ entries.Size += uint32(errorEntry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, errorEntry)
+ meta.NumEntries++
+ meta.Size = entries.Size
+
+ return entries, meta, nil
+}
+
+func marshalMatcher(matcher iptables.Matcher) []byte {
+ switch matcher.(type) {
+ default:
+ // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so
+ // any call to marshalMatcher will panic.
+ panic(fmt.Errorf("unknown matcher of type %T", matcher))
+ }
+}
+
+func marshalTarget(target iptables.Target) []byte {
+ switch target.(type) {
+ case iptables.UnconditionalAcceptTarget:
+ return marshalUnconditionalAcceptTarget()
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalUnconditionalAcceptTarget() []byte {
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateStandardVerdict(iptables.Accept),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateStandardVerdict(verdict iptables.Verdict) int32 {
+ switch verdict {
+ case iptables.Accept:
+ return -linux.NF_ACCEPT - 1
+ case iptables.Drop:
+ return -linux.NF_DROP - 1
+ case iptables.Queue:
+ return -linux.NF_QUEUE - 1
+ case iptables.Return:
+ return linux.NF_RETURN
+ case iptables.Jump:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic("Jump isn't supported yet")
+ default:
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8a3f65236..0d0cb68df 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}