lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


A while back I posted a commercial job offer for someone to help add
in-place mutation operators to Lua (+=, -=, *=, /=, %= and ^=). My
domain is a scripting language for games, and a good deal of the
Lua scripts consist of changing variable values, such as:

target.health -= 2
player.ammo.shells -= 1

so this saves a *lot* of typing.


Robert Jakabosky helped me out with the project (for the record, he's
good, hire him ;)), and I want to contribute the code he built for me
back to the community. The patch is below; it is against Lua 5.1.4 (it
applies cleanly with "git apply" for me). I have added this to the Lua
Power Patches page at

http://lua-users.org/wiki/LuaPowerPatches

But I didn't quite know how to fill out the "Lua authors' position" field :)


The patch can also be retrieved by forking or cloning my fork of James
Snyder's github of Lua and checking out the mutation-operators branch:

https://github.com/idmillington/lua/tree/mutation-operators

As the owner of this code, I hereby release it under the same MIT
license as Lua 5.0. If anyone has any need of a different license, let
me know. Robert also built a whole bunch of regression and feature
tests for this code, which I am also happy to release under the same
license. They aren't attached here, so get in touch if you'd like
them.

Thanks to everyone who offered to help with this job. There were so
many of you with excellent skills. Sorry I had to pick one.

Ian.

---


diff --git a/src/lcode.c b/src/lcode.c
index e9aa88d..89e113f 100644
--- a/src/lcode.c
+++ b/src/lcode.c
@@ -684,6 +684,73 @@ static void codearith (FuncState *fs, OpCode op, expdesc *e1, expdesc *e2) {
 }


