diff options
-rw-r--r-- | Makefile.in | 6 | ||||
-rw-r--r-- | aclocal.m4 | 480 | ||||
-rw-r--r-- | conf/cf-lex.l | 14 | ||||
-rw-r--r-- | configure.ac | 6 | ||||
-rw-r--r-- | filter/config.Y | 10 | ||||
-rw-r--r-- | filter/filter.c | 26 | ||||
-rw-r--r-- | filter/filter.h | 17 | ||||
-rw-r--r-- | lib/printf.c | 1 | ||||
-rw-r--r-- | lua/Makefile | 4 | ||||
-rw-r--r-- | lua/common.c | 323 | ||||
-rw-r--r-- | lua/filter.c | 187 | ||||
-rw-r--r-- | lua/lua.h | 18 | ||||
-rw-r--r-- | nest/config.Y | 3 | ||||
-rw-r--r-- | nest/route.h | 2 | ||||
-rw-r--r-- | sysdep/linux/Makefile | 2 | ||||
-rw-r--r-- | sysdep/linux/wireguard.c | 1766 | ||||
-rw-r--r-- | sysdep/linux/wireguard.h | 103 |
17 files changed, 2957 insertions, 11 deletions
diff --git a/Makefile.in b/Makefile.in index 0ecd6811..b858e898 100644 --- a/Makefile.in +++ b/Makefile.in @@ -7,12 +7,12 @@ MAKEFLAGS += -r # Variable definitions CPPFLAGS=-I$(objdir) -I$(srcdir) @CPPFLAGS@ -CFLAGS=$(CPPFLAGS) @CFLAGS@ +CFLAGS=$(CPPFLAGS) @CFLAGS@ @LUA_INCLUDE@ LDFLAGS=@LDFLAGS@ M4FLAGS=@M4FLAGS@ BISONFLAGS=@BISONFLAGS@ LIBS=@LIBS@ -DAEMON_LIBS=@DAEMON_LIBS@ +DAEMON_LIBS=@DAEMON_LIBS@ @LUA_LIB@ CLIENT_LIBS=@CLIENT_LIBS@ CC=@CC@ M4=@M4@ @@ -74,7 +74,7 @@ cli: $(client) $(daemon): LIBS += $(DAEMON_LIBS) # Include directories -dirs := client conf doc filter lib nest test $(addprefix proto/,$(protocols)) @sysdep_dirs@ +dirs := client conf doc filter lib lua nest test $(addprefix proto/,$(protocols)) @sysdep_dirs@ conf-y-targets := $(addprefix $(objdir)/conf/,cf-parse.y keywords.h commands.h) cf-local = $(conf-y-targets): $(s)config.Y @@ -196,3 +196,483 @@ AC_DEFUN([BIRD_CHECK_BISON_VERSION], ;; esac ]) + +dnl ========================================================================= +dnl AX_PROG_LUA([MINIMUM-VERSION], [TOO-BIG-VERSION], +dnl [ACTION-IF-FOUND], [ACTION-IF-NOT-FOUND]) +dnl ========================================================================= +AC_DEFUN([AX_PROG_LUA], +[ + dnl Check for required tools. + AC_REQUIRE([AC_PROG_GREP]) + AC_REQUIRE([AC_PROG_SED]) + + dnl Make LUA a precious variable. + AC_ARG_VAR([LUA], [The Lua interpreter, e.g. /usr/bin/lua5.1]) + + dnl Find a Lua interpreter. + m4_define_default([_AX_LUA_INTERPRETER_LIST], + [lua lua5.3 lua53 lua5.2 lua52 lua5.1 lua51 lua50]) + + m4_if([$1], [], + [ dnl No version check is needed. Find any Lua interpreter. + AS_IF([test "x$LUA" = 'x'], + [AC_PATH_PROGS([LUA], [_AX_LUA_INTERPRETER_LIST], [:])]) + ax_display_LUA='lua' + + AS_IF([test "x$LUA" != 'x:'], + [ dnl At least check if this is a Lua interpreter. + AC_MSG_CHECKING([if $LUA is a Lua interpreter]) + _AX_LUA_CHK_IS_INTRP([$LUA], + [AC_MSG_RESULT([yes])], + [ AC_MSG_RESULT([no]) + AC_MSG_ERROR([not a Lua interpreter]) + ]) + ]) + ], + [ dnl A version check is needed. + AS_IF([test "x$LUA" != 'x'], + [ dnl Check if this is a Lua interpreter. + AC_MSG_CHECKING([if $LUA is a Lua interpreter]) + _AX_LUA_CHK_IS_INTRP([$LUA], + [AC_MSG_RESULT([yes])], + [ AC_MSG_RESULT([no]) + AC_MSG_ERROR([not a Lua interpreter]) + ]) + dnl Check the version. + m4_if([$2], [], + [_ax_check_text="whether $LUA version >= $1"], + [_ax_check_text="whether $LUA version >= $1, < $2"]) + AC_MSG_CHECKING([$_ax_check_text]) + _AX_LUA_CHK_VER([$LUA], [$1], [$2], + [AC_MSG_RESULT([yes])], + [ AC_MSG_RESULT([no]) + AC_MSG_ERROR([version is out of range for specified LUA])]) + ax_display_LUA=$LUA + ], + [ dnl Try each interpreter until we find one that satisfies VERSION. + m4_if([$2], [], + [_ax_check_text="for a Lua interpreter with version >= $1"], + [_ax_check_text="for a Lua interpreter with version >= $1, < $2"]) + AC_CACHE_CHECK([$_ax_check_text], + [ax_cv_pathless_LUA], + [ for ax_cv_pathless_LUA in _AX_LUA_INTERPRETER_LIST none; do + test "x$ax_cv_pathless_LUA" = 'xnone' && break + _AX_LUA_CHK_IS_INTRP([$ax_cv_pathless_LUA], [], [continue]) + _AX_LUA_CHK_VER([$ax_cv_pathless_LUA], [$1], [$2], [break]) + done + ]) + dnl Set $LUA to the absolute path of $ax_cv_pathless_LUA. + AS_IF([test "x$ax_cv_pathless_LUA" = 'xnone'], + [LUA=':'], + [AC_PATH_PROG([LUA], [$ax_cv_pathless_LUA])]) + ax_display_LUA=$ax_cv_pathless_LUA + ]) + ]) + + AS_IF([test "x$LUA" = 'x:'], + [ dnl Run any user-specified action, or abort. + m4_default([$4], [AC_MSG_ERROR([cannot find suitable Lua interpreter])]) + ], + [ dnl Query Lua for its version number. + AC_CACHE_CHECK([for $ax_display_LUA version], + [ax_cv_lua_version], + [ dnl Get the interpreter version in X.Y format. This should work for + dnl interpreters version 5.0 and beyond. + ax_cv_lua_version=[`$LUA -e ' + -- return a version number in X.Y format + local _, _, ver = string.find(_VERSION, "^Lua (%d+%.%d+)") + print(ver)'`] + ]) + AS_IF([test "x$ax_cv_lua_version" = 'x'], + [AC_MSG_ERROR([invalid Lua version number])]) + AC_SUBST([LUA_VERSION], [$ax_cv_lua_version]) + AC_SUBST([LUA_SHORT_VERSION], [`echo "$LUA_VERSION" | $SED 's|\.||'`]) + + dnl The following check is not supported: + dnl At times (like when building shared libraries) you may want to know + dnl which OS platform Lua thinks this is. + AC_CACHE_CHECK([for $ax_display_LUA platform], + [ax_cv_lua_platform], + [ax_cv_lua_platform=[`$LUA -e 'print("unknown")'`]]) + AC_SUBST([LUA_PLATFORM], [$ax_cv_lua_platform]) + + dnl Use the values of $prefix and $exec_prefix for the corresponding + dnl values of LUA_PREFIX and LUA_EXEC_PREFIX. These are made distinct + dnl variables so they can be overridden if need be. However, the general + dnl consensus is that you shouldn't need this ability. + AC_SUBST([LUA_PREFIX], ['${prefix}']) + AC_SUBST([LUA_EXEC_PREFIX], ['${exec_prefix}']) + + dnl Lua provides no way to query the script directory, and instead + dnl provides LUA_PATH. However, we should be able to make a safe educated + dnl guess. If the built-in search path contains a directory which is + dnl prefixed by $prefix, then we can store scripts there. The first + dnl matching path will be used. + AC_CACHE_CHECK([for $ax_display_LUA script directory], + [ax_cv_lua_luadir], + [ AS_IF([test "x$prefix" = 'xNONE'], + [ax_lua_prefix=$ac_default_prefix], + [ax_lua_prefix=$prefix]) + + dnl Initialize to the default path. + ax_cv_lua_luadir="$LUA_PREFIX/share/lua/$LUA_VERSION" + + dnl Try to find a path with the prefix. + _AX_LUA_FND_PRFX_PTH([$LUA], [$ax_lua_prefix], [script]) + AS_IF([test "x$ax_lua_prefixed_path" != 'x'], + [ dnl Fix the prefix. + _ax_strip_prefix=`echo "$ax_lua_prefix" | $SED 's|.|.|g'` + ax_cv_lua_luadir=`echo "$ax_lua_prefixed_path" | \ + $SED "s|^$_ax_strip_prefix|$LUA_PREFIX|"` + ]) + ]) + AC_SUBST([luadir], [$ax_cv_lua_luadir]) + AC_SUBST([pkgluadir], [\${luadir}/$PACKAGE]) + + dnl Lua provides no way to query the module directory, and instead + dnl provides LUA_PATH. However, we should be able to make a safe educated + dnl guess. If the built-in search path contains a directory which is + dnl prefixed by $exec_prefix, then we can store modules there. The first + dnl matching path will be used. + AC_CACHE_CHECK([for $ax_display_LUA module directory], + [ax_cv_lua_luaexecdir], + [ AS_IF([test "x$exec_prefix" = 'xNONE'], + [ax_lua_exec_prefix=$ax_lua_prefix], + [ax_lua_exec_prefix=$exec_prefix]) + + dnl Initialize to the default path. + ax_cv_lua_luaexecdir="$LUA_EXEC_PREFIX/lib/lua/$LUA_VERSION" + + dnl Try to find a path with the prefix. + _AX_LUA_FND_PRFX_PTH([$LUA], + [$ax_lua_exec_prefix], [module]) + AS_IF([test "x$ax_lua_prefixed_path" != 'x'], + [ dnl Fix the prefix. + _ax_strip_prefix=`echo "$ax_lua_exec_prefix" | $SED 's|.|.|g'` + ax_cv_lua_luaexecdir=`echo "$ax_lua_prefixed_path" | \ + $SED "s|^$_ax_strip_prefix|$LUA_EXEC_PREFIX|"` + ]) + ]) + AC_SUBST([luaexecdir], [$ax_cv_lua_luaexecdir]) + AC_SUBST([pkgluaexecdir], [\${luaexecdir}/$PACKAGE]) + + dnl Run any user specified action. + $3 + ]) +]) + +dnl AX_WITH_LUA is now the same thing as AX_PROG_LUA. +AC_DEFUN([AX_WITH_LUA], +[ + AC_MSG_WARN([[$0 is deprecated, please use AX_PROG_LUA instead]]) + AX_PROG_LUA +]) + + +dnl ========================================================================= +dnl _AX_LUA_CHK_IS_INTRP(PROG, [ACTION-IF-TRUE], [ACTION-IF-FALSE]) +dnl ========================================================================= +AC_DEFUN([_AX_LUA_CHK_IS_INTRP], +[ + dnl A minimal Lua factorial to prove this is an interpreter. This should work + dnl for Lua interpreters version 5.0 and beyond. + _ax_lua_factorial=[`$1 2>/dev/null -e ' + -- a simple factorial + function fact (n) + if n == 0 then + return 1 + else + return n * fact(n-1) + end + end + print("fact(5) is " .. fact(5))'`] + AS_IF([test "$_ax_lua_factorial" = 'fact(5) is 120'], + [$2], [$3]) +]) + + +dnl ========================================================================= +dnl _AX_LUA_CHK_VER(PROG, MINIMUM-VERSION, [TOO-BIG-VERSION], +dnl [ACTION-IF-TRUE], [ACTION-IF-FALSE]) +dnl ========================================================================= +AC_DEFUN([_AX_LUA_CHK_VER], +[ + dnl Check that the Lua version is within the bounds. Only the major and minor + dnl version numbers are considered. This should work for Lua interpreters + dnl version 5.0 and beyond. + _ax_lua_good_version=[`$1 -e ' + -- a script to compare versions + function verstr2num(verstr) + local _, _, majorver, minorver = string.find(verstr, "^(%d+)%.(%d+)") + if majorver and minorver then + return tonumber(majorver) * 100 + tonumber(minorver) + end + end + local minver = verstr2num("$2") + local _, _, trimver = string.find(_VERSION, "^Lua (.*)") + local ver = verstr2num(trimver) + local maxver = verstr2num("$3") or 1e9 + if minver <= ver and ver < maxver then + print("yes") + else + print("no") + end'`] + AS_IF([test "x$_ax_lua_good_version" = "xyes"], + [$4], [$5]) +]) + + +dnl ========================================================================= +dnl _AX_LUA_FND_PRFX_PTH(PROG, PREFIX, SCRIPT-OR-MODULE-DIR) +dnl ========================================================================= +AC_DEFUN([_AX_LUA_FND_PRFX_PTH], +[ + dnl Get the script or module directory by querying the Lua interpreter, + dnl filtering on the given prefix, and selecting the shallowest path. If no + dnl path is found matching the prefix, the result will be an empty string. + dnl The third argument determines the type of search, it can be 'script' or + dnl 'module'. Supplying 'script' will perform the search with package.path + dnl and LUA_PATH, and supplying 'module' will search with package.cpath and + dnl LUA_CPATH. This is done for compatibility with Lua 5.0. + + ax_lua_prefixed_path=[`$1 -e ' + -- get the path based on search type + local searchtype = "$3" + local paths = "" + if searchtype == "script" then + paths = (package and package.path) or LUA_PATH + elseif searchtype == "module" then + paths = (package and package.cpath) or LUA_CPATH + end + -- search for the prefix + local prefix = "'$2'" + local minpath = "" + local mindepth = 1e9 + string.gsub(paths, "(@<:@^;@:>@+)", + function (path) + path = string.gsub(path, "%?.*$", "") + path = string.gsub(path, "/@<:@^/@:>@*$", "") + if string.find(path, prefix) then + local depth = string.len(string.gsub(path, "@<:@^/@:>@", "")) + if depth < mindepth then + minpath = path + mindepth = depth + end + end + end) + print(minpath)'`] +]) + + +dnl ========================================================================= +dnl AX_LUA_HEADERS([ACTION-IF-FOUND], [ACTION-IF-NOT-FOUND]) +dnl ========================================================================= +AC_DEFUN([AX_LUA_HEADERS], +[ + dnl Check for LUA_VERSION. + AC_MSG_CHECKING([if LUA_VERSION is defined]) + AS_IF([test "x$LUA_VERSION" != 'x'], + [AC_MSG_RESULT([yes])], + [ AC_MSG_RESULT([no]) + AC_MSG_ERROR([cannot check Lua headers without knowing LUA_VERSION]) + ]) + + dnl Make LUA_INCLUDE a precious variable. + AC_ARG_VAR([LUA_INCLUDE], [The Lua includes, e.g. -I/usr/include/lua5.1]) + + dnl Some default directories to search. + LUA_SHORT_VERSION=`echo "$LUA_VERSION" | $SED 's|\.||'` + m4_define_default([_AX_LUA_INCLUDE_LIST], + [ /usr/include/lua$LUA_VERSION \ + /usr/include/lua-$LUA_VERSION \ + /usr/include/lua/$LUA_VERSION \ + /usr/include/lua$LUA_SHORT_VERSION \ + /usr/local/include/lua$LUA_VERSION \ + /usr/local/include/lua-$LUA_VERSION \ + /usr/local/include/lua/$LUA_VERSION \ + /usr/local/include/lua$LUA_SHORT_VERSION \ + ]) + + dnl Try to find the headers. + _ax_lua_saved_cppflags=$CPPFLAGS + CPPFLAGS="$CPPFLAGS $LUA_INCLUDE" + AC_CHECK_HEADERS([lua.h lualib.h lauxlib.h luaconf.h]) + CPPFLAGS=$_ax_lua_saved_cppflags + + dnl Try some other directories if LUA_INCLUDE was not set. + AS_IF([test "x$LUA_INCLUDE" = 'x' && + test "x$ac_cv_header_lua_h" != 'xyes'], + [ dnl Try some common include paths. + for _ax_include_path in _AX_LUA_INCLUDE_LIST; do + test ! -d "$_ax_include_path" && continue + + AC_MSG_CHECKING([for Lua headers in]) + AC_MSG_RESULT([$_ax_include_path]) + + AS_UNSET([ac_cv_header_lua_h]) + AS_UNSET([ac_cv_header_lualib_h]) + AS_UNSET([ac_cv_header_lauxlib_h]) + AS_UNSET([ac_cv_header_luaconf_h]) + + _ax_lua_saved_cppflags=$CPPFLAGS + CPPFLAGS="$CPPFLAGS -I$_ax_include_path" + AC_CHECK_HEADERS([lua.h lualib.h lauxlib.h luaconf.h]) + CPPFLAGS=$_ax_lua_saved_cppflags + + AS_IF([test "x$ac_cv_header_lua_h" = 'xyes'], + [ LUA_INCLUDE="-I$_ax_include_path" + break + ]) + done + ]) + + AS_IF([test "x$ac_cv_header_lua_h" = 'xyes'], + [ dnl Make a program to print LUA_VERSION defined in the header. + dnl TODO It would be really nice if we could do this without compiling a + dnl program, then it would work when cross compiling. But I'm not sure how + dnl to do this reliably. For now, assume versions match when cross compiling. + + AS_IF([test "x$cross_compiling" != 'xyes'], + [ AC_CACHE_CHECK([for Lua header version], + [ax_cv_lua_header_version], + [ _ax_lua_saved_cppflags=$CPPFLAGS + CPPFLAGS="$CPPFLAGS $LUA_INCLUDE" + AC_RUN_IFELSE( + [ AC_LANG_SOURCE([[ +#include <lua.h> +#include <stdlib.h> +#include <stdio.h> +int main(int argc, char ** argv) +{ + if(argc > 1) printf("%s", LUA_VERSION); + exit(EXIT_SUCCESS); +} +]]) + ], + [ ax_cv_lua_header_version=`./conftest$EXEEXT p | \ + $SED -n "s|^Lua \(@<:@0-9@:>@\{1,\}\.@<:@0-9@:>@\{1,\}\).\{0,\}|\1|p"` + ], + [ax_cv_lua_header_version='unknown']) + CPPFLAGS=$_ax_lua_saved_cppflags + ]) + + dnl Compare this to the previously found LUA_VERSION. + AC_MSG_CHECKING([if Lua header version matches $LUA_VERSION]) + AS_IF([test "x$ax_cv_lua_header_version" = "x$LUA_VERSION"], + [ AC_MSG_RESULT([yes]) + ax_header_version_match='yes' + ], + [ AC_MSG_RESULT([no]) + ax_header_version_match='no' + ]) + ], + [ AC_MSG_WARN([cross compiling so assuming header version number matches]) + ax_header_version_match='yes' + ]) + ]) + + dnl Was LUA_INCLUDE specified? + AS_IF([test "x$ax_header_version_match" != 'xyes' && + test "x$LUA_INCLUDE" != 'x'], + [AC_MSG_ERROR([cannot find headers for specified LUA_INCLUDE])]) + + dnl Test the final result and run user code. + AS_IF([test "x$ax_header_version_match" = 'xyes'], [$1], + [m4_default([$2], [AC_MSG_ERROR([cannot find Lua includes])])]) +]) + +dnl AX_LUA_HEADERS_VERSION no longer exists, use AX_LUA_HEADERS. +AC_DEFUN([AX_LUA_HEADERS_VERSION], +[ + AC_MSG_WARN([[$0 is deprecated, please use AX_LUA_HEADERS instead]]) +]) + + +dnl ========================================================================= +dnl AX_LUA_LIBS([ACTION-IF-FOUND], [ACTION-IF-NOT-FOUND]) +dnl ========================================================================= +AC_DEFUN([AX_LUA_LIBS], +[ + dnl TODO Should this macro also check various -L flags? + + dnl Check for LUA_VERSION. + AC_MSG_CHECKING([if LUA_VERSION is defined]) + AS_IF([test "x$LUA_VERSION" != 'x'], + [AC_MSG_RESULT([yes])], + [ AC_MSG_RESULT([no]) + AC_MSG_ERROR([cannot check Lua libs without knowing LUA_VERSION]) + ]) + + dnl Make LUA_LIB a precious variable. + AC_ARG_VAR([LUA_LIB], [The Lua library, e.g. -llua5.1]) + + AS_IF([test "x$LUA_LIB" != 'x'], + [ dnl Check that LUA_LIBS works. + _ax_lua_saved_libs=$LIBS + LIBS="$LIBS $LUA_LIB" + AC_SEARCH_LIBS([lua_load], [], + [_ax_found_lua_libs='yes'], + [_ax_found_lua_libs='no']) + LIBS=$_ax_lua_saved_libs + + dnl Check the result. + AS_IF([test "x$_ax_found_lua_libs" != 'xyes'], + [AC_MSG_ERROR([cannot find libs for specified LUA_LIB])]) + ], + [ dnl First search for extra libs. + _ax_lua_extra_libs='' + + _ax_lua_saved_libs=$LIBS + LIBS="$LIBS $LUA_LIB" + AC_SEARCH_LIBS([exp], [m]) + AC_SEARCH_LIBS([dlopen], [dl]) + LIBS=$_ax_lua_saved_libs + + AS_IF([test "x$ac_cv_search_exp" != 'xno' && + test "x$ac_cv_search_exp" != 'xnone required'], + [_ax_lua_extra_libs="$_ax_lua_extra_libs $ac_cv_search_exp"]) + + AS_IF([test "x$ac_cv_search_dlopen" != 'xno' && + test "x$ac_cv_search_dlopen" != 'xnone required'], + [_ax_lua_extra_libs="$_ax_lua_extra_libs $ac_cv_search_dlopen"]) + + dnl Try to find the Lua libs. + _ax_lua_saved_libs=$LIBS + LIBS="$LIBS $LUA_LIB" + AC_SEARCH_LIBS([lua_load], + [ lua$LUA_VERSION \ + lua$LUA_SHORT_VERSION \ + lua-$LUA_VERSION \ + lua-$LUA_SHORT_VERSION \ + lua \ + ], + [_ax_found_lua_libs='yes'], + [_ax_found_lua_libs='no'], + [$_ax_lua_extra_libs]) + LIBS=$_ax_lua_saved_libs + + AS_IF([test "x$ac_cv_search_lua_load" != 'xno' && + test "x$ac_cv_search_lua_load" != 'xnone required'], + [LUA_LIB="$ac_cv_search_lua_load $_ax_lua_extra_libs"]) + ]) + + dnl Test the result and run user code. + AS_IF([test "x$_ax_found_lua_libs" = 'xyes'], [$1], + [m4_default([$2], [AC_MSG_ERROR([cannot find Lua libs])])]) +]) + + +dnl ========================================================================= +dnl AX_LUA_READLINE([ACTION-IF-FOUND], [ACTION-IF-NOT-FOUND]) +dnl ========================================================================= +AC_DEFUN([AX_LUA_READLINE], +[ + AX_LIB_READLINE + AS_IF([test "x$ac_cv_header_readline_readline_h" != 'x' && + test "x$ac_cv_header_readline_history_h" != 'x'], + [ LUA_LIBS_CFLAGS="-DLUA_USE_READLINE $LUA_LIBS_CFLAGS" + $1 + ], + [$2]) +]) diff --git a/conf/cf-lex.l b/conf/cf-lex.l index 9bbb3660..920d258e 100644 --- a/conf/cf-lex.l +++ b/conf/cf-lex.l @@ -112,7 +112,7 @@ static int check_eof(void); %option nounput %option noreject -%x COMMENT CCOMM CLI +%x COMMENT CCOMM CLI MULTISTRING ALPHA [a-zA-Z_] DIGIT [0-9] @@ -308,6 +308,18 @@ else: { return TEXT; } +\"\"\" BEGIN(MULTISTRING); +<MULTISTRING>\"\"\" { + yytext[yyleng-3] = 0; + cf_lval.t = cfg_strdup(yytext); + yytext[yyleng-3] = '\"'; + BEGIN(INITIAL); + return TEXT; +} +<MULTISTRING><<EOF>> cf_error("Unterminated multi-line string"); +<MULTISTRING>\n ifs->lino++; ifs->chno = 0; yymore(); +<MULTISTRING>. yymore(); + ["][^"\n]*\n cf_error("Unterminated string"); <INITIAL,COMMENT><<EOF>> { if (check_eof()) return END; } diff --git a/configure.ac b/configure.ac index d219b274..2b48c226 100644 --- a/configure.ac +++ b/configure.ac @@ -341,6 +341,12 @@ elif test "$bird_cv_lib_log" != yes ; then LIBS="$LIBS $bird_cv_lib_log" fi +AX_PROG_LUA(5.3) +AX_LUA_HEADERS +AX_LUA_LIBS +AC_SUBST(LUA_INCLUDE) +AC_SUBST(LUA_LIBS) + if test "$enable_debug" = yes ; then AC_DEFINE([DEBUGGING], [1], [Define to 1 if debugging is enabled]) LDFLAGS="$LDFLAGS -rdynamic" diff --git a/filter/config.Y b/filter/config.Y index c1e74531..6004c961 100644 --- a/filter/config.Y +++ b/filter/config.Y @@ -426,7 +426,7 @@ CF_KEYWORDS(FUNCTION, PRINT, PRINTN, UNSET, RETURN, %type <x> term block cmds cmds_int cmd function_body constant constructor print_one print_list var_list var_listn function_call symbol bgp_path_expr %type <fda> dynamic_attr %type <fsa> static_attr -%type <f> filter filter_body where_filter +%type <f> filter filter_body where_filter lua_call %type <i> type break_command ec_kind %type <i32> cnum %type <e> pair_item ec_item lc_item set_item switch_item set_items switch_items switch_body @@ -453,6 +453,7 @@ filter_def: conf: filter_eval ; filter_eval: EVAL term { f_eval_int($2); } + | EVAL LUA constant { lua_eval($3); } ; conf: custom_attr ; @@ -543,6 +544,7 @@ declsn: one_decl { $$ = $1; } filter_body: function_body { struct filter *f = cfg_alloc(sizeof(struct filter)); + f->type = FILTER_INTERNAL; f->name = NULL; f->root = $1; $$ = f; @@ -561,6 +563,7 @@ where_filter: WHERE term { /* Construct 'IF term THEN ACCEPT; REJECT;' */ struct filter *f = cfg_alloc(sizeof(struct filter)); + f->type = FILTER_INTERNAL; struct f_inst *i, *acc, *rej; acc = f_new_inst(FI_PRINT_AND_DIE); /* ACCEPT */ acc->a1.p = NULL; @@ -1064,6 +1067,11 @@ cmd: | BT_ASSERT '(' get_cf_position term get_cf_position ')' ';' { $$ = assert_done($4, $3 + 1, $5 - 1); } ; +lua_call: + LUA constant { + $$ = lua_new_filter($2); + } + get_cf_position: { $$ = cf_text; diff --git a/filter/filter.c b/filter/filter.c index 37cf16a3..d047b814 100644 --- a/filter/filter.c +++ b/filter/filter.c @@ -1773,7 +1773,18 @@ f_run(struct filter *filter, struct rte **rte, struct linpool *tmp_pool, int fla LOG_BUFFER_INIT(f_buf); - struct f_val res = interpret(filter->root); + struct f_val res; + switch (filter->type) { + case FILTER_INTERNAL: + res = interpret(filter->root); + break; + case FILTER_LUA: + ACCESS_EATTRS; + res = lua_interpret(filter->lua_chunk, rte, &f_old_rta, f_eattrs, tmp_pool, flags); + break; + default: + bug("filter type not set"); + } if (f_old_rta) { /* @@ -1867,5 +1878,16 @@ filter_same(struct filter *new, struct filter *old) if (old == FILTER_ACCEPT || old == FILTER_REJECT || new == FILTER_ACCEPT || new == FILTER_REJECT) return 0; - return i_same(new->root, old->root); + if (new->type != old->type) + return 0; + switch(new->type) { + case FILTER_INTERNAL: + return i_same(new->root, old->root); + break; + case FILTER_LUA: + return lua_filter_same(new->lua_chunk, old->lua_chunk); + break; + default: + bug("Unknown filter type"); + } } diff --git a/filter/filter.h b/filter/filter.h index a8c33287..2b1176dc 100644 --- a/filter/filter.h +++ b/filter/filter.h @@ -145,8 +145,15 @@ struct f_static_attr { }; struct filter { + enum filter_type { + FILTER_INTERNAL = 1, + FILTER_LUA = 2, + } type; char *name; - struct f_inst *root; + union { + struct f_inst *root; + struct lua_filter_chunk *lua_chunk; + }; }; struct f_inst *f_new_inst(enum f_instruction_code fi_code); @@ -284,6 +291,8 @@ struct f_trie }; #define NEW_F_VAL struct f_val * val; val = cfg_alloc(sizeof(struct f_val)); +#define F_VAL(_type, where, value) ((struct f_val) { .type = (_type), .val.where = (value) }) +#define F_VAL_VOID ((struct f_val) { .type = T_VOID }) #define FF_SILENT 2 /* Silent filter execution */ @@ -307,4 +316,10 @@ struct f_bt_test_suite { /* Hook for call bt_assert() function in configuration */ extern void (*bt_assert_hook)(int result, struct f_inst *assert); +/* Lua */ +struct filter * lua_new_filter(struct f_inst *inst); +struct f_val lua_interpret(struct lua_filter_chunk *chunk, struct rte **e, struct rta **a, struct ea_list **ea, struct linpool *lp, int flags); +int lua_filter_same(struct lua_filter_chunk *new, struct lua_filter_chunk *old); +uint lua_eval(struct f_inst *inst); + #endif diff --git a/lib/printf.c b/lib/printf.c index c2065d9a..130bc61c 100644 --- a/lib/printf.c +++ b/lib/printf.c @@ -8,6 +8,7 @@ */ #include "nest/bird.h" +#include "conf/conf.h" #include "string.h" #include <errno.h> diff --git a/lua/Makefile b/lua/Makefile new file mode 100644 index 00000000..b74309de --- /dev/null +++ b/lua/Makefile @@ -0,0 +1,4 @@ +src := common.c filter.c +obj := $(src-o-files) +$(all-daemon) +#$(cf-local) diff --git a/lua/common.c b/lua/common.c new file mode 100644 index 00000000..9f617775 --- /dev/null +++ b/lua/common.c @@ -0,0 +1,323 @@ +#include "nest/bird.h" +#include "nest/protocol.h" +#include "nest/route.h" +#include "conf/conf.h" +#include "filter/filter.h" +#include "lua.h" + +#include <lua.h> +#include <lualib.h> +#include <lauxlib.h> + +static linpool *lua_lp; + +static int luaB_err(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) + log(L_WARN "bird.err() accepts exactly 1 argument"); + + if (n < 1) + return 0; + + log(L_ERR "%s", lua_tostring(L, 1)); + return 0; +} + +static int luaB_warn(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) + log(L_WARN "bird.warn() accepts exactly 1 argument"); + + if (n < 1) + return 0; + + log(L_WARN "%s", lua_tostring(L, 1)); + return 0; +} + +static int luaB_info(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) + log(L_WARN "bird.info() accepts exactly 1 argument"); + + if (n < 1) + return 0; + + log(L_INFO "%s", lua_tostring(L, 1)); + return 0; +} + +static int luaB_trace(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) + log(L_WARN "bird.trace() accepts exactly 1 argument"); + + if (n < 1) + return 0; + + log(L_TRACE "%s", lua_tostring(L, 1)); + return 0; +} + +#define lua_sett(L, idx, val, what) do { \ + lua_pushstring(L, idx); \ + lua_push##what(L, val); \ + lua_settable(L, -3); \ +} while (0) + +#define lua_settableaddr(L, idx, val) lua_sett(L, idx, val, addr) +#define lua_settablecfunction(L, idx, val) lua_sett(L, idx, val, cfunction) +#define lua_settableinteger(L, idx, val) lua_sett(L, idx, val, integer) +#define lua_settableip4(L, idx, val) lua_sett(L, idx, val, ip4) +#define lua_settablelightuserdata(L, idx, val) lua_sett(L, idx, val, lightuserdata) +#define lua_settableeattr(L, idx, val) lua_sett(L, idx, val, eattr) +#define lua_settablevalue(L, idx, val) lua_sett(L, idx, val, value) + +#define lua_setglobalcfunction(L, n, val) do { \ + lua_pushcfunction(L, val); \ + lua_setglobal(L, n); \ +} while (0) + +static int luaB_generic_concat(lua_State *L) { + int n = lua_gettop(L); + if (n != 2) { + log(L_WARN "__concat needs exactly 2 arguments"); + return 0; + } + + const char *a, *b; + size_t la, lb; + + a = luaL_tolstring(L, 1, &la); + b = luaL_tolstring(L, 2, &lb); + + if (a == NULL) { + a = ""; + la = 0; + } + + if (b == NULL) { + b = ""; + lb = 0; + } + + char *c = alloca(la + lb + 1); + memcpy(c, a, la); + memcpy(c + la, b, lb); + c[la + lb] = 0; + + lua_pushlstring(L, c, la + lb); + + return 1; +} + +static int luaB_ip4_tostring(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) { + log(L_WARN "__tostring needs exactly 1 argument"); + return 0; + } + + lua_pushliteral(L, "addr"); + lua_gettable(L, 1); + lua_Integer a = lua_tointeger(L, -1); + char c[IP4_MAX_TEXT_LENGTH]; + bsnprintf(c, IP4_MAX_TEXT_LENGTH, "%I4", a); + + lua_pushstring(L, c); + return 1; +} + +static void lua_puship4(lua_State *L, ip4_addr a) { + lua_newtable(L); + lua_settableinteger(L, "addr", ip4_to_u32(a)); + + lua_newtable(L); + lua_settablecfunction(L, "__tostring", luaB_ip4_tostring); + lua_settablecfunction(L, "__concat", luaB_generic_concat); + lua_setmetatable(L, -2); +} + +static int luaB_addr_tostring(lua_State *L) { + int n = lua_gettop(L); + if (n != 1) { + log(L_WARN "__tostring needs exactly 1 argument"); + return 0; + } + + lua_pushliteral(L, "_internal"); + lua_gettable(L, 1); + if (!lua_isuserdata(L, -1)) + luaL_error(L, "fatal: bird internal state not found, type %d", lua_type(L, -1)); + + net_addr *addr = lua_touserdata(L, -1); + lua_pop(L, 1); + + char c[NET_MAX_TEXT_LENGTH+1]; + net_format(addr, c, sizeof(c)); + lua_pushstring(L, c); + return 1; +} + +static void lua_pushaddr(lua_State *L, net_addr *addr) { + lua_newtable(L); + lua_settablelightuserdata(L, "_internal", addr); + + lua_newtable(L); + lua_settablecfunction(L, "__tostring", luaB_addr_tostring); + lua_settablecfunction(L, "__concat", luaB_generic_concat); + lua_setmetatable(L, -2); +} + +static void lua_pusheattr(lua_State *L, eattr *ea) { + /* if (ea->type == EAF_TYPE_IP_ADDRESS) { */ + /* lua_settableinteger(L, "data", 17); */ + /* /\* lua_pushaddr(L, "addr", (net_addr*)ea->u.ptr->data); *\/ */ + /* } */ + lua_newtable(L); + lua_settableinteger(L, "id", ea->id); + lua_settableinteger(L, "type", ea->type); + if (ea->u.ptr && ea->u.ptr->data) { + lua_pushliteral(L, "data"); + lua_pushlstring(L, ea->u.ptr->data, ea->u.ptr->length); + lua_settable(L, -3); + } + + /* lua_settablecfunction(L, "__tostring", luaB_addr_tostring); */ + /* lua_settablecfunction(L, "__concat", luaB_generic_concat); */ + /* lua_setmetatable(L, -2); */ +} + +static lua_bird_state *luaB_getinternalstate(lua_State *L) { + lua_getglobal(L, "bird"); + lua_pushstring(L, "_internal_state"); + lua_gettable(L, -2); + if (!lua_isuserdata(L, -1)) + luaL_error(L, "fatal: bird internal state not found, type %d", lua_type(L, -1)); + + lua_bird_state *lbs = lua_touserdata(L, -1); + lua_pop(L, 2); /* Pop the user data and then the table. The string is consumed by gettable(). */ + return lbs; +} + +static int luaB_global_exception(lua_State *L, int value) { + int n = lua_gettop(L); + if (n > 1) + log(L_WARN "Called exception with too many arguments."); + + lua_bird_state *lbs = luaB_getinternalstate(L); + lbs->exception = value; + + lua_error(L); + return 0; +} + +static inline int luaB_accept(lua_State *L) { return luaB_global_exception(L, F_ACCEPT); } +static inline int luaB_reject(lua_State *L) { return luaB_global_exception(L, F_REJECT); } + +lua_bird_state *luaB_init(lua_State *L, struct linpool *lp) { + lua_newtable(L); + + lua_settablecfunction(L, "err", luaB_err); + lua_settablecfunction(L, "warn", luaB_warn); + lua_settablecfunction(L, "info", luaB_info); + lua_settablecfunction(L, "trace", luaB_trace); + + lua_bird_state *lbs = lp_allocz(lp, sizeof(lua_bird_state)); + + lua_settablelightuserdata(L, "_internal_state", lbs); + + lua_settableip4(L, "router_id", ip4_from_u32(config->router_id)); + + lua_setglobal(L, "bird"); + + lua_pushcfunction(L, luaB_accept); + lua_setglobal(L, "accept"); + + lua_pushcfunction(L, luaB_reject); + lua_setglobal(L, "reject"); + + return lbs; +} + +static int luaB_route_ea_find(lua_State *L) { + int n = lua_gettop(L); + if (n != 2) { + log(L_WARN "ea_find needs exactly 1 argument"); + return 0; + } + + lua_pushliteral(L, "_internal"); + lua_gettable(L, 1); + if (!lua_isuserdata(L, -1)) + luaL_error(L, "fatal: bird internal state not found, type %d", lua_type(L, -1)); + + struct rte *e = lua_touserdata(L, -1); + int ea = lua_tointeger(L, 2); + lua_pop(L, 2); + + struct ea_list *eattrs = e->attrs->eattrs; + eattr *t = ea_find(eattrs, ea); + + if (t) { + lua_pusheattr(L, t); + return 1; + } else { + log(L_ERR "eattr not found"); + return 0; + } +} + +/* ea_set_attr_data(id, flags, type, data(string) */ +static int luaB_route_ea_set_attr_data(lua_State *L) { + int n = lua_gettop(L); + if (n != 5) { + log(L_WARN "ea_set_attr_data needs exactly 4 argument"); + return 0; + } + + lua_pushliteral(L, "_internal"); + lua_gettable(L, 1); + if (!lua_isuserdata(L, -1)) + luaL_error(L, "fatal: bird internal state not found, type %d", lua_type(L, -1)); + + struct rte *e = lua_touserdata(L, -1); + uint id = lua_tointeger(L, 2); + uint flags = lua_tointeger(L, 3); + uint type = lua_tointeger(L, 4); + size_t len = 0; + const char *data = lua_tolstring(L, 5, &len); + lua_pop(L, 5); + + struct ea_list **eattrs = &e->attrs->eattrs; + if (!lua_lp) + lua_lp = lp_new_default(&root_pool); + ea_set_attr_data(eattrs, lua_lp, id, flags, type, data, len); + lua_pushboolean(L, 1); + return 0; +} + +void luaB_push_route(lua_State *L, struct rte *e) { + lua_newtable(L); + lua_settablelightuserdata(L, "_internal", e); + lua_settableaddr(L, "prefix", e->net->n.addr); + lua_settablecfunction(L, "ea_find", luaB_route_ea_find); + lua_settablecfunction(L, "ea_set_attr_data", luaB_route_ea_set_attr_data); + + lua_newtable(L); + lua_settablevalue(L, "__index", -2-1); + lua_setmetatable(L, -2); + + lua_setglobal(L, "route"); +} + +void luaB_push_eattrs(lua_State *L, struct ea_list *ea) { + lua_newtable(L); + + if (!ea) { + log(L_ERR "Null eattrs"); + } + + lua_settablecfunction(L, "__tostring", luaB_addr_tostring); + lua_setmetatable(L, -2); +} diff --git a/lua/filter.c b/lua/filter.c new file mode 100644 index 00000000..f4de8f1f --- /dev/null +++ b/lua/filter.c @@ -0,0 +1,187 @@ +#include "nest/bird.h" +#include "conf/conf.h" +#include "filter/filter.h" +#include "lua.h" + +#include <lua.h> +#include <lualib.h> +#include <lauxlib.h> + +/* Docs: http://pgl.yoyo.org/luai/i/luaL_dostring */ + +static lua_State *global_lua_state = NULL; + +static inline lua_State * luaB_getstate(void) { + if (!global_lua_state) { + lua_State *L = luaL_newstate(); + luaL_openlibs(L); + global_lua_state = L; + } + + return lua_newthread(global_lua_state); +} + +static inline void luaB_close(lua_State *L UNUSED) { + lua_pop(global_lua_state, 1); +} + +struct lua_new_filter_writer_data { + struct lua_filter_chunk *first, *last; +}; + +static int lua_new_filter_writer(lua_State *L UNUSED, const void *p, size_t sz, void *ud) { + struct lua_new_filter_writer_data *d = ud; + struct lua_filter_chunk *cur = cfg_allocz(sizeof(struct lua_filter_chunk)); + + cur->size = sz; + cur->chunk = cfg_alloc(sz); + memcpy(cur->chunk, p, sz); + + if (d->last) + d->last = d->last->next = cur; + else + d->last = d->first = cur; + + return 0; +} + +struct filter * lua_new_filter(struct f_inst *inst) { + struct filter *f = cfg_alloc(sizeof(struct filter)); + f->name = NULL; + f->type = FILTER_LUA; + + struct f_val string = f_eval(inst, cfg_mem); + if (string.type != T_STRING) { + cf_error("Lua filter must be a string"); + return NULL; + } + + lua_State *L = luaB_getstate(); + int loadres = luaL_loadstring(L, string.val.s); + switch (loadres) { + case LUA_ERRMEM: + luaB_close(L); + cf_error("Memory allocation error occured when loading lua chunk"); + return NULL; + case LUA_ERRSYNTAX: + { + const char *e = lua_tostring(L, -1); + char *ec = cfg_alloc(strlen(e) + 1); + strcpy(ec, e); + luaB_close(L); + cf_error("Lua syntax error: %s", ec); + return NULL; + } + case 0: /* Everything OK */ + break; + } + + struct lua_new_filter_writer_data lnfwd = {}; + lua_dump(L, lua_new_filter_writer, &lnfwd, 0); /* No error to handle */ + luaB_close(L); + + f->lua_chunk = lnfwd.first; + return f; +} + +static const char *lua_interpret_reader(lua_State *L UNUSED, void *ud, size_t *sz) { + struct lua_filter_chunk **cptr = ud; + if ((*cptr) == NULL) + return NULL; + + *sz = (*cptr)->size; + void *out = (*cptr)->chunk; + *cptr = (*cptr)->next; + return out; +} + +struct f_val lua_interpret(struct lua_filter_chunk *chunk, struct rte **e, struct rta **a UNUSED, struct ea_list **ea UNUSED, struct linpool *lp, int flags UNUSED) { + lua_State *L = luaB_getstate(); + + lua_bird_state *lbs = luaB_init(L, lp); + luaB_push_route(L, *e); + luaB_push_eattrs(L, *ea); + + struct lua_filter_chunk **rptr = &chunk; + lua_load(L, lua_interpret_reader, rptr, "", "b"); + int le = lua_pcall(L, 0, LUA_MULTRET, 0); + struct f_val out = F_VAL_VOID; + if (le && lbs->exception) { + out = F_VAL(T_RETURN, i, lbs->exception); + } else if (le) { + log(L_ERR "bad lua: %s", lua_tostring(L, -1)); + out = F_VAL(T_RETURN, i, F_ERROR); + } else if (lua_isnumber(L, -1)) { + out = F_VAL(T_INT, i, lua_tonumber(L, -1)); + } else { + log(L_WARN "lua return value is not a number (unimplemented): %s", lua_tostring(L, -1)); + out = F_VAL(T_RETURN, i, F_ERROR); + } + + *ea = (*e)->attrs->eattrs; + + luaB_close(L); + return out; +} + +int lua_filter_same(struct lua_filter_chunk *new, struct lua_filter_chunk *old) { + size_t npos = 0, opos = 0; + while (new && old) { + size_t nrem = new->size - npos; + size_t orem = old->size - opos; + size_t rem = MIN(nrem, orem); + if (memcmp(new->chunk + npos, old->chunk + opos, rem)) + return 0; + + npos += rem; + opos += rem; + + if (npos == new->size) { + new = new->next; + npos = 0; + } + + if (opos == old->size) { + old = old->next; + opos = 0; + } + } + + if (!new && !old) + return 1; + else + return 0; +} + +uint lua_eval(struct f_inst *inst) +{ + struct f_val string = f_eval(inst, cfg_mem); + if (string.type != T_STRING) { + cf_error("Lua filter must be a string"); + return -1; + } + + lua_State *L = luaB_getstate(); + int dores = luaL_dostring(L, string.val.s); + log(L_WARN "lua_eval dores '%s' %d", string.val.s, dores); + switch (dores) { + case LUA_ERRMEM: + luaB_close(L); + cf_error("Memory allocation error occured when loading lua chunk"); + return -1; + case LUA_ERRSYNTAX: + { + const char *e = lua_tostring(L, -1); + char *ec = cfg_alloc(strlen(e) + 1); + strcpy(ec, e); + luaB_close(L); + cf_error("Lua syntax error: %s", ec); + return -1; + } + case 0: /* Everything OK */ + break; + } + luaB_close(L); + + return 0; +} diff --git a/lua/lua.h b/lua/lua.h new file mode 100644 index 00000000..fc9a52d2 --- /dev/null +++ b/lua/lua.h @@ -0,0 +1,18 @@ +#include "nest/bird.h" + +#include <lua.h> + +struct lua_filter_chunk { + size_t size; + void *chunk; + struct lua_filter_chunk *next; +}; + +typedef struct lua_bird_state { + int exception; +} lua_bird_state; + +lua_bird_state *luaB_init(lua_State *L, struct linpool *lp); +void luaB_push_route(lua_State *L, rte *e); +void luaB_push_eattrs(lua_State *L, struct ea_list *ea); + diff --git a/nest/config.Y b/nest/config.Y index aef5ed46..3242f96e 100644 --- a/nest/config.Y +++ b/nest/config.Y @@ -65,7 +65,7 @@ proto_postconfig(void) CF_DECLS CF_KEYWORDS(ROUTER, ID, PROTOCOL, TEMPLATE, PREFERENCE, DISABLED, DEBUG, ALL, OFF, DIRECT) -CF_KEYWORDS(INTERFACE, IMPORT, EXPORT, FILTER, NONE, VRF, TABLE, STATES, ROUTES, FILTERS) +CF_KEYWORDS(INTERFACE, IMPORT, EXPORT, FILTER, LUA, NONE, VRF, TABLE, STATES, ROUTES, FILTERS) CF_KEYWORDS(IPV4, IPV6, VPN4, VPN6, ROA4, ROA6, FLOW4, FLOW6, SADR, MPLS) CF_KEYWORDS(RECEIVE, LIMIT, ACTION, WARN, BLOCK, RESTART, DISABLE, KEEP, FILTERED) CF_KEYWORDS(PASSWORD, FROM, PASSIVE, TO, ID, EVENTS, PACKETS, PROTOCOLS, INTERFACES) @@ -263,6 +263,7 @@ rtable: imexport: FILTER filter { $$ = $2; } + | lua_call | where_filter | ALL { $$ = FILTER_ACCEPT; } | NONE { $$ = FILTER_REJECT; } diff --git a/nest/route.h b/nest/route.h index ad89e4b2..e37ef0ce 100644 --- a/nest/route.h +++ b/nest/route.h @@ -616,7 +616,7 @@ ea_set_attr_ptr(ea_list **to, struct linpool *pool, uint id, uint flags, uint ty { ea_set_attr(to, pool, id, flags, type, (uintptr_t) val); } static inline void -ea_set_attr_data(ea_list **to, struct linpool *pool, uint id, uint flags, uint type, void *data, uint len) +ea_set_attr_data(ea_list **to, struct linpool *pool, uint id, uint flags, uint type, const void *data, uint len) { struct adata *a = lp_alloc_adata(pool, len); memcpy(a->data, data, len); diff --git a/sysdep/linux/Makefile b/sysdep/linux/Makefile index 188ac8de..12bb26c1 100644 --- a/sysdep/linux/Makefile +++ b/sysdep/linux/Makefile @@ -1,4 +1,4 @@ -src := netlink.c +src := netlink.c wireguard.c obj := $(src-o-files) $(all-daemon) $(conf-y-targets): $(s)netlink.Y diff --git a/sysdep/linux/wireguard.c b/sysdep/linux/wireguard.c new file mode 100644 index 00000000..51da8ece --- /dev/null +++ b/sysdep/linux/wireguard.c @@ -0,0 +1,1766 @@ +// SPDX-License-Identifier: LGPL-2.1+ +/* + * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + * Copyright (C) 2008-2012 Pablo Neira Ayuso <pablo@netfilter.org>. + */ + +#define _GNU_SOURCE + +#include <errno.h> +#include <linux/genetlink.h> +#include <linux/if_link.h> +#include <linux/netlink.h> +#include <linux/rtnetlink.h> +#include <netinet/in.h> +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <time.h> +#include <unistd.h> +#include <fcntl.h> +#include <assert.h> + +#include "wireguard.h" + +/* wireguard.h netlink uapi: */ + +#define WG_GENL_NAME "wireguard" +#define WG_GENL_VERSION 1 + +enum wg_cmd { + WG_CMD_GET_DEVICE, + WG_CMD_SET_DEVICE, + __WG_CMD_MAX +}; + +enum wgdevice_flag { + WGDEVICE_F_REPLACE_PEERS = 1U << 0 +}; +enum wgdevice_attribute { + WGDEVICE_A_UNSPEC, + WGDEVICE_A_IFINDEX, + WGDEVICE_A_IFNAME, + WGDEVICE_A_PRIVATE_KEY, + WGDEVICE_A_PUBLIC_KEY, + WGDEVICE_A_FLAGS, + WGDEVICE_A_LISTEN_PORT, + WGDEVICE_A_FWMARK, + WGDEVICE_A_PEERS, + __WGDEVICE_A_LAST +}; + +enum wgpeer_flag { + WGPEER_F_REMOVE_ME = 1U << 0, + WGPEER_F_REPLACE_ALLOWEDIPS = 1U << 1 +}; +enum wgpeer_attribute { + WGPEER_A_UNSPEC, + WGPEER_A_PUBLIC_KEY, + WGPEER_A_PRESHARED_KEY, + WGPEER_A_FLAGS, + WGPEER_A_ENDPOINT, + WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, + WGPEER_A_LAST_HANDSHAKE_TIME, + WGPEER_A_RX_BYTES, + WGPEER_A_TX_BYTES, + WGPEER_A_ALLOWEDIPS, + WGPEER_A_PROTOCOL_VERSION, + __WGPEER_A_LAST +}; + +enum wgallowedip_attribute { + WGALLOWEDIP_A_UNSPEC, + WGALLOWEDIP_A_FAMILY, + WGALLOWEDIP_A_IPADDR, + WGALLOWEDIP_A_CIDR_MASK, + __WGALLOWEDIP_A_LAST +}; + +/* libmnl mini library: */ + +#define MNL_SOCKET_AUTOPID 0 +#define MNL_SOCKET_BUFFER_SIZE (sysconf(_SC_PAGESIZE) < 8192L ? sysconf(_SC_PAGESIZE) : 8192L) +#define MNL_ALIGNTO 4 +#define MNL_ALIGN(len) (((len)+MNL_ALIGNTO-1) & ~(MNL_ALIGNTO-1)) +#define MNL_NLMSG_HDRLEN MNL_ALIGN(sizeof(struct nlmsghdr)) +#define MNL_ATTR_HDRLEN MNL_ALIGN(sizeof(struct nlattr)) + +enum mnl_attr_data_type { + MNL_TYPE_UNSPEC, + MNL_TYPE_U8, + MNL_TYPE_U16, + MNL_TYPE_U32, + MNL_TYPE_U64, + MNL_TYPE_STRING, + MNL_TYPE_FLAG, + MNL_TYPE_MSECS, + MNL_TYPE_NESTED, + MNL_TYPE_NESTED_COMPAT, + MNL_TYPE_NUL_STRING, + MNL_TYPE_BINARY, + MNL_TYPE_MAX, +}; + +#define mnl_attr_for_each(attr, nlh, offset) \ + for ((attr) = mnl_nlmsg_get_payload_offset((nlh), (offset)); \ + mnl_attr_ok((attr), (char *)mnl_nlmsg_get_payload_tail(nlh) - (char *)(attr)); \ + (attr) = mnl_attr_next(attr)) + +#define mnl_attr_for_each_nested(attr, nest) \ + for ((attr) = mnl_attr_get_payload(nest); \ + mnl_attr_ok((attr), (char *)mnl_attr_get_payload(nest) + mnl_attr_get_payload_len(nest) - (char *)(attr)); \ + (attr) = mnl_attr_next(attr)) + +#define mnl_attr_for_each_payload(payload, payload_size) \ + for ((attr) = (payload); \ + mnl_attr_ok((attr), (char *)(payload) + payload_size - (char *)(attr)); \ + (attr) = mnl_attr_next(attr)) + +#define MNL_CB_ERROR -1 +#define MNL_CB_STOP 0 +#define MNL_CB_OK 1 + +typedef int (*mnl_attr_cb_t)(const struct nlattr *attr, void *data); +typedef int (*mnl_cb_t)(const struct nlmsghdr *nlh, void *data); + +#ifndef MNL_ARRAY_SIZE +#define MNL_ARRAY_SIZE(a) (sizeof(a)/sizeof((a)[0])) +#endif + +static size_t mnl_nlmsg_size(size_t len) +{ + return len + MNL_NLMSG_HDRLEN; +} + +static struct nlmsghdr *mnl_nlmsg_put_header(void *buf) +{ + int len = MNL_ALIGN(sizeof(struct nlmsghdr)); + struct nlmsghdr *nlh = buf; + + memset(buf, 0, len); + nlh->nlmsg_len = len; + return nlh; +} + +static void *mnl_nlmsg_put_extra_header(struct nlmsghdr *nlh, size_t size) +{ + char *ptr = (char *)nlh + nlh->nlmsg_len; + size_t len = MNL_ALIGN(size); + nlh->nlmsg_len += len; + memset(ptr, 0, len); + return ptr; +} + +static void *mnl_nlmsg_get_payload(const struct nlmsghdr *nlh) +{ + return (void *)nlh + MNL_NLMSG_HDRLEN; +} + +static void *mnl_nlmsg_get_payload_offset(const struct nlmsghdr *nlh, size_t offset) +{ + return (void *)nlh + MNL_NLMSG_HDRLEN + MNL_ALIGN(offset); +} + + +static bool mnl_nlmsg_ok(const struct nlmsghdr *nlh, int len) +{ + return len >= (int)sizeof(struct nlmsghdr) && + nlh->nlmsg_len >= sizeof(struct nlmsghdr) && + (int)nlh->nlmsg_len <= len; +} + +static struct nlmsghdr *mnl_nlmsg_next(const struct nlmsghdr *nlh, int *len) +{ + *len -= MNL_ALIGN(nlh->nlmsg_len); + return (struct nlmsghdr *)((void *)nlh + MNL_ALIGN(nlh->nlmsg_len)); +} + +static void *mnl_nlmsg_get_payload_tail(const struct nlmsghdr *nlh) +{ + return (void *)nlh + MNL_ALIGN(nlh->nlmsg_len); +} + +static bool mnl_nlmsg_seq_ok(const struct nlmsghdr *nlh, unsigned int seq) +{ + return nlh->nlmsg_seq && seq ? nlh->nlmsg_seq == seq : true; +} + +static bool mnl_nlmsg_portid_ok(const struct nlmsghdr *nlh, unsigned int portid) +{ + return nlh->nlmsg_pid && portid ? nlh->nlmsg_pid == portid : true; +} + +static uint16_t mnl_attr_get_type(const struct nlattr *attr) +{ + return attr->nla_type & NLA_TYPE_MASK; +} + +static uint16_t mnl_attr_get_payload_len(const struct nlattr *attr) +{ + return attr->nla_len - MNL_ATTR_HDRLEN; +} + +static void *mnl_attr_get_payload(const struct nlattr *attr) +{ + return (void *)attr + MNL_ATTR_HDRLEN; +} + +static bool mnl_attr_ok(const struct nlattr *attr, int len) +{ + return len >= (int)sizeof(struct nlattr) && + attr->nla_len >= sizeof(struct nlattr) && + (int)attr->nla_len <= len; +} + +static struct nlattr *mnl_attr_next(const struct nlattr *attr) +{ + return (struct nlattr *)((void *)attr + MNL_ALIGN(attr->nla_len)); +} + +static int mnl_attr_type_valid(const struct nlattr *attr, uint16_t max) +{ + if (mnl_attr_get_type(attr) > max) { + errno = EOPNOTSUPP; + return -1; + } + return 1; +} + +static int __mnl_attr_validate(const struct nlattr *attr, + enum mnl_attr_data_type type, size_t exp_len) +{ + uint16_t attr_len = mnl_attr_get_payload_len(attr); + const char *attr_data = mnl_attr_get_payload(attr); + + if (attr_len < exp_len) { + errno = ERANGE; + return -1; + } + switch(type) { + case MNL_TYPE_FLAG: + if (attr_len > 0) { + errno = ERANGE; + return -1; + } + break; + case MNL_TYPE_NUL_STRING: + if (attr_len == 0) { + errno = ERANGE; + return -1; + } + if (attr_data[attr_len-1] != '\0') { + errno = EINVAL; + return -1; + } + break; + case MNL_TYPE_STRING: + if (attr_len == 0) { + errno = ERANGE; + return -1; + } + break; + case MNL_TYPE_NESTED: + + if (attr_len == 0) + break; + + if (attr_len < MNL_ATTR_HDRLEN) { + errno = ERANGE; + return -1; + } + break; + default: + + break; + } + if (exp_len && attr_len > exp_len) { + errno = ERANGE; + return -1; + } + return 0; +} + +static const size_t mnl_attr_data_type_len[MNL_TYPE_MAX] = { + [MNL_TYPE_U8] = sizeof(uint8_t), + [MNL_TYPE_U16] = sizeof(uint16_t), + [MNL_TYPE_U32] = sizeof(uint32_t), + [MNL_TYPE_U64] = sizeof(uint64_t), + [MNL_TYPE_MSECS] = sizeof(uint64_t), +}; + +static int mnl_attr_validate(const struct nlattr *attr, enum mnl_attr_data_type type) +{ + int exp_len; + + if (type >= MNL_TYPE_MAX) { + errno = EINVAL; + return -1; + } + exp_len = mnl_attr_data_type_len[type]; + return __mnl_attr_validate(attr, type, exp_len); +} + +static int mnl_attr_parse(const struct nlmsghdr *nlh, unsigned int offset, + mnl_attr_cb_t cb, void *data) +{ + int ret = MNL_CB_OK; + const struct nlattr *attr; + + mnl_attr_for_each(attr, nlh, offset) + if ((ret = cb(attr, data)) <= MNL_CB_STOP) + return ret; + return ret; +} + +static int mnl_attr_parse_nested(const struct nlattr *nested, mnl_attr_cb_t cb, + void *data) +{ + int ret = MNL_CB_OK; + const struct nlattr *attr; + + mnl_attr_for_each_nested(attr, nested) + if ((ret = cb(attr, data)) <= MNL_CB_STOP) + return ret; + return ret; +} + +static uint8_t mnl_attr_get_u8(const struct nlattr *attr) +{ + return *((uint8_t *)mnl_attr_get_payload(attr)); +} + +static uint16_t mnl_attr_get_u16(const struct nlattr *attr) +{ + return *((uint16_t *)mnl_attr_get_payload(attr)); +} + +static uint32_t mnl_attr_get_u32(const struct nlattr *attr) +{ + return *((uint32_t *)mnl_attr_get_payload(attr)); +} + + +static uint64_t mnl_attr_get_u64(const struct nlattr *attr) +{ + uint64_t tmp; + memcpy(&tmp, mnl_attr_get_payload(attr), sizeof(tmp)); + return tmp; +} + +static const char *mnl_attr_get_str(const struct nlattr *attr) +{ + return mnl_attr_get_payload(attr); +} + +static void mnl_attr_put(struct nlmsghdr *nlh, uint16_t type, size_t len, + const void *data) +{ + struct nlattr *attr = mnl_nlmsg_get_payload_tail(nlh); + uint16_t payload_len = MNL_ALIGN(sizeof(struct nlattr)) + len; + int pad; + + attr->nla_type = type; + attr->nla_len = payload_len; + memcpy(mnl_attr_get_payload(attr), data, len); + nlh->nlmsg_len += MNL_ALIGN(payload_len); + pad = MNL_ALIGN(len) - len; + if (pad > 0) + memset(mnl_attr_get_payload(attr) + len, 0, pad); +} + +static void mnl_attr_put_u16(struct nlmsghdr *nlh, uint16_t type, uint16_t data) +{ + mnl_attr_put(nlh, type, sizeof(uint16_t), &data); +} + +static void mnl_attr_put_u32(struct nlmsghdr *nlh, uint16_t type, uint32_t data) +{ + mnl_attr_put(nlh, type, sizeof(uint32_t), &data); +} + +static void mnl_attr_put_strz(struct nlmsghdr *nlh, uint16_t type, const char *data) +{ + mnl_attr_put(nlh, type, strlen(data)+1, data); +} + +static struct nlattr *mnl_attr_nest_start(struct nlmsghdr *nlh, uint16_t type) +{ + struct nlattr *start = mnl_nlmsg_get_payload_tail(nlh); + + start->nla_type = NLA_F_NESTED | type; + nlh->nlmsg_len += MNL_ALIGN(sizeof(struct nlattr)); + return start; +} + +static bool mnl_attr_put_check(struct nlmsghdr *nlh, size_t buflen, + uint16_t type, size_t len, const void *data) +{ + if (nlh->nlmsg_len + MNL_ATTR_HDRLEN + MNL_ALIGN(len) > buflen) + return false; + mnl_attr_put(nlh, type, len, data); + return true; +} + +static bool mnl_attr_put_u8_check(struct nlmsghdr *nlh, size_t buflen, + uint16_t type, uint8_t data) +{ + return mnl_attr_put_check(nlh, buflen, type, sizeof(uint8_t), &data); +} + + +static bool mnl_attr_put_u16_check(struct nlmsghdr *nlh, size_t buflen, + uint16_t type, uint16_t data) +{ + return mnl_attr_put_check(nlh, buflen, type, sizeof(uint16_t), &data); +} + + +static bool mnl_attr_put_u32_check(struct nlmsghdr *nlh, size_t buflen, + uint16_t type, uint32_t data) +{ + return mnl_attr_put_check(nlh, buflen, type, sizeof(uint32_t), &data); +} + +static struct nlattr *mnl_attr_nest_start_check(struct nlmsghdr *nlh, size_t buflen, + uint16_t type) +{ + if (nlh->nlmsg_len + MNL_ATTR_HDRLEN > buflen) + return NULL; + return mnl_attr_nest_start(nlh, type); +} + +static void mnl_attr_nest_end(struct nlmsghdr *nlh, struct nlattr *start) +{ + start->nla_len = mnl_nlmsg_get_payload_tail(nlh) - (void *)start; +} + +static void mnl_attr_nest_cancel(struct nlmsghdr *nlh, struct nlattr *start) +{ + nlh->nlmsg_len -= mnl_nlmsg_get_payload_tail(nlh) - (void *)start; +} + +static int mnl_cb_noop(const struct nlmsghdr *nlh, void *data) +{ + return MNL_CB_OK; +} + +static int mnl_cb_error(const struct nlmsghdr *nlh, void *data) +{ + const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); + + if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) { + errno = EBADMSG; + return MNL_CB_ERROR; + } + + if (err->error < 0) + errno = -err->error; + else + errno = err->error; + + return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; +} + +static int mnl_cb_stop(const struct nlmsghdr *nlh, void *data) +{ + return MNL_CB_STOP; +} + +static const mnl_cb_t default_cb_array[NLMSG_MIN_TYPE] = { + [NLMSG_NOOP] = mnl_cb_noop, + [NLMSG_ERROR] = mnl_cb_error, + [NLMSG_DONE] = mnl_cb_stop, + [NLMSG_OVERRUN] = mnl_cb_noop, +}; + +static int __mnl_cb_run(const void *buf, size_t numbytes, + unsigned int seq, unsigned int portid, + mnl_cb_t cb_data, void *data, + const mnl_cb_t *cb_ctl_array, + unsigned int cb_ctl_array_len) +{ + int ret = MNL_CB_OK, len = numbytes; + const struct nlmsghdr *nlh = buf; + + while (mnl_nlmsg_ok(nlh, len)) { + + if (!mnl_nlmsg_portid_ok(nlh, portid)) { + errno = ESRCH; + return -1; + } + + if (!mnl_nlmsg_seq_ok(nlh, seq)) { + errno = EPROTO; + return -1; + } + + + if (nlh->nlmsg_flags & NLM_F_DUMP_INTR) { + errno = EINTR; + return -1; + } + + + if (nlh->nlmsg_type >= NLMSG_MIN_TYPE) { + if (cb_data){ + ret = cb_data(nlh, data); + if (ret <= MNL_CB_STOP) + goto out; + } + } else if (nlh->nlmsg_type < cb_ctl_array_len) { + if (cb_ctl_array && cb_ctl_array[nlh->nlmsg_type]) { + ret = cb_ctl_array[nlh->nlmsg_type](nlh, data); + if (ret <= MNL_CB_STOP) + goto out; + } + } else if (default_cb_array[nlh->nlmsg_type]) { + ret = default_cb_array[nlh->nlmsg_type](nlh, data); + if (ret <= MNL_CB_STOP) + goto out; + } + nlh = mnl_nlmsg_next(nlh, &len); + } +out: + return ret; +} + +static int mnl_cb_run2(const void *buf, size_t numbytes, unsigned int seq, + unsigned int portid, mnl_cb_t cb_data, void *data, + const mnl_cb_t *cb_ctl_array, unsigned int cb_ctl_array_len) +{ + return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data, + cb_ctl_array, cb_ctl_array_len); +} + +static int mnl_cb_run(const void *buf, size_t numbytes, unsigned int seq, + unsigned int portid, mnl_cb_t cb_data, void *data) +{ + return __mnl_cb_run(buf, numbytes, seq, portid, cb_data, data, NULL, 0); +} + +struct mnl_socket { + int fd; + struct sockaddr_nl addr; +}; + +static unsigned int mnl_socket_get_portid(const struct mnl_socket *nl) +{ + return nl->addr.nl_pid; +} + +static struct mnl_socket *__mnl_socket_open(int bus, int flags) +{ + struct mnl_socket *nl; + + nl = calloc(1, sizeof(struct mnl_socket)); + if (nl == NULL) + return NULL; + + nl->fd = socket(AF_NETLINK, SOCK_RAW | flags, bus); + if (nl->fd == -1) { + free(nl); + return NULL; + } + + return nl; +} + +static struct mnl_socket *mnl_socket_open(int bus) +{ + return __mnl_socket_open(bus, 0); +} + + +static int mnl_socket_bind(struct mnl_socket *nl, unsigned int groups, pid_t pid) +{ + int ret; + socklen_t addr_len; + + nl->addr.nl_family = AF_NETLINK; + nl->addr.nl_groups = groups; + nl->addr.nl_pid = pid; + + ret = bind(nl->fd, (struct sockaddr *) &nl->addr, sizeof (nl->addr)); + if (ret < 0) + return ret; + + addr_len = sizeof(nl->addr); + ret = getsockname(nl->fd, (struct sockaddr *) &nl->addr, &addr_len); + if (ret < 0) + return ret; + + if (addr_len != sizeof(nl->addr)) { + errno = EINVAL; + return -1; + } + if (nl->addr.nl_family != AF_NETLINK) { + errno = EINVAL; + return -1; + } + return 0; +} + + +static ssize_t mnl_socket_sendto(const struct mnl_socket *nl, const void *buf, + size_t len) +{ + static const struct sockaddr_nl snl = { + .nl_family = AF_NETLINK + }; + return sendto(nl->fd, buf, len, 0, + (struct sockaddr *) &snl, sizeof(snl)); +} + + +static ssize_t mnl_socket_recvfrom(const struct mnl_socket *nl, void *buf, + size_t bufsiz) +{ + ssize_t ret; + struct sockaddr_nl addr; + struct iovec iov = { + .iov_base = buf, + .iov_len = bufsiz, + }; + struct msghdr msg = { + .msg_name = &addr, + .msg_namelen = sizeof(struct sockaddr_nl), + .msg_iov = &iov, + .msg_iovlen = 1, + .msg_control = NULL, + .msg_controllen = 0, + .msg_flags = 0, + }; + ret = recvmsg(nl->fd, &msg, 0); + if (ret == -1) + return ret; + + if (msg.msg_flags & MSG_TRUNC) { + errno = ENOSPC; + return -1; + } + if (msg.msg_namelen != sizeof(struct sockaddr_nl)) { + errno = EINVAL; + return -1; + } + return ret; +} + +static int mnl_socket_close(struct mnl_socket *nl) +{ + int ret = close(nl->fd); + free(nl); + return ret; +} + +/* mnlg mini library: */ + +struct mnlg_socket { + struct mnl_socket *nl; + char *buf; + uint16_t id; + uint8_t version; + unsigned int seq; + unsigned int portid; +}; + +static struct nlmsghdr *__mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd, + uint16_t flags, uint16_t id, + uint8_t version) +{ + struct nlmsghdr *nlh; + struct genlmsghdr *genl; + + nlh = mnl_nlmsg_put_header(nlg->buf); + nlh->nlmsg_type = id; + nlh->nlmsg_flags = flags; + nlg->seq = time(NULL); + nlh->nlmsg_seq = nlg->seq; + + genl = mnl_nlmsg_put_extra_header(nlh, sizeof(struct genlmsghdr)); + genl->cmd = cmd; + genl->version = version; + + return nlh; +} + +static struct nlmsghdr *mnlg_msg_prepare(struct mnlg_socket *nlg, uint8_t cmd, + uint16_t flags) +{ + return __mnlg_msg_prepare(nlg, cmd, flags, nlg->id, nlg->version); +} + +static int mnlg_socket_send(struct mnlg_socket *nlg, const struct nlmsghdr *nlh) +{ + return mnl_socket_sendto(nlg->nl, nlh, nlh->nlmsg_len); +} + +static int mnlg_cb_noop(const struct nlmsghdr *nlh, void *data) +{ + (void)nlh; + (void)data; + return MNL_CB_OK; +} + +static int mnlg_cb_error(const struct nlmsghdr *nlh, void *data) +{ + const struct nlmsgerr *err = mnl_nlmsg_get_payload(nlh); + (void)data; + + if (nlh->nlmsg_len < mnl_nlmsg_size(sizeof(struct nlmsgerr))) { + errno = EBADMSG; + return MNL_CB_ERROR; + } + /* Netlink subsystems returns the errno value with different signess */ + if (err->error < 0) + errno = -err->error; + else + errno = err->error; + + return err->error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; +} + +static int mnlg_cb_stop(const struct nlmsghdr *nlh, void *data) +{ + (void)data; + if (nlh->nlmsg_flags & NLM_F_MULTI && nlh->nlmsg_len == mnl_nlmsg_size(sizeof(int))) { + int error = *(int *)mnl_nlmsg_get_payload(nlh); + /* Netlink subsystems returns the errno value with different signess */ + if (error < 0) + errno = -error; + else + errno = error; + + return error == 0 ? MNL_CB_STOP : MNL_CB_ERROR; + } + return MNL_CB_STOP; +} + +static const mnl_cb_t mnlg_cb_array[] = { + [NLMSG_NOOP] = mnlg_cb_noop, + [NLMSG_ERROR] = mnlg_cb_error, + [NLMSG_DONE] = mnlg_cb_stop, + [NLMSG_OVERRUN] = mnlg_cb_noop, +}; + +static int mnlg_socket_recv_run(struct mnlg_socket *nlg, mnl_cb_t data_cb, void *data) +{ + int err; + + do { + err = mnl_socket_recvfrom(nlg->nl, nlg->buf, + MNL_SOCKET_BUFFER_SIZE); + if (err <= 0) + break; + err = mnl_cb_run2(nlg->buf, err, nlg->seq, nlg->portid, + data_cb, data, mnlg_cb_array, MNL_ARRAY_SIZE(mnlg_cb_array)); + } while (err > 0); + + return err; +} + +static int get_family_id_attr_cb(const struct nlattr *attr, void *data) +{ + const struct nlattr **tb = data; + int type = mnl_attr_get_type(attr); + + if (mnl_attr_type_valid(attr, CTRL_ATTR_MAX) < 0) + return MNL_CB_ERROR; + + if (type == CTRL_ATTR_FAMILY_ID && + mnl_attr_validate(attr, MNL_TYPE_U16) < 0) + return MNL_CB_ERROR; + tb[type] = attr; + return MNL_CB_OK; +} + +static int get_family_id_cb(const struct nlmsghdr *nlh, void *data) +{ + uint16_t *p_id = data; + struct nlattr *tb[CTRL_ATTR_MAX + 1] = { 0 }; + + mnl_attr_parse(nlh, sizeof(struct genlmsghdr), get_family_id_attr_cb, tb); + if (!tb[CTRL_ATTR_FAMILY_ID]) + return MNL_CB_ERROR; + *p_id = mnl_attr_get_u16(tb[CTRL_ATTR_FAMILY_ID]); + return MNL_CB_OK; +} + +static struct mnlg_socket *mnlg_socket_open(const char *family_name, uint8_t version) +{ + struct mnlg_socket *nlg; + struct nlmsghdr *nlh; + int err; + + nlg = malloc(sizeof(*nlg)); + if (!nlg) + return NULL; + + err = -ENOMEM; + nlg->buf = malloc(MNL_SOCKET_BUFFER_SIZE); + if (!nlg->buf) + goto err_buf_alloc; + + nlg->nl = mnl_socket_open(NETLINK_GENERIC); + if (!nlg->nl) { + err = -errno; + goto err_mnl_socket_open; + } + + if (mnl_socket_bind(nlg->nl, 0, MNL_SOCKET_AUTOPID) < 0) { + err = -errno; + goto err_mnl_socket_bind; + } + + nlg->portid = mnl_socket_get_portid(nlg->nl); + + nlh = __mnlg_msg_prepare(nlg, CTRL_CMD_GETFAMILY, + NLM_F_REQUEST | NLM_F_ACK, GENL_ID_CTRL, 1); + mnl_attr_put_strz(nlh, CTRL_ATTR_FAMILY_NAME, family_name); + + if (mnlg_socket_send(nlg, nlh) < 0) { + err = -errno; + goto err_mnlg_socket_send; + } + + errno = 0; + if (mnlg_socket_recv_run(nlg, get_family_id_cb, &nlg->id) < 0) { + errno = errno == ENOENT ? EPROTONOSUPPORT : errno; + err = errno ? -errno : -ENOSYS; + goto err_mnlg_socket_recv_run; + } + + nlg->version = version; + errno = 0; + return nlg; + +err_mnlg_socket_recv_run: +err_mnlg_socket_send: +err_mnl_socket_bind: + mnl_socket_close(nlg->nl); +err_mnl_socket_open: + free(nlg->buf); +err_buf_alloc: + free(nlg); + errno = -err; + return NULL; +} + +static void mnlg_socket_close(struct mnlg_socket *nlg) +{ + mnl_socket_close(nlg->nl); + free(nlg->buf); + free(nlg); +} + +/* wireguard-specific parts: */ + +struct inflatable_buffer { + char *buffer; + char *next; + bool good; + size_t len; + size_t pos; +}; + +#define max(a, b) ((a) > (b) ? (a) : (b)) + +static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer) +{ + size_t len, expand_to; + char *new_buffer; + + if (!buffer->good || !buffer->next) { + free(buffer->next); + buffer->good = false; + return 0; + } + + len = strlen(buffer->next) + 1; + + if (len == 1) { + free(buffer->next); + buffer->good = false; + return 0; + } + + if (buffer->len - buffer->pos <= len) { + expand_to = max(buffer->len * 2, buffer->len + len + 1); + new_buffer = realloc(buffer->buffer, expand_to); + if (!new_buffer) { + free(buffer->next); + buffer->good = false; + return -errno; + } + memset(&new_buffer[buffer->len], 0, expand_to - buffer->len); + buffer->buffer = new_buffer; + buffer->len = expand_to; + } + memcpy(&buffer->buffer[buffer->pos], buffer->next, len); + free(buffer->next); + buffer->good = false; + buffer->pos += len; + return 0; +} + +static int parse_linkinfo(const struct nlattr *attr, void *data) +{ + struct inflatable_buffer *buffer = data; + + if (mnl_attr_get_type(attr) == IFLA_INFO_KIND && !strcmp(WG_GENL_NAME, mnl_attr_get_str(attr))) + buffer->good = true; + return MNL_CB_OK; +} + +static int parse_infomsg(const struct nlattr *attr, void *data) +{ + struct inflatable_buffer *buffer = data; + + if (mnl_attr_get_type(attr) == IFLA_LINKINFO) + return mnl_attr_parse_nested(attr, parse_linkinfo, data); + else if (mnl_attr_get_type(attr) == IFLA_IFNAME) + buffer->next = strdup(mnl_attr_get_str(attr)); + return MNL_CB_OK; +} + +static int read_devices_cb(const struct nlmsghdr *nlh, void *data) +{ + struct inflatable_buffer *buffer = data; + int ret; + + buffer->good = false; + buffer->next = NULL; + ret = mnl_attr_parse(nlh, sizeof(struct ifinfomsg), parse_infomsg, data); + if (ret != MNL_CB_OK) + return ret; + ret = add_next_to_inflatable_buffer(buffer); + if (ret < 0) + return ret; + if (nlh->nlmsg_type != NLMSG_DONE) + return MNL_CB_OK + 1; + return MNL_CB_OK; +} + +static int fetch_device_names(struct inflatable_buffer *buffer) +{ + struct mnl_socket *nl = NULL; + char *rtnl_buffer = NULL; + size_t message_len; + unsigned int portid, seq; + ssize_t len; + int ret = 0; + struct nlmsghdr *nlh; + struct ifinfomsg *ifm; + + ret = -ENOMEM; + rtnl_buffer = calloc(MNL_SOCKET_BUFFER_SIZE, 1); + if (!rtnl_buffer) + goto cleanup; + + nl = mnl_socket_open(NETLINK_ROUTE); + if (!nl) { + ret = -errno; + goto cleanup; + } + + if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { + ret = -errno; + goto cleanup; + } + + seq = time(NULL); + portid = mnl_socket_get_portid(nl); + nlh = mnl_nlmsg_put_header(rtnl_buffer); + nlh->nlmsg_type = RTM_GETLINK; + nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP; + nlh->nlmsg_seq = seq; + ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm)); + ifm->ifi_family = AF_UNSPEC; + message_len = nlh->nlmsg_len; + + if (mnl_socket_sendto(nl, rtnl_buffer, message_len) < 0) { + ret = -errno; + goto cleanup; + } + +another: + if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, MNL_SOCKET_BUFFER_SIZE)) < 0) { + ret = -errno; + goto cleanup; + } + if ((len = mnl_cb_run(rtnl_buffer, len, seq, portid, read_devices_cb, buffer)) < 0) { + /* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed + * during the dump. That's unfortunate, but is pretty common on busy + * systems that are adding and removing tunnels all the time. Rather + * than retrying, potentially indefinitely, we just work with the + * partial results. */ + if (errno != EINTR) { + ret = -errno; + goto cleanup; + } + } + if (len == MNL_CB_OK + 1) + goto another; + ret = 0; + +cleanup: + free(rtnl_buffer); + if (nl) + mnl_socket_close(nl); + return ret; +} + +static int add_del_iface(const char *ifname, bool add) +{ + struct mnl_socket *nl = NULL; + char *rtnl_buffer; + ssize_t len; + int ret; + struct nlmsghdr *nlh; + struct ifinfomsg *ifm; + struct nlattr *nest; + + rtnl_buffer = calloc(MNL_SOCKET_BUFFER_SIZE, 1); + if (!rtnl_buffer) { + ret = -ENOMEM; + goto cleanup; + } + + nl = mnl_socket_open(NETLINK_ROUTE); + if (!nl) { + ret = -errno; + goto cleanup; + } + + if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) { + ret = -errno; + goto cleanup; + } + + nlh = mnl_nlmsg_put_header(rtnl_buffer); + nlh->nlmsg_type = add ? RTM_NEWLINK : RTM_DELLINK; + nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | (add ? NLM_F_CREATE | NLM_F_EXCL : 0); + nlh->nlmsg_seq = time(NULL); + ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm)); + ifm->ifi_family = AF_UNSPEC; + mnl_attr_put_strz(nlh, IFLA_IFNAME, ifname); + nest = mnl_attr_nest_start(nlh, IFLA_LINKINFO); + mnl_attr_put_strz(nlh, IFLA_INFO_KIND, WG_GENL_NAME); + mnl_attr_nest_end(nlh, nest); + + if (mnl_socket_sendto(nl, rtnl_buffer, nlh->nlmsg_len) < 0) { + ret = -errno; + goto cleanup; + } + if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, MNL_SOCKET_BUFFER_SIZE)) < 0) { + ret = -errno; + goto cleanup; + } + if (mnl_cb_run(rtnl_buffer, len, nlh->nlmsg_seq, mnl_socket_get_portid(nl), NULL, NULL) < 0) { + ret = -errno; + goto cleanup; + } + ret = 0; + +cleanup: + free(rtnl_buffer); + if (nl) + mnl_socket_close(nl); + return ret; +} + +int wg_set_device(wg_device *dev) +{ + int ret = 0; + wg_peer *peer = NULL; + wg_allowedip *allowedip = NULL; + struct nlattr *peers_nest, *peer_nest, *allowedips_nest, *allowedip_nest; + struct nlmsghdr *nlh; + struct mnlg_socket *nlg; + + nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION); + if (!nlg) + return -errno; + +again: + nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK); + mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name); + + if (!peer) { + uint32_t flags = 0; + + if (dev->flags & WGDEVICE_HAS_PRIVATE_KEY) + mnl_attr_put(nlh, WGDEVICE_A_PRIVATE_KEY, sizeof(dev->private_key), dev->private_key); + if (dev->flags & WGDEVICE_HAS_LISTEN_PORT) + mnl_attr_put_u16(nlh, WGDEVICE_A_LISTEN_PORT, dev->listen_port); + if (dev->flags & WGDEVICE_HAS_FWMARK) + mnl_attr_put_u32(nlh, WGDEVICE_A_FWMARK, dev->fwmark); + if (dev->flags & WGDEVICE_REPLACE_PEERS) + flags |= WGDEVICE_F_REPLACE_PEERS; + if (flags) + mnl_attr_put_u32(nlh, WGDEVICE_A_FLAGS, flags); + } + if (!dev->first_peer) + goto send; + peers_nest = peer_nest = allowedips_nest = allowedip_nest = NULL; + peers_nest = mnl_attr_nest_start(nlh, WGDEVICE_A_PEERS); + for (peer = peer ? peer : dev->first_peer; peer; peer = peer->next_peer) { + uint32_t flags = 0; + + peer_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, 0); + if (!peer_nest) + goto toobig_peers; + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PUBLIC_KEY, sizeof(peer->public_key), peer->public_key)) + goto toobig_peers; + if (peer->flags & WGPEER_REMOVE_ME) + flags |= WGPEER_F_REMOVE_ME; + if (!allowedip) { + if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS) + flags |= WGPEER_F_REPLACE_ALLOWEDIPS; + if (peer->flags & WGPEER_HAS_PRESHARED_KEY) { + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PRESHARED_KEY, sizeof(peer->preshared_key), peer->preshared_key)) + goto toobig_peers; + } + if (peer->endpoint.addr.sa_family == AF_INET) { + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr4), &peer->endpoint.addr4)) + goto toobig_peers; + } else if (peer->endpoint.addr.sa_family == AF_INET6) { + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr6), &peer->endpoint.addr6)) + goto toobig_peers; + } + if (peer->flags & WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL) { + if (!mnl_attr_put_u16_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval)) + goto toobig_peers; + } + } + if (flags) { + if (!mnl_attr_put_u32_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_FLAGS, flags)) + goto toobig_peers; + } + if (peer->first_allowedip) { + if (!allowedip) + allowedip = peer->first_allowedip; + allowedips_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ALLOWEDIPS); + if (!allowedips_nest) + goto toobig_allowedips; + for (; allowedip; allowedip = allowedip->next_allowedip) { + allowedip_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, 0); + if (!allowedip_nest) + goto toobig_allowedips; + if (!mnl_attr_put_u16_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_FAMILY, allowedip->family)) + goto toobig_allowedips; + if (allowedip->family == AF_INET) { + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip4), &allowedip->ip4)) + goto toobig_allowedips; + } else if (allowedip->family == AF_INET6) { + if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip6), &allowedip->ip6)) + goto toobig_allowedips; + } + if (!mnl_attr_put_u8_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_CIDR_MASK, allowedip->cidr)) + goto toobig_allowedips; + mnl_attr_nest_end(nlh, allowedip_nest); + allowedip_nest = NULL; + } + mnl_attr_nest_end(nlh, allowedips_nest); + allowedips_nest = NULL; + } + + mnl_attr_nest_end(nlh, peer_nest); + peer_nest = NULL; + } + mnl_attr_nest_end(nlh, peers_nest); + peers_nest = NULL; + goto send; +toobig_allowedips: + if (allowedip_nest) + mnl_attr_nest_cancel(nlh, allowedip_nest); + if (allowedips_nest) + mnl_attr_nest_end(nlh, allowedips_nest); + mnl_attr_nest_end(nlh, peer_nest); + mnl_attr_nest_end(nlh, peers_nest); + goto send; +toobig_peers: + if (peer_nest) + mnl_attr_nest_cancel(nlh, peer_nest); + mnl_attr_nest_end(nlh, peers_nest); + goto send; +send: + if (mnlg_socket_send(nlg, nlh) < 0) { + ret = -errno; + goto out; + } + errno = 0; + if (mnlg_socket_recv_run(nlg, NULL, NULL) < 0) { + ret = errno ? -errno : -EINVAL; + goto out; + } + if (peer) + goto again; + +out: + mnlg_socket_close(nlg); + errno = -ret; + return ret; +} + +static int parse_allowedip(const struct nlattr *attr, void *data) +{ + wg_allowedip *allowedip = data; + + switch (mnl_attr_get_type(attr)) { + case WGALLOWEDIP_A_UNSPEC: + break; + case WGALLOWEDIP_A_FAMILY: + if (!mnl_attr_validate(attr, MNL_TYPE_U16)) + allowedip->family = mnl_attr_get_u16(attr); + break; + case WGALLOWEDIP_A_IPADDR: + if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4)) + memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4)); + else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6)) + memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6)); + break; + case WGALLOWEDIP_A_CIDR_MASK: + if (!mnl_attr_validate(attr, MNL_TYPE_U8)) + allowedip->cidr = mnl_attr_get_u8(attr); + break; + } + + return MNL_CB_OK; +} + +static int parse_allowedips(const struct nlattr *attr, void *data) +{ + wg_peer *peer = data; + wg_allowedip *new_allowedip = calloc(1, sizeof(wg_allowedip)); + int ret; + + if (!new_allowedip) + return MNL_CB_ERROR; + if (!peer->first_allowedip) + peer->first_allowedip = peer->last_allowedip = new_allowedip; + else { + peer->last_allowedip->next_allowedip = new_allowedip; + peer->last_allowedip = new_allowedip; + } + ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip); + if (!ret) + return ret; + if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128))) { + errno = EAFNOSUPPORT; + return MNL_CB_ERROR; + } + return MNL_CB_OK; +} + +bool wg_key_is_zero(const wg_key key) +{ + volatile uint8_t acc = 0; + unsigned int i; + + for (i = 0; i < sizeof(wg_key); ++i) { + acc |= key[i]; + __asm__ ("" : "=r" (acc) : "0" (acc)); + } + return 1 & ((acc - 1) >> 8); +} + +static int parse_peer(const struct nlattr *attr, void *data) +{ + wg_peer *peer = data; + + switch (mnl_attr_get_type(attr)) { + case WGPEER_A_UNSPEC: + break; + case WGPEER_A_PUBLIC_KEY: + if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key)) { + memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key)); + peer->flags |= WGPEER_HAS_PUBLIC_KEY; + } + break; + case WGPEER_A_PRESHARED_KEY: + if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key)) { + memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key)); + if (!wg_key_is_zero(peer->preshared_key)) + peer->flags |= WGPEER_HAS_PRESHARED_KEY; + } + break; + case WGPEER_A_ENDPOINT: { + struct sockaddr *addr; + + if (mnl_attr_get_payload_len(attr) < sizeof(*addr)) + break; + addr = mnl_attr_get_payload(attr); + if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4)) + memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4)); + else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6)) + memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6)); + break; + } + case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL: + if (!mnl_attr_validate(attr, MNL_TYPE_U16)) + peer->persistent_keepalive_interval = mnl_attr_get_u16(attr); + break; + case WGPEER_A_LAST_HANDSHAKE_TIME: + if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time)) + memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time)); + break; + case WGPEER_A_RX_BYTES: + if (!mnl_attr_validate(attr, MNL_TYPE_U64)) + peer->rx_bytes = mnl_attr_get_u64(attr); + break; + case WGPEER_A_TX_BYTES: + if (!mnl_attr_validate(attr, MNL_TYPE_U64)) + peer->tx_bytes = mnl_attr_get_u64(attr); + break; + case WGPEER_A_ALLOWEDIPS: + return mnl_attr_parse_nested(attr, parse_allowedips, peer); + } + + return MNL_CB_OK; +} + +static int parse_peers(const struct nlattr *attr, void *data) +{ + wg_device *device = data; + wg_peer *new_peer = calloc(1, sizeof(wg_peer)); + int ret; + + if (!new_peer) + return MNL_CB_ERROR; + if (!device->first_peer) + device->first_peer = device->last_peer = new_peer; + else { + device->last_peer->next_peer = new_peer; + device->last_peer = new_peer; + } + ret = mnl_attr_parse_nested(attr, parse_peer, new_peer); + if (!ret) + return ret; + if (!(new_peer->flags & WGPEER_HAS_PUBLIC_KEY)) { + errno = ENXIO; + return MNL_CB_ERROR; + } + return MNL_CB_OK; +} + +static int parse_device(const struct nlattr *attr, void *data) +{ + wg_device *device = data; + + switch (mnl_attr_get_type(attr)) { + case WGDEVICE_A_UNSPEC: + break; + case WGDEVICE_A_IFINDEX: + if (!mnl_attr_validate(attr, MNL_TYPE_U32)) + device->ifindex = mnl_attr_get_u32(attr); + break; + case WGDEVICE_A_IFNAME: + if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) { + strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1); + device->name[sizeof(device->name) - 1] = '\0'; + } + break; + case WGDEVICE_A_PRIVATE_KEY: + if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key)) { + memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key)); + device->flags |= WGDEVICE_HAS_PRIVATE_KEY; + } + break; + case WGDEVICE_A_PUBLIC_KEY: + if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key)) { + memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key)); + device->flags |= WGDEVICE_HAS_PUBLIC_KEY; + } + break; + case WGDEVICE_A_LISTEN_PORT: + if (!mnl_attr_validate(attr, MNL_TYPE_U16)) + device->listen_port = mnl_attr_get_u16(attr); + break; + case WGDEVICE_A_FWMARK: + if (!mnl_attr_validate(attr, MNL_TYPE_U32)) + device->fwmark = mnl_attr_get_u32(attr); + break; + case WGDEVICE_A_PEERS: + return mnl_attr_parse_nested(attr, parse_peers, device); + } + + return MNL_CB_OK; +} + +static int read_device_cb(const struct nlmsghdr *nlh, void *data) +{ + return mnl_attr_parse(nlh, sizeof(struct genlmsghdr), parse_device, data); +} + +static void coalesce_peers(wg_device *device) +{ + wg_peer *old_next_peer, *peer = device->first_peer; + + while (peer && peer->next_peer) { + if (memcmp(peer->public_key, peer->next_peer->public_key, sizeof(wg_key))) { + peer = peer->next_peer; + continue; + } + if (!peer->first_allowedip) { + peer->first_allowedip = peer->next_peer->first_allowedip; + peer->last_allowedip = peer->next_peer->last_allowedip; + } else { + peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip; + peer->last_allowedip = peer->next_peer->last_allowedip; + } + old_next_peer = peer->next_peer; + peer->next_peer = old_next_peer->next_peer; + free(old_next_peer); + } +} + +int wg_get_device(wg_device **device, const char *device_name) +{ + int ret = 0; + struct nlmsghdr *nlh; + struct mnlg_socket *nlg; + +try_again: + *device = calloc(1, sizeof(wg_device)); + if (!*device) + return -errno; + + nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION); + if (!nlg) { + wg_free_device(*device); + *device = NULL; + return -errno; + } + + nlh = mnlg_msg_prepare(nlg, WG_CMD_GET_DEVICE, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP); + mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, device_name); + if (mnlg_socket_send(nlg, nlh) < 0) { + ret = -errno; + goto out; + } + errno = 0; + if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) { + ret = errno ? -errno : -EINVAL; + goto out; + } + coalesce_peers(*device); + +out: + if (nlg) + mnlg_socket_close(nlg); + if (ret) { + wg_free_device(*device); + if (ret == -EINTR) + goto try_again; + *device = NULL; + } + errno = -ret; + return ret; +} + +/* first\0second\0third\0forth\0last\0\0 */ +char *wg_list_device_names(void) +{ + struct inflatable_buffer buffer = { .len = MNL_SOCKET_BUFFER_SIZE }; + int ret; + + ret = -ENOMEM; + buffer.buffer = calloc(1, buffer.len); + if (!buffer.buffer) + goto err; + + ret = fetch_device_names(&buffer); +err: + errno = -ret; + if (errno) { + free(buffer.buffer); + return NULL; + } + return buffer.buffer; +} + +int wg_add_device(const char *device_name) +{ + return add_del_iface(device_name, true); +} + +int wg_del_device(const char *device_name) +{ + return add_del_iface(device_name, false); +} + +void wg_free_device(wg_device *dev) +{ + wg_peer *peer, *np; + wg_allowedip *allowedip, *na; + + if (!dev) + return; + for (peer = dev->first_peer, np = peer ? peer->next_peer : NULL; peer; peer = np, np = peer ? peer->next_peer : NULL) { + for (allowedip = peer->first_allowedip, na = allowedip ? allowedip->next_allowedip : NULL; allowedip; allowedip = na, na = allowedip ? allowedip->next_allowedip : NULL) + free(allowedip); + free(peer); + } + free(dev); +} + +static void encode_base64(char dest[static 4], const uint8_t src[static 3]) +{ + const uint8_t input[] = { (src[0] >> 2) & 63, ((src[0] << 4) | (src[1] >> 4)) & 63, ((src[1] << 2) | (src[2] >> 6)) & 63, src[2] & 63 }; + unsigned int i; + + for (i = 0; i < 4; ++i) + dest[i] = input[i] + 'A' + + (((25 - input[i]) >> 8) & 6) + - (((51 - input[i]) >> 8) & 75) + - (((61 - input[i]) >> 8) & 15) + + (((62 - input[i]) >> 8) & 3); + +} + +void wg_key_to_base64(wg_key_b64_string base64, const wg_key key) +{ + unsigned int i; + + for (i = 0; i < 32 / 3; ++i) + encode_base64(&base64[i * 4], &key[i * 3]); + encode_base64(&base64[i * 4], (const uint8_t[]){ key[i * 3 + 0], key[i * 3 + 1], 0 }); + base64[sizeof(wg_key_b64_string) - 2] = '='; + base64[sizeof(wg_key_b64_string) - 1] = '\0'; +} + +static int decode_base64(const char src[static 4]) +{ + int val = 0; + unsigned int i; + + for (i = 0; i < 4; ++i) + val |= (-1 + + ((((('A' - 1) - src[i]) & (src[i] - ('Z' + 1))) >> 8) & (src[i] - 64)) + + ((((('a' - 1) - src[i]) & (src[i] - ('z' + 1))) >> 8) & (src[i] - 70)) + + ((((('0' - 1) - src[i]) & (src[i] - ('9' + 1))) >> 8) & (src[i] + 5)) + + ((((('+' - 1) - src[i]) & (src[i] - ('+' + 1))) >> 8) & 63) + + ((((('/' - 1) - src[i]) & (src[i] - ('/' + 1))) >> 8) & 64) + ) << (18 - 6 * i); + return val; +} + +int wg_key_from_base64(wg_key key, const wg_key_b64_string base64) +{ + unsigned int i; + int val; + volatile uint8_t ret = 0; + + if (strlen(base64) != sizeof(wg_key_b64_string) - 1 || base64[sizeof(wg_key_b64_string) - 2] != '=') { + errno = EINVAL; + goto out; + } + + for (i = 0; i < 32 / 3; ++i) { + val = decode_base64(&base64[i * 4]); + ret |= (uint32_t)val >> 31; + key[i * 3 + 0] = (val >> 16) & 0xff; + key[i * 3 + 1] = (val >> 8) & 0xff; + key[i * 3 + 2] = val & 0xff; + } + val = decode_base64((const char[]){ base64[i * 4 + 0], base64[i * 4 + 1], base64[i * 4 + 2], 'A' }); + ret |= ((uint32_t)val >> 31) | (val & 0xff); + key[i * 3 + 0] = (val >> 16) & 0xff; + key[i * 3 + 1] = (val >> 8) & 0xff; + errno = EINVAL & ~((ret - 1) >> 8); +out: + return -errno; +} + +typedef int64_t fe[16]; + +static __attribute__((noinline)) void memzero_explicit(void *s, size_t count) +{ + memset(s, 0, count); + __asm__ __volatile__("": :"r"(s) :"memory"); +} + +static void carry(fe o) +{ + int i; + + for (i = 0; i < 16; ++i) { + o[(i + 1) % 16] += (i == 15 ? 38 : 1) * (o[i] >> 16); + o[i] &= 0xffff; + } +} + +static void cswap(fe p, fe q, int b) +{ + int i; + int64_t t, c = ~(b - 1); + + for (i = 0; i < 16; ++i) { + t = c & (p[i] ^ q[i]); + p[i] ^= t; + q[i] ^= t; + } + + memzero_explicit(&t, sizeof(t)); + memzero_explicit(&c, sizeof(c)); + memzero_explicit(&b, sizeof(b)); +} + +static void pack(uint8_t *o, const fe n) +{ + int i, j, b; + fe m, t; + + memcpy(t, n, sizeof(t)); + carry(t); + carry(t); + carry(t); + for (j = 0; j < 2; ++j) { + m[0] = t[0] - 0xffed; + for (i = 1; i < 15; ++i) { + m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1); + m[i - 1] &= 0xffff; + } + m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1); + b = (m[15] >> 16) & 1; + m[14] &= 0xffff; + cswap(t, m, 1 - b); + } + for (i = 0; i < 16; ++i) { + o[2 * i] = t[i] & 0xff; + o[2 * i + 1] = t[i] >> 8; + } + + memzero_explicit(m, sizeof(m)); + memzero_explicit(t, sizeof(t)); + memzero_explicit(&b, sizeof(b)); +} + +static void add(fe o, const fe a, const fe b) +{ + int i; + + for (i = 0; i < 16; ++i) + o[i] = a[i] + b[i]; +} + +static void subtract(fe o, const fe a, const fe b) +{ + int i; + + for (i = 0; i < 16; ++i) + o[i] = a[i] - b[i]; +} + +static void multmod(fe o, const fe a, const fe b) +{ + int i, j; + int64_t t[31] = { 0 }; + + for (i = 0; i < 16; ++i) { + for (j = 0; j < 16; ++j) + t[i + j] += a[i] * b[j]; + } + for (i = 0; i < 15; ++i) + t[i] += 38 * t[i + 16]; + memcpy(o, t, sizeof(fe)); + carry(o); + carry(o); + + memzero_explicit(t, sizeof(t)); +} + +static void invert(fe o, const fe i) +{ + fe c; + int a; + + memcpy(c, i, sizeof(c)); + for (a = 253; a >= 0; --a) { + multmod(c, c, c); + if (a != 2 && a != 4) + multmod(c, c, i); + } + memcpy(o, c, sizeof(fe)); + + memzero_explicit(c, sizeof(c)); +} + +static void clamp_key(uint8_t *z) +{ + z[31] = (z[31] & 127) | 64; + z[0] &= 248; +} + +void wg_generate_public_key(wg_key public_key, const wg_key private_key) +{ + int i, r; + uint8_t z[32]; + fe a = { 1 }, b = { 9 }, c = { 0 }, d = { 1 }, e, f; + + memcpy(z, private_key, sizeof(z)); + clamp_key(z); + + for (i = 254; i >= 0; --i) { + r = (z[i >> 3] >> (i & 7)) & 1; + cswap(a, b, r); + cswap(c, d, r); + add(e, a, c); + subtract(a, a, c); + add(c, b, d); + subtract(b, b, d); + multmod(d, e, e); + multmod(f, a, a); + multmod(a, c, a); + multmod(c, b, e); + add(e, a, c); + subtract(a, a, c); + multmod(b, a, a); + subtract(c, d, f); + multmod(a, c, (const fe){ 0xdb41, 1 }); + add(a, a, d); + multmod(c, c, a); + multmod(a, d, f); + multmod(d, b, (const fe){ 9 }); + multmod(b, e, e); + cswap(a, b, r); + cswap(c, d, r); + } + invert(c, c); + multmod(a, a, c); + pack(public_key, a); + + memzero_explicit(&r, sizeof(r)); + memzero_explicit(z, sizeof(z)); + memzero_explicit(a, sizeof(a)); + memzero_explicit(b, sizeof(b)); + memzero_explicit(c, sizeof(c)); + memzero_explicit(d, sizeof(d)); + memzero_explicit(e, sizeof(e)); + memzero_explicit(f, sizeof(f)); +} + +void wg_generate_private_key(wg_key private_key) +{ + wg_generate_preshared_key(private_key); + clamp_key(private_key); +} + +void wg_generate_preshared_key(wg_key preshared_key) +{ + ssize_t ret; + int fd; + +#if defined(__NR_getrandom) + ret = syscall(__NR_getrandom, preshared_key, sizeof(wg_key), 0); + if (ret == sizeof(wg_key)) + return; +#endif + fd = open("/dev/urandom", O_RDONLY); + assert(fd >= 0); + ret = read(fd, preshared_key, sizeof(wg_key)); + close(fd); + assert(ret == sizeof(wg_key)); +} diff --git a/sysdep/linux/wireguard.h b/sysdep/linux/wireguard.h new file mode 100644 index 00000000..d8965f75 --- /dev/null +++ b/sysdep/linux/wireguard.h @@ -0,0 +1,103 @@ +/* SPDX-License-Identifier: LGPL-2.1+ */ +/* + * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. + */ + +#ifndef WIREGUARD_H +#define WIREGUARD_H + +#include <linux/if.h> +#include <netinet/in.h> +#include <sys/socket.h> +#include <time.h> +#include <stdint.h> +#include <stdbool.h> + +typedef uint8_t wg_key[32]; +typedef char wg_key_b64_string[((sizeof(wg_key) + 2) / 3) * 4 + 1]; + +/* Cross platform __kernel_timespec */ +struct timespec64 { + int64_t tv_sec; + int64_t tv_nsec; +}; + +typedef struct wg_allowedip { + uint16_t family; + union { + struct in_addr ip4; + struct in6_addr ip6; + }; + uint8_t cidr; + struct wg_allowedip *next_allowedip; +} wg_allowedip; + +enum wg_peer_flags { + WGPEER_REMOVE_ME = 1U << 0, + WGPEER_REPLACE_ALLOWEDIPS = 1U << 1, + WGPEER_HAS_PUBLIC_KEY = 1U << 2, + WGPEER_HAS_PRESHARED_KEY = 1U << 3, + WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL = 1U << 4 +}; + +typedef struct wg_peer { + enum wg_peer_flags flags; + + wg_key public_key; + wg_key preshared_key; + + union { + struct sockaddr addr; + struct sockaddr_in addr4; + struct sockaddr_in6 addr6; + } endpoint; + + struct timespec64 last_handshake_time; + uint64_t rx_bytes, tx_bytes; + uint16_t persistent_keepalive_interval; + + struct wg_allowedip *first_allowedip, *last_allowedip; + struct wg_peer *next_peer; +} wg_peer; + +enum wg_device_flags { + WGDEVICE_REPLACE_PEERS = 1U << 0, + WGDEVICE_HAS_PRIVATE_KEY = 1U << 1, + WGDEVICE_HAS_PUBLIC_KEY = 1U << 2, + WGDEVICE_HAS_LISTEN_PORT = 1U << 3, + WGDEVICE_HAS_FWMARK = 1U << 4 +}; + +typedef struct wg_device { + char name[IFNAMSIZ]; + uint32_t ifindex; + + enum wg_device_flags flags; + + wg_key public_key; + wg_key private_key; + + uint32_t fwmark; + uint16_t listen_port; + + struct wg_peer *first_peer, *last_peer; +} wg_device; + +#define wg_for_each_device_name(__names, __name, __len) for ((__name) = (__names), (__len) = 0; ((__len) = strlen(__name)); (__name) += (__len) + 1) +#define wg_for_each_peer(__dev, __peer) for ((__peer) = (__dev)->first_peer; (__peer); (__peer) = (__peer)->next_peer) +#define wg_for_each_allowedip(__peer, __allowedip) for ((__allowedip) = (__peer)->first_allowedip; (__allowedip); (__allowedip) = (__allowedip)->next_allowedip) + +int wg_set_device(wg_device *dev); +int wg_get_device(wg_device **dev, const char *device_name); +int wg_add_device(const char *device_name); +int wg_del_device(const char *device_name); +void wg_free_device(wg_device *dev); +char *wg_list_device_names(void); /* first\0second\0third\0forth\0last\0\0 */ +void wg_key_to_base64(wg_key_b64_string base64, const wg_key key); +int wg_key_from_base64(wg_key key, const wg_key_b64_string base64); +bool wg_key_is_zero(const wg_key key); +void wg_generate_public_key(wg_key public_key, const wg_key private_key); +void wg_generate_private_key(wg_key private_key); +void wg_generate_preshared_key(wg_key preshared_key); + +#endif |