diff --git a/lstrlib.c b/lstrlib.c index 4d8ff6f9..a257e872 100644 --- a/lstrlib.c +++ b/lstrlib.c @@ -1,5 +1,5 @@ /* -** $Id: lstrlib.c,v 1.231 2015/06/24 18:25:10 roberto Exp roberto $ +** $Id: lstrlib.c,v 1.232 2015/07/20 16:30:22 roberto Exp roberto $ ** Standard library for string operations and pattern-matching ** See Copyright Notice in lua.h */ @@ -41,8 +41,10 @@ ** Some sizes are better limited to fit in 'int', but must also fit in ** 'size_t'. (We assume that 'lua_Integer' cannot be smaller than 'int'.) */ +#define MAX_SIZET ((size_t)(~(size_t)0)) + #define MAXSIZE \ - (sizeof(size_t) < sizeof(int) ? (~(size_t)0) : (size_t)(INT_MAX)) + (sizeof(size_t) < sizeof(int) ? MAX_SIZET : (size_t)(INT_MAX)) @@ -208,11 +210,12 @@ static int str_dump (lua_State *L) { typedef struct MatchState { - int matchdepth; /* control for recursive depth (to avoid C stack overflow) */ const char *src_init; /* init of source string */ const char *src_end; /* end ('\0') of source string */ const char *p_end; /* end ('\0') of pattern */ lua_State *L; + size_t nrep; /* limit to avoid non-linear complexity */ + int matchdepth; /* control for recursive depth (to avoid C stack overflow) */ int level; /* total number of captures (finished or unfinished) */ struct { const char *init; @@ -231,6 +234,17 @@ static const char *match (MatchState *ms, const char *s, const char *p); #endif +/* +** parameters to control the maximum number of operators handled in +** a match (to avoid non-linear complexity). The maximum will be: +** (subject length) * A_REPS + B_REPS +*/ +#if !defined(A_REPS) +#define A_REPS 4 +#define B_REPS 100000 +#endif + + #define L_ESC '%' #define SPECIALS "^$*+?.([%-" @@ -488,6 +502,8 @@ static const char *match (MatchState *ms, const char *s, const char *p) { s = NULL; /* fail */ } else { /* matched once */ + if (ms->nrep-- == 0) + luaL_error(ms->L, "pattern too complex"); switch (*ep) { /* handle optional suffix */ case '?': { /* optional */ const char *res; @@ -584,6 +600,26 @@ static int nospecials (const char *p, size_t l) { } +static void prepstate (MatchState *ms, lua_State *L, + const char *s, size_t ls, const char *p, size_t lp) { + ms->L = L; + ms->matchdepth = MAXCCALLS; + ms->src_init = s; + ms->src_end = s + ls; + ms->p_end = p + lp; + if (ls < (MAX_SIZET - B_REPS) / A_REPS) + ms->nrep = A_REPS * ls + B_REPS; + else /* overflow (very long subject) */ + ms->nrep = MAX_SIZET; /* no limit */ +} + + +static void reprepstate (MatchState *ms) { + ms->level = 0; + lua_assert(ms->matchdepth == MAXCCALLS); +} + + static int str_find_aux (lua_State *L, int find) { size_t ls, lp; const char *s = luaL_checklstring(L, 1, &ls); @@ -611,15 +647,10 @@ static int str_find_aux (lua_State *L, int find) { if (anchor) { p++; lp--; /* skip anchor character */ } - ms.L = L; - ms.matchdepth = MAXCCALLS; - ms.src_init = s; - ms.src_end = s + ls; - ms.p_end = p + lp; + prepstate(&ms, L, s, ls, p, lp); do { const char *res; - ms.level = 0; - lua_assert(ms.matchdepth == MAXCCALLS); + reprepstate(&ms); if ((res=match(&ms, s1, p)) != NULL) { if (find) { lua_pushinteger(L, (s1 - s) + 1); /* start */ @@ -652,17 +683,12 @@ static int gmatch_aux (lua_State *L) { const char *s = lua_tolstring(L, lua_upvalueindex(1), &ls); const char *p = lua_tolstring(L, lua_upvalueindex(2), &lp); const char *src; - ms.L = L; - ms.matchdepth = MAXCCALLS; - ms.src_init = s; - ms.src_end = s+ls; - ms.p_end = p + lp; + prepstate(&ms, L, s, ls, p, lp); for (src = s + (size_t)lua_tointeger(L, lua_upvalueindex(3)); src <= ms.src_end; src++) { const char *e; - ms.level = 0; - lua_assert(ms.matchdepth == MAXCCALLS); + reprepstate(&ms); if ((e = match(&ms, src, p)) != NULL) { lua_Integer newstart = e-s; if (e == src) newstart++; /* empty match? go at least one position */ @@ -761,17 +787,11 @@ static int str_gsub (lua_State *L) { if (anchor) { p++; lp--; /* skip anchor character */ } - ms.L = L; - ms.matchdepth = MAXCCALLS; - ms.src_init = src; - ms.src_end = src+srcl; - ms.p_end = p + lp; + prepstate(&ms, L, src, srcl, p, lp); while (n < max_s) { const char *e; - ms.level = 0; - lua_assert(ms.matchdepth == MAXCCALLS); - e = match(&ms, src, p); - if (e) { + reprepstate(&ms); + if ((e = match(&ms, src, p)) != NULL) { n++; add_value(&ms, &b, src, e, tr); }