+/*
+** Code the compound assignment `e1 op= e2', where `op' is one of the
+** OP_*_EQ opcodes (R(A) := R(A) op RK(B)).  A local is updated in its own
+** register; an upvalue, global or indexed variable is loaded into a
+** temporary register, updated there and stored back.
+*/
+static void codecompound (FuncState *fs, OpCode op, expdesc *e1, expdesc *e2) {
+  int o1;  /* temporary register holding the value being updated */
+  int o2;  /* right-hand operand: register or constant index */
+
+  /* put expression 2 into a register or constant before e1 is read */
+  o2 = luaK_exp2RK(fs, e2);
+
+  switch (e1->k) {
+    case VLOCAL: {
+      /* the local's register is operated on directly */
+      luaK_codeABC(fs, op, e1->u.s.info, o2, 0);
+      freeexp(fs, e2);
+      return;
+    }
+    case VUPVAL: {
+      o1 = fs->freereg;  /* allocate temp. register */
+      luaK_reserveregs(fs, 1);
+      /* load upvalue, operate, store the result back */
+      luaK_codeABC(fs, OP_GETUPVAL, o1, e1->u.s.info, 0);
+      luaK_codeABC(fs, op, o1, o2, 0);
+      luaK_codeABC(fs, OP_SETUPVAL, o1, e1->u.s.info, 0);
+      break;
+    }
+    case VGLOBAL: {
+      o1 = fs->freereg;  /* allocate temp. register */
+      luaK_reserveregs(fs, 1);
+      /* load global, operate, store the result back */
+      luaK_codeABx(fs, OP_GETGLOBAL, o1, e1->u.s.info);
+      luaK_codeABC(fs, op, o1, o2, 0);
+      luaK_codeABx(fs, OP_SETGLOBAL, o1, e1->u.s.info);
+      break;
+    }
+    case VINDEXED: {
+      o1 = fs->freereg;  /* allocate temp. register */
+      luaK_reserveregs(fs, 1);
+      /* load table element, operate, store the result back */
+      luaK_codeABC(fs, OP_GETTABLE, o1, e1->u.s.info, e1->u.s.aux);
+      luaK_codeABC(fs, op, o1, o2, 0);
+      luaK_codeABC(fs, OP_SETTABLE, e1->u.s.info, e1->u.s.aux, o1);
+      break;
+    }
+    default: {
+      lua_assert(0);  /* invalid var kind to store */
+      freeexp(fs, e2);
+      return;
+    }
+  }
+  /*
+  ** Free registers in reverse order of allocation: the temporary first,
+  ** then expression 2, and only then the table and key registers of an
+  ** indexed variable, which were allocated before expression 2.
+  */
+  freereg(fs, o1);
+  freeexp(fs, e2);
+  if (e1->k == VINDEXED) {
+    freereg(fs, e1->u.s.aux);
+    freereg(fs, e1->u.s.info);
+  }
+}
+
+
 static void codecomp (FuncState *fs, OpCode op, int cond, expdesc *e1,
                                                           expdesc *e2) {
   int o1 = luaK_exp2RK(fs, e1);
@@ -784,6 +851,12 @@ void luaK_posfix (FuncState *fs, BinOpr op, expdesc *e1, expdesc *e2) {
     case OPR_DIV: codearith(fs, OP_DIV, e1, e2); break;
     case OPR_MOD: codearith(fs, OP_MOD, e1, e2); break;
     case OPR_POW: codearith(fs, OP_POW, e1, e2); break;
+    case OPR_ADD_EQ: codecompound(fs, OP_ADD_EQ, e1, e2); break;
+    case OPR_SUB_EQ: codecompound(fs, OP_SUB_EQ, e1, e2); break;
+    case OPR_MUL_EQ: codecompound(fs, OP_MUL_EQ, e1, e2); break;
+    case OPR_DIV_EQ: codecompound(fs, OP_DIV_EQ, e1, e2); break;
+    case OPR_MOD_EQ: codecompound(fs, OP_MOD_EQ, e1, e2); break;
+    case OPR_POW_EQ: codecompound(fs, OP_POW_EQ, e1, e2); break;
     case OPR_EQ: codecomp(fs, OP_EQ, 1, e1, e2); break;
     case OPR_NE: codecomp(fs, OP_EQ, 0, e1, e2); break;
     case OPR_LT: codecomp(fs, OP_LT, 1, e1, e2); break;
diff --git a/src/lcode.h b/src/lcode.h
index b941c60..e851bca 100644
--- a/src/lcode.h
+++ b/src/lcode.h
@@ -29,6 +29,7 @@ typedef enum BinOpr {
   OPR_NE, OPR_EQ,
   OPR_LT, OPR_LE, OPR_GT, OPR_GE,
   OPR_AND, OPR_OR,
+  OPR_ADD_EQ, OPR_SUB_EQ, OPR_MUL_EQ, OPR_DIV_EQ, OPR_MOD_EQ, OPR_POW_EQ,
   OPR_NOBINOPR
 } BinOpr;

diff --git a/src/llex.c b/src/llex.c
index 88c6790..9248bda 100644
--- a/src/llex.c
+++ b/src/llex.c
@@ -40,6 +40,7 @@ const char *const luaX_tokens [] = {
     "in", "local", "nil", "not", "or", "repeat",
     "return", "then", "true", "until", "while",
     "..", "...", "==", ">=", "<=", "~=",
+    "+=", "-=", "*=", "/=", "%=", "^=",
     "<number>", "<name>", "<string>", "<eof>",
     NULL
 };
@@ -342,6 +343,10 @@ static int llex (LexState *ls, SemInfo *seminfo) {
       }
       case '-': {
         next(ls);
+        if (ls->current == '=') {
+          next(ls);
+          return TK_SUB_EQ;
+        }
         if (ls->current != '-') return '-';
         /* else is a comment */
         next(ls);
@@ -388,6 +393,31 @@ static int llex (LexState *ls, SemInfo *seminfo) {
         if (ls->current != '=') return '~';
         else { next(ls); return TK_NE; }
       }
+      case '+': {
+        next(ls);
+        if (ls->current != '=') return '+';
+        else { next(ls); return TK_ADD_EQ; }
+      }
+      case '*': {
+        next(ls);
+        if (ls->current != '=') return '*';
+        else { next(ls); return TK_MUL_EQ; }
+      }
+      case '/': {
+        next(ls);
+        if (ls->current != '=') return '/';
+        else { next(ls); return TK_DIV_EQ; }
+      }
+      case '%': {
+        next(ls);
+        if (ls->current != '=') return '%';
+        else { next(ls); return TK_MOD_EQ; }
+      }
+      case '^': {
+        next(ls);
+        if (ls->current != '=') return '^';
+        else { next(ls); return TK_POW_EQ; }
+      }
       case '"':
       case '\'': {
         read_string(ls, ls->current, seminfo);
diff --git a/src/llex.h b/src/llex.h
index a9201ce..cdcb213 100644
--- a/src/llex.h
+++ b/src/llex.h
@@ -28,8 +28,9 @@ enum RESERVED {
   TK_IF, TK_IN, TK_LOCAL, TK_NIL, TK_NOT, TK_OR, TK_REPEAT,
   TK_RETURN, TK_THEN, TK_TRUE, TK_UNTIL, TK_WHILE,
   /* other terminal symbols */
-  TK_CONCAT, TK_DOTS, TK_EQ, TK_GE, TK_LE, TK_NE, TK_NUMBER,
-  TK_NAME, TK_STRING, TK_EOS
+  TK_CONCAT, TK_DOTS, TK_EQ, TK_GE, TK_LE, TK_NE,
+  TK_ADD_EQ, TK_SUB_EQ, TK_MUL_EQ, TK_DIV_EQ, TK_MOD_EQ, TK_POW_EQ,
+  TK_NUMBER, TK_NAME, TK_STRING, TK_EOS
 };

 /* number of reserved words */
diff --git a/src/lopcodes.c b/src/lopcodes.c
index 4cc7452..f5f8331 100644
--- a/src/lopcodes.c
+++ b/src/lopcodes.c
@@ -52,6 +52,12 @@ const char *const luaP_opnames[NUM_OPCODES+1] = {
   "CLOSE",
   "CLOSURE",
   "VARARG",
+  "ADD_EQ",
+  "SUB_EQ",
+  "MUL_EQ",
+  "DIV_EQ",
+  "MOD_EQ",
+  "POW_EQ",
   NULL
 };

@@ -98,5 +104,12 @@ const lu_byte luaP_opmodes[NUM_OPCODES] = {
  ,opmode(0, 0, OpArgN, OpArgN, iABC)		/* OP_CLOSE */
  ,opmode(0, 1, OpArgU, OpArgN, iABx)		/* OP_CLOSURE */
  ,opmode(0, 1, OpArgU, OpArgN, iABC)		/* OP_VARARG */
+/* NEW: opcodes */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_ADD_EQ */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_SUB_EQ */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_MUL_EQ */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_DIV_EQ */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_MOD_EQ */
+ ,opmode(0, 1, OpArgK, OpArgN, iABC)		/* OP_POW_EQ */
 };

diff --git a/src/lopcodes.h b/src/lopcodes.h
index 41224d6..7340dc6 100644
--- a/src/lopcodes.h
+++ b/src/lopcodes.h
@@ -204,11 +204,18 @@ OP_SETLIST,/*	A B C	R(A)[(C-1)*FPF+i] := R(A+i), 1 <= i <= B	*/
 OP_CLOSE,/*	A 	close all variables in the stack up to (>=) R(A)*/
 OP_CLOSURE,/*	A Bx	R(A) := closure(KPROTO[Bx], R(A), ... ,R(A+n))	*/

-OP_VARARG/*	A B	R(A), R(A+1), ..., R(A+B-1) = vararg		*/
+OP_VARARG,/*	A B	R(A), R(A+1), ..., R(A+B-1) = vararg		*/
+
+OP_ADD_EQ,/*	A B R(A) := R(A) + RK(B)				*/
+OP_SUB_EQ,/*	A B R(A) := R(A) - RK(B)				*/
+OP_MUL_EQ,/*	A B R(A) := R(A) * RK(B)				*/
+OP_DIV_EQ,/*	A B R(A) := R(A) / RK(B)				*/
+OP_MOD_EQ,/*	A B R(A) := R(A) % RK(B)				*/
+OP_POW_EQ,/*	A B R(A) := R(A) ^ RK(B)				*/
 } OpCode;


-#define NUM_OPCODES	(cast(int, OP_VARARG) + 1)
+#define NUM_OPCODES	(cast(int, OP_POW_EQ) + 1)



diff --git a/src/lparser.c b/src/lparser.c
index 1e2a9a8..8bb66e4 100644
--- a/src/lparser.c
+++ b/src/lparser.c
@@ -944,7 +944,7 @@ static void assignment (LexState *ls, struct LHS_assign *lh, int nvars) {
   }
   else {  /* assignment -> `=' explist1 */
     int nexps;
-    checknext(ls, '=');
+    checknext(ls, '=');  /* keep the check: exprstat vets only the first var */
     nexps = explist1(ls, &e);
     if (nexps != nvars) {
       adjust_assign(ls, nvars, nexps, &e);
@@ -962,6 +962,38 @@ static void assignment (LexState *ls, struct LHS_assign *lh, int nvars) {
 }


+/* map a compound-assignment token to its BinOpr (else OPR_NOBINOPR) */
+static BinOpr getcompopr (int op) {
+  BinOpr opr = OPR_NOBINOPR;
+  if (op == TK_ADD_EQ) opr = OPR_ADD_EQ;
+  else if (op == TK_SUB_EQ) opr = OPR_SUB_EQ;
+  else if (op == TK_MUL_EQ) opr = OPR_MUL_EQ;
+  else if (op == TK_DIV_EQ) opr = OPR_DIV_EQ;
+  else if (op == TK_MOD_EQ) opr = OPR_MOD_EQ;
+  else if (op == TK_POW_EQ) opr = OPR_POW_EQ;
+  return opr;
+}
+
+
+/*
+** stat -> primaryexp compoundop exp
+** The target in `lh' must be an assignable variable, and exactly one
+** right-hand expression is accepted.
+*/
+static void compound (LexState *ls, struct LHS_assign *lh) {
+  expdesc *var = &lh->v;
+  expdesc value;
+  BinOpr opr;
+  check_condition(ls, VLOCAL <= var->k && var->k <= VINDEXED,
+                      "syntax error");
+  opr = getcompopr(ls->t.token);  /* the operator is the current token */
+  luaX_next(ls);  /* skip it */
+  /* parse the right-hand side; a list of several expressions is rejected */
+  check_condition(ls, explist1(ls, &value) == 1, "syntax error");
+  luaK_posfix(ls->fs, opr, var, &value);
+}
+
+
 static int cond (LexState *ls) {
   /* cond -> exp */
   expdesc v;
@@ -1222,15 +1254,33 @@ static void funcstat (LexState *ls, int line) {


 static void exprstat (LexState *ls) {
-  /* stat -> func | assignment */
+  /* stat -> func | compound | assignment */
   FuncState *fs = ls->fs;
   struct LHS_assign v;
   primaryexp(ls, &v.v);
   if (v.v.k == VCALL)  /* stat -> func */
     SETARG_C(getcode(fs, &v.v), 1);  /* call statement uses no results */
-  else {  /* stat -> assignment */
+  else {  /* stat -> compound | assignment */
     v.prev = NULL;
-    assignment(ls, &v, 1);
+    switch(ls->t.token) {  /* the token after the target selects the form */
+      case TK_ADD_EQ:
+      case TK_SUB_EQ:
+      case TK_MUL_EQ:
+      case TK_DIV_EQ:
+      case TK_MOD_EQ:
+      case TK_POW_EQ:
+        compound(ls, &v);
+        break;
+      case ',':
+      case '=':
+        assignment(ls, &v, 1);
+        break;
+      default:
+        /* literal message passed as is: it is not a format, so `%=' is safe */
+        luaX_syntaxerror(ls,
+                         "'+=','-=','*=', '/=', '%=', '^=', '=' expected");
+        break;
+    }
   }
 }

diff --git a/src/ltm.c b/src/ltm.c
index c27f0f6..322ece8 100644
--- a/src/ltm.c
+++ b/src/ltm.c
@@ -33,7 +33,9 @@ void luaT_init (lua_State *L) {
     "__gc", "__mode", "__eq",
     "__add", "__sub", "__mul", "__div", "__mod",
     "__pow", "__unm", "__len", "__lt", "__le",
-    "__concat", "__call"
+    "__concat", "__call",
+    "__add_eq", "__sub_eq", "__mul_eq", "__div_eq", "__mod_eq", "__pow_eq"
+
   };
   int i;
   for (i=0; i<TM_N; i++) {
diff --git a/src/ltm.h b/src/ltm.h
index 64343b7..0ab68d9 100644
--- a/src/ltm.h
+++ b/src/ltm.h
@@ -33,6 +33,12 @@ typedef enum {
   TM_LE,
   TM_CONCAT,
   TM_CALL,
+  TM_ADD_EQ,
+  TM_SUB_EQ,
+  TM_MUL_EQ,
+  TM_DIV_EQ,
+  TM_MOD_EQ,
+  TM_POW_EQ,
   TM_N		/* number of elements in the enum */
 } TMS;

diff --git a/src/lvm.c b/src/lvm.c
index 8aeafda..3a4a624 100644
--- a/src/lvm.c
+++ b/src/lvm.c
@@ -172,6 +172,15 @@ static int call_binTM (lua_State *L, const TValue *p1, const TValue *p2,
 }


+static int call_compTM (lua_State *L, const TValue *p1, const TValue *p2,
+                       StkId res, TMS event) {  /* only p1's handler is tried */
+  const TValue *handler = luaT_gettmbyobj(L, p1, event);
+  int found = !ttisnil(handler);
+  if (found) callTMres(L, res, handler, p1, p2);  /* invoke it on (p1, p2) */
+  return found;  /* 0 tells the caller that there is no metamethod */
+}
+
+
 static const TValue *get_compTM (lua_State *L, Table *mt1, Table *mt2,
                                   TMS event) {
   const TValue *tm1 = fasttm(L, mt1, event);
@@ -336,6 +345,29 @@ static void Arith (lua_State *L, StkId ra, const TValue *rb,
 }


+/* Slow path of the compound opcodes: operate when luaV_tonumber accepts both
+** operands, else call ra's `op' metamethod, else raise an arithmetic error */
+static void Compound (lua_State *L, StkId ra, const TValue *rb, TMS op) {
+  TValue cvta, cvtb;  /* buffers handed to luaV_tonumber */
+  const TValue *lhs = luaV_tonumber(ra, &cvta), *rhs = NULL;
+  if (lhs != NULL && (rhs = luaV_tonumber(rb, &cvtb)) != NULL) {
+    lua_Number x = nvalue(lhs), y = nvalue(rhs);
+    switch (op) {
+      case TM_ADD_EQ: setnvalue(ra, luai_numadd(x, y)); break;
+      case TM_SUB_EQ: setnvalue(ra, luai_numsub(x, y)); break;
+      case TM_MUL_EQ: setnvalue(ra, luai_nummul(x, y)); break;
+      case TM_DIV_EQ: setnvalue(ra, luai_numdiv(x, y)); break;
+      case TM_MOD_EQ: setnvalue(ra, luai_nummod(x, y)); break;
+      case TM_POW_EQ: setnvalue(ra, luai_numpow(x, y)); break;
+      default: lua_assert(0); break;
+    }
+  }
+  else if (!call_compTM(L, ra, rb, ra, op))
+    luaG_aritherror(L, ra, rb);
+}
+
+
+

 /*
 ** some macros for common tasks in `luaV_execute'
@@ -371,6 +403,17 @@ static void Arith (lua_State *L, StkId ra, const TValue *rb,
           Protect(Arith(L, ra, rb, rc, tm)); \
       }

+/* R(A) op= RK(B): inline when both are numbers, else call Compound */
+#define compound_op(op,tm) { \
+        TValue *rkb = RKB(i); \
+        if (ttisnumber(ra) && ttisnumber(rkb)) { \
+          lua_Number x = nvalue(ra), y = nvalue(rkb); \
+          setnvalue(ra, op(x, y)); \
+        } \
+        else Protect(Compound(L, ra, rkb, tm)); \
+      }
+
+


 void luaV_execute (lua_State *L, int nexeccalls) {
@@ -760,6 +803,30 @@ void luaV_execute (lua_State *L, int nexeccalls) {
         }
         continue;
       }
+      case OP_ADD_EQ: {
+        compound_op(luai_numadd, TM_ADD_EQ);
+        continue;
+      }
+      case OP_SUB_EQ: {
+        compound_op(luai_numsub, TM_SUB_EQ);
+        continue;
+      }
+      case OP_MUL_EQ: {
+        compound_op(luai_nummul, TM_MUL_EQ);
+        continue;
+      }
+      case OP_DIV_EQ: {
+        compound_op(luai_numdiv, TM_DIV_EQ);
+        continue;
+      }
+      case OP_MOD_EQ: {
+        compound_op(luai_nummod, TM_MOD_EQ);
+        continue;
+      }
+      case OP_POW_EQ: {
+        compound_op(luai_numpow, TM_POW_EQ);
+        continue;
+      }
     }
   }
 }