implemented counter to abort non-linear behavior in pattern matching

This commit is contained in:
Roberto Ierusalimschy 2015-09-26 15:45:03 -03:00
parent 9fae7b6d3f
commit 8264dbc2bb
1 changed files with 46 additions and 26 deletions

View File

@ -1,5 +1,5 @@
/*
** $Id: lstrlib.c,v 1.231 2015/06/24 18:25:10 roberto Exp roberto $
** $Id: lstrlib.c,v 1.232 2015/07/20 16:30:22 roberto Exp roberto $
** Standard library for string operations and pattern-matching
** See Copyright Notice in lua.h
*/
@ -41,8 +41,10 @@
** Some sizes are better limited to fit in 'int', but must also fit in
** 'size_t'. (We assume that 'lua_Integer' cannot be smaller than 'int'.)
*/
#define MAX_SIZET ((size_t)(~(size_t)0))
#define MAXSIZE \
(sizeof(size_t) < sizeof(int) ? (~(size_t)0) : (size_t)(INT_MAX))
(sizeof(size_t) < sizeof(int) ? MAX_SIZET : (size_t)(INT_MAX))
@ -208,11 +210,12 @@ static int str_dump (lua_State *L) {
typedef struct MatchState {
int matchdepth; /* control for recursive depth (to avoid C stack overflow) */
const char *src_init; /* init of source string */
const char *src_end; /* end ('\0') of source string */
const char *p_end; /* end ('\0') of pattern */
lua_State *L;
size_t nrep; /* limit to avoid non-linear complexity */
int matchdepth; /* control for recursive depth (to avoid C stack overflow) */
int level; /* total number of captures (finished or unfinished) */
struct {
const char *init;
@ -231,6 +234,17 @@ static const char *match (MatchState *ms, const char *s, const char *p);
#endif
/*
** parameters to control the maximum number of operators handled in
** a match (to avoid non-linear complexity). The maximum will be:
** (subject length) * A_REPS + B_REPS
*/
#if !defined(A_REPS)
#define A_REPS 4
#define B_REPS 100000
#endif
#define L_ESC '%'
#define SPECIALS "^$*+?.([%-"
@ -488,6 +502,8 @@ static const char *match (MatchState *ms, const char *s, const char *p) {
s = NULL; /* fail */
}
else { /* matched once */
if (ms->nrep-- == 0)
luaL_error(ms->L, "pattern too complex");
switch (*ep) { /* handle optional suffix */
case '?': { /* optional */
const char *res;
@ -584,6 +600,26 @@ static int nospecials (const char *p, size_t l) {
}
static void prepstate (MatchState *ms, lua_State *L,
const char *s, size_t ls, const char *p, size_t lp) {
ms->L = L;
ms->matchdepth = MAXCCALLS;
ms->src_init = s;
ms->src_end = s + ls;
ms->p_end = p + lp;
if (ls < (MAX_SIZET - B_REPS) / A_REPS)
ms->nrep = A_REPS * ls + B_REPS;
else /* overflow (very long subject) */
ms->nrep = MAX_SIZET; /* no limit */
}
static void reprepstate (MatchState *ms) {
ms->level = 0;
lua_assert(ms->matchdepth == MAXCCALLS);
}
static int str_find_aux (lua_State *L, int find) {
size_t ls, lp;
const char *s = luaL_checklstring(L, 1, &ls);
@ -611,15 +647,10 @@ static int str_find_aux (lua_State *L, int find) {
if (anchor) {
p++; lp--; /* skip anchor character */
}
ms.L = L;
ms.matchdepth = MAXCCALLS;
ms.src_init = s;
ms.src_end = s + ls;
ms.p_end = p + lp;
prepstate(&ms, L, s, ls, p, lp);
do {
const char *res;
ms.level = 0;
lua_assert(ms.matchdepth == MAXCCALLS);
reprepstate(&ms);
if ((res=match(&ms, s1, p)) != NULL) {
if (find) {
lua_pushinteger(L, (s1 - s) + 1); /* start */
@ -652,17 +683,12 @@ static int gmatch_aux (lua_State *L) {
const char *s = lua_tolstring(L, lua_upvalueindex(1), &ls);
const char *p = lua_tolstring(L, lua_upvalueindex(2), &lp);
const char *src;
ms.L = L;
ms.matchdepth = MAXCCALLS;
ms.src_init = s;
ms.src_end = s+ls;
ms.p_end = p + lp;
prepstate(&ms, L, s, ls, p, lp);
for (src = s + (size_t)lua_tointeger(L, lua_upvalueindex(3));
src <= ms.src_end;
src++) {
const char *e;
ms.level = 0;
lua_assert(ms.matchdepth == MAXCCALLS);
reprepstate(&ms);
if ((e = match(&ms, src, p)) != NULL) {
lua_Integer newstart = e-s;
if (e == src) newstart++; /* empty match? go at least one position */
@ -761,17 +787,11 @@ static int str_gsub (lua_State *L) {
if (anchor) {
p++; lp--; /* skip anchor character */
}
ms.L = L;
ms.matchdepth = MAXCCALLS;
ms.src_init = src;
ms.src_end = src+srcl;
ms.p_end = p + lp;
prepstate(&ms, L, src, srcl, p, lp);
while (n < max_s) {
const char *e;
ms.level = 0;
lua_assert(ms.matchdepth == MAXCCALLS);
e = match(&ms, src, p);
if (e) {
reprepstate(&ms);
if ((e = match(&ms, src, p)) != NULL) {
n++;
add_value(&ms, &b, src, e, tr);
}