Revamp of format validation in 'string.format'

When calling 'sprintf', not all conversion specifiers accept all
flags; some combinations are undefined behavior.
This commit is contained in:
Roberto Ierusalimschy 2021-09-03 13:14:56 -03:00
parent 91673a8ec0
commit 9db4bfed6b
3 changed files with 118 additions and 36 deletions

112
lstrlib.c
View File

@ -1090,13 +1090,31 @@ static int lua_number2strx (lua_State *L, char *buff, int sz,
/* valid flags in a format specification */ /* valid flags in a format specification */
#if !defined(L_FMTFLAGS) #if !defined(L_FMTFLAGSF)
#define L_FMTFLAGS "-+ #0"
/* valid flags for a, A, e, E, f, F, g, and G conversions */
#define L_FMTFLAGSF "-+#0 "
/* valid flags for o, x, and X conversions */
#define L_FMTFLAGSX "-#0"
/* valid flags for d and i conversions */
#define L_FMTFLAGSI "-+0 "
/* valid flags for u conversions */
#define L_FMTFLAGSU "-0"
/* valid flags for c, p, and s conversions */
#define L_FMTFLAGSC "-"
#endif #endif
/* /*
** maximum size of each format specification (such as "%-099.99d") ** Maximum size of each format specification (such as "%-099.99d"):
** Initial '%', flags (up to 5), width (2), period, precision (2),
** length modifier (8), conversion specifier, and final '\0', plus some
** extra.
*/ */
#define MAX_FORMAT 32 #define MAX_FORMAT 32
@ -1189,25 +1207,53 @@ static void addliteral (lua_State *L, luaL_Buffer *b, int arg) {
} }
static const char *scanformat (lua_State *L, const char *strfrmt, char *form) { static const char *get2digits (const char *s) {
const char *p = strfrmt; if (isdigit(uchar(*s))) {
while (*p != '\0' && strchr(L_FMTFLAGS, *p) != NULL) p++; /* skip flags */ s++;
if ((size_t)(p - strfrmt) >= sizeof(L_FMTFLAGS)/sizeof(char)) if (isdigit(uchar(*s))) s++; /* (2 digits at most) */
luaL_error(L, "invalid format (repeated flags)");
if (isdigit(uchar(*p))) p++; /* skip width */
if (isdigit(uchar(*p))) p++; /* (2 digits at most) */
if (*p == '.') {
p++;
if (isdigit(uchar(*p))) p++; /* skip precision */
if (isdigit(uchar(*p))) p++; /* (2 digits at most) */
} }
if (isdigit(uchar(*p))) return s;
luaL_error(L, "invalid format (width or precision too long)"); }
/*
** Chech whether a conversion specification is valid. When called,
** first character in 'form' must be '%' and last character must
** be a valid conversion specifier. 'flags' are the accepted flags;
** 'precision' signals whether to accept a precision.
*/
static void checkformat (lua_State *L, const char *form, const char *flags,
int precision) {
const char *spec = form + 1; /* skip '%' */
spec += strspn(spec, flags); /* skip flags */
if (*spec != '0') { /* a width cannot start with '0' */
spec = get2digits(spec); /* skip width */
if (*spec == '.' && precision) {
spec++;
spec = get2digits(spec); /* skip precision */
}
}
if (!isalpha(uchar(*spec))) /* did not go to the end? */
luaL_error(L, "invalid conversion specification: '%s'", form);
}
/*
** Get a conversion specification and copy it to 'form'.
** Return the address of its last character.
*/
static const char *getformat (lua_State *L, const char *strfrmt,
char *form) {
/* spans flags, width, and precision ('0' is included as a flag) */
size_t len = strspn(strfrmt, L_FMTFLAGSF "123456789.");
len++; /* adds following character (should be the specifier) */
/* still needs space for '%', '\0', plus a length modifier */
if (len >= MAX_FORMAT - 10)
luaL_error(L, "invalid format (too long)");
*(form++) = '%'; *(form++) = '%';
memcpy(form, strfrmt, ((p - strfrmt) + 1) * sizeof(char)); memcpy(form, strfrmt, len * sizeof(char));
form += (p - strfrmt) + 1; *(form + len) = '\0';
*form = '\0'; return strfrmt + len - 1;
return p;
} }
@ -1230,6 +1276,7 @@ static int str_format (lua_State *L) {
size_t sfl; size_t sfl;
const char *strfrmt = luaL_checklstring(L, arg, &sfl); const char *strfrmt = luaL_checklstring(L, arg, &sfl);
const char *strfrmt_end = strfrmt+sfl; const char *strfrmt_end = strfrmt+sfl;
const char *flags;
luaL_Buffer b; luaL_Buffer b;
luaL_buffinit(L, &b); luaL_buffinit(L, &b);
while (strfrmt < strfrmt_end) { while (strfrmt < strfrmt_end) {
@ -1239,25 +1286,35 @@ static int str_format (lua_State *L) {
luaL_addchar(&b, *strfrmt++); /* %% */ luaL_addchar(&b, *strfrmt++); /* %% */
else { /* format item */ else { /* format item */
char form[MAX_FORMAT]; /* to store the format ('%...') */ char form[MAX_FORMAT]; /* to store the format ('%...') */
int maxitem = MAX_ITEM; int maxitem = MAX_ITEM; /* maximum length for the result */
char *buff = luaL_prepbuffsize(&b, maxitem); /* to put formatted item */ char *buff = luaL_prepbuffsize(&b, maxitem); /* to put result */
int nb = 0; /* number of bytes in added item */ int nb = 0; /* number of bytes in result */
if (++arg > top) if (++arg > top)
return luaL_argerror(L, arg, "no value"); return luaL_argerror(L, arg, "no value");
strfrmt = scanformat(L, strfrmt, form); strfrmt = getformat(L, strfrmt, form);
switch (*strfrmt++) { switch (*strfrmt++) {
case 'c': { case 'c': {
checkformat(L, form, L_FMTFLAGSC, 0);
nb = l_sprintf(buff, maxitem, form, (int)luaL_checkinteger(L, arg)); nb = l_sprintf(buff, maxitem, form, (int)luaL_checkinteger(L, arg));
break; break;
} }
case 'd': case 'i': case 'd': case 'i':
case 'o': case 'u': case 'x': case 'X': { flags = L_FMTFLAGSI;
goto intcase;
case 'u':
flags = L_FMTFLAGSU;
goto intcase;
case 'o': case 'x': case 'X':
flags = L_FMTFLAGSX;
intcase: {
lua_Integer n = luaL_checkinteger(L, arg); lua_Integer n = luaL_checkinteger(L, arg);
checkformat(L, form, flags, 1);
addlenmod(form, LUA_INTEGER_FRMLEN); addlenmod(form, LUA_INTEGER_FRMLEN);
nb = l_sprintf(buff, maxitem, form, (LUAI_UACINT)n); nb = l_sprintf(buff, maxitem, form, (LUAI_UACINT)n);
break; break;
} }
case 'a': case 'A': case 'a': case 'A':
checkformat(L, form, L_FMTFLAGSF, 1);
addlenmod(form, LUA_NUMBER_FRMLEN); addlenmod(form, LUA_NUMBER_FRMLEN);
nb = lua_number2strx(L, buff, maxitem, form, nb = lua_number2strx(L, buff, maxitem, form,
luaL_checknumber(L, arg)); luaL_checknumber(L, arg));
@ -1268,12 +1325,14 @@ static int str_format (lua_State *L) {
/* FALLTHROUGH */ /* FALLTHROUGH */
case 'e': case 'E': case 'g': case 'G': { case 'e': case 'E': case 'g': case 'G': {
lua_Number n = luaL_checknumber(L, arg); lua_Number n = luaL_checknumber(L, arg);
checkformat(L, form, L_FMTFLAGSF, 1);
addlenmod(form, LUA_NUMBER_FRMLEN); addlenmod(form, LUA_NUMBER_FRMLEN);
nb = l_sprintf(buff, maxitem, form, (LUAI_UACNUMBER)n); nb = l_sprintf(buff, maxitem, form, (LUAI_UACNUMBER)n);
break; break;
} }
case 'p': { case 'p': {
const void *p = lua_topointer(L, arg); const void *p = lua_topointer(L, arg);
checkformat(L, form, L_FMTFLAGSC, 0);
if (p == NULL) { /* avoid calling 'printf' with argument NULL */ if (p == NULL) { /* avoid calling 'printf' with argument NULL */
p = "(null)"; /* result */ p = "(null)"; /* result */
form[strlen(form) - 1] = 's'; /* format it as a string */ form[strlen(form) - 1] = 's'; /* format it as a string */
@ -1294,7 +1353,8 @@ static int str_format (lua_State *L) {
luaL_addvalue(&b); /* keep entire string */ luaL_addvalue(&b); /* keep entire string */
else { else {
luaL_argcheck(L, l == strlen(s), arg, "string contains zeros"); luaL_argcheck(L, l == strlen(s), arg, "string contains zeros");
if (!strchr(form, '.') && l >= 100) { checkformat(L, form, L_FMTFLAGSC, 1);
if (strchr(form, '.') == NULL && l >= 100) {
/* no precision and string is too long to be formatted */ /* no precision and string is too long to be formatted */
luaL_addvalue(&b); /* keep entire string */ luaL_addvalue(&b); /* keep entire string */
} }

View File

@ -7078,8 +7078,10 @@ following the description given in its first argument,
which must be a string. which must be a string.
The format string follows the same rules as the @ANSI{sprintf}. The format string follows the same rules as the @ANSI{sprintf}.
The only differences are that the conversion specifiers and modifiers The only differences are that the conversion specifiers and modifiers
@T{*}, @id{h}, @id{L}, @id{l}, and @id{n} are not supported @id{F}, @id{n}, @T{*}, @id{h}, @id{L}, and @id{l} are not supported
and that there is an extra specifier, @id{q}. and that there is an extra specifier, @id{q}.
Both width and precision, when present,
are limited to two digits.
The specifier @id{q} formats booleans, nil, numbers, and strings The specifier @id{q} formats booleans, nil, numbers, and strings
in a way that the result is a valid constant in Lua source code. in a way that the result is a valid constant in Lua source code.
@ -7099,7 +7101,7 @@ may produce the string:
"a string with \"quotes\" and \ "a string with \"quotes\" and \
new line" new line"
} }
This specifier does not support modifiers (flags, width, length). This specifier does not support modifiers (flags, width, precision).
The conversion specifiers The conversion specifiers
@id{A}, @id{a}, @id{E}, @id{e}, @id{f}, @id{A}, @id{a}, @id{E}, @id{e}, @id{f},

View File

@ -202,13 +202,11 @@ assert(string.format("\0%c\0%c%x\0", string.byte("\xe4"), string.byte("b"), 140)
"\0\xe4\0b8c\0") "\0\xe4\0b8c\0")
assert(string.format('') == "") assert(string.format('') == "")
assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) == assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) ==
string.format("%c%c%c%c", 34, 48, 90, 100)) string.format("%1c%-c%-1c%c", 34, 48, 90, 100))
assert(string.format("%s\0 is not \0%s", 'not be', 'be') == 'not be\0 is not \0be') assert(string.format("%s\0 is not \0%s", 'not be', 'be') == 'not be\0 is not \0be')
assert(string.format("%%%d %010d", 10, 23) == "%10 0000000023") assert(string.format("%%%d %010d", 10, 23) == "%10 0000000023")
assert(tonumber(string.format("%f", 10.3)) == 10.3) assert(tonumber(string.format("%f", 10.3)) == 10.3)
x = string.format('"%-50s"', 'a') assert(string.format('"%-50s"', 'a') == '"a' .. string.rep(' ', 49) .. '"')
assert(#x == 52)
assert(string.sub(x, 1, 4) == '"a ')
assert(string.format("-%.20s.20s", string.rep("%", 2000)) == assert(string.format("-%.20s.20s", string.rep("%", 2000)) ==
"-"..string.rep("%", 20)..".20s") "-"..string.rep("%", 20)..".20s")
@ -237,7 +235,6 @@ end
assert(string.format("\0%s\0", "\0\0\1") == "\0\0\0\1\0") assert(string.format("\0%s\0", "\0\0\1") == "\0\0\0\1\0")
checkerror("contains zeros", string.format, "%10s", "\0") checkerror("contains zeros", string.format, "%10s", "\0")
checkerror("cannot have modifiers", string.format, "%10q", "1")
-- format x tostring -- format x tostring
assert(string.format("%s %s", nil, true) == "nil true") assert(string.format("%s %s", nil, true) == "nil true")
@ -341,6 +338,21 @@ do print("testing 'format %a %A'")
end end
-- testing some flags (all these results are required by ISO C)
assert(string.format("%#12o", 10) == " 012")
assert(string.format("%#10x", 100) == " 0x64")
assert(string.format("%#-17X", 100) == "0X64 ")
assert(string.format("%013i", -100) == "-000000000100")
assert(string.format("%2.5d", -100) == "-00100")
assert(string.format("%.u", 0) == "")
assert(string.format("%+#014.0f", 100) == "+000000000100.")
assert(string.format("% 1.0E", 100) == " 1E+02")
assert(string.format("%-16c", 97) == "a ")
assert(string.format("%+.3G", 1.5) == "+1.5")
assert(string.format("% .1g", 2^10) == " 1e+03")
assert(string.format("%.0s", "alo") == "")
assert(string.format("%.s", "alo") == "")
-- errors in format -- errors in format
local function check (fmt, msg) local function check (fmt, msg)
@ -348,13 +360,21 @@ local function check (fmt, msg)
end end
local aux = string.rep('0', 600) local aux = string.rep('0', 600)
check("%100.3d", "too long") check("%100.3d", "invalid conversion")
check("%1"..aux..".3d", "too long") check("%1"..aux..".3d", "too long")
check("%1.100d", "too long") check("%1.100d", "invalid conversion")
check("%10.1"..aux.."004d", "too long") check("%10.1"..aux.."004d", "too long")
check("%t", "invalid conversion") check("%t", "invalid conversion")
check("%"..aux.."d", "repeated flags") check("%"..aux.."d", "too long")
check("%d %d", "no value") check("%d %d", "no value")
check("%010c", "invalid conversion")
check("%.10c", "invalid conversion")
check("%0.34s", "invalid conversion")
check("%#i", "invalid conversion")
check("%3.1p", "invalid conversion")
check("%0.s", "invalid conversion")
check("%10q", "cannot have modifiers")
check("%F", "invalid conversion") -- useless and not in C89
assert(load("return 1\n--comment without ending EOL")() == 1) assert(load("return 1\n--comment without ending EOL")() == 1)