diff options
Diffstat (limited to 'pkg/seccomp/seccomp.go')
-rw-r--r-- | pkg/seccomp/seccomp.go | 217 |
1 files changed, 135 insertions, 82 deletions
diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index 7ee63140c..cd6b0b4bc 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -12,24 +12,28 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package seccomp provides basic seccomp filters. +// Package seccomp provides basic seccomp filters for x86_64 (little endian). package seccomp import ( "fmt" + "reflect" "sort" + "gvisor.googlesource.com/gvisor/pkg/abi" "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/bpf" "gvisor.googlesource.com/gvisor/pkg/log" + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/strace" ) const ( // violationLabel is added to the program to take action on a violation. violationLabel = "violation" - // allowLabel is added to the program to allow the syscall to take place. - allowLabel = "allow" + // skipOneInst is the offset to take for skipping one instruction. + skipOneInst = 1 ) // Install generates BPF code based on the set of syscalls provided. It only @@ -38,27 +42,19 @@ const ( // // (*) The current implementation only checks the syscall number. It does NOT // validate any of the arguments. -func Install(syscalls []uintptr, kill bool) error { - // Sort syscalls and remove duplicates to build the BST. - sort.Slice(syscalls, func(i, j int) bool { return syscalls[i] < syscalls[j] }) - syscalls = filterUnique(syscalls) - - log.Infof("Installing seccomp filters for %d syscalls (kill=%t)", len(syscalls), kill) - for _, s := range syscalls { - log.Infof("syscall filter: %v", s) - } - - instrs, err := buildProgram(syscalls, kill) - if err != nil { - return err - } +func Install(rules SyscallRules, kill bool) error { + log.Infof("Installing seccomp filters for %d syscalls (kill=%t)", len(rules), kill) + instrs, err := buildProgram(rules, kill) if log.IsLogging(log.Debug) { - programStr, err := bpf.DecodeProgram(instrs) - if err != nil { - programStr = fmt.Sprintf("Error: %v\n%s", err, programStr) + programStr, errDecode := bpf.DecodeProgram(instrs) + if errDecode != nil { + programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr) } log.Debugf("Seccomp program dump:\n%s", programStr) } + if err != nil { + return err + } if err := seccomp(instrs); err != nil { return err @@ -68,11 +64,8 @@ func Install(syscalls []uintptr, kill bool) error { return nil } -// buildProgram builds a BPF program that whitelists all given syscalls. -// -// Precondition: syscalls must be sorted and unique. -func buildProgram(syscalls []uintptr, kill bool) ([]linux.BPFInstruction, error) { - const archOffset = 4 // offsetof(seccomp_data, arch) +// buildProgram builds a BPF program that whitelists all given syscall rules. +func buildProgram(rules SyscallRules, kill bool) ([]linux.BPFInstruction, error) { program := bpf.NewProgramBuilder() violationAction := uint32(linux.SECCOMP_RET_KILL) if !kill { @@ -83,10 +76,13 @@ func buildProgram(syscalls []uintptr, kill bool) ([]linux.BPFInstruction, error) // // A = seccomp_data.arch // if (A != AUDIT_ARCH_X86_64) goto violation - program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, archOffset) - program.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, linux.AUDIT_ARCH_X86_64, 0, violationLabel) + program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArch) + // violationLabel is at the bottom of the program. The size of program + // may exceeds 255 lines, which is the limit of a condition jump. + program.AddJump(bpf.Jmp|bpf.Jeq|bpf.K, linux.AUDIT_ARCH_X86_64, skipOneInst, 0) + program.AddDirectJumpLabel(violationLabel) - if err := buildIndex(syscalls, program); err != nil { + if err := buildIndex(rules, program); err != nil { return nil, err } @@ -96,41 +92,34 @@ func buildProgram(syscalls []uintptr, kill bool) ([]linux.BPFInstruction, error) } program.AddStmt(bpf.Ret|bpf.K, violationAction) - // allow: return SECCOMP_RET_ALLOW - if err := program.AddLabel(allowLabel); err != nil { - return nil, err - } - program.AddStmt(bpf.Ret|bpf.K, linux.SECCOMP_RET_ALLOW) - return program.Instructions() } -// filterUnique filters unique system calls. -// -// Precondition: syscalls must be sorted. -func filterUnique(syscalls []uintptr) []uintptr { - filtered := make([]uintptr, 0, len(syscalls)) - for i := 0; i < len(syscalls); i++ { - if len(filtered) > 0 && syscalls[i] == filtered[len(filtered)-1] { - // This call has already been inserted, skip. - continue - } - filtered = append(filtered, syscalls[i]) +// buildIndex builds a BST to quickly search through all syscalls that are whitelisted. +func buildIndex(rules SyscallRules, program *bpf.ProgramBuilder) error { + syscalls := []uintptr{} + for sysno, _ := range rules { + syscalls = append(syscalls, sysno) + } + + t, ok := strace.Lookup(abi.Linux, arch.AMD64) + if !ok { + panic("Can't find amd64 Linux syscall table") + } + + sort.Slice(syscalls, func(i, j int) bool { return syscalls[i] < syscalls[j] }) + for _, s := range syscalls { + log.Infof("syscall filter: %v (%v): %s", s, t.Name(s), rules[s]) } - return filtered -} -// buildIndex builds a BST to quickly search through all syscalls that are whitelisted. -// -// Precondition: syscalls must be sorted and unique. -func buildIndex(syscalls []uintptr, program *bpf.ProgramBuilder) error { root := createBST(syscalls) + root.root = true // Load syscall number into A and run through BST. // // A = seccomp_data.nr - program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, 0) - return root.buildBSTProgram(program, true) + program.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetNR) + return root.traverse(buildBSTProgram, program, rules) } // createBST converts sorted syscall slice into a balanced BST. @@ -147,64 +136,128 @@ func createBST(syscalls []uintptr) *node { return &parent } -// node represents a tree node. -type node struct { - value uintptr - left *node - right *node +func ruleViolationLabel(sysno uintptr, idx int) string { + return fmt.Sprintf("ruleViolation_%v_%v", sysno, idx) } -// label returns the label corresponding to this node. If node is nil (syscall not present), -// violationLabel is returned for convenience. -func (n *node) label() string { - if n == nil { - return violationLabel +func checkArgsLabel(sysno uintptr) string { + return fmt.Sprintf("checkArgs_%v", sysno) +} + +func addSyscallArgsCheck(p *bpf.ProgramBuilder, rules []Rule, sysno uintptr) error { + for ruleidx, rule := range rules { + labelled := false + for i, arg := range rule { + if arg != nil { + switch a := arg.(type) { + case AllowAny: + case AllowValue: + high, low := uint32(a>>32), uint32(a) + // assert arg_low == low + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgLow(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, low, 0, ruleViolationLabel(sysno, ruleidx)) + // assert arg_high == high + p.AddStmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArgHigh(i)) + p.AddJumpFalseLabel(bpf.Jmp|bpf.Jeq|bpf.K, high, 0, ruleViolationLabel(sysno, ruleidx)) + labelled = true + + default: + return fmt.Errorf("unknown syscall rule type: %v", reflect.TypeOf(a)) + } + } + } + // Matched, allow the syscall. + p.AddStmt(bpf.Ret|bpf.K, linux.SECCOMP_RET_ALLOW) + // Label the end of the rule if necessary. + if labelled { + if err := p.AddLabel(ruleViolationLabel(sysno, ruleidx)); err != nil { + return err + } + } } - return fmt.Sprintf("index_%v", n.value) + // Not matched? + p.AddDirectJumpLabel(violationLabel) + return nil } // buildBSTProgram converts a binary tree started in 'root' into BPF code. The ouline of the code // is as follows: // // // SYS_PIPE(22), root -// (A == 22) ? goto allow : continue +// (A == 22) ? goto argument check : continue // (A > 22) ? goto index_35 : goto index_9 // // index_9: // SYS_MMAP(9), leaf -// (A == 9) ? goto allow : goto violation +// A == 9) ? goto argument check : violation // // index_35: // SYS_NANOSLEEP(35), single child -// (A == 35) ? goto allow : continue +// (A == 35) ? goto argument check : continue // (A > 35) ? goto index_50 : goto violation // // index_50: // SYS_LISTEN(50), leaf -// (A == 50) ? goto allow : goto violation +// (A == 50) ? goto argument check : goto violation // -func (n *node) buildBSTProgram(program *bpf.ProgramBuilder, root bool) error { - if n == nil { - return nil - } - +func buildBSTProgram(program *bpf.ProgramBuilder, rules SyscallRules, n *node) error { // Root node is never referenced by label, skip it. - if !root { + if !n.root { if err := program.AddLabel(n.label()); err != nil { return err } } - // Leaf nodes don't require extra check, they either allow or violate! + sysno := n.value + program.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, uint32(sysno), checkArgsLabel(sysno), 0) if n.left == nil && n.right == nil { - program.AddJumpLabels(bpf.Jmp|bpf.Jeq|bpf.K, uint32(n.value), allowLabel, violationLabel) + // Leaf nodes don't require extra check. + program.AddDirectJumpLabel(violationLabel) + } else { + // Non-leaf node. Check which turn to take otherwise. Using direct jumps + // in case that the offset may exceed the limit of a conditional jump (255) + // Note that 'violationLabel' is returned for nil children. + program.AddJump(bpf.Jmp|bpf.Jgt|bpf.K, uint32(sysno), 0, skipOneInst) + program.AddDirectJumpLabel(n.right.label()) + program.AddDirectJumpLabel(n.left.label()) + } + + if err := program.AddLabel(checkArgsLabel(sysno)); err != nil { + return err + } + // No rules, just allow it and save one jmp. + if len(rules[sysno]) == 0 { + program.AddStmt(bpf.Ret|bpf.K, linux.SECCOMP_RET_ALLOW) return nil } + return addSyscallArgsCheck(program, rules[sysno], sysno) +} - // Non-leaf node. Allows syscall if it matches, check which turn to take otherwise. Note - // that 'violationLabel' is returned for nil children. - program.AddJumpTrueLabel(bpf.Jmp|bpf.Jeq|bpf.K, uint32(n.value), allowLabel, 0) - program.AddJumpLabels(bpf.Jmp|bpf.Jgt|bpf.K, uint32(n.value), n.right.label(), n.left.label()) +// node represents a tree node. +type node struct { + value uintptr + left *node + right *node + root bool +} + +// label returns the label corresponding to this node. If node is nil (syscall not present), +// violationLabel is returned for convenience. +func (n *node) label() string { + if n == nil { + return violationLabel + } + return fmt.Sprintf("index_%v", n.value) +} - if err := n.left.buildBSTProgram(program, false); err != nil { +type traverseFunc func(*bpf.ProgramBuilder, SyscallRules, *node) error + +func (n *node) traverse(fn traverseFunc, p *bpf.ProgramBuilder, rules SyscallRules) error { + if n == nil { + return nil + } + if err := fn(p, rules, n); err != nil { + return err + } + if err := n.left.traverse(fn, p, rules); err != nil { return err } - return n.right.buildBSTProgram(program, false) + return n.right.traverse(fn, p, rules) } |