diff --git a/.gitignore b/.gitignore
index b77f87e..885f96e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,6 @@
web/
-/data
-/.settings
+/data/programs.json
+/data/sd.json
+/data/snames.txt
*.pyc
-.project
-.pydevproject
+/static/log/
diff --git a/.project b/.project
new file mode 100644
index 0000000..a77e25c
--- /dev/null
+++ b/.project
@@ -0,0 +1,17 @@
+
+
+ OSPi
+
+
+
+
+
+ org.python.pydev.PyDevBuilder
+
+
+
+
+
+ org.python.pydev.pythonNature
+
+
diff --git a/.pydevproject b/.pydevproject
new file mode 100644
index 0000000..2783106
--- /dev/null
+++ b/.pydevproject
@@ -0,0 +1,10 @@
+
+
+
+
+
+/${PROJECT_DIR_NAME}
+
+python 2.7
+Default
+
diff --git a/README.md b/README.md
index bf2866f..bf666fc 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,14 @@ June 2013, http://rayshobby.net
UPDATES
===========
***********
-
+
+October 16 2013
+--------------
+(Dan)
+Additions, bug fixes:
+1. Fixed a bug that would cause an error in program preview when a master was enabled.
+2. Changing to manual mode would clear rain delay setting, Setting rain delay in manual mode would switch to program mode - fixed.
+
October 11 2013
--------------
(Dan)
diff --git a/data/meta.txt b/data/meta.txt
index 4d19ddf..e11637c 100644
--- a/data/meta.txt
+++ b/data/meta.txt
@@ -1,3 +1,3 @@
-
-
-
+
+
+
diff --git a/ospi.py b/ospi.py
index 2b88482..ea398c1 100644
--- a/ospi.py
+++ b/ospi.py
@@ -11,10 +11,10 @@ except ImportError:
#### Revision information ####
gv.ver = 183
-gv.rev = 138
-gv.rev_date = '11/October/2013'
+gv.rev = 139
+gv.rev_date = '16/October/2013'
- #### urls is a feature of web.py. When a GET request is recieved , the corrisponding class is executed.
+ #### urls is a feature of web.py. When a GET request is received , the corresponding class is executed.
urls = [
'/', 'home',
'/cv', 'change_values',
@@ -193,8 +193,8 @@ def stop_stations():
return
def main_loop(): # Runs in a separate thread
- """ ***** Main algorithm.***** """
- print 'Starting main loop \n'
+ """ ***** Main timing algorithm.***** """
+ print 'Starting timing loop \n'
last_min = 0
while True: # infinite loop
gv.now = time.time()+((gv.sd['tz']/4)-12)*3600 # Current time based on UTC time from the Pi adjusted by the Time Zone setting from options. updated once per second.
@@ -309,7 +309,7 @@ def main_loop(): # Runs in a separate thread
gv.sd['rdst'] = 0 # Rain delay stop time
jsave(gv.sd, 'sd')
time.sleep(1)
- #### End of main loop ####
+ #### End of timing loop ####
def data(dataf):
"""Return contents of requested text file as string or create file if a missing config file."""
@@ -556,7 +556,7 @@ class change_values:
gv.srvals = [0]*(gv.sd['nst']) # turn off all stations
set_output()
if qdict.has_key('mm') and qdict['mm'] == '0': clear_mm()
- if qdict.has_key('rd') and qdict['rd'] != '0':
+ if qdict.has_key('rd') and qdict['rd'] != '0' and qdict['rd'] != '':
gv.sd['rdst'] = (gv.now+(int(qdict['rd'])*3600))
stop_onrain()
elif qdict.has_key('rd') and qdict['rd'] == '0': gv.sd['rdst'] = 0
diff --git a/ospi.sh b/ospi.sh
index 57dbe2b..a4904a6 100644
--- a/ospi.sh
+++ b/ospi.sh
@@ -1,170 +1,170 @@
-#! /bin/sh
-### BEGIN INIT INFO
-# Provides: ospi
-# Required-Start: $remote_fs $syslog
-# Required-Stop: $remote_fs $syslog
-# Default-Start: 2 3 4 5
-# Default-Stop: 0 1 6
-# Short-Description: OpenSprinkler + Raspberry Pi
-# Description: OpenSprinkler + Raspberry Pi - Raspberry Pi with
-# OpenSprinkler Pi board from Ray's Hobby
-### END INIT INFO
-
-#
-# To auto start on boot execute (once) as root
-#
-# update-rc.d ospi defaults
-#
-# To stop auto start on boot execute
-#
-# update-rc.d ospi remove
-#
-
-# Author: Denny Fox
-#
-# Please remove the "Author" lines above and replace them
-# with your own name if you copy and modify this script.
-
-# Do NOT "set -e"
-
-# PATH should only include /usr/* if it runs after the mountnfs.sh script
-PATH=/sbin:/usr/sbin:/bin:/usr/bin
-DESC="OpenSprinkler Raspberry Pi"
-NAME=ospi.py
-DAEMON=/usr/bin/python
-DAEMON_ARGS="ospi.py"
-HOMEDIR=/home/pi/OSPi/ # Edit if different on your Raspberry Pi
-PIDFILE=/var/run/$NAME.pid
-SCRIPTNAME=/etc/init.d/$NAME
-
-# Exit if the package is not installed
-[ -x "$DAEMON" ] || exit 0
-
-# Read configuration variable file if it is present
-[ -r /etc/default/$NAME ] && . /etc/default/$NAME
-
-# Load the VERBOSE setting and other rcS variables
-. /lib/init/vars.sh
-
-# Define LSB log_* functions.
-# Depend on lsb-base (>= 3.2-14) to ensure that this file is present
-# and status_of_proc is working.
-. /lib/lsb/init-functions
-
-#
-# Function that starts the daemon/service
-#
-do_start()
-{
- # Return
- # 0 if daemon has been started
- # 1 if daemon was already running
- # 2 if daemon could not be started
- start-stop-daemon --start --quiet --chdir $HOMEDIR --pidfile $PIDFILE --make-pidfile --background --exec $DAEMON --test > /dev/null \
- || return 1
- start-stop-daemon --start --quiet --chdir $HOMEDIR --pidfile $PIDFILE --make-pidfile --background --exec $DAEMON -- \
- $DAEMON_ARGS \
- || return 2
- # Add code here, if necessary, that waits for the process to be ready
- # to handle requests from services started subsequently which depend
- # on this one. As a last resort, sleep for some time.
-}
-
-#
-# Function that stops the daemon/service
-#
-do_stop()
-{
- # Return
- # 0 if daemon has been stopped
- # 1 if daemon was already stopped
- # 2 if daemon could not be stopped
- # other if a failure occurred
- start-stop-daemon --stop --quiet --retry=TERM/30/KILL/5 --pidfile $PIDFILE
- RETVAL="$?"
- [ "$RETVAL" = 2 ] && return 2
- # Wait for children to finish too if this is a daemon that forks
- # and if the daemon is only ever run from this initscript.
- # If the above conditions are not satisfied then add some other code
- # that waits for the process to drop all resources that could be
- # needed by services started subsequently. A last resort is to
- # sleep for some time.
- start-stop-daemon --stop --quiet --oknodo --retry=0/30/KILL/5 --exec $DAEMON
- [ "$?" = 2 ] && return 2
- # Many daemons don't delete their pidfiles when they exit.
- rm -f $PIDFILE
- return "$RETVAL"
-}
-
-#
-# Function that sends a SIGHUP to the daemon/service
-#
-do_reload() {
- #
- # If the daemon can reload its configuration without
- # restarting (for example, when it is sent a SIGHUP),
- # then implement that here.
- #
- start-stop-daemon --stop --signal 1 --quiet --pidfile $PIDFILE --name $NAME
- return 0
-}
-
-case "$1" in
- start)
- [ "$VERBOSE" != no ] && log_daemon_msg "Starting $DESC" "$NAME"
- do_start
- case "$?" in
- 0|1) [ "$VERBOSE" != no ] && log_end_msg 0 ;;
- 2) [ "$VERBOSE" != no ] && log_end_msg 1 ;;
- esac
- ;;
- stop)
- [ "$VERBOSE" != no ] && log_daemon_msg "Stopping $DESC" "$NAME"
- do_stop
- case "$?" in
- 0|1) [ "$VERBOSE" != no ] && log_end_msg 0 ;;
- 2) [ "$VERBOSE" != no ] && log_end_msg 1 ;;
- esac
- ;;
- status)
- status_of_proc "$DAEMON" "$NAME" && exit 0 || exit $?
- ;;
- #reload|force-reload)
- #
- # If do_reload() is not implemented then leave this commented out
- # and leave 'force-reload' as an alias for 'restart'.
- #
- #log_daemon_msg "Reloading $DESC" "$NAME"
- #do_reload
- #log_end_msg $?
- #;;
- restart|force-reload)
- #
- # If the "reload" option is implemented then remove the
- # 'force-reload' alias
- #
- log_daemon_msg "Restarting $DESC" "$NAME"
- do_stop
- case "$?" in
- 0|1)
- do_start
- case "$?" in
- 0) log_end_msg 0 ;;
- 1) log_end_msg 1 ;; # Old process is still running
- *) log_end_msg 1 ;; # Failed to start
- esac
- ;;
- *)
- # Failed to stop
- log_end_msg 1
- ;;
- esac
- ;;
- *)
- #echo "Usage: $SCRIPTNAME {start|stop|restart|reload|force-reload}" >&2
- echo "Usage: $SCRIPTNAME {start|stop|status|restart|force-reload}" >&2
- exit 3
- ;;
-esac
-
-:
+#! /bin/sh
+### BEGIN INIT INFO
+# Provides: ospi
+# Required-Start: $remote_fs $syslog
+# Required-Stop: $remote_fs $syslog
+# Default-Start: 2 3 4 5
+# Default-Stop: 0 1 6
+# Short-Description: OpenSprinkler + Raspberry Pi
+# Description: OpenSprinkler + Raspberry Pi - Raspberry Pi with
+# OpenSprinkler Pi board from Ray's Hobby
+### END INIT INFO
+
+#
+# To auto start on boot execute (once) as root
+#
+# update-rc.d ospi defaults
+#
+# To stop auto start on boot execute
+#
+# update-rc.d ospi remove
+#
+
+# Author: Denny Fox
+#
+# Please remove the "Author" lines above and replace them
+# with your own name if you copy and modify this script.
+
+# Do NOT "set -e"
+
+# PATH should only include /usr/* if it runs after the mountnfs.sh script
+PATH=/sbin:/usr/sbin:/bin:/usr/bin
+DESC="OpenSprinkler Raspberry Pi"
+NAME=ospi.py
+DAEMON=/usr/bin/python
+DAEMON_ARGS="ospi.py"
+HOMEDIR=/home/pi/OSPi/ # Edit if different on your Raspberry Pi
+PIDFILE=/var/run/$NAME.pid
+SCRIPTNAME=/etc/init.d/$NAME
+
+# Exit if the package is not installed
+[ -x "$DAEMON" ] || exit 0
+
+# Read configuration variable file if it is present
+[ -r /etc/default/$NAME ] && . /etc/default/$NAME
+
+# Load the VERBOSE setting and other rcS variables
+. /lib/init/vars.sh
+
+# Define LSB log_* functions.
+# Depend on lsb-base (>= 3.2-14) to ensure that this file is present
+# and status_of_proc is working.
+. /lib/lsb/init-functions
+
+#
+# Function that starts the daemon/service
+#
+do_start()
+{
+ # Return
+ # 0 if daemon has been started
+ # 1 if daemon was already running
+ # 2 if daemon could not be started
+ start-stop-daemon --start --quiet --chdir $HOMEDIR --pidfile $PIDFILE --make-pidfile --background --exec $DAEMON --test > /dev/null \
+ || return 1
+ start-stop-daemon --start --quiet --chdir $HOMEDIR --pidfile $PIDFILE --make-pidfile --background --exec $DAEMON -- \
+ $DAEMON_ARGS \
+ || return 2
+ # Add code here, if necessary, that waits for the process to be ready
+ # to handle requests from services started subsequently which depend
+ # on this one. As a last resort, sleep for some time.
+}
+
+#
+# Function that stops the daemon/service
+#
+do_stop()
+{
+ # Return
+ # 0 if daemon has been stopped
+ # 1 if daemon was already stopped
+ # 2 if daemon could not be stopped
+ # other if a failure occurred
+ start-stop-daemon --stop --quiet --retry=TERM/30/KILL/5 --pidfile $PIDFILE
+ RETVAL="$?"
+ [ "$RETVAL" = 2 ] && return 2
+ # Wait for children to finish too if this is a daemon that forks
+ # and if the daemon is only ever run from this initscript.
+ # If the above conditions are not satisfied then add some other code
+ # that waits for the process to drop all resources that could be
+ # needed by services started subsequently. A last resort is to
+ # sleep for some time.
+ start-stop-daemon --stop --quiet --oknodo --retry=0/30/KILL/5 --exec $DAEMON
+ [ "$?" = 2 ] && return 2
+ # Many daemons don't delete their pidfiles when they exit.
+ rm -f $PIDFILE
+ return "$RETVAL"
+}
+
+#
+# Function that sends a SIGHUP to the daemon/service
+#
+do_reload() {
+ #
+ # If the daemon can reload its configuration without
+ # restarting (for example, when it is sent a SIGHUP),
+ # then implement that here.
+ #
+ start-stop-daemon --stop --signal 1 --quiet --pidfile $PIDFILE --name $NAME
+ return 0
+}
+
+case "$1" in
+ start)
+ [ "$VERBOSE" != no ] && log_daemon_msg "Starting $DESC" "$NAME"
+ do_start
+ case "$?" in
+ 0|1) [ "$VERBOSE" != no ] && log_end_msg 0 ;;
+ 2) [ "$VERBOSE" != no ] && log_end_msg 1 ;;
+ esac
+ ;;
+ stop)
+ [ "$VERBOSE" != no ] && log_daemon_msg "Stopping $DESC" "$NAME"
+ do_stop
+ case "$?" in
+ 0|1) [ "$VERBOSE" != no ] && log_end_msg 0 ;;
+ 2) [ "$VERBOSE" != no ] && log_end_msg 1 ;;
+ esac
+ ;;
+ status)
+ status_of_proc "$DAEMON" "$NAME" && exit 0 || exit $?
+ ;;
+ #reload|force-reload)
+ #
+ # If do_reload() is not implemented then leave this commented out
+ # and leave 'force-reload' as an alias for 'restart'.
+ #
+ #log_daemon_msg "Reloading $DESC" "$NAME"
+ #do_reload
+ #log_end_msg $?
+ #;;
+ restart|force-reload)
+ #
+ # If the "reload" option is implemented then remove the
+ # 'force-reload' alias
+ #
+ log_daemon_msg "Restarting $DESC" "$NAME"
+ do_stop
+ case "$?" in
+ 0|1)
+ do_start
+ case "$?" in
+ 0) log_end_msg 0 ;;
+ 1) log_end_msg 1 ;; # Old process is still running
+ *) log_end_msg 1 ;; # Failed to start
+ esac
+ ;;
+ *)
+ # Failed to stop
+ log_end_msg 1
+ ;;
+ esac
+ ;;
+ *)
+ #echo "Usage: $SCRIPTNAME {start|stop|restart|reload|force-reload}" >&2
+ echo "Usage: $SCRIPTNAME {start|stop|status|restart|force-reload}" >&2
+ exit 3
+ ;;
+esac
+
+:
diff --git a/ospi_addon.py b/ospi_addon.py
index 2b36cde..6e3d47b 100644
--- a/ospi_addon.py
+++ b/ospi_addon.py
@@ -1,19 +1,19 @@
-#!/usr/bin/python
-import ospi
-
- #### Add any new page urls here ####
-ospi.urls.extend(['/c1', 'ospi_addon.custom_page_1']) # example: (['/c1', 'ospi_addon.custom_page_1', '/c2', 'ospi_addon.custom_page_2', '/c3', 'ospi_addon.custom_page_3'])
-
- #### add new functions and classes here ####
- ### Example custom class ###
-class custom_page_1:
- """Add description here"""
- def GET(self):
- custpg = '\n'
- #Insert Custom Code here.
- custpg += 'Hello form an ospi_addon program!'
- return custpg
-
-
-
-
+#!/usr/bin/python
+import ospi
+
+ #### Add any new page urls here ####
+ospi.urls.extend(['/c1', 'ospi_addon.custom_page_1']) # example: (['/c1', 'ospi_addon.custom_page_1', '/c2', 'ospi_addon.custom_page_2', '/c3', 'ospi_addon.custom_page_3'])
+
+ #### add new functions and classes here ####
+ ### Example custom class ###
+class custom_page_1:
+ """Add description here"""
+ def GET(self):
+ custpg = '\n'
+ #Insert Custom Code here.
+ custpg += 'Hello form an ospi_addon program!'
+ return custpg
+
+
+
+
diff --git a/static/log/water_log.csv b/static/log/water_log.csv
index 23d8f7c..4a8781c 100644
--- a/static/log/water_log.csv
+++ b/static/log/water_log.csv
@@ -1 +1,4 @@
-Program, Zone, Duration, Finish Time, Date
+Program, Zone, Duration, Finish Time, Date
+Manual, S02, 0m5s, 12:06:53, Wed. 16 Oct 2013
+Manual, S01, 0m6s, 20:51:03, Tue. 15 Oct 2013
+Manual, S02, 0m8s, 20:50:54, Tue. 15 Oct 2013
diff --git a/static/scripts/java/svc1.8.3/home.js b/static/scripts/java/svc1.8.3/home.js
index c51a415..02dcb9f 100644
--- a/static/scripts/java/svc1.8.3/home.js
+++ b/static/scripts/java/svc1.8.3/home.js
@@ -51,7 +51,7 @@ else w(" Log: n/a");
w("");
// print html form
w("");
+w("");
w("");
w("");
w("");
diff --git a/static/scripts/java/svc1.8.3/manualmode.js b/static/scripts/java/svc1.8.3/manualmode.js
index 96d3a9a..0523fa3 100644
--- a/static/scripts/java/svc1.8.3/manualmode.js
+++ b/static/scripts/java/svc1.8.3/manualmode.js
@@ -1,42 +1,42 @@
-// Javascript for printing OpenSprinkler homepage (manual mode)
-// Firmware v1.8
-// All content is published under:
-// Creative Commons Attribution ShareAlike 3.0 License
-// Sep 2012, Rayshobby.net
-
-function id(s) {return document.getElementById(s);}
-function snf(sid,sbit) {
- if(sbit==1) window.location="/sn"+(sid+1)+"=0"; // turn off station
- else {
- var strmm=id("mm"+sid).value, strss=id("ss"+sid).value;
- var mm=(strmm=="")?0:parseInt(strmm);
- var ss=(strss=="")?0:parseInt(strss);
- if(!(mm>=0&&ss>=0&&ss<60)) {alert("Timer values wrong: "+strmm+":"+strss);return;}
- window.location="/sn"+(sid+1)+"=1"+"&t="+(mm*60+ss); // turn it off with timer
- }
-}
-w("Manual Control: (timer is optional)");
-w("
");
diff --git a/static/scripts/java/svc1.8.3/modprog.js b/static/scripts/java/svc1.8.3/modprog.js
index 0d12045..4873091 100644
--- a/static/scripts/java/svc1.8.3/modprog.js
+++ b/static/scripts/java/svc1.8.3/modprog.js
@@ -1,142 +1,142 @@
-// Javascript for printing OpenSprinkler modify program page
-// Firmware v1.8
-// All content is published under:
-// Creative Commons Attribution ShareAlike 3.0 License
-// Sep 2012, Rayshobby.net
-
-function w(s) {document.writeln(s);}
-function id(s){return document.getElementById(s);}
-function imgstr(s) {return " ";}
-// parse time
-function parse_time(prefix) {
- var h=parseInt(id(prefix+"h").value,10);
- var m=parseInt(id(prefix+"m").value,10);
- if(!(h>=0&&h<24&&m>=0&&m<60)) {alert("Error: Incorrect time input "+prefix+".");return -1;}
- return h*60+m;
-}
-// fill time
-function fill_time(prefix,idx) {
- var t=prog[idx];
- id(prefix+"h").value=""+((t/60>>0)/10>>0)+((t/60>>0)%10);
- id(prefix+"m").value=""+((t%60)/10>>0)+((t%60)%10);
-}
-// check/uncheck all days
-function seldays(v) {
- var i;
- for(i=0;i<7;i++) id("d"+i).checked=(v>0)?true:false;
-}
-// handle form submit
-function fsubmit(f) {
- var errmsg = "",days=[0,0],i,s,sid;
- var en=0;
- if(id("en_on").checked) en=1;
- // process days
- if(id("days_week").checked) {
- for(i=0;i<7;i++) {if(id("d"+i).checked) {days[0] |= (1<=2&&days[1]<=128)) {alert("Error: interval days must be between 2 and 128.");return;}
- days[0]=parseInt(id("drem").value,10);
- if(!(days[0]>=0&&days[0]=0&&ds>=0&&ds<60&&duration>0)) {alert("Error: Incorrect duration.");return;}
- // password
- var p="";
- if(!sd['ipas']) p=prompt("Please enter your password:","");
- if(p!=null){
- f.elements[0].value=p;
- f.elements[1].value=pid;
- f.elements[2].value="["+en+","+days[0]+","+days[1]+","+start_time+","+end_time+","+interval+","+duration;
- for(i=0;i"+((pid>-1)?"Modify Program "+(pid+1):"Add a New Program")+"");
-w("");
-w("");
-w("");
-// default values
-id("en_on").checked=true;
-id("days_week").checked=true;id("days_norst").checked=true;
-id("dn").value="3";id("drem").value="0";
-id("tsh").value="06";id("tsm").value="00";id("teh").value="18";id("tem").value="00";
-id("tih").value="04";id("tim").value="00";id("tdm").value="15";id("tds").value="00";
-// fill in existing program values
-if(pid>-1) {
- if(prog[0]==0) id("en_off").checked=true;
- // process days
- var _days=[prog[1],prog[2]];
- if((_days[0]&0x80)&&(_days[1]>1)) {
- id("days_n").checked=true;
- id("dn").value=_days[1];id("drem").value=_days[0]&0x7f;
- } else {
- id("days_week").checked=true;
- for(i=0;i<7;i++) {if(_days[0]&(1<>0)/10>>0)+((t/60>>0)%10);
- id("tds").value=""+((t%60)/10>>0)+((t%60)%10);
- // process stations
- var bits;
- for(bid=0;bid ";}
+// parse time
+function parse_time(prefix) {
+ var h=parseInt(id(prefix+"h").value,10);
+ var m=parseInt(id(prefix+"m").value,10);
+ if(!(h>=0&&h<24&&m>=0&&m<60)) {alert("Error: Incorrect time input "+prefix+".");return -1;}
+ return h*60+m;
+}
+// fill time
+function fill_time(prefix,idx) {
+ var t=prog[idx];
+ id(prefix+"h").value=""+((t/60>>0)/10>>0)+((t/60>>0)%10);
+ id(prefix+"m").value=""+((t%60)/10>>0)+((t%60)%10);
+}
+// check/uncheck all days
+function seldays(v) {
+ var i;
+ for(i=0;i<7;i++) id("d"+i).checked=(v>0)?true:false;
+}
+// handle form submit
+function fsubmit(f) {
+ var errmsg = "",days=[0,0],i,s,sid;
+ var en=0;
+ if(id("en_on").checked) en=1;
+ // process days
+ if(id("days_week").checked) {
+ for(i=0;i<7;i++) {if(id("d"+i).checked) {days[0] |= (1<=2&&days[1]<=128)) {alert("Error: interval days must be between 2 and 128.");return;}
+ days[0]=parseInt(id("drem").value,10);
+ if(!(days[0]>=0&&days[0]=0&&ds>=0&&ds<60&&duration>0)) {alert("Error: Incorrect duration.");return;}
+ // password
+ var p="";
+ if(!sd['ipas']) p=prompt("Please enter your password:","");
+ if(p!=null){
+ f.elements[0].value=p;
+ f.elements[1].value=pid;
+ f.elements[2].value="["+en+","+days[0]+","+days[1]+","+start_time+","+end_time+","+interval+","+duration;
+ for(i=0;i"+((pid>-1)?"Modify Program "+(pid+1):"Add a New Program")+"");
+w("");
+w("");
+w("");
+// default values
+id("en_on").checked=true;
+id("days_week").checked=true;id("days_norst").checked=true;
+id("dn").value="3";id("drem").value="0";
+id("tsh").value="06";id("tsm").value="00";id("teh").value="18";id("tem").value="00";
+id("tih").value="04";id("tim").value="00";id("tdm").value="15";id("tds").value="00";
+// fill in existing program values
+if(pid>-1) {
+ if(prog[0]==0) id("en_off").checked=true;
+ // process days
+ var _days=[prog[1],prog[2]];
+ if((_days[0]&0x80)&&(_days[1]>1)) {
+ id("days_n").checked=true;
+ id("dn").value=_days[1];id("drem").value=_days[0]&0x7f;
+ } else {
+ id("days_week").checked=true;
+ for(i=0;i<7;i++) {if(_days[0]&(1<>0)/10>>0)+((t/60>>0)%10);
+ id("tds").value=""+((t%60)/10>>0)+((t%60)%10);
+ // process stations
+ var bits;
+ for(bid=0;bid>0;
-function w(s) {document.writeln(s);}
-function check_match(prog,simminutes,simdate,simday) {
- // simdate is Java date object, simday is the #days since 1970 01-01
- var wd,dn,drem;
- if(prog[0]==0) return 0;
- if ((prog[1]&0x80)&&(prog[2]>1)) { // inverval checking
- dn=prog[2];drem=prog[1]&0x7f;
- if((simday%dn)!=((devday+drem)%dn)) return 0; // remainder checking
- } else {
- wd=(simdate.getUTCDay()+6)%7; // getDay assumes sunday is 0, converts to Monday 0
- if((prog[1]&(1<prog[4]) return 0; // start and end time checking
- if(prog[5]==0) return 0;
- if(((simminutes-prog[3])/prog[5]>>0)*prog[5] == (simminutes-prog[3])) { // interval checking
- return 1;
- }
- return 0; // no match found
-}
-function getx(sid) {return xstart+sid*stwidth-stwidth/2;} // x coordinate given a station
-function gety(t) {return ystart+t*stheight/60;} // y coordinate given a time
-function getrunstr(start,end){ // run time string
- var h,m,s,str;
- h=start/3600>>0;m=(start/60>>0)%60;s=start%60;
- str=""+(h/10>>0)+(h%10)+":"+(m/10>>0)+(m%10)+":"+(s/10>>0)+(s%10);
- h=end/3600>>0;m=(end/60>>0)%60;s=end%60;
- str+="->"+(h/10>>0)+(h%10)+":"+(m/10>>0)+(m%10)+":"+(s/10>>0)+(s%10);
- return str;
-}
-function plot_bar(sid,start,pid,end) { // plot program bar
- w("
P"+pid+"
");
-}
-function plot_master(start,end) { // plot master station
- w("");
- //if(mas==0||start==end) return;
- //ctx.fillStyle="rgba(64,64,64,0.5)";
- //ctx.fillRect(getx(mas-1),gety(start/60),stwidth,(end-start)/60*stheight/60);
-}
-function plot_currtime() {
- w("");
-}
-function run_sched(simseconds,st_array,pid_array,et_array) { // run and plot schedule stored in array data
- var sid,endtime=simseconds;
- for(sid=0;sid0)&&(sd['mas']!=sid+1)&&(sd['mo'][sid>>3]&(1<<(sid%8))))
- plot_master(st_array[sid]+sd['mton'], et_array[sid]+sd['mtoff']);
- endtime=et_array[sid];
- } else { // concurrent
- plot_bar(sid,simseconds,pid_array[sid],et_array[sid]);
- // check if this station activates master
- if((sd['mas']>0)&&(sd['mas']!=sid+1)&&(sd['mo'][sid>>3]&(1<<(sid%8))))
- endtime=(endtime>et_array[sid])?endtime:et_array[sid];
- }
- }
- }
- if(sd['seq']==0&&sd['mas']>0) plot_master(simseconds,endtime);
- return endtime;
-}
-function draw_title() {
- w("
Program Preview of ");
- w(days_str[simdate.getUTCDay()]+" "+(simdate.getUTCMonth()+1)+"/"+(simdate.getUTCDate())+" "+(simdate.getUTCFullYear()));
- w(" (Hover over each colored bar to see tooltip)");
- w("
");
- }
- plot_currtime();
-}
-function draw_program() {
- // plot program data by a full simulation
- var simminutes=0,busy=0,match_found=0,bid,s,sid,pid,match=[0,0];
- var st_array=new Array(sd['nbrd']*8),pid_array=new Array(sd['nbrd']*8);
- var et_array=new Array(sd['nbrd']*8);
- for(sid=0;sid>3;s=sid%8;
- if(sd['mas']==(sid+1)) continue; // skip master station
- if(prog[7+bid]&(1<>0;pid_array[sid]=pid+1;
- match_found=1;
- }//if
- }//for_sid
- }//if_match
- }//for_pid
- if(match_found) {
- var acctime=simminutes*60;
- if(sd['seq']) { // sequential
- for(sid=0;sid>0;
- if(sd['seq']&&simminutes!=endminutes) simminutes=endminutes;
- else simminutes++;
- for(sid=0;sid>0)*60)); // scroll to the hour line cloest to the current time
-}
-
-draw_title();
-draw_grid();
-draw_program();
+// Javascript for printing OpenSprinkler schedule page
+// Firmware v1.8
+// All content is published under:
+// Creative Commons Attribution ShareAlike 3.0 License
+// Sep 2012, Rayshobby.net
+
+// colors to draw different programs
+var prog_color=["rgba(0,0,200,0.5)","rgba(0,200,0,0.5)","rgba(200,0,0,0.5)","rgba(0,200,200,0.5)"];
+var days_str=["Sun","Mon","Tue","Wed","Thur","Fri","Sat"];
+var xstart=80,ystart=80,stwidth=40,stheight=180;
+var winwidth=stwidth*sd['nbrd']*8+xstart, winheight=26*stheight+ystart;
+var sid,sn,t;
+var simt=Date.UTC(yy,mm-1,dd,0,0,0,0);
+var simdate=new Date(simt);
+var simday = (simt/1000/3600/24)>>0;
+function w(s) {document.writeln(s);}
+function check_match(prog,simminutes,simdate,simday) {
+ // simdate is Java date object, simday is the #days since 1970 01-01
+ var wd,dn,drem;
+ if(prog[0]==0) return 0;
+ if ((prog[1]&0x80)&&(prog[2]>1)) { // inverval checking
+ dn=prog[2];drem=prog[1]&0x7f;
+ if((simday%dn)!=((devday+drem)%dn)) return 0; // remainder checking
+ } else {
+ wd=(simdate.getUTCDay()+6)%7; // getDay assumes sunday is 0, converts to Monday 0
+ if((prog[1]&(1<prog[4]) return 0; // start and end time checking
+ if(prog[5]==0) return 0;
+ if(((simminutes-prog[3])/prog[5]>>0)*prog[5] == (simminutes-prog[3])) { // interval checking
+ return 1;
+ }
+ return 0; // no match found
+}
+function getx(sid) {return xstart+sid*stwidth-stwidth/2;} // x coordinate given a station
+function gety(t) {return ystart+t*stheight/60;} // y coordinate given a time
+function getrunstr(start,end){ // run time string
+ var h,m,s,str;
+ h=start/3600>>0;m=(start/60>>0)%60;s=start%60;
+ str=""+(h/10>>0)+(h%10)+":"+(m/10>>0)+(m%10)+":"+(s/10>>0)+(s%10);
+ h=end/3600>>0;m=(end/60>>0)%60;s=end%60;
+ str+="->"+(h/10>>0)+(h%10)+":"+(m/10>>0)+(m%10)+":"+(s/10>>0)+(s%10);
+ return str;
+}
+function plot_bar(sid,start,pid,end) { // plot program bar
+ w("
P"+pid+"
");
+}
+function plot_master(start,end) { // plot master station
+ w("");
+ //if(sd['mas']==0||start==end) return;
+ //ctx.fillStyle="rgba(64,64,64,0.5)";
+ //ctx.fillRect(getx(mas-1),gety(start/60),stwidth,(end-start)/60*stheight/60);
+}
+function plot_currtime() {
+ w("");
+}
+function run_sched(simseconds,st_array,pid_array,et_array) { // run and plot schedule stored in array data
+ var sid,endtime=simseconds;
+ for(sid=0;sid0)&&(sd['mas']!=sid+1)&&(sd['mo'][sid>>3]&(1<<(sid%8))))
+ plot_master(st_array[sid]+sd['mton'], et_array[sid]+sd['mtoff']);
+ endtime=et_array[sid];
+ } else { // concurrent
+ plot_bar(sid,simseconds,pid_array[sid],et_array[sid]);
+ // check if this station activates master
+ if((sd['mas']>0)&&(sd['mas']!=sid+1)&&(sd['mo'][sid>>3]&(1<<(sid%8))))
+ endtime=(endtime>et_array[sid])?endtime:et_array[sid];
+ }
+ }
+ }
+ if(sd['seq']==0&&sd['mas']>0) plot_master(simseconds,endtime);
+ return endtime;
+}
+function draw_title() {
+ w("
Program Preview of ");
+ w(days_str[simdate.getUTCDay()]+" "+(simdate.getUTCMonth()+1)+"/"+(simdate.getUTCDate())+" "+(simdate.getUTCFullYear()));
+ w(" (Hover over each colored bar to see tooltip)");
+ w("
+$ newctx = [(k, v) for (k, v) in ctx.iteritems() if not k.startswith('_') and not isinstance(v, dict)]
+$:dicttable(dict(newctx))
+
+
ENVIRONMENT
+$:dicttable(ctx.env)
+
+
+
+
+ You're seeing this error because you have web.config.debug
+ set to True. Set that to False if you don't want to see this.
+
+
+
+
+
+"""
+
+djangoerror_r = None
+
+def djangoerror():
+ def _get_lines_from_file(filename, lineno, context_lines):
+ """
+ Returns context_lines before and after lineno from file.
+ Returns (pre_context_lineno, pre_context, context_line, post_context).
+ """
+ try:
+ source = open(filename).readlines()
+ lower_bound = max(0, lineno - context_lines)
+ upper_bound = lineno + context_lines
+
+ pre_context = \
+ [line.strip('\n') for line in source[lower_bound:lineno]]
+ context_line = source[lineno].strip('\n')
+ post_context = \
+ [line.strip('\n') for line in source[lineno + 1:upper_bound]]
+
+ return lower_bound, pre_context, context_line, post_context
+ except (OSError, IOError, IndexError):
+ return None, [], None, []
+
+ exception_type, exception_value, tback = sys.exc_info()
+ frames = []
+ while tback is not None:
+ filename = tback.tb_frame.f_code.co_filename
+ function = tback.tb_frame.f_code.co_name
+ lineno = tback.tb_lineno - 1
+
+ # hack to get correct line number for templates
+ lineno += tback.tb_frame.f_locals.get("__lineoffset__", 0)
+
+ pre_context_lineno, pre_context, context_line, post_context = \
+ _get_lines_from_file(filename, lineno, 7)
+
+ if '__hidetraceback__' not in tback.tb_frame.f_locals:
+ frames.append(web.storage({
+ 'tback': tback,
+ 'filename': filename,
+ 'function': function,
+ 'lineno': lineno,
+ 'vars': tback.tb_frame.f_locals,
+ 'id': id(tback),
+ 'pre_context': pre_context,
+ 'context_line': context_line,
+ 'post_context': post_context,
+ 'pre_context_lineno': pre_context_lineno,
+ }))
+ tback = tback.tb_next
+ frames.reverse()
+ urljoin = urlparse.urljoin
+ def prettify(x):
+ try:
+ out = pprint.pformat(x)
+ except Exception, e:
+ out = '[could not display: <' + e.__class__.__name__ + \
+ ': '+str(e)+'>]'
+ return out
+
+ global djangoerror_r
+ if djangoerror_r is None:
+ djangoerror_r = Template(djangoerror_t, filename=__file__, filter=websafe)
+
+ t = djangoerror_r
+ globals = {'ctx': web.ctx, 'web':web, 'dict':dict, 'str':str, 'prettify': prettify}
+ t.t.func_globals.update(globals)
+ return t(exception_type, exception_value, frames)
+
+def debugerror():
+ """
+ A replacement for `internalerror` that presents a nice page with lots
+ of debug information for the programmer.
+
+ (Based on the beautiful 500 page from [Django](http://djangoproject.com/),
+ designed by [Wilson Miner](http://wilsonminer.com/).)
+ """
+ return web._InternalError(djangoerror())
+
+def emailerrors(to_address, olderror, from_address=None):
+ """
+ Wraps the old `internalerror` handler (pass as `olderror`) to
+ additionally email all errors to `to_address`, to aid in
+ debugging production websites.
+
+ Emails contain a normal text traceback as well as an
+ attachment containing the nice `debugerror` page.
+ """
+ from_address = from_address or to_address
+
+ def emailerrors_internal():
+ error = olderror()
+ tb = sys.exc_info()
+ error_name = tb[0]
+ error_value = tb[1]
+ tb_txt = ''.join(traceback.format_exception(*tb))
+ path = web.ctx.path
+ request = web.ctx.method + ' ' + web.ctx.home + web.ctx.fullpath
+
+ message = "\n%s\n\n%s\n\n" % (request, tb_txt)
+
+ sendmail(
+ "your buggy site <%s>" % from_address,
+ "the bugfixer <%s>" % to_address,
+ "bug: %(error_name)s: %(error_value)s (%(path)s)" % locals(),
+ message,
+ attachments=[
+ dict(filename="bug.html", content=safestr(djangoerror()))
+ ],
+ )
+ return error
+
+ return emailerrors_internal
+
+if __name__ == "__main__":
+ urls = (
+ '/', 'index'
+ )
+ from application import application
+ app = application(urls, globals())
+ app.internalerror = debugerror
+
+ class index:
+ def GET(self):
+ thisdoesnotexist
+
+ app.run()
diff --git a/web/form.py b/web/form.py
index 8099c38..3f615e0 100644
--- a/web/form.py
+++ b/web/form.py
@@ -1,410 +1,410 @@
-"""
-HTML forms
-(part of web.py)
-"""
-
-import copy, re
-import webapi as web
-import utils, net
-
-def attrget(obj, attr, value=None):
- try:
- if hasattr(obj, 'has_key') and obj.has_key(attr):
- return obj[attr]
- except TypeError:
- # Handle the case where has_key takes different number of arguments.
- # This is the case with Model objects on appengine. See #134
- pass
- if hasattr(obj, attr):
- return getattr(obj, attr)
- return value
-
-class Form(object):
- r"""
- HTML form.
-
- >>> f = Form(Textbox("x"))
- >>> f.render()
- u'
\n
\n
'
- """
- def __init__(self, *inputs, **kw):
- self.inputs = inputs
- self.valid = True
- self.note = None
- self.validators = kw.pop('validators', [])
-
- def __call__(self, x=None):
- o = copy.deepcopy(self)
- if x: o.validates(x)
- return o
-
- def render(self):
- out = ''
- out += self.rendernote(self.note)
- out += '
\n'
-
- for i in self.inputs:
- html = utils.safeunicode(i.pre) + i.render() + self.rendernote(i.note) + utils.safeunicode(i.post)
- if i.is_hidden():
- out += '
%s
\n' % (html)
- else:
- out += '
%s
\n' % (i.id, net.websafe(i.description), html)
- out += "
"
- return out
-
- def render_css(self):
- out = []
- out.append(self.rendernote(self.note))
- for i in self.inputs:
- if not i.is_hidden():
- out.append('' % (i.id, net.websafe(i.description)))
- out.append(i.pre)
- out.append(i.render())
- out.append(self.rendernote(i.note))
- out.append(i.post)
- out.append('\n')
- return ''.join(out)
-
- def rendernote(self, note):
- if note: return '%s' % net.websafe(note)
- else: return ""
-
- def validates(self, source=None, _validate=True, **kw):
- source = source or kw or web.input()
- out = True
- for i in self.inputs:
- v = attrget(source, i.name)
- if _validate:
- out = i.validate(v) and out
- else:
- i.set_value(v)
- if _validate:
- out = out and self._validate(source)
- self.valid = out
- return out
-
- def _validate(self, value):
- self.value = value
- for v in self.validators:
- if not v.valid(value):
- self.note = v.msg
- return False
- return True
-
- def fill(self, source=None, **kw):
- return self.validates(source, _validate=False, **kw)
-
- def __getitem__(self, i):
- for x in self.inputs:
- if x.name == i: return x
- raise KeyError, i
-
- def __getattr__(self, name):
- # don't interfere with deepcopy
- inputs = self.__dict__.get('inputs') or []
- for x in inputs:
- if x.name == name: return x
- raise AttributeError, name
-
- def get(self, i, default=None):
- try:
- return self[i]
- except KeyError:
- return default
-
- def _get_d(self): #@@ should really be form.attr, no?
- return utils.storage([(i.name, i.get_value()) for i in self.inputs])
- d = property(_get_d)
-
-class Input(object):
- def __init__(self, name, *validators, **attrs):
- self.name = name
- self.validators = validators
- self.attrs = attrs = AttributeList(attrs)
-
- self.description = attrs.pop('description', name)
- self.value = attrs.pop('value', None)
- self.pre = attrs.pop('pre', "")
- self.post = attrs.pop('post', "")
- self.note = None
-
- self.id = attrs.setdefault('id', self.get_default_id())
-
- if 'class_' in attrs:
- attrs['class'] = attrs['class_']
- del attrs['class_']
-
- def is_hidden(self):
- return False
-
- def get_type(self):
- raise NotImplementedError
-
- def get_default_id(self):
- return self.name
-
- def validate(self, value):
- self.set_value(value)
-
- for v in self.validators:
- if not v.valid(value):
- self.note = v.msg
- return False
- return True
-
- def set_value(self, value):
- self.value = value
-
- def get_value(self):
- return self.value
-
- def render(self):
- attrs = self.attrs.copy()
- attrs['type'] = self.get_type()
- if self.value is not None:
- attrs['value'] = self.value
- attrs['name'] = self.name
- return '' % attrs
-
- def rendernote(self, note):
- if note: return '%s' % net.websafe(note)
- else: return ""
-
- def addatts(self):
- # add leading space for backward-compatibility
- return " " + str(self.attrs)
-
-class AttributeList(dict):
- """List of atributes of input.
-
- >>> a = AttributeList(type='text', name='x', value=20)
- >>> a
-
- """
- def copy(self):
- return AttributeList(self)
-
- def __str__(self):
- return " ".join(['%s="%s"' % (k, net.websafe(v)) for k, v in self.items()])
-
- def __repr__(self):
- return '' % repr(str(self))
-
-class Textbox(Input):
- """Textbox input.
-
- >>> Textbox(name='foo', value='bar').render()
- u''
- >>> Textbox(name='foo', value=0).render()
- u''
- """
- def get_type(self):
- return 'text'
-
-class Password(Input):
- """Password input.
-
- >>> Password(name='password', value='secret').render()
- u''
- """
-
- def get_type(self):
- return 'password'
-
-class Textarea(Input):
- """Textarea input.
-
- >>> Textarea(name='foo', value='bar').render()
- u''
- """
- def render(self):
- attrs = self.attrs.copy()
- attrs['name'] = self.name
- value = net.websafe(self.value or '')
- return '' % (attrs, value)
-
-class Dropdown(Input):
- r"""Dropdown/select input.
-
- >>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render()
- u'\n'
- >>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render()
- u'\n'
- """
- def __init__(self, name, args, *validators, **attrs):
- self.args = args
- super(Dropdown, self).__init__(name, *validators, **attrs)
-
- def render(self):
- attrs = self.attrs.copy()
- attrs['name'] = self.name
-
- x = '\n'
- return x
-
- def _render_option(self, arg, indent=' '):
- if isinstance(arg, (tuple, list)):
- value, desc= arg
- else:
- value, desc = arg, arg
-
- if self.value == value or (isinstance(self.value, list) and value in self.value):
- select_p = ' selected="selected"'
- else:
- select_p = ''
- return indent + '\n' % (select_p, net.websafe(value), net.websafe(desc))
-
-
-class GroupedDropdown(Dropdown):
- r"""Grouped Dropdown/select input.
-
- >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', ('Volvo', 'Saab')), ('German Cars', ('Mercedes', 'Audi'))), value='Audi').render()
- u'\n'
- >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', (('v', 'Volvo'), ('s', 'Saab'))), ('German Cars', (('m', 'Mercedes'), ('a', 'Audi')))), value='a').render()
- u'\n'
-
- """
- def __init__(self, name, args, *validators, **attrs):
- self.args = args
- super(Dropdown, self).__init__(name, *validators, **attrs)
-
- def render(self):
- attrs = self.attrs.copy()
- attrs['name'] = self.name
-
- x = '\n'
- return x
-
-class Radio(Input):
- def __init__(self, name, args, *validators, **attrs):
- self.args = args
- super(Radio, self).__init__(name, *validators, **attrs)
-
- def render(self):
- x = ''
- for arg in self.args:
- if isinstance(arg, (tuple, list)):
- value, desc= arg
- else:
- value, desc = arg, arg
- attrs = self.attrs.copy()
- attrs['name'] = self.name
- attrs['type'] = 'radio'
- attrs['value'] = value
- if self.value == value:
- attrs['checked'] = 'checked'
- x += ' %s' % (attrs, net.websafe(desc))
- x += ''
- return x
-
-class Checkbox(Input):
- """Checkbox input.
-
- >>> Checkbox('foo', value='bar', checked=True).render()
- u''
- >>> Checkbox('foo', value='bar').render()
- u''
- >>> c = Checkbox('foo', value='bar')
- >>> c.validate('on')
- True
- >>> c.render()
- u''
- """
- def __init__(self, name, *validators, **attrs):
- self.checked = attrs.pop('checked', False)
- Input.__init__(self, name, *validators, **attrs)
-
- def get_default_id(self):
- value = utils.safestr(self.value or "")
- return self.name + '_' + value.replace(' ', '_')
-
- def render(self):
- attrs = self.attrs.copy()
- attrs['type'] = 'checkbox'
- attrs['name'] = self.name
- attrs['value'] = self.value
-
- if self.checked:
- attrs['checked'] = 'checked'
- return '' % attrs
-
- def set_value(self, value):
- self.checked = bool(value)
-
- def get_value(self):
- return self.checked
-
-class Button(Input):
- """HTML Button.
-
- >>> Button("save").render()
- u''
- >>> Button("action", value="save", html="Save Changes").render()
- u''
- """
- def __init__(self, name, *validators, **attrs):
- super(Button, self).__init__(name, *validators, **attrs)
- self.description = ""
-
- def render(self):
- attrs = self.attrs.copy()
- attrs['name'] = self.name
- if self.value is not None:
- attrs['value'] = self.value
- html = attrs.pop('html', None) or net.websafe(self.name)
- return '' % (attrs, html)
-
-class Hidden(Input):
- """Hidden Input.
-
- >>> Hidden(name='foo', value='bar').render()
- u''
- """
- def is_hidden(self):
- return True
-
- def get_type(self):
- return 'hidden'
-
-class File(Input):
- """File input.
-
- >>> File(name='f').render()
- u''
- """
- def get_type(self):
- return 'file'
-
-class Validator:
- def __deepcopy__(self, memo): return copy.copy(self)
- def __init__(self, msg, test, jstest=None): utils.autoassign(self, locals())
- def valid(self, value):
- try: return self.test(value)
- except: return False
-
-notnull = Validator("Required", bool)
-
-class regexp(Validator):
- def __init__(self, rexp, msg):
- self.rexp = re.compile(rexp)
- self.msg = msg
-
- def valid(self, value):
- return bool(self.rexp.match(value))
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod()
+"""
+HTML forms
+(part of web.py)
+"""
+
+import copy, re
+import webapi as web
+import utils, net
+
+def attrget(obj, attr, value=None):
+ try:
+ if hasattr(obj, 'has_key') and obj.has_key(attr):
+ return obj[attr]
+ except TypeError:
+ # Handle the case where has_key takes different number of arguments.
+ # This is the case with Model objects on appengine. See #134
+ pass
+ if hasattr(obj, attr):
+ return getattr(obj, attr)
+ return value
+
+class Form(object):
+ r"""
+ HTML form.
+
+ >>> f = Form(Textbox("x"))
+ >>> f.render()
+ u'
\n
\n
'
+ """
+ def __init__(self, *inputs, **kw):
+ self.inputs = inputs
+ self.valid = True
+ self.note = None
+ self.validators = kw.pop('validators', [])
+
+ def __call__(self, x=None):
+ o = copy.deepcopy(self)
+ if x: o.validates(x)
+ return o
+
+ def render(self):
+ out = ''
+ out += self.rendernote(self.note)
+ out += '
\n'
+
+ for i in self.inputs:
+ html = utils.safeunicode(i.pre) + i.render() + self.rendernote(i.note) + utils.safeunicode(i.post)
+ if i.is_hidden():
+ out += '
%s
\n' % (html)
+ else:
+ out += '
%s
\n' % (i.id, net.websafe(i.description), html)
+ out += "
"
+ return out
+
+ def render_css(self):
+ out = []
+ out.append(self.rendernote(self.note))
+ for i in self.inputs:
+ if not i.is_hidden():
+ out.append('' % (i.id, net.websafe(i.description)))
+ out.append(i.pre)
+ out.append(i.render())
+ out.append(self.rendernote(i.note))
+ out.append(i.post)
+ out.append('\n')
+ return ''.join(out)
+
+ def rendernote(self, note):
+ if note: return '%s' % net.websafe(note)
+ else: return ""
+
+ def validates(self, source=None, _validate=True, **kw):
+ source = source or kw or web.input()
+ out = True
+ for i in self.inputs:
+ v = attrget(source, i.name)
+ if _validate:
+ out = i.validate(v) and out
+ else:
+ i.set_value(v)
+ if _validate:
+ out = out and self._validate(source)
+ self.valid = out
+ return out
+
+ def _validate(self, value):
+ self.value = value
+ for v in self.validators:
+ if not v.valid(value):
+ self.note = v.msg
+ return False
+ return True
+
+ def fill(self, source=None, **kw):
+ return self.validates(source, _validate=False, **kw)
+
+ def __getitem__(self, i):
+ for x in self.inputs:
+ if x.name == i: return x
+ raise KeyError, i
+
+ def __getattr__(self, name):
+ # don't interfere with deepcopy
+ inputs = self.__dict__.get('inputs') or []
+ for x in inputs:
+ if x.name == name: return x
+ raise AttributeError, name
+
+ def get(self, i, default=None):
+ try:
+ return self[i]
+ except KeyError:
+ return default
+
+ def _get_d(self): #@@ should really be form.attr, no?
+ return utils.storage([(i.name, i.get_value()) for i in self.inputs])
+ d = property(_get_d)
+
+class Input(object):
+ def __init__(self, name, *validators, **attrs):
+ self.name = name
+ self.validators = validators
+ self.attrs = attrs = AttributeList(attrs)
+
+ self.description = attrs.pop('description', name)
+ self.value = attrs.pop('value', None)
+ self.pre = attrs.pop('pre', "")
+ self.post = attrs.pop('post', "")
+ self.note = None
+
+ self.id = attrs.setdefault('id', self.get_default_id())
+
+ if 'class_' in attrs:
+ attrs['class'] = attrs['class_']
+ del attrs['class_']
+
+ def is_hidden(self):
+ return False
+
+ def get_type(self):
+ raise NotImplementedError
+
+ def get_default_id(self):
+ return self.name
+
+ def validate(self, value):
+ self.set_value(value)
+
+ for v in self.validators:
+ if not v.valid(value):
+ self.note = v.msg
+ return False
+ return True
+
+ def set_value(self, value):
+ self.value = value
+
+ def get_value(self):
+ return self.value
+
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['type'] = self.get_type()
+ if self.value is not None:
+ attrs['value'] = self.value
+ attrs['name'] = self.name
+ return '' % attrs
+
+ def rendernote(self, note):
+ if note: return '%s' % net.websafe(note)
+ else: return ""
+
+ def addatts(self):
+ # add leading space for backward-compatibility
+ return " " + str(self.attrs)
+
+class AttributeList(dict):
+ """List of atributes of input.
+
+ >>> a = AttributeList(type='text', name='x', value=20)
+ >>> a
+
+ """
+ def copy(self):
+ return AttributeList(self)
+
+ def __str__(self):
+ return " ".join(['%s="%s"' % (k, net.websafe(v)) for k, v in self.items()])
+
+ def __repr__(self):
+ return '' % repr(str(self))
+
+class Textbox(Input):
+ """Textbox input.
+
+ >>> Textbox(name='foo', value='bar').render()
+ u''
+ >>> Textbox(name='foo', value=0).render()
+ u''
+ """
+ def get_type(self):
+ return 'text'
+
+class Password(Input):
+ """Password input.
+
+ >>> Password(name='password', value='secret').render()
+ u''
+ """
+
+ def get_type(self):
+ return 'password'
+
+class Textarea(Input):
+ """Textarea input.
+
+ >>> Textarea(name='foo', value='bar').render()
+ u''
+ """
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['name'] = self.name
+ value = net.websafe(self.value or '')
+ return '' % (attrs, value)
+
+class Dropdown(Input):
+ r"""Dropdown/select input.
+
+ >>> Dropdown(name='foo', args=['a', 'b', 'c'], value='b').render()
+ u'\n'
+ >>> Dropdown(name='foo', args=[('a', 'aa'), ('b', 'bb'), ('c', 'cc')], value='b').render()
+ u'\n'
+ """
+ def __init__(self, name, args, *validators, **attrs):
+ self.args = args
+ super(Dropdown, self).__init__(name, *validators, **attrs)
+
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['name'] = self.name
+
+ x = '\n'
+ return x
+
+ def _render_option(self, arg, indent=' '):
+ if isinstance(arg, (tuple, list)):
+ value, desc= arg
+ else:
+ value, desc = arg, arg
+
+ if self.value == value or (isinstance(self.value, list) and value in self.value):
+ select_p = ' selected="selected"'
+ else:
+ select_p = ''
+ return indent + '\n' % (select_p, net.websafe(value), net.websafe(desc))
+
+
+class GroupedDropdown(Dropdown):
+ r"""Grouped Dropdown/select input.
+
+ >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', ('Volvo', 'Saab')), ('German Cars', ('Mercedes', 'Audi'))), value='Audi').render()
+ u'\n'
+ >>> GroupedDropdown(name='car_type', args=(('Swedish Cars', (('v', 'Volvo'), ('s', 'Saab'))), ('German Cars', (('m', 'Mercedes'), ('a', 'Audi')))), value='a').render()
+ u'\n'
+
+ """
+ def __init__(self, name, args, *validators, **attrs):
+ self.args = args
+ super(Dropdown, self).__init__(name, *validators, **attrs)
+
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['name'] = self.name
+
+ x = '\n'
+ return x
+
+class Radio(Input):
+ def __init__(self, name, args, *validators, **attrs):
+ self.args = args
+ super(Radio, self).__init__(name, *validators, **attrs)
+
+ def render(self):
+ x = ''
+ for arg in self.args:
+ if isinstance(arg, (tuple, list)):
+ value, desc= arg
+ else:
+ value, desc = arg, arg
+ attrs = self.attrs.copy()
+ attrs['name'] = self.name
+ attrs['type'] = 'radio'
+ attrs['value'] = value
+ if self.value == value:
+ attrs['checked'] = 'checked'
+ x += ' %s' % (attrs, net.websafe(desc))
+ x += ''
+ return x
+
+class Checkbox(Input):
+ """Checkbox input.
+
+ >>> Checkbox('foo', value='bar', checked=True).render()
+ u''
+ >>> Checkbox('foo', value='bar').render()
+ u''
+ >>> c = Checkbox('foo', value='bar')
+ >>> c.validate('on')
+ True
+ >>> c.render()
+ u''
+ """
+ def __init__(self, name, *validators, **attrs):
+ self.checked = attrs.pop('checked', False)
+ Input.__init__(self, name, *validators, **attrs)
+
+ def get_default_id(self):
+ value = utils.safestr(self.value or "")
+ return self.name + '_' + value.replace(' ', '_')
+
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['type'] = 'checkbox'
+ attrs['name'] = self.name
+ attrs['value'] = self.value
+
+ if self.checked:
+ attrs['checked'] = 'checked'
+ return '' % attrs
+
+ def set_value(self, value):
+ self.checked = bool(value)
+
+ def get_value(self):
+ return self.checked
+
+class Button(Input):
+ """HTML Button.
+
+ >>> Button("save").render()
+ u''
+ >>> Button("action", value="save", html="Save Changes").render()
+ u''
+ """
+ def __init__(self, name, *validators, **attrs):
+ super(Button, self).__init__(name, *validators, **attrs)
+ self.description = ""
+
+ def render(self):
+ attrs = self.attrs.copy()
+ attrs['name'] = self.name
+ if self.value is not None:
+ attrs['value'] = self.value
+ html = attrs.pop('html', None) or net.websafe(self.name)
+ return '' % (attrs, html)
+
+class Hidden(Input):
+ """Hidden Input.
+
+ >>> Hidden(name='foo', value='bar').render()
+ u''
+ """
+ def is_hidden(self):
+ return True
+
+ def get_type(self):
+ return 'hidden'
+
+class File(Input):
+ """File input.
+
+ >>> File(name='f').render()
+ u''
+ """
+ def get_type(self):
+ return 'file'
+
+class Validator:
+ def __deepcopy__(self, memo): return copy.copy(self)
+ def __init__(self, msg, test, jstest=None): utils.autoassign(self, locals())
+ def valid(self, value):
+ try: return self.test(value)
+ except: return False
+
+notnull = Validator("Required", bool)
+
+class regexp(Validator):
+ def __init__(self, rexp, msg):
+ self.rexp = re.compile(rexp)
+ self.msg = msg
+
+ def valid(self, value):
+ return bool(self.rexp.match(value))
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/web/http.py b/web/http.py
index da67eba..9644ceb 100644
--- a/web/http.py
+++ b/web/http.py
@@ -1,150 +1,150 @@
-"""
-HTTP Utilities
-(from web.py)
-"""
-
-__all__ = [
- "expires", "lastmodified",
- "prefixurl", "modified",
- "changequery", "url",
- "profiler",
-]
-
-import sys, os, threading, urllib, urlparse
-try: import datetime
-except ImportError: pass
-import net, utils, webapi as web
-
-def prefixurl(base=''):
- """
- Sorry, this function is really difficult to explain.
- Maybe some other time.
- """
- url = web.ctx.path.lstrip('/')
- for i in xrange(url.count('/')):
- base += '../'
- if not base:
- base = './'
- return base
-
-def expires(delta):
- """
- Outputs an `Expires` header for `delta` from now.
- `delta` is a `timedelta` object or a number of seconds.
- """
- if isinstance(delta, (int, long)):
- delta = datetime.timedelta(seconds=delta)
- date_obj = datetime.datetime.utcnow() + delta
- web.header('Expires', net.httpdate(date_obj))
-
-def lastmodified(date_obj):
- """Outputs a `Last-Modified` header for `datetime`."""
- web.header('Last-Modified', net.httpdate(date_obj))
-
-def modified(date=None, etag=None):
- """
- Checks to see if the page has been modified since the version in the
- requester's cache.
-
- When you publish pages, you can include `Last-Modified` and `ETag`
- with the date the page was last modified and an opaque token for
- the particular version, respectively. When readers reload the page,
- the browser sends along the modification date and etag value for
- the version it has in its cache. If the page hasn't changed,
- the server can just return `304 Not Modified` and not have to
- send the whole page again.
-
- This function takes the last-modified date `date` and the ETag `etag`
- and checks the headers to see if they match. If they do, it returns
- `True`, or otherwise it raises NotModified error. It also sets
- `Last-Modified` and `ETag` output headers.
- """
- try:
- from __builtin__ import set
- except ImportError:
- # for python 2.3
- from sets import Set as set
-
- n = set([x.strip('" ') for x in web.ctx.env.get('HTTP_IF_NONE_MATCH', '').split(',')])
- m = net.parsehttpdate(web.ctx.env.get('HTTP_IF_MODIFIED_SINCE', '').split(';')[0])
- validate = False
- if etag:
- if '*' in n or etag in n:
- validate = True
- if date and m:
- # we subtract a second because
- # HTTP dates don't have sub-second precision
- if date-datetime.timedelta(seconds=1) <= m:
- validate = True
-
- if date: lastmodified(date)
- if etag: web.header('ETag', '"' + etag + '"')
- if validate:
- raise web.notmodified()
- else:
- return True
-
-def urlencode(query, doseq=0):
- """
- Same as urllib.urlencode, but supports unicode strings.
-
- >>> urlencode({'text':'foo bar'})
- 'text=foo+bar'
- >>> urlencode({'x': [1, 2]}, doseq=True)
- 'x=1&x=2'
- """
- def convert(value, doseq=False):
- if doseq and isinstance(value, list):
- return [convert(v) for v in value]
- else:
- return utils.safestr(value)
-
- query = dict([(k, convert(v, doseq)) for k, v in query.items()])
- return urllib.urlencode(query, doseq=doseq)
-
-def changequery(query=None, **kw):
- """
- Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return
- `/foo?a=3&b=2` -- the same URL but with the arguments you requested
- changed.
- """
- if query is None:
- query = web.rawinput(method='get')
- for k, v in kw.iteritems():
- if v is None:
- query.pop(k, None)
- else:
- query[k] = v
- out = web.ctx.path
- if query:
- out += '?' + urlencode(query, doseq=True)
- return out
-
-def url(path=None, doseq=False, **kw):
- """
- Makes url by concatenating web.ctx.homepath and path and the
- query string created using the arguments.
- """
- if path is None:
- path = web.ctx.path
- if path.startswith("/"):
- out = web.ctx.homepath + path
- else:
- out = path
-
- if kw:
- out += '?' + urlencode(kw, doseq=doseq)
-
- return out
-
-def profiler(app):
- """Outputs basic profiling information at the bottom of each response."""
- from utils import profile
- def profile_internal(e, o):
- out, result = profile(app)(e, o)
- return list(out) + ['
' + net.websafe(result) + '
']
- return profile_internal
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod()
+"""
+HTTP Utilities
+(from web.py)
+"""
+
+__all__ = [
+ "expires", "lastmodified",
+ "prefixurl", "modified",
+ "changequery", "url",
+ "profiler",
+]
+
+import sys, os, threading, urllib, urlparse
+try: import datetime
+except ImportError: pass
+import net, utils, webapi as web
+
+def prefixurl(base=''):
+ """
+ Sorry, this function is really difficult to explain.
+ Maybe some other time.
+ """
+ url = web.ctx.path.lstrip('/')
+ for i in xrange(url.count('/')):
+ base += '../'
+ if not base:
+ base = './'
+ return base
+
+def expires(delta):
+ """
+ Outputs an `Expires` header for `delta` from now.
+ `delta` is a `timedelta` object or a number of seconds.
+ """
+ if isinstance(delta, (int, long)):
+ delta = datetime.timedelta(seconds=delta)
+ date_obj = datetime.datetime.utcnow() + delta
+ web.header('Expires', net.httpdate(date_obj))
+
+def lastmodified(date_obj):
+ """Outputs a `Last-Modified` header for `datetime`."""
+ web.header('Last-Modified', net.httpdate(date_obj))
+
+def modified(date=None, etag=None):
+ """
+ Checks to see if the page has been modified since the version in the
+ requester's cache.
+
+ When you publish pages, you can include `Last-Modified` and `ETag`
+ with the date the page was last modified and an opaque token for
+ the particular version, respectively. When readers reload the page,
+ the browser sends along the modification date and etag value for
+ the version it has in its cache. If the page hasn't changed,
+ the server can just return `304 Not Modified` and not have to
+ send the whole page again.
+
+ This function takes the last-modified date `date` and the ETag `etag`
+ and checks the headers to see if they match. If they do, it returns
+ `True`, or otherwise it raises NotModified error. It also sets
+ `Last-Modified` and `ETag` output headers.
+ """
+ try:
+ from __builtin__ import set
+ except ImportError:
+ # for python 2.3
+ from sets import Set as set
+
+ n = set([x.strip('" ') for x in web.ctx.env.get('HTTP_IF_NONE_MATCH', '').split(',')])
+ m = net.parsehttpdate(web.ctx.env.get('HTTP_IF_MODIFIED_SINCE', '').split(';')[0])
+ validate = False
+ if etag:
+ if '*' in n or etag in n:
+ validate = True
+ if date and m:
+ # we subtract a second because
+ # HTTP dates don't have sub-second precision
+ if date-datetime.timedelta(seconds=1) <= m:
+ validate = True
+
+ if date: lastmodified(date)
+ if etag: web.header('ETag', '"' + etag + '"')
+ if validate:
+ raise web.notmodified()
+ else:
+ return True
+
+def urlencode(query, doseq=0):
+ """
+ Same as urllib.urlencode, but supports unicode strings.
+
+ >>> urlencode({'text':'foo bar'})
+ 'text=foo+bar'
+ >>> urlencode({'x': [1, 2]}, doseq=True)
+ 'x=1&x=2'
+ """
+ def convert(value, doseq=False):
+ if doseq and isinstance(value, list):
+ return [convert(v) for v in value]
+ else:
+ return utils.safestr(value)
+
+ query = dict([(k, convert(v, doseq)) for k, v in query.items()])
+ return urllib.urlencode(query, doseq=doseq)
+
+def changequery(query=None, **kw):
+ """
+ Imagine you're at `/foo?a=1&b=2`. Then `changequery(a=3)` will return
+ `/foo?a=3&b=2` -- the same URL but with the arguments you requested
+ changed.
+ """
+ if query is None:
+ query = web.rawinput(method='get')
+ for k, v in kw.iteritems():
+ if v is None:
+ query.pop(k, None)
+ else:
+ query[k] = v
+ out = web.ctx.path
+ if query:
+ out += '?' + urlencode(query, doseq=True)
+ return out
+
+def url(path=None, doseq=False, **kw):
+ """
+ Makes url by concatenating web.ctx.homepath and path and the
+ query string created using the arguments.
+ """
+ if path is None:
+ path = web.ctx.path
+ if path.startswith("/"):
+ out = web.ctx.homepath + path
+ else:
+ out = path
+
+ if kw:
+ out += '?' + urlencode(kw, doseq=doseq)
+
+ return out
+
+def profiler(app):
+ """Outputs basic profiling information at the bottom of each response."""
+ from utils import profile
+ def profile_internal(e, o):
+ out, result = profile(app)(e, o)
+ return list(out) + ['
' + net.websafe(result) + '
']
+ return profile_internal
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/web/httpserver.py b/web/httpserver.py
index 9c0909e..3644f98 100644
--- a/web/httpserver.py
+++ b/web/httpserver.py
@@ -1,319 +1,319 @@
-__all__ = ["runsimple"]
-
-import sys, os
-from SimpleHTTPServer import SimpleHTTPRequestHandler
-import urllib
-import posixpath
-
-import webapi as web
-import net
-import utils
-
-def runbasic(func, server_address=("0.0.0.0", 8080)):
- """
- Runs a simple HTTP server hosting WSGI app `func`. The directory `static/`
- is hosted statically.
-
- Based on [WsgiServer][ws] from [Colin Stewart][cs].
-
- [ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html
- [cs]: http://www.owlfish.com/
- """
- # Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/)
- # Modified somewhat for simplicity
- # Used under the modified BSD license:
- # http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5
-
- import SimpleHTTPServer, SocketServer, BaseHTTPServer, urlparse
- import socket, errno
- import traceback
-
- class WSGIHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
- def run_wsgi_app(self):
- protocol, host, path, parameters, query, fragment = \
- urlparse.urlparse('http://dummyhost%s' % self.path)
-
- # we only use path, query
- env = {'wsgi.version': (1, 0)
- ,'wsgi.url_scheme': 'http'
- ,'wsgi.input': self.rfile
- ,'wsgi.errors': sys.stderr
- ,'wsgi.multithread': 1
- ,'wsgi.multiprocess': 0
- ,'wsgi.run_once': 0
- ,'REQUEST_METHOD': self.command
- ,'REQUEST_URI': self.path
- ,'PATH_INFO': path
- ,'QUERY_STRING': query
- ,'CONTENT_TYPE': self.headers.get('Content-Type', '')
- ,'CONTENT_LENGTH': self.headers.get('Content-Length', '')
- ,'REMOTE_ADDR': self.client_address[0]
- ,'SERVER_NAME': self.server.server_address[0]
- ,'SERVER_PORT': str(self.server.server_address[1])
- ,'SERVER_PROTOCOL': self.request_version
- }
-
- for http_header, http_value in self.headers.items():
- env ['HTTP_%s' % http_header.replace('-', '_').upper()] = \
- http_value
-
- # Setup the state
- self.wsgi_sent_headers = 0
- self.wsgi_headers = []
-
- try:
- # We have there environment, now invoke the application
- result = self.server.app(env, self.wsgi_start_response)
- try:
- try:
- for data in result:
- if data:
- self.wsgi_write_data(data)
- finally:
- if hasattr(result, 'close'):
- result.close()
- except socket.error, socket_err:
- # Catch common network errors and suppress them
- if (socket_err.args[0] in \
- (errno.ECONNABORTED, errno.EPIPE)):
- return
- except socket.timeout, socket_timeout:
- return
- except:
- print >> web.debug, traceback.format_exc(),
-
- if (not self.wsgi_sent_headers):
- # We must write out something!
- self.wsgi_write_data(" ")
- return
-
- do_POST = run_wsgi_app
- do_PUT = run_wsgi_app
- do_DELETE = run_wsgi_app
-
- def do_GET(self):
- if self.path.startswith('/static/'):
- SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
- else:
- self.run_wsgi_app()
-
- def wsgi_start_response(self, response_status, response_headers,
- exc_info=None):
- if (self.wsgi_sent_headers):
- raise Exception \
- ("Headers already sent and start_response called again!")
- # Should really take a copy to avoid changes in the application....
- self.wsgi_headers = (response_status, response_headers)
- return self.wsgi_write_data
-
- def wsgi_write_data(self, data):
- if (not self.wsgi_sent_headers):
- status, headers = self.wsgi_headers
- # Need to send header prior to data
- status_code = status[:status.find(' ')]
- status_msg = status[status.find(' ') + 1:]
- self.send_response(int(status_code), status_msg)
- for header, value in headers:
- self.send_header(header, value)
- self.end_headers()
- self.wsgi_sent_headers = 1
- # Send the data
- self.wfile.write(data)
-
- class WSGIServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
- def __init__(self, func, server_address):
- BaseHTTPServer.HTTPServer.__init__(self,
- server_address,
- WSGIHandler)
- self.app = func
- self.serverShuttingDown = 0
-
- print "http://%s:%d/" % server_address
- WSGIServer(func, server_address).serve_forever()
-
-# The WSGIServer instance.
-# Made global so that it can be stopped in embedded mode.
-server = None
-
-def runsimple(func, server_address=("0.0.0.0", 8080)):
- """
- Runs [CherryPy][cp] WSGI server hosting WSGI app `func`.
- The directory `static/` is hosted statically.
-
- [cp]: http://www.cherrypy.org
- """
- global server
- func = StaticMiddleware(func)
- func = LogMiddleware(func)
-
- server = WSGIServer(server_address, func)
-
- if server.ssl_adapter:
- print "https://%s:%d/" % server_address
- else:
- print "http://%s:%d/" % server_address
-
- try:
- server.start()
- except (KeyboardInterrupt, SystemExit):
- server.stop()
- server = None
-
-def WSGIServer(server_address, wsgi_app):
- """Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`.
- This function can be overwritten to customize the webserver or use a different webserver.
- """
- import wsgiserver
-
- # Default values of wsgiserver.ssl_adapters uses cherrypy.wsgiserver
- # prefix. Overwriting it make it work with web.wsgiserver.
- wsgiserver.ssl_adapters = {
- 'builtin': 'web.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
- 'pyopenssl': 'web.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
- }
-
- server = wsgiserver.CherryPyWSGIServer(server_address, wsgi_app, server_name="localhost")
-
- def create_ssl_adapter(cert, key):
- # wsgiserver tries to import submodules as cherrypy.wsgiserver.foo.
- # That doesn't work as not it is web.wsgiserver.
- # Patching sys.modules temporarily to make it work.
- import types
- cherrypy = types.ModuleType('cherrypy')
- cherrypy.wsgiserver = wsgiserver
- sys.modules['cherrypy'] = cherrypy
- sys.modules['cherrypy.wsgiserver'] = wsgiserver
-
- from wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
- adapter = pyOpenSSLAdapter(cert, key)
-
- # We are done with our work. Cleanup the patches.
- del sys.modules['cherrypy']
- del sys.modules['cherrypy.wsgiserver']
-
- return adapter
-
- # SSL backward compatibility
- if (server.ssl_adapter is None and
- getattr(server, 'ssl_certificate', None) and
- getattr(server, 'ssl_private_key', None)):
- server.ssl_adapter = create_ssl_adapter(server.ssl_certificate, server.ssl_private_key)
-
- server.nodelay = not sys.platform.startswith('java') # TCP_NODELAY isn't supported on the JVM
- return server
-
-class StaticApp(SimpleHTTPRequestHandler):
- """WSGI application for serving static files."""
- def __init__(self, environ, start_response):
- self.headers = []
- self.environ = environ
- self.start_response = start_response
-
- def send_response(self, status, msg=""):
- self.status = str(status) + " " + msg
-
- def send_header(self, name, value):
- self.headers.append((name, value))
-
- def end_headers(self):
- pass
-
- def log_message(*a): pass
-
- def __iter__(self):
- environ = self.environ
-
- self.path = environ.get('PATH_INFO', '')
- self.client_address = environ.get('REMOTE_ADDR','-'), \
- environ.get('REMOTE_PORT','-')
- self.command = environ.get('REQUEST_METHOD', '-')
-
- from cStringIO import StringIO
- self.wfile = StringIO() # for capturing error
-
- try:
- path = self.translate_path(self.path)
- etag = '"%s"' % os.path.getmtime(path)
- client_etag = environ.get('HTTP_IF_NONE_MATCH')
- self.send_header('ETag', etag)
- if etag == client_etag:
- self.send_response(304, "Not Modified")
- self.start_response(self.status, self.headers)
- raise StopIteration
- except OSError:
- pass # Probably a 404
-
- f = self.send_head()
- self.start_response(self.status, self.headers)
-
- if f:
- block_size = 16 * 1024
- while True:
- buf = f.read(block_size)
- if not buf:
- break
- yield buf
- f.close()
- else:
- value = self.wfile.getvalue()
- yield value
-
-class StaticMiddleware:
- """WSGI middleware for serving static files."""
- def __init__(self, app, prefix='/static/'):
- self.app = app
- self.prefix = prefix
-
- def __call__(self, environ, start_response):
- path = environ.get('PATH_INFO', '')
- path = self.normpath(path)
-
- if path.startswith(self.prefix):
- return StaticApp(environ, start_response)
- else:
- return self.app(environ, start_response)
-
- def normpath(self, path):
- path2 = posixpath.normpath(urllib.unquote(path))
- if path.endswith("/"):
- path2 += "/"
- return path2
-
-
-class LogMiddleware:
- """WSGI middleware for logging the status."""
- def __init__(self, app):
- self.app = app
- self.format = '%s - - [%s] "%s %s %s" - %s'
-
- from BaseHTTPServer import BaseHTTPRequestHandler
- import StringIO
- f = StringIO.StringIO()
-
- class FakeSocket:
- def makefile(self, *a):
- return f
-
- # take log_date_time_string method from BaseHTTPRequestHandler
- self.log_date_time_string = BaseHTTPRequestHandler(FakeSocket(), None, None).log_date_time_string
-
- def __call__(self, environ, start_response):
- def xstart_response(status, response_headers, *args):
- out = start_response(status, response_headers, *args)
- self.log(status, environ)
- return out
-
- return self.app(environ, xstart_response)
-
- def log(self, status, environ):
- outfile = environ.get('wsgi.errors', web.debug)
- req = environ.get('PATH_INFO', '_')
- protocol = environ.get('ACTUAL_SERVER_PROTOCOL', '-')
- method = environ.get('REQUEST_METHOD', '-')
- host = "%s:%s" % (environ.get('REMOTE_ADDR','-'),
- environ.get('REMOTE_PORT','-'))
-
- time = self.log_date_time_string()
-
- msg = self.format % (host, time, protocol, method, req, status)
- print >> outfile, utils.safestr(msg)
+__all__ = ["runsimple"]
+
+import sys, os
+from SimpleHTTPServer import SimpleHTTPRequestHandler
+import urllib
+import posixpath
+
+import webapi as web
+import net
+import utils
+
+def runbasic(func, server_address=("0.0.0.0", 8080)):
+ """
+ Runs a simple HTTP server hosting WSGI app `func`. The directory `static/`
+ is hosted statically.
+
+ Based on [WsgiServer][ws] from [Colin Stewart][cs].
+
+ [ws]: http://www.owlfish.com/software/wsgiutils/documentation/wsgi-server-api.html
+ [cs]: http://www.owlfish.com/
+ """
+ # Copyright (c) 2004 Colin Stewart (http://www.owlfish.com/)
+ # Modified somewhat for simplicity
+ # Used under the modified BSD license:
+ # http://www.xfree86.org/3.3.6/COPYRIGHT2.html#5
+
+ import SimpleHTTPServer, SocketServer, BaseHTTPServer, urlparse
+ import socket, errno
+ import traceback
+
+ class WSGIHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
+ def run_wsgi_app(self):
+ protocol, host, path, parameters, query, fragment = \
+ urlparse.urlparse('http://dummyhost%s' % self.path)
+
+ # we only use path, query
+ env = {'wsgi.version': (1, 0)
+ ,'wsgi.url_scheme': 'http'
+ ,'wsgi.input': self.rfile
+ ,'wsgi.errors': sys.stderr
+ ,'wsgi.multithread': 1
+ ,'wsgi.multiprocess': 0
+ ,'wsgi.run_once': 0
+ ,'REQUEST_METHOD': self.command
+ ,'REQUEST_URI': self.path
+ ,'PATH_INFO': path
+ ,'QUERY_STRING': query
+ ,'CONTENT_TYPE': self.headers.get('Content-Type', '')
+ ,'CONTENT_LENGTH': self.headers.get('Content-Length', '')
+ ,'REMOTE_ADDR': self.client_address[0]
+ ,'SERVER_NAME': self.server.server_address[0]
+ ,'SERVER_PORT': str(self.server.server_address[1])
+ ,'SERVER_PROTOCOL': self.request_version
+ }
+
+ for http_header, http_value in self.headers.items():
+ env ['HTTP_%s' % http_header.replace('-', '_').upper()] = \
+ http_value
+
+ # Setup the state
+ self.wsgi_sent_headers = 0
+ self.wsgi_headers = []
+
+ try:
+                # We have the environment, now invoke the application
+ result = self.server.app(env, self.wsgi_start_response)
+ try:
+ try:
+ for data in result:
+ if data:
+ self.wsgi_write_data(data)
+ finally:
+ if hasattr(result, 'close'):
+ result.close()
+ except socket.error, socket_err:
+ # Catch common network errors and suppress them
+ if (socket_err.args[0] in \
+ (errno.ECONNABORTED, errno.EPIPE)):
+ return
+ except socket.timeout, socket_timeout:
+ return
+ except:
+ print >> web.debug, traceback.format_exc(),
+
+ if (not self.wsgi_sent_headers):
+ # We must write out something!
+ self.wsgi_write_data(" ")
+ return
+
+ do_POST = run_wsgi_app
+ do_PUT = run_wsgi_app
+ do_DELETE = run_wsgi_app
+
+ def do_GET(self):
+ if self.path.startswith('/static/'):
+ SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
+ else:
+ self.run_wsgi_app()
+
+ def wsgi_start_response(self, response_status, response_headers,
+ exc_info=None):
+ if (self.wsgi_sent_headers):
+ raise Exception \
+ ("Headers already sent and start_response called again!")
+ # Should really take a copy to avoid changes in the application....
+ self.wsgi_headers = (response_status, response_headers)
+ return self.wsgi_write_data
+
+ def wsgi_write_data(self, data):
+ if (not self.wsgi_sent_headers):
+ status, headers = self.wsgi_headers
+ # Need to send header prior to data
+ status_code = status[:status.find(' ')]
+ status_msg = status[status.find(' ') + 1:]
+ self.send_response(int(status_code), status_msg)
+ for header, value in headers:
+ self.send_header(header, value)
+ self.end_headers()
+ self.wsgi_sent_headers = 1
+ # Send the data
+ self.wfile.write(data)
+
+ class WSGIServer(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer):
+ def __init__(self, func, server_address):
+ BaseHTTPServer.HTTPServer.__init__(self,
+ server_address,
+ WSGIHandler)
+ self.app = func
+ self.serverShuttingDown = 0
+
+ print "http://%s:%d/" % server_address
+ WSGIServer(func, server_address).serve_forever()
+
+# The WSGIServer instance.
+# Made global so that it can be stopped in embedded mode.
+server = None
+
+def runsimple(func, server_address=("0.0.0.0", 8080)):
+ """
+ Runs [CherryPy][cp] WSGI server hosting WSGI app `func`.
+ The directory `static/` is hosted statically.
+
+ [cp]: http://www.cherrypy.org
+ """
+ global server
+ func = StaticMiddleware(func)
+ func = LogMiddleware(func)
+
+ server = WSGIServer(server_address, func)
+
+ if server.ssl_adapter:
+ print "https://%s:%d/" % server_address
+ else:
+ print "http://%s:%d/" % server_address
+
+ try:
+ server.start()
+ except (KeyboardInterrupt, SystemExit):
+ server.stop()
+ server = None
+
+def WSGIServer(server_address, wsgi_app):
+ """Creates CherryPy WSGI server listening at `server_address` to serve `wsgi_app`.
+ This function can be overwritten to customize the webserver or use a different webserver.
+ """
+ import wsgiserver
+
+    # Default values of wsgiserver.ssl_adapters use cherrypy.wsgiserver
+    # prefix. Overwriting it makes it work with web.wsgiserver.
+ wsgiserver.ssl_adapters = {
+ 'builtin': 'web.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
+ 'pyopenssl': 'web.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
+ }
+
+ server = wsgiserver.CherryPyWSGIServer(server_address, wsgi_app, server_name="localhost")
+
+ def create_ssl_adapter(cert, key):
+ # wsgiserver tries to import submodules as cherrypy.wsgiserver.foo.
+        # That doesn't work as it is now web.wsgiserver.
+ # Patching sys.modules temporarily to make it work.
+ import types
+ cherrypy = types.ModuleType('cherrypy')
+ cherrypy.wsgiserver = wsgiserver
+ sys.modules['cherrypy'] = cherrypy
+ sys.modules['cherrypy.wsgiserver'] = wsgiserver
+
+ from wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
+ adapter = pyOpenSSLAdapter(cert, key)
+
+ # We are done with our work. Cleanup the patches.
+ del sys.modules['cherrypy']
+ del sys.modules['cherrypy.wsgiserver']
+
+ return adapter
+
+ # SSL backward compatibility
+ if (server.ssl_adapter is None and
+ getattr(server, 'ssl_certificate', None) and
+ getattr(server, 'ssl_private_key', None)):
+ server.ssl_adapter = create_ssl_adapter(server.ssl_certificate, server.ssl_private_key)
+
+ server.nodelay = not sys.platform.startswith('java') # TCP_NODELAY isn't supported on the JVM
+ return server
+
+class StaticApp(SimpleHTTPRequestHandler):
+ """WSGI application for serving static files."""
+ def __init__(self, environ, start_response):
+ self.headers = []
+ self.environ = environ
+ self.start_response = start_response
+
+ def send_response(self, status, msg=""):
+ self.status = str(status) + " " + msg
+
+ def send_header(self, name, value):
+ self.headers.append((name, value))
+
+ def end_headers(self):
+ pass
+
+ def log_message(*a): pass
+
+ def __iter__(self):
+ environ = self.environ
+
+ self.path = environ.get('PATH_INFO', '')
+ self.client_address = environ.get('REMOTE_ADDR','-'), \
+ environ.get('REMOTE_PORT','-')
+ self.command = environ.get('REQUEST_METHOD', '-')
+
+ from cStringIO import StringIO
+ self.wfile = StringIO() # for capturing error
+
+ try:
+ path = self.translate_path(self.path)
+ etag = '"%s"' % os.path.getmtime(path)
+ client_etag = environ.get('HTTP_IF_NONE_MATCH')
+ self.send_header('ETag', etag)
+ if etag == client_etag:
+ self.send_response(304, "Not Modified")
+ self.start_response(self.status, self.headers)
+ raise StopIteration
+ except OSError:
+ pass # Probably a 404
+
+ f = self.send_head()
+ self.start_response(self.status, self.headers)
+
+ if f:
+ block_size = 16 * 1024
+ while True:
+ buf = f.read(block_size)
+ if not buf:
+ break
+ yield buf
+ f.close()
+ else:
+ value = self.wfile.getvalue()
+ yield value
+
+class StaticMiddleware:
+ """WSGI middleware for serving static files."""
+ def __init__(self, app, prefix='/static/'):
+ self.app = app
+ self.prefix = prefix
+
+ def __call__(self, environ, start_response):
+ path = environ.get('PATH_INFO', '')
+ path = self.normpath(path)
+
+ if path.startswith(self.prefix):
+ return StaticApp(environ, start_response)
+ else:
+ return self.app(environ, start_response)
+
+ def normpath(self, path):
+ path2 = posixpath.normpath(urllib.unquote(path))
+ if path.endswith("/"):
+ path2 += "/"
+ return path2
+
+
+class LogMiddleware:
+ """WSGI middleware for logging the status."""
+ def __init__(self, app):
+ self.app = app
+ self.format = '%s - - [%s] "%s %s %s" - %s'
+
+ from BaseHTTPServer import BaseHTTPRequestHandler
+ import StringIO
+ f = StringIO.StringIO()
+
+ class FakeSocket:
+ def makefile(self, *a):
+ return f
+
+ # take log_date_time_string method from BaseHTTPRequestHandler
+ self.log_date_time_string = BaseHTTPRequestHandler(FakeSocket(), None, None).log_date_time_string
+
+ def __call__(self, environ, start_response):
+ def xstart_response(status, response_headers, *args):
+ out = start_response(status, response_headers, *args)
+ self.log(status, environ)
+ return out
+
+ return self.app(environ, xstart_response)
+
+ def log(self, status, environ):
+ outfile = environ.get('wsgi.errors', web.debug)
+ req = environ.get('PATH_INFO', '_')
+ protocol = environ.get('ACTUAL_SERVER_PROTOCOL', '-')
+ method = environ.get('REQUEST_METHOD', '-')
+ host = "%s:%s" % (environ.get('REMOTE_ADDR','-'),
+ environ.get('REMOTE_PORT','-'))
+
+ time = self.log_date_time_string()
+
+ msg = self.format % (host, time, protocol, method, req, status)
+ print >> outfile, utils.safestr(msg)
diff --git a/web/net.py b/web/net.py
index 40ff197..3e228a1 100644
--- a/web/net.py
+++ b/web/net.py
@@ -1,193 +1,193 @@
-"""
-Network Utilities
-(from web.py)
-"""
-
-__all__ = [
- "validipaddr", "validipport", "validip", "validaddr",
- "urlquote",
- "httpdate", "parsehttpdate",
- "htmlquote", "htmlunquote", "websafe",
-]
-
-import urllib, time
-try: import datetime
-except ImportError: pass
-
-def validipaddr(address):
- """
- Returns True if `address` is a valid IPv4 address.
-
- >>> validipaddr('192.168.1.1')
- True
- >>> validipaddr('192.168.1.800')
- False
- >>> validipaddr('192.168.1')
- False
- """
- try:
- octets = address.split('.')
- if len(octets) != 4:
- return False
- for x in octets:
- if not (0 <= int(x) <= 255):
- return False
- except ValueError:
- return False
- return True
-
-def validipport(port):
- """
- Returns True if `port` is a valid IPv4 port.
-
- >>> validipport('9000')
- True
- >>> validipport('foo')
- False
- >>> validipport('1000000')
- False
- """
- try:
- if not (0 <= int(port) <= 65535):
- return False
- except ValueError:
- return False
- return True
-
-def validip(ip, defaultaddr="0.0.0.0", defaultport=8080):
- """Returns `(ip_address, port)` from string `ip_addr_port`"""
- addr = defaultaddr
- port = defaultport
-
- ip = ip.split(":", 1)
- if len(ip) == 1:
- if not ip[0]:
- pass
- elif validipaddr(ip[0]):
- addr = ip[0]
- elif validipport(ip[0]):
- port = int(ip[0])
- else:
- raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
- elif len(ip) == 2:
- addr, port = ip
- if not validipaddr(addr) and validipport(port):
- raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
- port = int(port)
- else:
- raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
- return (addr, port)
-
-def validaddr(string_):
- """
- Returns either (ip_address, port) or "/path/to/socket" from string_
-
- >>> validaddr('/path/to/socket')
- '/path/to/socket'
- >>> validaddr('8000')
- ('0.0.0.0', 8000)
- >>> validaddr('127.0.0.1')
- ('127.0.0.1', 8080)
- >>> validaddr('127.0.0.1:8000')
- ('127.0.0.1', 8000)
- >>> validaddr('fff')
- Traceback (most recent call last):
- ...
- ValueError: fff is not a valid IP address/port
- """
- if '/' in string_:
- return string_
- else:
- return validip(string_)
-
-def urlquote(val):
- """
- Quotes a string for use in a URL.
-
- >>> urlquote('://?f=1&j=1')
- '%3A//%3Ff%3D1%26j%3D1'
- >>> urlquote(None)
- ''
- >>> urlquote(u'\u203d')
- '%E2%80%BD'
- """
- if val is None: return ''
- if not isinstance(val, unicode): val = str(val)
- else: val = val.encode('utf-8')
- return urllib.quote(val)
-
-def httpdate(date_obj):
- """
- Formats a datetime object for use in HTTP headers.
-
- >>> import datetime
- >>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1))
- 'Thu, 01 Jan 1970 01:01:01 GMT'
- """
- return date_obj.strftime("%a, %d %b %Y %H:%M:%S GMT")
-
-def parsehttpdate(string_):
- """
- Parses an HTTP date into a datetime object.
-
- >>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT')
- datetime.datetime(1970, 1, 1, 1, 1, 1)
- """
- try:
- t = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z")
- except ValueError:
- return None
- return datetime.datetime(*t[:6])
-
-def htmlquote(text):
- r"""
- Encodes `text` for raw use in HTML.
-
- >>> htmlquote(u"<'&\">")
- u'<'&">'
- """
- text = text.replace(u"&", u"&") # Must be done first!
- text = text.replace(u"<", u"<")
- text = text.replace(u">", u">")
- text = text.replace(u"'", u"'")
- text = text.replace(u'"', u""")
- return text
-
-def htmlunquote(text):
- r"""
- Decodes `text` that's HTML quoted.
-
- >>> htmlunquote(u'<'&">')
- u'<\'&">'
- """
- text = text.replace(u""", u'"')
- text = text.replace(u"'", u"'")
- text = text.replace(u">", u">")
- text = text.replace(u"<", u"<")
- text = text.replace(u"&", u"&") # Must be done last!
- return text
-
-def websafe(val):
- r"""Converts `val` so that it is safe for use in Unicode HTML.
-
- >>> websafe("<'&\">")
- u'<'&">'
- >>> websafe(None)
- u''
- >>> websafe(u'\u203d')
- u'\u203d'
- >>> websafe('\xe2\x80\xbd')
- u'\u203d'
- """
- if val is None:
- return u''
- elif isinstance(val, str):
- val = val.decode('utf-8')
- elif not isinstance(val, unicode):
- val = unicode(val)
-
- return htmlquote(val)
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod()
+"""
+Network Utilities
+(from web.py)
+"""
+
+__all__ = [
+ "validipaddr", "validipport", "validip", "validaddr",
+ "urlquote",
+ "httpdate", "parsehttpdate",
+ "htmlquote", "htmlunquote", "websafe",
+]
+
+import urllib, time
+try: import datetime
+except ImportError: pass
+
+def validipaddr(address):
+ """
+ Returns True if `address` is a valid IPv4 address.
+
+ >>> validipaddr('192.168.1.1')
+ True
+ >>> validipaddr('192.168.1.800')
+ False
+ >>> validipaddr('192.168.1')
+ False
+ """
+ try:
+ octets = address.split('.')
+ if len(octets) != 4:
+ return False
+ for x in octets:
+ if not (0 <= int(x) <= 255):
+ return False
+ except ValueError:
+ return False
+ return True
+
+def validipport(port):
+ """
+ Returns True if `port` is a valid IPv4 port.
+
+ >>> validipport('9000')
+ True
+ >>> validipport('foo')
+ False
+ >>> validipport('1000000')
+ False
+ """
+ try:
+ if not (0 <= int(port) <= 65535):
+ return False
+ except ValueError:
+ return False
+ return True
+
+def validip(ip, defaultaddr="0.0.0.0", defaultport=8080):
+ """Returns `(ip_address, port)` from string `ip_addr_port`"""
+ addr = defaultaddr
+ port = defaultport
+
+ ip = ip.split(":", 1)
+ if len(ip) == 1:
+ if not ip[0]:
+ pass
+ elif validipaddr(ip[0]):
+ addr = ip[0]
+ elif validipport(ip[0]):
+ port = int(ip[0])
+ else:
+ raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
+ elif len(ip) == 2:
+ addr, port = ip
+ if not validipaddr(addr) and validipport(port):
+ raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
+ port = int(port)
+ else:
+ raise ValueError, ':'.join(ip) + ' is not a valid IP address/port'
+ return (addr, port)
+
+def validaddr(string_):
+ """
+ Returns either (ip_address, port) or "/path/to/socket" from string_
+
+ >>> validaddr('/path/to/socket')
+ '/path/to/socket'
+ >>> validaddr('8000')
+ ('0.0.0.0', 8000)
+ >>> validaddr('127.0.0.1')
+ ('127.0.0.1', 8080)
+ >>> validaddr('127.0.0.1:8000')
+ ('127.0.0.1', 8000)
+ >>> validaddr('fff')
+ Traceback (most recent call last):
+ ...
+ ValueError: fff is not a valid IP address/port
+ """
+ if '/' in string_:
+ return string_
+ else:
+ return validip(string_)
+
+def urlquote(val):
+ """
+ Quotes a string for use in a URL.
+
+ >>> urlquote('://?f=1&j=1')
+ '%3A//%3Ff%3D1%26j%3D1'
+ >>> urlquote(None)
+ ''
+ >>> urlquote(u'\u203d')
+ '%E2%80%BD'
+ """
+ if val is None: return ''
+ if not isinstance(val, unicode): val = str(val)
+ else: val = val.encode('utf-8')
+ return urllib.quote(val)
+
+def httpdate(date_obj):
+ """
+ Formats a datetime object for use in HTTP headers.
+
+ >>> import datetime
+ >>> httpdate(datetime.datetime(1970, 1, 1, 1, 1, 1))
+ 'Thu, 01 Jan 1970 01:01:01 GMT'
+ """
+ return date_obj.strftime("%a, %d %b %Y %H:%M:%S GMT")
+
+def parsehttpdate(string_):
+ """
+ Parses an HTTP date into a datetime object.
+
+ >>> parsehttpdate('Thu, 01 Jan 1970 01:01:01 GMT')
+ datetime.datetime(1970, 1, 1, 1, 1, 1)
+ """
+ try:
+ t = time.strptime(string_, "%a, %d %b %Y %H:%M:%S %Z")
+ except ValueError:
+ return None
+ return datetime.datetime(*t[:6])
+
+def htmlquote(text):
+ r"""
+ Encodes `text` for raw use in HTML.
+
+ >>> htmlquote(u"<'&\">")
+ u'<'&">'
+ """
+ text = text.replace(u"&", u"&") # Must be done first!
+ text = text.replace(u"<", u"<")
+ text = text.replace(u">", u">")
+ text = text.replace(u"'", u"'")
+ text = text.replace(u'"', u""")
+ return text
+
+def htmlunquote(text):
+ r"""
+ Decodes `text` that's HTML quoted.
+
+ >>> htmlunquote(u'<'&">')
+ u'<\'&">'
+ """
+ text = text.replace(u""", u'"')
+ text = text.replace(u"'", u"'")
+ text = text.replace(u">", u">")
+ text = text.replace(u"<", u"<")
+ text = text.replace(u"&", u"&") # Must be done last!
+ return text
+
+def websafe(val):
+ r"""Converts `val` so that it is safe for use in Unicode HTML.
+
+ >>> websafe("<'&\">")
+ u'<'&">'
+ >>> websafe(None)
+ u''
+ >>> websafe(u'\u203d')
+ u'\u203d'
+ >>> websafe('\xe2\x80\xbd')
+ u'\u203d'
+ """
+ if val is None:
+ return u''
+ elif isinstance(val, str):
+ val = val.decode('utf-8')
+ elif not isinstance(val, unicode):
+ val = unicode(val)
+
+ return htmlquote(val)
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/web/python23.py b/web/python23.py
index dfb331a..0361672 100644
--- a/web/python23.py
+++ b/web/python23.py
@@ -1,46 +1,46 @@
-"""Python 2.3 compatabilty"""
-import threading
-
-class threadlocal(object):
- """Implementation of threading.local for python2.3.
- """
- def __getattribute__(self, name):
- if name == "__dict__":
- return threadlocal._getd(self)
- else:
- try:
- return object.__getattribute__(self, name)
- except AttributeError:
- try:
- return self.__dict__[name]
- except KeyError:
- raise AttributeError, name
-
- def __setattr__(self, name, value):
- self.__dict__[name] = value
-
- def __delattr__(self, name):
- try:
- del self.__dict__[name]
- except KeyError:
- raise AttributeError, name
-
- def _getd(self):
- t = threading.currentThread()
- if not hasattr(t, '_d'):
- # using __dict__ of thread as thread local storage
- t._d = {}
-
- _id = id(self)
- # there could be multiple instances of threadlocal.
- # use id(self) as key
- if _id not in t._d:
- t._d[_id] = {}
- return t._d[_id]
-
-if __name__ == '__main__':
- d = threadlocal()
- d.x = 1
- print d.__dict__
- print d.x
+"""Python 2.3 compatabilty"""
+import threading
+
+class threadlocal(object):
+ """Implementation of threading.local for python2.3.
+ """
+ def __getattribute__(self, name):
+ if name == "__dict__":
+ return threadlocal._getd(self)
+ else:
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ try:
+ return self.__dict__[name]
+ except KeyError:
+ raise AttributeError, name
+
+ def __setattr__(self, name, value):
+ self.__dict__[name] = value
+
+ def __delattr__(self, name):
+ try:
+ del self.__dict__[name]
+ except KeyError:
+ raise AttributeError, name
+
+ def _getd(self):
+ t = threading.currentThread()
+ if not hasattr(t, '_d'):
+ # using __dict__ of thread as thread local storage
+ t._d = {}
+
+ _id = id(self)
+ # there could be multiple instances of threadlocal.
+ # use id(self) as key
+ if _id not in t._d:
+ t._d[_id] = {}
+ return t._d[_id]
+
+if __name__ == '__main__':
+ d = threadlocal()
+ d.x = 1
+ print d.__dict__
+ print d.x
\ No newline at end of file
diff --git a/web/session.py b/web/session.py
index 02d6908..b1a63ff 100644
--- a/web/session.py
+++ b/web/session.py
@@ -1,358 +1,358 @@
-"""
-Session Management
-(from web.py)
-"""
-
-import os, time, datetime, random, base64
-import os.path
-from copy import deepcopy
-try:
- import cPickle as pickle
-except ImportError:
- import pickle
-try:
- import hashlib
- sha1 = hashlib.sha1
-except ImportError:
- import sha
- sha1 = sha.new
-
-import utils
-import webapi as web
-
-__all__ = [
- 'Session', 'SessionExpired',
- 'Store', 'DiskStore', 'DBStore',
-]
-
-web.config.session_parameters = utils.storage({
- 'cookie_name': 'webpy_session_id',
- 'cookie_domain': None,
- 'cookie_path' : None,
- 'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds
- 'ignore_expiry': True,
- 'ignore_change_ip': True,
- 'secret_key': 'fLjUfxqXtfNoIldA0A0J',
- 'expired_message': 'Session expired',
- 'httponly': True,
- 'secure': False
-})
-
-class SessionExpired(web.HTTPError):
- def __init__(self, message):
- web.HTTPError.__init__(self, '200 OK', {}, data=message)
-
-class Session(object):
- """Session management for web.py
- """
- __slots__ = [
- "store", "_initializer", "_last_cleanup_time", "_config", "_data",
- "__getitem__", "__setitem__", "__delitem__"
- ]
-
- def __init__(self, app, store, initializer=None):
- self.store = store
- self._initializer = initializer
- self._last_cleanup_time = 0
- self._config = utils.storage(web.config.session_parameters)
- self._data = utils.threadeddict()
-
- self.__getitem__ = self._data.__getitem__
- self.__setitem__ = self._data.__setitem__
- self.__delitem__ = self._data.__delitem__
-
- if app:
- app.add_processor(self._processor)
-
- def __contains__(self, name):
- return name in self._data
-
- def __getattr__(self, name):
- return getattr(self._data, name)
-
- def __setattr__(self, name, value):
- if name in self.__slots__:
- object.__setattr__(self, name, value)
- else:
- setattr(self._data, name, value)
-
- def __delattr__(self, name):
- delattr(self._data, name)
-
- def _processor(self, handler):
- """Application processor to setup session for every request"""
- self._cleanup()
- self._load()
-
- try:
- return handler()
- finally:
- self._save()
-
- def _load(self):
- """Load the session from the store, by the id from cookie"""
- cookie_name = self._config.cookie_name
- cookie_domain = self._config.cookie_domain
- cookie_path = self._config.cookie_path
- httponly = self._config.httponly
- self.session_id = web.cookies().get(cookie_name)
-
- # protection against session_id tampering
- if self.session_id and not self._valid_session_id(self.session_id):
- self.session_id = None
-
- self._check_expiry()
- if self.session_id:
- d = self.store[self.session_id]
- self.update(d)
- self._validate_ip()
-
- if not self.session_id:
- self.session_id = self._generate_session_id()
-
- if self._initializer:
- if isinstance(self._initializer, dict):
- self.update(deepcopy(self._initializer))
- elif hasattr(self._initializer, '__call__'):
- self._initializer()
-
- self.ip = web.ctx.ip
-
- def _check_expiry(self):
- # check for expiry
- if self.session_id and self.session_id not in self.store:
- if self._config.ignore_expiry:
- self.session_id = None
- else:
- return self.expired()
-
- def _validate_ip(self):
- # check for change of IP
- if self.session_id and self.get('ip', None) != web.ctx.ip:
- if not self._config.ignore_change_ip:
- return self.expired()
-
- def _save(self):
- if not self.get('_killed'):
- self._setcookie(self.session_id)
- self.store[self.session_id] = dict(self._data)
- else:
- self._setcookie(self.session_id, expires=-1)
-
- def _setcookie(self, session_id, expires='', **kw):
- cookie_name = self._config.cookie_name
- cookie_domain = self._config.cookie_domain
- cookie_path = self._config.cookie_path
- httponly = self._config.httponly
- secure = self._config.secure
- web.setcookie(cookie_name, session_id, expires=expires, domain=cookie_domain, httponly=httponly, secure=secure, path=cookie_path)
-
- def _generate_session_id(self):
- """Generate a random id for session"""
-
- while True:
- rand = os.urandom(16)
- now = time.time()
- secret_key = self._config.secret_key
- session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key))
- session_id = session_id.hexdigest()
- if session_id not in self.store:
- break
- return session_id
-
- def _valid_session_id(self, session_id):
- rx = utils.re_compile('^[0-9a-fA-F]+$')
- return rx.match(session_id)
-
- def _cleanup(self):
- """Cleanup the stored sessions"""
- current_time = time.time()
- timeout = self._config.timeout
- if current_time - self._last_cleanup_time > timeout:
- self.store.cleanup(timeout)
- self._last_cleanup_time = current_time
-
- def expired(self):
- """Called when an expired session is atime"""
- self._killed = True
- self._save()
- raise SessionExpired(self._config.expired_message)
-
- def kill(self):
- """Kill the session, make it no longer available"""
- del self.store[self.session_id]
- self._killed = True
-
-class Store:
- """Base class for session stores"""
-
- def __contains__(self, key):
- raise NotImplementedError
-
- def __getitem__(self, key):
- raise NotImplementedError
-
- def __setitem__(self, key, value):
- raise NotImplementedError
-
- def cleanup(self, timeout):
- """removes all the expired sessions"""
- raise NotImplementedError
-
- def encode(self, session_dict):
- """encodes session dict as a string"""
- pickled = pickle.dumps(session_dict)
- return base64.encodestring(pickled)
-
- def decode(self, session_data):
- """decodes the data to get back the session dict """
- pickled = base64.decodestring(session_data)
- return pickle.loads(pickled)
-
-class DiskStore(Store):
- """
- Store for saving a session on disk.
-
- >>> import tempfile
- >>> root = tempfile.mkdtemp()
- >>> s = DiskStore(root)
- >>> s['a'] = 'foo'
- >>> s['a']
- 'foo'
- >>> time.sleep(0.01)
- >>> s.cleanup(0.01)
- >>> s['a']
- Traceback (most recent call last):
- ...
- KeyError: 'a'
- """
- def __init__(self, root):
- # if the storage root doesn't exists, create it.
- if not os.path.exists(root):
- os.makedirs(
- os.path.abspath(root)
- )
- self.root = root
-
- def _get_path(self, key):
- if os.path.sep in key:
- raise ValueError, "Bad key: %s" % repr(key)
- return os.path.join(self.root, key)
-
- def __contains__(self, key):
- path = self._get_path(key)
- return os.path.exists(path)
-
- def __getitem__(self, key):
- path = self._get_path(key)
- if os.path.exists(path):
- pickled = open(path).read()
- return self.decode(pickled)
- else:
- raise KeyError, key
-
- def __setitem__(self, key, value):
- path = self._get_path(key)
- pickled = self.encode(value)
- try:
- f = open(path, 'w')
- try:
- f.write(pickled)
- finally:
- f.close()
- except IOError:
- pass
-
- def __delitem__(self, key):
- path = self._get_path(key)
- if os.path.exists(path):
- os.remove(path)
-
- def cleanup(self, timeout):
- now = time.time()
- for f in os.listdir(self.root):
- path = self._get_path(f)
- atime = os.stat(path).st_atime
- if now - atime > timeout :
- os.remove(path)
-
-class DBStore(Store):
- """Store for saving a session in database
- Needs a table with the following columns:
-
- session_id CHAR(128) UNIQUE NOT NULL,
- atime DATETIME NOT NULL default current_timestamp,
- data TEXT
- """
- def __init__(self, db, table_name):
- self.db = db
- self.table = table_name
-
- def __contains__(self, key):
- data = self.db.select(self.table, where="session_id=$key", vars=locals())
- return bool(list(data))
-
- def __getitem__(self, key):
- now = datetime.datetime.now()
- try:
- s = self.db.select(self.table, where="session_id=$key", vars=locals())[0]
- self.db.update(self.table, where="session_id=$key", atime=now, vars=locals())
- except IndexError:
- raise KeyError
- else:
- return self.decode(s.data)
-
- def __setitem__(self, key, value):
- pickled = self.encode(value)
- now = datetime.datetime.now()
- if key in self:
- self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals())
- else:
- self.db.insert(self.table, False, session_id=key, data=pickled )
-
- def __delitem__(self, key):
- self.db.delete(self.table, where="session_id=$key", vars=locals())
-
- def cleanup(self, timeout):
- timeout = datetime.timedelta(timeout/(24.0*60*60)) #timedelta takes numdays as arg
- last_allowed_time = datetime.datetime.now() - timeout
- self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals())
-
-class ShelfStore:
- """Store for saving session using `shelve` module.
-
- import shelve
- store = ShelfStore(shelve.open('session.shelf'))
-
- XXX: is shelve thread-safe?
- """
- def __init__(self, shelf):
- self.shelf = shelf
-
- def __contains__(self, key):
- return key in self.shelf
-
- def __getitem__(self, key):
- atime, v = self.shelf[key]
- self[key] = v # update atime
- return v
-
- def __setitem__(self, key, value):
- self.shelf[key] = time.time(), value
-
- def __delitem__(self, key):
- try:
- del self.shelf[key]
- except KeyError:
- pass
-
- def cleanup(self, timeout):
- now = time.time()
- for k in self.shelf.keys():
- atime, v = self.shelf[k]
- if now - atime > timeout :
- del self[k]
-
-if __name__ == '__main__' :
- import doctest
- doctest.testmod()
+"""
+Session Management
+(from web.py)
+"""
+
+import os, time, datetime, random, base64
+import os.path
+from copy import deepcopy
+try:
+ import cPickle as pickle
+except ImportError:
+ import pickle
+try:
+ import hashlib
+ sha1 = hashlib.sha1
+except ImportError:
+ import sha
+ sha1 = sha.new
+
+import utils
+import webapi as web
+
+__all__ = [
+ 'Session', 'SessionExpired',
+ 'Store', 'DiskStore', 'DBStore',
+]
+
+web.config.session_parameters = utils.storage({
+ 'cookie_name': 'webpy_session_id',
+ 'cookie_domain': None,
+ 'cookie_path' : None,
+ 'timeout': 86400, #24 * 60 * 60, # 24 hours in seconds
+ 'ignore_expiry': True,
+ 'ignore_change_ip': True,
+ 'secret_key': 'fLjUfxqXtfNoIldA0A0J',
+ 'expired_message': 'Session expired',
+ 'httponly': True,
+ 'secure': False
+})
+
+class SessionExpired(web.HTTPError):
+ def __init__(self, message):
+ web.HTTPError.__init__(self, '200 OK', {}, data=message)
+
+class Session(object):
+ """Session management for web.py
+ """
+ __slots__ = [
+ "store", "_initializer", "_last_cleanup_time", "_config", "_data",
+ "__getitem__", "__setitem__", "__delitem__"
+ ]
+
+ def __init__(self, app, store, initializer=None):
+ self.store = store
+ self._initializer = initializer
+ self._last_cleanup_time = 0
+ self._config = utils.storage(web.config.session_parameters)
+ self._data = utils.threadeddict()
+
+ self.__getitem__ = self._data.__getitem__
+ self.__setitem__ = self._data.__setitem__
+ self.__delitem__ = self._data.__delitem__
+
+ if app:
+ app.add_processor(self._processor)
+
+ def __contains__(self, name):
+ return name in self._data
+
+ def __getattr__(self, name):
+ return getattr(self._data, name)
+
+ def __setattr__(self, name, value):
+ if name in self.__slots__:
+ object.__setattr__(self, name, value)
+ else:
+ setattr(self._data, name, value)
+
+ def __delattr__(self, name):
+ delattr(self._data, name)
+
+ def _processor(self, handler):
+ """Application processor to setup session for every request"""
+ self._cleanup()
+ self._load()
+
+ try:
+ return handler()
+ finally:
+ self._save()
+
+ def _load(self):
+ """Load the session from the store, by the id from cookie"""
+ cookie_name = self._config.cookie_name
+ cookie_domain = self._config.cookie_domain
+ cookie_path = self._config.cookie_path
+ httponly = self._config.httponly
+ self.session_id = web.cookies().get(cookie_name)
+
+ # protection against session_id tampering
+ if self.session_id and not self._valid_session_id(self.session_id):
+ self.session_id = None
+
+ self._check_expiry()
+ if self.session_id:
+ d = self.store[self.session_id]
+ self.update(d)
+ self._validate_ip()
+
+ if not self.session_id:
+ self.session_id = self._generate_session_id()
+
+ if self._initializer:
+ if isinstance(self._initializer, dict):
+ self.update(deepcopy(self._initializer))
+ elif hasattr(self._initializer, '__call__'):
+ self._initializer()
+
+ self.ip = web.ctx.ip
+
+ def _check_expiry(self):
+ # check for expiry
+ if self.session_id and self.session_id not in self.store:
+ if self._config.ignore_expiry:
+ self.session_id = None
+ else:
+ return self.expired()
+
+ def _validate_ip(self):
+ # check for change of IP
+ if self.session_id and self.get('ip', None) != web.ctx.ip:
+ if not self._config.ignore_change_ip:
+ return self.expired()
+
+ def _save(self):
+ if not self.get('_killed'):
+ self._setcookie(self.session_id)
+ self.store[self.session_id] = dict(self._data)
+ else:
+ self._setcookie(self.session_id, expires=-1)
+
+ def _setcookie(self, session_id, expires='', **kw):
+ cookie_name = self._config.cookie_name
+ cookie_domain = self._config.cookie_domain
+ cookie_path = self._config.cookie_path
+ httponly = self._config.httponly
+ secure = self._config.secure
+ web.setcookie(cookie_name, session_id, expires=expires, domain=cookie_domain, httponly=httponly, secure=secure, path=cookie_path)
+
+ def _generate_session_id(self):
+ """Generate a random id for session"""
+
+ while True:
+ rand = os.urandom(16)
+ now = time.time()
+ secret_key = self._config.secret_key
+ session_id = sha1("%s%s%s%s" %(rand, now, utils.safestr(web.ctx.ip), secret_key))
+ session_id = session_id.hexdigest()
+ if session_id not in self.store:
+ break
+ return session_id
+
+ def _valid_session_id(self, session_id):
+ rx = utils.re_compile('^[0-9a-fA-F]+$')
+ return rx.match(session_id)
+
+ def _cleanup(self):
+ """Cleanup the stored sessions"""
+ current_time = time.time()
+ timeout = self._config.timeout
+ if current_time - self._last_cleanup_time > timeout:
+ self.store.cleanup(timeout)
+ self._last_cleanup_time = current_time
+
+ def expired(self):
+ """Called when an expired session is atime"""
+ self._killed = True
+ self._save()
+ raise SessionExpired(self._config.expired_message)
+
+ def kill(self):
+ """Kill the session, make it no longer available"""
+ del self.store[self.session_id]
+ self._killed = True
+
+class Store:
+ """Base class for session stores"""
+
+ def __contains__(self, key):
+ raise NotImplementedError
+
+ def __getitem__(self, key):
+ raise NotImplementedError
+
+ def __setitem__(self, key, value):
+ raise NotImplementedError
+
+ def cleanup(self, timeout):
+ """removes all the expired sessions"""
+ raise NotImplementedError
+
+ def encode(self, session_dict):
+ """encodes session dict as a string"""
+ pickled = pickle.dumps(session_dict)
+ return base64.encodestring(pickled)
+
+ def decode(self, session_data):
+ """decodes the data to get back the session dict """
+ pickled = base64.decodestring(session_data)
+ return pickle.loads(pickled)
+
+class DiskStore(Store):
+ """
+ Store for saving a session on disk.
+
+ >>> import tempfile
+ >>> root = tempfile.mkdtemp()
+ >>> s = DiskStore(root)
+ >>> s['a'] = 'foo'
+ >>> s['a']
+ 'foo'
+ >>> time.sleep(0.01)
+ >>> s.cleanup(0.01)
+ >>> s['a']
+ Traceback (most recent call last):
+ ...
+ KeyError: 'a'
+ """
+ def __init__(self, root):
+ # if the storage root doesn't exists, create it.
+ if not os.path.exists(root):
+ os.makedirs(
+ os.path.abspath(root)
+ )
+ self.root = root
+
+ def _get_path(self, key):
+ if os.path.sep in key:
+ raise ValueError, "Bad key: %s" % repr(key)
+ return os.path.join(self.root, key)
+
+ def __contains__(self, key):
+ path = self._get_path(key)
+ return os.path.exists(path)
+
+ def __getitem__(self, key):
+ path = self._get_path(key)
+ if os.path.exists(path):
+ pickled = open(path).read()
+ return self.decode(pickled)
+ else:
+ raise KeyError, key
+
+ def __setitem__(self, key, value):
+ path = self._get_path(key)
+ pickled = self.encode(value)
+ try:
+ f = open(path, 'w')
+ try:
+ f.write(pickled)
+ finally:
+ f.close()
+ except IOError:
+ pass
+
+ def __delitem__(self, key):
+ path = self._get_path(key)
+ if os.path.exists(path):
+ os.remove(path)
+
+ def cleanup(self, timeout):
+ now = time.time()
+ for f in os.listdir(self.root):
+ path = self._get_path(f)
+ atime = os.stat(path).st_atime
+ if now - atime > timeout :
+ os.remove(path)
+
+class DBStore(Store):
+ """Store for saving a session in database
+ Needs a table with the following columns:
+
+ session_id CHAR(128) UNIQUE NOT NULL,
+ atime DATETIME NOT NULL default current_timestamp,
+ data TEXT
+ """
+ def __init__(self, db, table_name):
+ self.db = db
+ self.table = table_name
+
+ def __contains__(self, key):
+ data = self.db.select(self.table, where="session_id=$key", vars=locals())
+ return bool(list(data))
+
+ def __getitem__(self, key):
+ now = datetime.datetime.now()
+ try:
+ s = self.db.select(self.table, where="session_id=$key", vars=locals())[0]
+ self.db.update(self.table, where="session_id=$key", atime=now, vars=locals())
+ except IndexError:
+ raise KeyError
+ else:
+ return self.decode(s.data)
+
+ def __setitem__(self, key, value):
+ pickled = self.encode(value)
+ now = datetime.datetime.now()
+ if key in self:
+ self.db.update(self.table, where="session_id=$key", data=pickled, vars=locals())
+ else:
+ self.db.insert(self.table, False, session_id=key, data=pickled )
+
+ def __delitem__(self, key):
+ self.db.delete(self.table, where="session_id=$key", vars=locals())
+
+ def cleanup(self, timeout):
+ timeout = datetime.timedelta(timeout/(24.0*60*60)) #timedelta takes numdays as arg
+ last_allowed_time = datetime.datetime.now() - timeout
+ self.db.delete(self.table, where="$last_allowed_time > atime", vars=locals())
+
+class ShelfStore:
+ """Store for saving session using `shelve` module.
+
+ import shelve
+ store = ShelfStore(shelve.open('session.shelf'))
+
+ XXX: is shelve thread-safe?
+ """
+ def __init__(self, shelf):
+ self.shelf = shelf
+
+ def __contains__(self, key):
+ return key in self.shelf
+
+ def __getitem__(self, key):
+ atime, v = self.shelf[key]
+ self[key] = v # update atime
+ return v
+
+ def __setitem__(self, key, value):
+ self.shelf[key] = time.time(), value
+
+ def __delitem__(self, key):
+ try:
+ del self.shelf[key]
+ except KeyError:
+ pass
+
+ def cleanup(self, timeout):
+ now = time.time()
+ for k in self.shelf.keys():
+ atime, v = self.shelf[k]
+ if now - atime > timeout :
+ del self[k]
+
+if __name__ == '__main__' :
+ import doctest
+ doctest.testmod()
diff --git a/web/template.py b/web/template.py
index 4a37e58..e886d35 100644
--- a/web/template.py
+++ b/web/template.py
@@ -1,1515 +1,1515 @@
-"""
-simple, elegant templating
-(part of web.py)
-
-Template design:
-
-Template string is split into tokens and the tokens are combined into nodes.
-Parse tree is a nodelist. TextNode and ExpressionNode are simple nodes and
-for-loop, if-loop etc are block nodes, which contain multiple child nodes.
-
-Each node can emit some python string. python string emitted by the
-root node is validated for safeeval and executed using python in the given environment.
-
-Enough care is taken to make sure the generated code and the template has line to line match,
-so that the error messages can point to exact line number in template. (It doesn't work in some cases still.)
-
-Grammar:
-
- template -> defwith sections
- defwith -> '$def with (' arguments ')' | ''
- sections -> section*
- section -> block | assignment | line
-
- assignment -> '$ '
- line -> (text|expr)*
- text ->
- expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}'
- pyexpr ->
-"""
-
-__all__ = [
- "Template",
- "Render", "render", "frender",
- "ParseError", "SecurityError",
- "test"
-]
-
-import tokenize
-import os
-import sys
-import glob
-import re
-from UserDict import DictMixin
-import warnings
-
-from utils import storage, safeunicode, safestr, re_compile
-from webapi import config
-from net import websafe
-
-def splitline(text):
- r"""
- Splits the given text at newline.
-
- >>> splitline('foo\nbar')
- ('foo\n', 'bar')
- >>> splitline('foo')
- ('foo', '')
- >>> splitline('')
- ('', '')
- """
- index = text.find('\n') + 1
- if index:
- return text[:index], text[index:]
- else:
- return text, ''
-
-class Parser:
- """Parser Base.
- """
- def __init__(self):
- self.statement_nodes = STATEMENT_NODES
- self.keywords = KEYWORDS
-
- def parse(self, text, name=""):
- self.text = text
- self.name = name
-
- defwith, text = self.read_defwith(text)
- suite = self.read_suite(text)
- return DefwithNode(defwith, suite)
-
- def read_defwith(self, text):
- if text.startswith('$def with'):
- defwith, text = splitline(text)
- defwith = defwith[1:].strip() # strip $ and spaces
- return defwith, text
- else:
- return '', text
-
- def read_section(self, text):
- r"""Reads one section from the given text.
-
- section -> block | assignment | line
-
- >>> read_section = Parser().read_section
- >>> read_section('foo\nbar\n')
- (, 'bar\n')
- >>> read_section('$ a = b + 1\nfoo\n')
- (, 'foo\n')
-
- read_section('$for in range(10):\n hello $i\nfoo)
- """
- if text.lstrip(' ').startswith('$'):
- index = text.index('$')
- begin_indent, text2 = text[:index], text[index+1:]
- ahead = self.python_lookahead(text2)
-
- if ahead == 'var':
- return self.read_var(text2)
- elif ahead in self.statement_nodes:
- return self.read_block_section(text2, begin_indent)
- elif ahead in self.keywords:
- return self.read_keyword(text2)
- elif ahead.strip() == '':
- # assignments starts with a space after $
- # ex: $ a = b + 2
- return self.read_assignment(text2)
- return self.readline(text)
-
- def read_var(self, text):
- r"""Reads a var statement.
-
- >>> read_var = Parser().read_var
- >>> read_var('var x=10\nfoo')
- (, 'foo')
- >>> read_var('var x: hello $name\nfoo')
- (, 'foo')
- """
- line, text = splitline(text)
- tokens = self.python_tokens(line)
- if len(tokens) < 4:
- raise SyntaxError('Invalid var statement')
-
- name = tokens[1]
- sep = tokens[2]
- value = line.split(sep, 1)[1].strip()
-
- if sep == '=':
- pass # no need to process value
- elif sep == ':':
- #@@ Hack for backward-compatability
- if tokens[3] == '\n': # multi-line var statement
- block, text = self.read_indented_block(text, ' ')
- lines = [self.readline(x)[0] for x in block.splitlines()]
- nodes = []
- for x in lines:
- nodes.extend(x.nodes)
- nodes.append(TextNode('\n'))
- else: # single-line var statement
- linenode, _ = self.readline(value)
- nodes = linenode.nodes
- parts = [node.emit('') for node in nodes]
- value = "join_(%s)" % ", ".join(parts)
- else:
- raise SyntaxError('Invalid var statement')
- return VarNode(name, value), text
-
- def read_suite(self, text):
- r"""Reads section by section till end of text.
-
- >>> read_suite = Parser().read_suite
- >>> read_suite('hello $name\nfoo\n')
- [, ]
- """
- sections = []
- while text:
- section, text = self.read_section(text)
- sections.append(section)
- return SuiteNode(sections)
-
- def readline(self, text):
- r"""Reads one line from the text. Newline is supressed if the line ends with \.
-
- >>> readline = Parser().readline
- >>> readline('hello $name!\nbye!')
- (, 'bye!')
- >>> readline('hello $name!\\\nbye!')
- (, 'bye!')
- >>> readline('$f()\n\n')
- (, '\n')
- """
- line, text = splitline(text)
-
- # supress new line if line ends with \
- if line.endswith('\\\n'):
- line = line[:-2]
-
- nodes = []
- while line:
- node, line = self.read_node(line)
- nodes.append(node)
-
- return LineNode(nodes), text
-
- def read_node(self, text):
- r"""Reads a node from the given text and returns the node and remaining text.
-
- >>> read_node = Parser().read_node
- >>> read_node('hello $name')
- (t'hello ', '$name')
- >>> read_node('$name')
- ($name, '')
- """
- if text.startswith('$$'):
- return TextNode('$'), text[2:]
- elif text.startswith('$#'): # comment
- line, text = splitline(text)
- return TextNode('\n'), text
- elif text.startswith('$'):
- text = text[1:] # strip $
- if text.startswith(':'):
- escape = False
- text = text[1:] # strip :
- else:
- escape = True
- return self.read_expr(text, escape=escape)
- else:
- return self.read_text(text)
-
- def read_text(self, text):
- r"""Reads a text node from the given text.
-
- >>> read_text = Parser().read_text
- >>> read_text('hello $name')
- (t'hello ', '$name')
- """
- index = text.find('$')
- if index < 0:
- return TextNode(text), ''
- else:
- return TextNode(text[:index]), text[index:]
-
- def read_keyword(self, text):
- line, text = splitline(text)
- return StatementNode(line.strip() + "\n"), text
-
- def read_expr(self, text, escape=True):
- """Reads a python expression from the text and returns the expression and remaining text.
-
- expr -> simple_expr | paren_expr
- simple_expr -> id extended_expr
- extended_expr -> attr_access | paren_expr extended_expr | ''
- attr_access -> dot id extended_expr
- paren_expr -> [ tokens ] | ( tokens ) | { tokens }
-
- >>> read_expr = Parser().read_expr
- >>> read_expr("name")
- ($name, '')
- >>> read_expr("a.b and c")
- ($a.b, ' and c')
- >>> read_expr("a. b")
- ($a, '. b')
- >>> read_expr("name")
- ($name, '')
- >>> read_expr("(limit)ing")
- ($(limit), 'ing')
- >>> read_expr('a[1, 2][:3].f(1+2, "weird string[).", 3 + 4) done.')
- ($a[1, 2][:3].f(1+2, "weird string[).", 3 + 4), ' done.')
- """
- def simple_expr():
- identifier()
- extended_expr()
-
- def identifier():
- tokens.next()
-
- def extended_expr():
- lookahead = tokens.lookahead()
- if lookahead is None:
- return
- elif lookahead.value == '.':
- attr_access()
- elif lookahead.value in parens:
- paren_expr()
- extended_expr()
- else:
- return
-
- def attr_access():
- from token import NAME # python token constants
- dot = tokens.lookahead()
- if tokens.lookahead2().type == NAME:
- tokens.next() # consume dot
- identifier()
- extended_expr()
-
- def paren_expr():
- begin = tokens.next().value
- end = parens[begin]
- while True:
- if tokens.lookahead().value in parens:
- paren_expr()
- else:
- t = tokens.next()
- if t.value == end:
- break
- return
-
- parens = {
- "(": ")",
- "[": "]",
- "{": "}"
- }
-
- def get_tokens(text):
- """tokenize text using python tokenizer.
- Python tokenizer ignores spaces, but they might be important in some cases.
- This function introduces dummy space tokens when it identifies any ignored space.
- Each token is a storage object containing type, value, begin and end.
- """
- readline = iter([text]).next
- end = None
- for t in tokenize.generate_tokens(readline):
- t = storage(type=t[0], value=t[1], begin=t[2], end=t[3])
- if end is not None and end != t.begin:
- _, x1 = end
- _, x2 = t.begin
- yield storage(type=-1, value=text[x1:x2], begin=end, end=t.begin)
- end = t.end
- yield t
-
- class BetterIter:
- """Iterator like object with 2 support for 2 look aheads."""
- def __init__(self, items):
- self.iteritems = iter(items)
- self.items = []
- self.position = 0
- self.current_item = None
-
- def lookahead(self):
- if len(self.items) <= self.position:
- self.items.append(self._next())
- return self.items[self.position]
-
- def _next(self):
- try:
- return self.iteritems.next()
- except StopIteration:
- return None
-
- def lookahead2(self):
- if len(self.items) <= self.position+1:
- self.items.append(self._next())
- return self.items[self.position+1]
-
- def next(self):
- self.current_item = self.lookahead()
- self.position += 1
- return self.current_item
-
- tokens = BetterIter(get_tokens(text))
-
- if tokens.lookahead().value in parens:
- paren_expr()
- else:
- simple_expr()
- row, col = tokens.current_item.end
- return ExpressionNode(text[:col], escape=escape), text[col:]
-
- def read_assignment(self, text):
- r"""Reads assignment statement from text.
-
- >>> read_assignment = Parser().read_assignment
- >>> read_assignment('a = b + 1\nfoo')
- (, 'foo')
- """
- line, text = splitline(text)
- return AssignmentNode(line.strip()), text
-
- def python_lookahead(self, text):
- """Returns the first python token from the given text.
-
- >>> python_lookahead = Parser().python_lookahead
- >>> python_lookahead('for i in range(10):')
- 'for'
- >>> python_lookahead('else:')
- 'else'
- >>> python_lookahead(' x = 1')
- ' '
- """
- readline = iter([text]).next
- tokens = tokenize.generate_tokens(readline)
- return tokens.next()[1]
-
- def python_tokens(self, text):
- readline = iter([text]).next
- tokens = tokenize.generate_tokens(readline)
- return [t[1] for t in tokens]
-
- def read_indented_block(self, text, indent):
- r"""Read a block of text. A block is what typically follows a for or it statement.
- It can be in the same line as that of the statement or an indented block.
-
- >>> read_indented_block = Parser().read_indented_block
- >>> read_indented_block(' a\n b\nc', ' ')
- ('a\nb\n', 'c')
- >>> read_indented_block(' a\n b\n c\nd', ' ')
- ('a\n b\nc\n', 'd')
- >>> read_indented_block(' a\n\n b\nc', ' ')
- ('a\n\n b\n', 'c')
- """
- if indent == '':
- return '', text
-
- block = ""
- while text:
- line, text2 = splitline(text)
- if line.strip() == "":
- block += '\n'
- elif line.startswith(indent):
- block += line[len(indent):]
- else:
- break
- text = text2
- return block, text
-
- def read_statement(self, text):
- r"""Reads a python statement.
-
- >>> read_statement = Parser().read_statement
- >>> read_statement('for i in range(10): hello $name')
- ('for i in range(10):', ' hello $name')
- """
- tok = PythonTokenizer(text)
- tok.consume_till(':')
- return text[:tok.index], text[tok.index:]
-
- def read_block_section(self, text, begin_indent=''):
- r"""
- >>> read_block_section = Parser().read_block_section
- >>> read_block_section('for i in range(10): hello $i\nfoo')
- (]>, 'foo')
- >>> read_block_section('for i in range(10):\n hello $i\n foo', begin_indent=' ')
- (]>, ' foo')
- >>> read_block_section('for i in range(10):\n hello $i\nfoo')
- (]>, 'foo')
- """
- line, text = splitline(text)
- stmt, line = self.read_statement(line)
- keyword = self.python_lookahead(stmt)
-
- # if there is some thing left in the line
- if line.strip():
- block = line.lstrip()
- else:
- def find_indent(text):
- rx = re_compile(' +')
- match = rx.match(text)
- first_indent = match and match.group(0)
- return first_indent or ""
-
- # find the indentation of the block by looking at the first line
- first_indent = find_indent(text)[len(begin_indent):]
-
- #TODO: fix this special case
- if keyword == "code":
- indent = begin_indent + first_indent
- else:
- indent = begin_indent + min(first_indent, INDENT)
-
- block, text = self.read_indented_block(text, indent)
-
- return self.create_block_node(keyword, stmt, block, begin_indent), text
-
- def create_block_node(self, keyword, stmt, block, begin_indent):
- if keyword in self.statement_nodes:
- return self.statement_nodes[keyword](stmt, block, begin_indent)
- else:
- raise ParseError, 'Unknown statement: %s' % repr(keyword)
-
-class PythonTokenizer:
- """Utility wrapper over python tokenizer."""
- def __init__(self, text):
- self.text = text
- readline = iter([text]).next
- self.tokens = tokenize.generate_tokens(readline)
- self.index = 0
-
- def consume_till(self, delim):
- """Consumes tokens till colon.
-
- >>> tok = PythonTokenizer('for i in range(10): hello $i')
- >>> tok.consume_till(':')
- >>> tok.text[:tok.index]
- 'for i in range(10):'
- >>> tok.text[tok.index:]
- ' hello $i'
- """
- try:
- while True:
- t = self.next()
- if t.value == delim:
- break
- elif t.value == '(':
- self.consume_till(')')
- elif t.value == '[':
- self.consume_till(']')
- elif t.value == '{':
- self.consume_till('}')
-
- # if end of line is found, it is an exception.
- # Since there is no easy way to report the line number,
- # leave the error reporting to the python parser later
- #@@ This should be fixed.
- if t.value == '\n':
- break
- except:
- #raise ParseError, "Expected %s, found end of line." % repr(delim)
-
- # raising ParseError doesn't show the line number.
- # if this error is ignored, then it will be caught when compiling the python code.
- return
-
- def next(self):
- type, t, begin, end, line = self.tokens.next()
- row, col = end
- self.index = col
- return storage(type=type, value=t, begin=begin, end=end)
-
-class DefwithNode:
- def __init__(self, defwith, suite):
- if defwith:
- self.defwith = defwith.replace('with', '__template__') + ':'
- # offset 4 lines. for encoding, __lineoffset__, loop and self.
- self.defwith += "\n __lineoffset__ = -4"
- else:
- self.defwith = 'def __template__():'
- # offset 4 lines for encoding, __template__, __lineoffset__, loop and self.
- self.defwith += "\n __lineoffset__ = -5"
-
- self.defwith += "\n loop = ForLoop()"
- self.defwith += "\n self = TemplateResult(); extend_ = self.extend"
- self.suite = suite
- self.end = "\n return self"
-
- def emit(self, indent):
- encoding = "# coding: utf-8\n"
- return encoding + self.defwith + self.suite.emit(indent + INDENT) + self.end
-
- def __repr__(self):
- return "" % (self.defwith, self.suite)
-
-class TextNode:
- def __init__(self, value):
- self.value = value
-
- def emit(self, indent, begin_indent=''):
- return repr(safeunicode(self.value))
-
- def __repr__(self):
- return 't' + repr(self.value)
-
-class ExpressionNode:
- def __init__(self, value, escape=True):
- self.value = value.strip()
-
- # convert ${...} to $(...)
- if value.startswith('{') and value.endswith('}'):
- self.value = '(' + self.value[1:-1] + ')'
-
- self.escape = escape
-
- def emit(self, indent, begin_indent=''):
- return 'escape_(%s, %s)' % (self.value, bool(self.escape))
-
- def __repr__(self):
- if self.escape:
- escape = ''
- else:
- escape = ':'
- return "$%s%s" % (escape, self.value)
-
-class AssignmentNode:
- def __init__(self, code):
- self.code = code
-
- def emit(self, indent, begin_indent=''):
- return indent + self.code + "\n"
-
- def __repr__(self):
- return "" % repr(self.code)
-
-class LineNode:
- def __init__(self, nodes):
- self.nodes = nodes
-
- def emit(self, indent, text_indent='', name=''):
- text = [node.emit('') for node in self.nodes]
- if text_indent:
- text = [repr(text_indent)] + text
-
- return indent + "extend_([%s])\n" % ", ".join(text)
-
- def __repr__(self):
- return "" % repr(self.nodes)
-
-INDENT = ' ' # 4 spaces
-
-class BlockNode:
- def __init__(self, stmt, block, begin_indent=''):
- self.stmt = stmt
- self.suite = Parser().read_suite(block)
- self.begin_indent = begin_indent
-
- def emit(self, indent, text_indent=''):
- text_indent = self.begin_indent + text_indent
- out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
- return out
-
- def __repr__(self):
- return "" % (repr(self.stmt), repr(self.suite))
-
-class ForNode(BlockNode):
- def __init__(self, stmt, block, begin_indent=''):
- self.original_stmt = stmt
- tok = PythonTokenizer(stmt)
- tok.consume_till('in')
- a = stmt[:tok.index] # for i in
- b = stmt[tok.index:-1] # rest of for stmt excluding :
- stmt = a + ' loop.setup(' + b.strip() + '):'
- BlockNode.__init__(self, stmt, block, begin_indent)
-
- def __repr__(self):
- return "" % (repr(self.original_stmt), repr(self.suite))
-
-class CodeNode:
- def __init__(self, stmt, block, begin_indent=''):
- # compensate one line for $code:
- self.code = "\n" + block
-
- def emit(self, indent, text_indent=''):
- import re
- rx = re.compile('^', re.M)
- return rx.sub(indent, self.code).rstrip(' ')
-
- def __repr__(self):
- return "" % repr(self.code)
-
-class StatementNode:
- def __init__(self, stmt):
- self.stmt = stmt
-
- def emit(self, indent, begin_indent=''):
- return indent + self.stmt
-
- def __repr__(self):
- return "" % repr(self.stmt)
-
-class IfNode(BlockNode):
- pass
-
-class ElseNode(BlockNode):
- pass
-
-class ElifNode(BlockNode):
- pass
-
-class DefNode(BlockNode):
- def __init__(self, *a, **kw):
- BlockNode.__init__(self, *a, **kw)
-
- code = CodeNode("", "")
- code.code = "self = TemplateResult(); extend_ = self.extend\n"
- self.suite.sections.insert(0, code)
-
- code = CodeNode("", "")
- code.code = "return self\n"
- self.suite.sections.append(code)
-
- def emit(self, indent, text_indent=''):
- text_indent = self.begin_indent + text_indent
- out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
- return indent + "__lineoffset__ -= 3\n" + out
-
-class VarNode:
- def __init__(self, name, value):
- self.name = name
- self.value = value
-
- def emit(self, indent, text_indent):
- return indent + "self[%s] = %s\n" % (repr(self.name), self.value)
-
- def __repr__(self):
- return "" % (self.name, self.value)
-
-class SuiteNode:
- """Suite is a list of sections."""
- def __init__(self, sections):
- self.sections = sections
-
- def emit(self, indent, text_indent=''):
- return "\n" + "".join([s.emit(indent, text_indent) for s in self.sections])
-
- def __repr__(self):
- return repr(self.sections)
-
-STATEMENT_NODES = {
- 'for': ForNode,
- 'while': BlockNode,
- 'if': IfNode,
- 'elif': ElifNode,
- 'else': ElseNode,
- 'def': DefNode,
- 'code': CodeNode
-}
-
-KEYWORDS = [
- "pass",
- "break",
- "continue",
- "return"
-]
-
-TEMPLATE_BUILTIN_NAMES = [
- "dict", "enumerate", "float", "int", "bool", "list", "long", "reversed",
- "set", "slice", "tuple", "xrange",
- "abs", "all", "any", "callable", "chr", "cmp", "divmod", "filter", "hex",
- "id", "isinstance", "iter", "len", "max", "min", "oct", "ord", "pow", "range",
- "True", "False",
- "None",
- "__import__", # some c-libraries like datetime requires __import__ to present in the namespace
-]
-
-import __builtin__
-TEMPLATE_BUILTINS = dict([(name, getattr(__builtin__, name)) for name in TEMPLATE_BUILTIN_NAMES if name in __builtin__.__dict__])
-
-class ForLoop:
- """
- Wrapper for expression in for stament to support loop.xxx helpers.
-
- >>> loop = ForLoop()
- >>> for x in loop.setup(['a', 'b', 'c']):
- ... print loop.index, loop.revindex, loop.parity, x
- ...
- 1 3 odd a
- 2 2 even b
- 3 1 odd c
- >>> loop.index
- Traceback (most recent call last):
- ...
- AttributeError: index
- """
- def __init__(self):
- self._ctx = None
-
- def __getattr__(self, name):
- if self._ctx is None:
- raise AttributeError, name
- else:
- return getattr(self._ctx, name)
-
- def setup(self, seq):
- self._push()
- return self._ctx.setup(seq)
-
- def _push(self):
- self._ctx = ForLoopContext(self, self._ctx)
-
- def _pop(self):
- self._ctx = self._ctx.parent
-
-class ForLoopContext:
- """Stackable context for ForLoop to support nested for loops.
- """
- def __init__(self, forloop, parent):
- self._forloop = forloop
- self.parent = parent
-
- def setup(self, seq):
- try:
- self.length = len(seq)
- except:
- self.length = 0
-
- self.index = 0
- for a in seq:
- self.index += 1
- yield a
- self._forloop._pop()
-
- index0 = property(lambda self: self.index-1)
- first = property(lambda self: self.index == 1)
- last = property(lambda self: self.index == self.length)
- odd = property(lambda self: self.index % 2 == 1)
- even = property(lambda self: self.index % 2 == 0)
- parity = property(lambda self: ['odd', 'even'][self.even])
- revindex0 = property(lambda self: self.length - self.index)
- revindex = property(lambda self: self.length - self.index + 1)
-
-class BaseTemplate:
- def __init__(self, code, filename, filter, globals, builtins):
- self.filename = filename
- self.filter = filter
- self._globals = globals
- self._builtins = builtins
- if code:
- self.t = self._compile(code)
- else:
- self.t = lambda: ''
-
- def _compile(self, code):
- env = self.make_env(self._globals or {}, self._builtins)
- exec(code, env)
- return env['__template__']
-
- def __call__(self, *a, **kw):
- __hidetraceback__ = True
- return self.t(*a, **kw)
-
- def make_env(self, globals, builtins):
- return dict(globals,
- __builtins__=builtins,
- ForLoop=ForLoop,
- TemplateResult=TemplateResult,
- escape_=self._escape,
- join_=self._join
- )
- def _join(self, *items):
- return u"".join(items)
-
- def _escape(self, value, escape=False):
- if value is None:
- value = ''
-
- value = safeunicode(value)
- if escape and self.filter:
- value = self.filter(value)
- return value
-
-class Template(BaseTemplate):
- CONTENT_TYPES = {
- '.html' : 'text/html; charset=utf-8',
- '.xhtml' : 'application/xhtml+xml; charset=utf-8',
- '.txt' : 'text/plain',
- }
- FILTERS = {
- '.html': websafe,
- '.xhtml': websafe,
- '.xml': websafe
- }
- globals = {}
-
- def __init__(self, text, filename='', filter=None, globals=None, builtins=None, extensions=None):
- self.extensions = extensions or []
- text = Template.normalize_text(text)
- code = self.compile_template(text, filename)
-
- _, ext = os.path.splitext(filename)
- filter = filter or self.FILTERS.get(ext, None)
- self.content_type = self.CONTENT_TYPES.get(ext, None)
-
- if globals is None:
- globals = self.globals
- if builtins is None:
- builtins = TEMPLATE_BUILTINS
-
- BaseTemplate.__init__(self, code=code, filename=filename, filter=filter, globals=globals, builtins=builtins)
-
- def normalize_text(text):
- """Normalizes template text by correcting \r\n, tabs and BOM chars."""
- text = text.replace('\r\n', '\n').replace('\r', '\n').expandtabs()
- if not text.endswith('\n'):
- text += '\n'
-
- # ignore BOM chars at the begining of template
- BOM = '\xef\xbb\xbf'
- if isinstance(text, str) and text.startswith(BOM):
- text = text[len(BOM):]
-
- # support fort \$ for backward-compatibility
- text = text.replace(r'\$', '$$')
- return text
- normalize_text = staticmethod(normalize_text)
-
- def __call__(self, *a, **kw):
- __hidetraceback__ = True
- import webapi as web
- if 'headers' in web.ctx and self.content_type:
- web.header('Content-Type', self.content_type, unique=True)
-
- return BaseTemplate.__call__(self, *a, **kw)
-
- def generate_code(text, filename, parser=None):
- # parse the text
- parser = parser or Parser()
- rootnode = parser.parse(text, filename)
-
- # generate python code from the parse tree
- code = rootnode.emit(indent="").strip()
- return safestr(code)
-
- generate_code = staticmethod(generate_code)
-
- def create_parser(self):
- p = Parser()
- for ext in self.extensions:
- p = ext(p)
- return p
-
- def compile_template(self, template_string, filename):
- code = Template.generate_code(template_string, filename, parser=self.create_parser())
-
- def get_source_line(filename, lineno):
- try:
- lines = open(filename).read().splitlines()
- return lines[lineno]
- except:
- return None
-
- try:
- # compile the code first to report the errors, if any, with the filename
- compiled_code = compile(code, filename, 'exec')
- except SyntaxError, e:
- # display template line that caused the error along with the traceback.
- try:
- e.msg += '\n\nTemplate traceback:\n File %s, line %s\n %s' % \
- (repr(e.filename), e.lineno, get_source_line(e.filename, e.lineno-1))
- except:
- pass
- raise
-
- # make sure code is safe - but not with jython, it doesn't have a working compiler module
- if not sys.platform.startswith('java'):
- try:
- import compiler
- ast = compiler.parse(code)
- SafeVisitor().walk(ast, filename)
- except ImportError:
- warnings.warn("Unabled to import compiler module. Unable to check templates for safety.")
- else:
- warnings.warn("SECURITY ISSUE: You are using Jython, which does not support checking templates for safety. Your templates can execute arbitrary code.")
-
- return compiled_code
-
-class CompiledTemplate(Template):
- def __init__(self, f, filename):
- Template.__init__(self, '', filename)
- self.t = f
-
- def compile_template(self, *a):
- return None
-
- def _compile(self, *a):
- return None
-
-class Render:
- """The most preferred way of using templates.
-
- render = web.template.render('templates')
- print render.foo()
-
- Optional parameter can be `base` can be used to pass output of
- every template through the base template.
-
- render = web.template.render('templates', base='layout')
- """
- def __init__(self, loc='templates', cache=None, base=None, **keywords):
- self._loc = loc
- self._keywords = keywords
-
- if cache is None:
- cache = not config.get('debug', False)
-
- if cache:
- self._cache = {}
- else:
- self._cache = None
-
- if base and not hasattr(base, '__call__'):
- # make base a function, so that it can be passed to sub-renders
- self._base = lambda page: self._template(base)(page)
- else:
- self._base = base
-
- def _add_global(self, obj, name=None):
- """Add a global to this rendering instance."""
- if 'globals' not in self._keywords: self._keywords['globals'] = {}
- if not name:
- name = obj.__name__
- self._keywords['globals'][name] = obj
-
- def _lookup(self, name):
- path = os.path.join(self._loc, name)
- if os.path.isdir(path):
- return 'dir', path
- else:
- path = self._findfile(path)
- if path:
- return 'file', path
- else:
- return 'none', None
-
- def _load_template(self, name):
- kind, path = self._lookup(name)
-
- if kind == 'dir':
- return Render(path, cache=self._cache is not None, base=self._base, **self._keywords)
- elif kind == 'file':
- return Template(open(path).read(), filename=path, **self._keywords)
- else:
- raise AttributeError, "No template named " + name
-
- def _findfile(self, path_prefix):
- p = [f for f in glob.glob(path_prefix + '.*') if not f.endswith('~')] # skip backup files
- p.sort() # sort the matches for deterministic order
- return p and p[0]
-
- def _template(self, name):
- if self._cache is not None:
- if name not in self._cache:
- self._cache[name] = self._load_template(name)
- return self._cache[name]
- else:
- return self._load_template(name)
-
- def __getattr__(self, name):
- t = self._template(name)
- if self._base and isinstance(t, Template):
- def template(*a, **kw):
- return self._base(t(*a, **kw))
- return template
- else:
- return self._template(name)
-
-class GAE_Render(Render):
- # Render gets over-written. make a copy here.
- super = Render
- def __init__(self, loc, *a, **kw):
- GAE_Render.super.__init__(self, loc, *a, **kw)
-
- import types
- if isinstance(loc, types.ModuleType):
- self.mod = loc
- else:
- name = loc.rstrip('/').replace('/', '.')
- self.mod = __import__(name, None, None, ['x'])
-
- self.mod.__dict__.update(kw.get('builtins', TEMPLATE_BUILTINS))
- self.mod.__dict__.update(Template.globals)
- self.mod.__dict__.update(kw.get('globals', {}))
-
- def _load_template(self, name):
- t = getattr(self.mod, name)
- import types
- if isinstance(t, types.ModuleType):
- return GAE_Render(t, cache=self._cache is not None, base=self._base, **self._keywords)
- else:
- return t
-
-render = Render
-# setup render for Google App Engine.
-try:
- from google import appengine
- render = Render = GAE_Render
-except ImportError:
- pass
-
-def frender(path, **keywords):
- """Creates a template from the given file path.
- """
- return Template(open(path).read(), filename=path, **keywords)
-
-def compile_templates(root):
- """Compiles templates to python code."""
- re_start = re_compile('^', re.M)
-
- for dirpath, dirnames, filenames in os.walk(root):
- filenames = [f for f in filenames if not f.startswith('.') and not f.endswith('~') and not f.startswith('__init__.py')]
-
- for d in dirnames[:]:
- if d.startswith('.'):
- dirnames.remove(d) # don't visit this dir
-
- out = open(os.path.join(dirpath, '__init__.py'), 'w')
- out.write('from web.template import CompiledTemplate, ForLoop, TemplateResult\n\n')
- if dirnames:
- out.write("import " + ", ".join(dirnames))
- out.write("\n")
-
- for f in filenames:
- path = os.path.join(dirpath, f)
-
- if '.' in f:
- name, _ = f.split('.', 1)
- else:
- name = f
-
- text = open(path).read()
- text = Template.normalize_text(text)
- code = Template.generate_code(text, path)
-
- code = code.replace("__template__", name, 1)
-
- out.write(code)
-
- out.write('\n\n')
- out.write('%s = CompiledTemplate(%s, %s)\n' % (name, name, repr(path)))
- out.write("join_ = %s._join; escape_ = %s._escape\n\n" % (name, name))
-
- # create template to make sure it compiles
- t = Template(open(path).read(), path)
- out.close()
-
-class ParseError(Exception):
- pass
-
-class SecurityError(Exception):
- """The template seems to be trying to do something naughty."""
- pass
-
-# Enumerate all the allowed AST nodes
-ALLOWED_AST_NODES = [
- "Add", "And",
-# "AssAttr",
- "AssList", "AssName", "AssTuple",
-# "Assert",
- "Assign", "AugAssign",
-# "Backquote",
- "Bitand", "Bitor", "Bitxor", "Break",
- "CallFunc","Class", "Compare", "Const", "Continue",
- "Decorators", "Dict", "Discard", "Div",
- "Ellipsis", "EmptyNode",
-# "Exec",
- "Expression", "FloorDiv", "For",
-# "From",
- "Function",
- "GenExpr", "GenExprFor", "GenExprIf", "GenExprInner",
- "Getattr",
-# "Global",
- "If", "IfExp",
-# "Import",
- "Invert", "Keyword", "Lambda", "LeftShift",
- "List", "ListComp", "ListCompFor", "ListCompIf", "Mod",
- "Module",
- "Mul", "Name", "Not", "Or", "Pass", "Power",
-# "Print", "Printnl", "Raise",
- "Return", "RightShift", "Slice", "Sliceobj",
- "Stmt", "Sub", "Subscript",
-# "TryExcept", "TryFinally",
- "Tuple", "UnaryAdd", "UnarySub",
- "While", "With", "Yield",
-]
-
-class SafeVisitor(object):
- """
- Make sure code is safe by walking through the AST.
-
- Code considered unsafe if:
- * it has restricted AST nodes
- * it is trying to access resricted attributes
-
- Adopted from http://www.zafar.se/bkz/uploads/safe.txt (public domain, Babar K. Zafar)
- """
- def __init__(self):
- "Initialize visitor by generating callbacks for all AST node types."
- self.errors = []
-
- def walk(self, ast, filename):
- "Validate each node in AST and raise SecurityError if the code is not safe."
- self.filename = filename
- self.visit(ast)
-
- if self.errors:
- raise SecurityError, '\n'.join([str(err) for err in self.errors])
-
- def visit(self, node, *args):
- "Recursively validate node and all of its children."
- def classname(obj):
- return obj.__class__.__name__
- nodename = classname(node)
- fn = getattr(self, 'visit' + nodename, None)
-
- if fn:
- fn(node, *args)
- else:
- if nodename not in ALLOWED_AST_NODES:
- self.fail(node, *args)
-
- for child in node.getChildNodes():
- self.visit(child, *args)
-
- def visitName(self, node, *args):
- "Disallow any attempts to access a restricted attr."
- #self.assert_attr(node.getChildren()[0], node)
- pass
-
- def visitGetattr(self, node, *args):
- "Disallow any attempts to access a restricted attribute."
- self.assert_attr(node.attrname, node)
-
- def assert_attr(self, attrname, node):
- if self.is_unallowed_attr(attrname):
- lineno = self.get_node_lineno(node)
- e = SecurityError("%s:%d - access to attribute '%s' is denied" % (self.filename, lineno, attrname))
- self.errors.append(e)
-
- def is_unallowed_attr(self, name):
- return name.startswith('_') \
- or name.startswith('func_') \
- or name.startswith('im_')
-
- def get_node_lineno(self, node):
- return (node.lineno) and node.lineno or 0
-
- def fail(self, node, *args):
- "Default callback for unallowed AST nodes."
- lineno = self.get_node_lineno(node)
- nodename = node.__class__.__name__
- e = SecurityError("%s:%d - execution of '%s' statements is denied" % (self.filename, lineno, nodename))
- self.errors.append(e)
-
-class TemplateResult(object, DictMixin):
- """Dictionary like object for storing template output.
-
- The result of a template execution is usally a string, but sometimes it
- contains attributes set using $var. This class provides a simple
- dictionary like interface for storing the output of the template and the
- attributes. The output is stored with a special key __body__. Convering
- the the TemplateResult to string or unicode returns the value of __body__.
-
- When the template is in execution, the output is generated part by part
- and those parts are combined at the end. Parts are added to the
- TemplateResult by calling the `extend` method and the parts are combined
- seemlessly when __body__ is accessed.
-
- >>> d = TemplateResult(__body__='hello, world', x='foo')
- >>> d
-
- >>> print d
- hello, world
- >>> d.x
- 'foo'
- >>> d = TemplateResult()
- >>> d.extend([u'hello', u'world'])
- >>> d
-
- """
- def __init__(self, *a, **kw):
- self.__dict__["_d"] = dict(*a, **kw)
- self._d.setdefault("__body__", u'')
-
- self.__dict__['_parts'] = []
- self.__dict__["extend"] = self._parts.extend
-
- self._d.setdefault("__body__", None)
-
- def keys(self):
- return self._d.keys()
-
- def _prepare_body(self):
- """Prepare value of __body__ by joining parts.
- """
- if self._parts:
- value = u"".join(self._parts)
- self._parts[:] = []
- body = self._d.get('__body__')
- if body:
- self._d['__body__'] = body + value
- else:
- self._d['__body__'] = value
-
- def __getitem__(self, name):
- if name == "__body__":
- self._prepare_body()
- return self._d[name]
-
- def __setitem__(self, name, value):
- if name == "__body__":
- self._prepare_body()
- return self._d.__setitem__(name, value)
-
- def __delitem__(self, name):
- if name == "__body__":
- self._prepare_body()
- return self._d.__delitem__(name)
-
- def __getattr__(self, key):
- try:
- return self[key]
- except KeyError, k:
- raise AttributeError, k
-
- def __setattr__(self, key, value):
- self[key] = value
-
- def __delattr__(self, key):
- try:
- del self[key]
- except KeyError, k:
- raise AttributeError, k
-
- def __unicode__(self):
- self._prepare_body()
- return self["__body__"]
-
- def __str__(self):
- self._prepare_body()
- return self["__body__"].encode('utf-8')
-
- def __repr__(self):
- self._prepare_body()
- return "" % self._d
-
-def test():
- r"""Doctest for testing template module.
-
- Define a utility function to run template test.
-
- >>> class TestResult:
- ... def __init__(self, t): self.t = t
- ... def __getattr__(self, name): return getattr(self.t, name)
- ... def __repr__(self): return repr(unicode(self))
- ...
- >>> def t(code, **keywords):
- ... tmpl = Template(code, **keywords)
- ... return lambda *a, **kw: TestResult(tmpl(*a, **kw))
- ...
-
- Simple tests.
-
- >>> t('1')()
- u'1\n'
- >>> t('$def with ()\n1')()
- u'1\n'
- >>> t('$def with (a)\n$a')(1)
- u'1\n'
- >>> t('$def with (a=0)\n$a')(1)
- u'1\n'
- >>> t('$def with (a=0)\n$a')(a=1)
- u'1\n'
-
- Test complicated expressions.
-
- >>> t('$def with (x)\n$x.upper()')('hello')
- u'HELLO\n'
- >>> t('$(2 * 3 + 4 * 5)')()
- u'26\n'
- >>> t('${2 * 3 + 4 * 5}')()
- u'26\n'
- >>> t('$def with (limit)\nkeep $(limit)ing.')('go')
- u'keep going.\n'
- >>> t('$def with (a)\n$a.b[0]')(storage(b=[1]))
- u'1\n'
-
- Test html escaping.
-
- >>> t('$def with (x)\n$x', filename='a.html')('')
- u'<html>\n'
- >>> t('$def with (x)\n$x', filename='a.txt')('')
- u'\n'
-
- Test if, for and while.
-
- >>> t('$if 1: 1')()
- u'1\n'
- >>> t('$if 1:\n 1')()
- u'1\n'
- >>> t('$if 1:\n 1\\')()
- u'1'
- >>> t('$if 0: 0\n$elif 1: 1')()
- u'1\n'
- >>> t('$if 0: 0\n$elif None: 0\n$else: 1')()
- u'1\n'
- >>> t('$if 0 < 1 and 1 < 2: 1')()
- u'1\n'
- >>> t('$for x in [1, 2, 3]: $x')()
- u'1\n2\n3\n'
- >>> t('$def with (d)\n$for k, v in d.iteritems(): $k')({1: 1})
- u'1\n'
- >>> t('$for x in [1, 2, 3]:\n\t$x')()
- u' 1\n 2\n 3\n'
- >>> t('$def with (a)\n$while a and a.pop():1')([1, 2, 3])
- u'1\n1\n1\n'
-
- The space after : must be ignored.
-
- >>> t('$if True: foo')()
- u'foo\n'
-
- Test loop.xxx.
-
- >>> t("$for i in range(5):$loop.index, $loop.parity")()
- u'1, odd\n2, even\n3, odd\n4, even\n5, odd\n'
- >>> t("$for i in range(2):\n $for j in range(2):$loop.parent.parity $loop.parity")()
- u'odd odd\nodd even\neven odd\neven even\n'
-
- Test assignment.
-
- >>> t('$ a = 1\n$a')()
- u'1\n'
- >>> t('$ a = [1]\n$a[0]')()
- u'1\n'
- >>> t('$ a = {1: 1}\n$a.keys()[0]')()
- u'1\n'
- >>> t('$ a = []\n$if not a: 1')()
- u'1\n'
- >>> t('$ a = {}\n$if not a: 1')()
- u'1\n'
- >>> t('$ a = -1\n$a')()
- u'-1\n'
- >>> t('$ a = "1"\n$a')()
- u'1\n'
-
- Test comments.
-
- >>> t('$# 0')()
- u'\n'
- >>> t('hello$#comment1\nhello$#comment2')()
- u'hello\nhello\n'
- >>> t('$#comment0\nhello$#comment1\nhello$#comment2')()
- u'\nhello\nhello\n'
-
- Test unicode.
-
- >>> t('$def with (a)\n$a')(u'\u203d')
- u'\u203d\n'
- >>> t('$def with (a)\n$a')(u'\u203d'.encode('utf-8'))
- u'\u203d\n'
- >>> t(u'$def with (a)\n$a $:a')(u'\u203d')
- u'\u203d \u203d\n'
- >>> t(u'$def with ()\nfoo')()
- u'foo\n'
- >>> def f(x): return x
- ...
- >>> t(u'$def with (f)\n$:f("x")')(f)
- u'x\n'
- >>> t('$def with (f)\n$:f("x")')(f)
- u'x\n'
-
- Test dollar escaping.
-
- >>> t("Stop, $$money isn't evaluated.")()
- u"Stop, $money isn't evaluated.\n"
- >>> t("Stop, \$money isn't evaluated.")()
- u"Stop, $money isn't evaluated.\n"
-
- Test space sensitivity.
-
- >>> t('$def with (x)\n$x')(1)
- u'1\n'
- >>> t('$def with(x ,y)\n$x')(1, 1)
- u'1\n'
- >>> t('$(1 + 2*3 + 4)')()
- u'11\n'
-
- Make sure globals are working.
-
- >>> t('$x')()
- Traceback (most recent call last):
- ...
- NameError: global name 'x' is not defined
- >>> t('$x', globals={'x': 1})()
- u'1\n'
-
- Can't change globals.
-
- >>> t('$ x = 2\n$x', globals={'x': 1})()
- u'2\n'
- >>> t('$ x = x + 1\n$x', globals={'x': 1})()
- Traceback (most recent call last):
- ...
- UnboundLocalError: local variable 'x' referenced before assignment
-
- Make sure builtins are customizable.
-
- >>> t('$min(1, 2)')()
- u'1\n'
- >>> t('$min(1, 2)', builtins={})()
- Traceback (most recent call last):
- ...
- NameError: global name 'min' is not defined
-
- Test vars.
-
- >>> x = t('$var x: 1')()
- >>> x.x
- u'1'
- >>> x = t('$var x = 1')()
- >>> x.x
- 1
- >>> x = t('$var x: \n foo\n bar')()
- >>> x.x
- u'foo\nbar\n'
-
- Test BOM chars.
-
- >>> t('\xef\xbb\xbf$def with(x)\n$x')('foo')
- u'foo\n'
-
- Test for with weird cases.
-
- >>> t('$for i in range(10)[1:5]:\n $i')()
- u'1\n2\n3\n4\n'
- >>> t("$for k, v in {'a': 1, 'b': 2}.items():\n $k $v")()
- u'a 1\nb 2\n'
- >>> t("$for k, v in ({'a': 1, 'b': 2}.items():\n $k $v")()
- Traceback (most recent call last):
- ...
- SyntaxError: invalid syntax
-
- Test datetime.
-
- >>> import datetime
- >>> t("$def with (date)\n$date.strftime('%m %Y')")(datetime.datetime(2009, 1, 1))
- u'01 2009\n'
- """
- pass
-
-if __name__ == "__main__":
- import sys
- if '--compile' in sys.argv:
- compile_templates(sys.argv[2])
- else:
- import doctest
- doctest.testmod()
+"""
+simple, elegant templating
+(part of web.py)
+
+Template design:
+
+Template string is split into tokens and the tokens are combined into nodes.
+Parse tree is a nodelist. TextNode and ExpressionNode are simple nodes and
+for-loop, if-loop etc are block nodes, which contain multiple child nodes.
+
+Each node can emit some python string. python string emitted by the
+root node is validated for safeeval and executed using python in the given environment.
+
+Enough care is taken to make sure the generated code and the template has line to line match,
+so that the error messages can point to exact line number in template. (It doesn't work in some cases still.)
+
+Grammar:
+
+ template -> defwith sections
+ defwith -> '$def with (' arguments ')' | ''
+ sections -> section*
+ section -> block | assignment | line
+
+ assignment -> '$ '
+ line -> (text|expr)*
+ text ->
+ expr -> '$' pyexpr | '$(' pyexpr ')' | '${' pyexpr '}'
+ pyexpr ->
+"""
+
+__all__ = [
+ "Template",
+ "Render", "render", "frender",
+ "ParseError", "SecurityError",
+ "test"
+]
+
+import tokenize
+import os
+import sys
+import glob
+import re
+from UserDict import DictMixin
+import warnings
+
+from utils import storage, safeunicode, safestr, re_compile
+from webapi import config
+from net import websafe
+
+def splitline(text):
+ r"""
+ Splits the given text at newline.
+
+ >>> splitline('foo\nbar')
+ ('foo\n', 'bar')
+ >>> splitline('foo')
+ ('foo', '')
+ >>> splitline('')
+ ('', '')
+ """
+ index = text.find('\n') + 1
+ if index:
+ return text[:index], text[index:]
+ else:
+ return text, ''
+
+class Parser:
+ """Parser Base.
+ """
+ def __init__(self):
+ self.statement_nodes = STATEMENT_NODES
+ self.keywords = KEYWORDS
+
+ def parse(self, text, name=""):
+ self.text = text
+ self.name = name
+
+ defwith, text = self.read_defwith(text)
+ suite = self.read_suite(text)
+ return DefwithNode(defwith, suite)
+
+ def read_defwith(self, text):
+ if text.startswith('$def with'):
+ defwith, text = splitline(text)
+ defwith = defwith[1:].strip() # strip $ and spaces
+ return defwith, text
+ else:
+ return '', text
+
+ def read_section(self, text):
+ r"""Reads one section from the given text.
+
+ section -> block | assignment | line
+
+ >>> read_section = Parser().read_section
+ >>> read_section('foo\nbar\n')
+ (, 'bar\n')
+ >>> read_section('$ a = b + 1\nfoo\n')
+ (, 'foo\n')
+
+ read_section('$for in range(10):\n hello $i\nfoo)
+ """
+ if text.lstrip(' ').startswith('$'):
+ index = text.index('$')
+ begin_indent, text2 = text[:index], text[index+1:]
+ ahead = self.python_lookahead(text2)
+
+ if ahead == 'var':
+ return self.read_var(text2)
+ elif ahead in self.statement_nodes:
+ return self.read_block_section(text2, begin_indent)
+ elif ahead in self.keywords:
+ return self.read_keyword(text2)
+ elif ahead.strip() == '':
+ # assignments starts with a space after $
+ # ex: $ a = b + 2
+ return self.read_assignment(text2)
+ return self.readline(text)
+
+ def read_var(self, text):
+ r"""Reads a var statement.
+
+ >>> read_var = Parser().read_var
+ >>> read_var('var x=10\nfoo')
+ (, 'foo')
+ >>> read_var('var x: hello $name\nfoo')
+ (, 'foo')
+ """
+ line, text = splitline(text)
+ tokens = self.python_tokens(line)
+ if len(tokens) < 4:
+ raise SyntaxError('Invalid var statement')
+
+ name = tokens[1]
+ sep = tokens[2]
+ value = line.split(sep, 1)[1].strip()
+
+ if sep == '=':
+ pass # no need to process value
+ elif sep == ':':
+ #@@ Hack for backward-compatability
+ if tokens[3] == '\n': # multi-line var statement
+ block, text = self.read_indented_block(text, ' ')
+ lines = [self.readline(x)[0] for x in block.splitlines()]
+ nodes = []
+ for x in lines:
+ nodes.extend(x.nodes)
+ nodes.append(TextNode('\n'))
+ else: # single-line var statement
+ linenode, _ = self.readline(value)
+ nodes = linenode.nodes
+ parts = [node.emit('') for node in nodes]
+ value = "join_(%s)" % ", ".join(parts)
+ else:
+ raise SyntaxError('Invalid var statement')
+ return VarNode(name, value), text
+
+ def read_suite(self, text):
+ r"""Reads section by section till end of text.
+
+ >>> read_suite = Parser().read_suite
+ >>> read_suite('hello $name\nfoo\n')
+ [, ]
+ """
+ sections = []
+ while text:
+ section, text = self.read_section(text)
+ sections.append(section)
+ return SuiteNode(sections)
+
+ def readline(self, text):
+ r"""Reads one line from the text. Newline is supressed if the line ends with \.
+
+ >>> readline = Parser().readline
+ >>> readline('hello $name!\nbye!')
+ (, 'bye!')
+ >>> readline('hello $name!\\\nbye!')
+ (, 'bye!')
+ >>> readline('$f()\n\n')
+ (, '\n')
+ """
+ line, text = splitline(text)
+
+ # supress new line if line ends with \
+ if line.endswith('\\\n'):
+ line = line[:-2]
+
+ nodes = []
+ while line:
+ node, line = self.read_node(line)
+ nodes.append(node)
+
+ return LineNode(nodes), text
+
+ def read_node(self, text):
+ r"""Reads a node from the given text and returns the node and remaining text.
+
+ >>> read_node = Parser().read_node
+ >>> read_node('hello $name')
+ (t'hello ', '$name')
+ >>> read_node('$name')
+ ($name, '')
+ """
+ if text.startswith('$$'):
+ return TextNode('$'), text[2:]
+ elif text.startswith('$#'): # comment
+ line, text = splitline(text)
+ return TextNode('\n'), text
+ elif text.startswith('$'):
+ text = text[1:] # strip $
+ if text.startswith(':'):
+ escape = False
+ text = text[1:] # strip :
+ else:
+ escape = True
+ return self.read_expr(text, escape=escape)
+ else:
+ return self.read_text(text)
+
+ def read_text(self, text):
+ r"""Reads a text node from the given text.
+
+ >>> read_text = Parser().read_text
+ >>> read_text('hello $name')
+ (t'hello ', '$name')
+ """
+ index = text.find('$')
+ if index < 0:
+ return TextNode(text), ''
+ else:
+ return TextNode(text[:index]), text[index:]
+
+ def read_keyword(self, text):
+ line, text = splitline(text)
+ return StatementNode(line.strip() + "\n"), text
+
+ def read_expr(self, text, escape=True):
+ """Reads a python expression from the text and returns the expression and remaining text.
+
+ expr -> simple_expr | paren_expr
+ simple_expr -> id extended_expr
+ extended_expr -> attr_access | paren_expr extended_expr | ''
+ attr_access -> dot id extended_expr
+ paren_expr -> [ tokens ] | ( tokens ) | { tokens }
+
+ >>> read_expr = Parser().read_expr
+ >>> read_expr("name")
+ ($name, '')
+ >>> read_expr("a.b and c")
+ ($a.b, ' and c')
+ >>> read_expr("a. b")
+ ($a, '. b')
+ >>> read_expr("name")
+ ($name, '')
+ >>> read_expr("(limit)ing")
+ ($(limit), 'ing')
+ >>> read_expr('a[1, 2][:3].f(1+2, "weird string[).", 3 + 4) done.')
+ ($a[1, 2][:3].f(1+2, "weird string[).", 3 + 4), ' done.')
+ """
+ def simple_expr():
+ identifier()
+ extended_expr()
+
+ def identifier():
+ tokens.next()
+
+ def extended_expr():
+ lookahead = tokens.lookahead()
+ if lookahead is None:
+ return
+ elif lookahead.value == '.':
+ attr_access()
+ elif lookahead.value in parens:
+ paren_expr()
+ extended_expr()
+ else:
+ return
+
+ def attr_access():
+ from token import NAME # python token constants
+ dot = tokens.lookahead()
+ if tokens.lookahead2().type == NAME:
+ tokens.next() # consume dot
+ identifier()
+ extended_expr()
+
+ def paren_expr():
+ begin = tokens.next().value
+ end = parens[begin]
+ while True:
+ if tokens.lookahead().value in parens:
+ paren_expr()
+ else:
+ t = tokens.next()
+ if t.value == end:
+ break
+ return
+
+ parens = {
+ "(": ")",
+ "[": "]",
+ "{": "}"
+ }
+
+ def get_tokens(text):
+ """tokenize text using python tokenizer.
+ Python tokenizer ignores spaces, but they might be important in some cases.
+ This function introduces dummy space tokens when it identifies any ignored space.
+ Each token is a storage object containing type, value, begin and end.
+ """
+ readline = iter([text]).next
+ end = None
+ for t in tokenize.generate_tokens(readline):
+ t = storage(type=t[0], value=t[1], begin=t[2], end=t[3])
+ if end is not None and end != t.begin:
+ _, x1 = end
+ _, x2 = t.begin
+ yield storage(type=-1, value=text[x1:x2], begin=end, end=t.begin)
+ end = t.end
+ yield t
+
+ class BetterIter:
+            """Iterator like object with support for 2 look aheads."""
+ def __init__(self, items):
+ self.iteritems = iter(items)
+ self.items = []
+ self.position = 0
+ self.current_item = None
+
+ def lookahead(self):
+ if len(self.items) <= self.position:
+ self.items.append(self._next())
+ return self.items[self.position]
+
+ def _next(self):
+ try:
+ return self.iteritems.next()
+ except StopIteration:
+ return None
+
+ def lookahead2(self):
+ if len(self.items) <= self.position+1:
+ self.items.append(self._next())
+ return self.items[self.position+1]
+
+ def next(self):
+ self.current_item = self.lookahead()
+ self.position += 1
+ return self.current_item
+
+ tokens = BetterIter(get_tokens(text))
+
+ if tokens.lookahead().value in parens:
+ paren_expr()
+ else:
+ simple_expr()
+ row, col = tokens.current_item.end
+ return ExpressionNode(text[:col], escape=escape), text[col:]
+
+ def read_assignment(self, text):
+ r"""Reads assignment statement from text.
+
+ >>> read_assignment = Parser().read_assignment
+ >>> read_assignment('a = b + 1\nfoo')
+            (<assignment: 'a = b + 1'>, 'foo')
+ """
+ line, text = splitline(text)
+ return AssignmentNode(line.strip()), text
+
+ def python_lookahead(self, text):
+ """Returns the first python token from the given text.
+
+ >>> python_lookahead = Parser().python_lookahead
+ >>> python_lookahead('for i in range(10):')
+ 'for'
+ >>> python_lookahead('else:')
+ 'else'
+ >>> python_lookahead(' x = 1')
+ ' '
+ """
+ readline = iter([text]).next
+ tokens = tokenize.generate_tokens(readline)
+ return tokens.next()[1]
+
+ def python_tokens(self, text):
+ readline = iter([text]).next
+ tokens = tokenize.generate_tokens(readline)
+ return [t[1] for t in tokens]
+
+ def read_indented_block(self, text, indent):
+        r"""Read a block of text. A block is what typically follows a for or if statement.
+ It can be in the same line as that of the statement or an indented block.
+
+ >>> read_indented_block = Parser().read_indented_block
+ >>> read_indented_block(' a\n b\nc', ' ')
+ ('a\nb\n', 'c')
+ >>> read_indented_block(' a\n b\n c\nd', ' ')
+ ('a\n b\nc\n', 'd')
+ >>> read_indented_block(' a\n\n b\nc', ' ')
+ ('a\n\n b\n', 'c')
+ """
+ if indent == '':
+ return '', text
+
+ block = ""
+ while text:
+ line, text2 = splitline(text)
+ if line.strip() == "":
+ block += '\n'
+ elif line.startswith(indent):
+ block += line[len(indent):]
+ else:
+ break
+ text = text2
+ return block, text
+
+ def read_statement(self, text):
+ r"""Reads a python statement.
+
+ >>> read_statement = Parser().read_statement
+ >>> read_statement('for i in range(10): hello $name')
+ ('for i in range(10):', ' hello $name')
+ """
+ tok = PythonTokenizer(text)
+ tok.consume_till(':')
+ return text[:tok.index], text[tok.index:]
+
+ def read_block_section(self, text, begin_indent=''):
+ r"""
+ >>> read_block_section = Parser().read_block_section
+            >>> read_block_section('for i in range(10): hello $i\nfoo')
+            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
+            >>> read_block_section('for i in range(10):\n        hello $i\n    foo', begin_indent='    ')
+            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, '    foo')
+            >>> read_block_section('for i in range(10):\n  hello $i\nfoo')
+            (<block: 'for i in range(10):', [<line: [t'hello ', $i, t'\n']>]>, 'foo')
+ """
+ line, text = splitline(text)
+ stmt, line = self.read_statement(line)
+ keyword = self.python_lookahead(stmt)
+
+ # if there is some thing left in the line
+ if line.strip():
+ block = line.lstrip()
+ else:
+ def find_indent(text):
+ rx = re_compile(' +')
+ match = rx.match(text)
+ first_indent = match and match.group(0)
+ return first_indent or ""
+
+ # find the indentation of the block by looking at the first line
+ first_indent = find_indent(text)[len(begin_indent):]
+
+ #TODO: fix this special case
+ if keyword == "code":
+ indent = begin_indent + first_indent
+ else:
+ indent = begin_indent + min(first_indent, INDENT)
+
+ block, text = self.read_indented_block(text, indent)
+
+ return self.create_block_node(keyword, stmt, block, begin_indent), text
+
+ def create_block_node(self, keyword, stmt, block, begin_indent):
+ if keyword in self.statement_nodes:
+ return self.statement_nodes[keyword](stmt, block, begin_indent)
+ else:
+ raise ParseError, 'Unknown statement: %s' % repr(keyword)
+
+class PythonTokenizer:
+ """Utility wrapper over python tokenizer."""
+ def __init__(self, text):
+ self.text = text
+ readline = iter([text]).next
+ self.tokens = tokenize.generate_tokens(readline)
+ self.index = 0
+
+ def consume_till(self, delim):
+        """Consumes tokens till the given delim is found.
+
+ >>> tok = PythonTokenizer('for i in range(10): hello $i')
+ >>> tok.consume_till(':')
+ >>> tok.text[:tok.index]
+ 'for i in range(10):'
+ >>> tok.text[tok.index:]
+ ' hello $i'
+ """
+ try:
+ while True:
+ t = self.next()
+ if t.value == delim:
+ break
+ elif t.value == '(':
+ self.consume_till(')')
+ elif t.value == '[':
+ self.consume_till(']')
+ elif t.value == '{':
+ self.consume_till('}')
+
+ # if end of line is found, it is an exception.
+ # Since there is no easy way to report the line number,
+ # leave the error reporting to the python parser later
+ #@@ This should be fixed.
+ if t.value == '\n':
+ break
+ except:
+ #raise ParseError, "Expected %s, found end of line." % repr(delim)
+
+ # raising ParseError doesn't show the line number.
+ # if this error is ignored, then it will be caught when compiling the python code.
+ return
+
+ def next(self):
+ type, t, begin, end, line = self.tokens.next()
+ row, col = end
+ self.index = col
+ return storage(type=type, value=t, begin=begin, end=end)
+
+class DefwithNode:
+ def __init__(self, defwith, suite):
+ if defwith:
+ self.defwith = defwith.replace('with', '__template__') + ':'
+ # offset 4 lines. for encoding, __lineoffset__, loop and self.
+ self.defwith += "\n __lineoffset__ = -4"
+ else:
+ self.defwith = 'def __template__():'
+ # offset 4 lines for encoding, __template__, __lineoffset__, loop and self.
+ self.defwith += "\n __lineoffset__ = -5"
+
+ self.defwith += "\n loop = ForLoop()"
+ self.defwith += "\n self = TemplateResult(); extend_ = self.extend"
+ self.suite = suite
+ self.end = "\n return self"
+
+ def emit(self, indent):
+ encoding = "# coding: utf-8\n"
+ return encoding + self.defwith + self.suite.emit(indent + INDENT) + self.end
+
+ def __repr__(self):
+        return "<defwith: %s, %s>" % (self.defwith, self.suite)
+
+class TextNode:
+ def __init__(self, value):
+ self.value = value
+
+ def emit(self, indent, begin_indent=''):
+ return repr(safeunicode(self.value))
+
+ def __repr__(self):
+ return 't' + repr(self.value)
+
+class ExpressionNode:
+ def __init__(self, value, escape=True):
+ self.value = value.strip()
+
+ # convert ${...} to $(...)
+ if value.startswith('{') and value.endswith('}'):
+ self.value = '(' + self.value[1:-1] + ')'
+
+ self.escape = escape
+
+ def emit(self, indent, begin_indent=''):
+ return 'escape_(%s, %s)' % (self.value, bool(self.escape))
+
+ def __repr__(self):
+ if self.escape:
+ escape = ''
+ else:
+ escape = ':'
+ return "$%s%s" % (escape, self.value)
+
+class AssignmentNode:
+ def __init__(self, code):
+ self.code = code
+
+ def emit(self, indent, begin_indent=''):
+ return indent + self.code + "\n"
+
+ def __repr__(self):
+        return "<assignment: %s>" % repr(self.code)
+
+class LineNode:
+ def __init__(self, nodes):
+ self.nodes = nodes
+
+ def emit(self, indent, text_indent='', name=''):
+ text = [node.emit('') for node in self.nodes]
+ if text_indent:
+ text = [repr(text_indent)] + text
+
+ return indent + "extend_([%s])\n" % ", ".join(text)
+
+ def __repr__(self):
+        return "<line: %s>" % repr(self.nodes)
+
+INDENT = ' ' # 4 spaces
+
+class BlockNode:
+ def __init__(self, stmt, block, begin_indent=''):
+ self.stmt = stmt
+ self.suite = Parser().read_suite(block)
+ self.begin_indent = begin_indent
+
+ def emit(self, indent, text_indent=''):
+ text_indent = self.begin_indent + text_indent
+ out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
+ return out
+
+ def __repr__(self):
+        return "<block: %s, %s>" % (repr(self.stmt), repr(self.suite))
+
+class ForNode(BlockNode):
+ def __init__(self, stmt, block, begin_indent=''):
+ self.original_stmt = stmt
+ tok = PythonTokenizer(stmt)
+ tok.consume_till('in')
+ a = stmt[:tok.index] # for i in
+ b = stmt[tok.index:-1] # rest of for stmt excluding :
+ stmt = a + ' loop.setup(' + b.strip() + '):'
+ BlockNode.__init__(self, stmt, block, begin_indent)
+
+ def __repr__(self):
+        return "<block: %s, %s>" % (repr(self.original_stmt), repr(self.suite))
+
+class CodeNode:
+ def __init__(self, stmt, block, begin_indent=''):
+ # compensate one line for $code:
+ self.code = "\n" + block
+
+ def emit(self, indent, text_indent=''):
+ import re
+ rx = re.compile('^', re.M)
+ return rx.sub(indent, self.code).rstrip(' ')
+
+ def __repr__(self):
+        return "<code: %s>" % repr(self.code)
+
+class StatementNode:
+ def __init__(self, stmt):
+ self.stmt = stmt
+
+ def emit(self, indent, begin_indent=''):
+ return indent + self.stmt
+
+ def __repr__(self):
+        return "<stmt: %s>" % repr(self.stmt)
+
+class IfNode(BlockNode):
+ pass
+
+class ElseNode(BlockNode):
+ pass
+
+class ElifNode(BlockNode):
+ pass
+
+class DefNode(BlockNode):
+ def __init__(self, *a, **kw):
+ BlockNode.__init__(self, *a, **kw)
+
+ code = CodeNode("", "")
+ code.code = "self = TemplateResult(); extend_ = self.extend\n"
+ self.suite.sections.insert(0, code)
+
+ code = CodeNode("", "")
+ code.code = "return self\n"
+ self.suite.sections.append(code)
+
+ def emit(self, indent, text_indent=''):
+ text_indent = self.begin_indent + text_indent
+ out = indent + self.stmt + self.suite.emit(indent + INDENT, text_indent)
+ return indent + "__lineoffset__ -= 3\n" + out
+
+class VarNode:
+ def __init__(self, name, value):
+ self.name = name
+ self.value = value
+
+ def emit(self, indent, text_indent):
+ return indent + "self[%s] = %s\n" % (repr(self.name), self.value)
+
+ def __repr__(self):
+        return "<var: %s = %s>" % (self.name, self.value)
+
+class SuiteNode:
+ """Suite is a list of sections."""
+ def __init__(self, sections):
+ self.sections = sections
+
+ def emit(self, indent, text_indent=''):
+ return "\n" + "".join([s.emit(indent, text_indent) for s in self.sections])
+
+ def __repr__(self):
+ return repr(self.sections)
+
+STATEMENT_NODES = {
+ 'for': ForNode,
+ 'while': BlockNode,
+ 'if': IfNode,
+ 'elif': ElifNode,
+ 'else': ElseNode,
+ 'def': DefNode,
+ 'code': CodeNode
+}
+
+KEYWORDS = [
+ "pass",
+ "break",
+ "continue",
+ "return"
+]
+
+TEMPLATE_BUILTIN_NAMES = [
+ "dict", "enumerate", "float", "int", "bool", "list", "long", "reversed",
+ "set", "slice", "tuple", "xrange",
+ "abs", "all", "any", "callable", "chr", "cmp", "divmod", "filter", "hex",
+ "id", "isinstance", "iter", "len", "max", "min", "oct", "ord", "pow", "range",
+ "True", "False",
+ "None",
+ "__import__", # some c-libraries like datetime requires __import__ to present in the namespace
+]
+
+import __builtin__
+TEMPLATE_BUILTINS = dict([(name, getattr(__builtin__, name)) for name in TEMPLATE_BUILTIN_NAMES if name in __builtin__.__dict__])
+
+class ForLoop:
+ """
+    Wrapper for expression in for statement to support loop.xxx helpers.
+
+ >>> loop = ForLoop()
+ >>> for x in loop.setup(['a', 'b', 'c']):
+ ... print loop.index, loop.revindex, loop.parity, x
+ ...
+ 1 3 odd a
+ 2 2 even b
+ 3 1 odd c
+ >>> loop.index
+ Traceback (most recent call last):
+ ...
+ AttributeError: index
+ """
+ def __init__(self):
+ self._ctx = None
+
+ def __getattr__(self, name):
+ if self._ctx is None:
+ raise AttributeError, name
+ else:
+ return getattr(self._ctx, name)
+
+ def setup(self, seq):
+ self._push()
+ return self._ctx.setup(seq)
+
+ def _push(self):
+ self._ctx = ForLoopContext(self, self._ctx)
+
+ def _pop(self):
+ self._ctx = self._ctx.parent
+
+class ForLoopContext:
+ """Stackable context for ForLoop to support nested for loops.
+ """
+ def __init__(self, forloop, parent):
+ self._forloop = forloop
+ self.parent = parent
+
+ def setup(self, seq):
+ try:
+ self.length = len(seq)
+ except:
+ self.length = 0
+
+ self.index = 0
+ for a in seq:
+ self.index += 1
+ yield a
+ self._forloop._pop()
+
+ index0 = property(lambda self: self.index-1)
+ first = property(lambda self: self.index == 1)
+ last = property(lambda self: self.index == self.length)
+ odd = property(lambda self: self.index % 2 == 1)
+ even = property(lambda self: self.index % 2 == 0)
+ parity = property(lambda self: ['odd', 'even'][self.even])
+ revindex0 = property(lambda self: self.length - self.index)
+ revindex = property(lambda self: self.length - self.index + 1)
+
+class BaseTemplate:
+ def __init__(self, code, filename, filter, globals, builtins):
+ self.filename = filename
+ self.filter = filter
+ self._globals = globals
+ self._builtins = builtins
+ if code:
+ self.t = self._compile(code)
+ else:
+ self.t = lambda: ''
+
+ def _compile(self, code):
+ env = self.make_env(self._globals or {}, self._builtins)
+ exec(code, env)
+ return env['__template__']
+
+ def __call__(self, *a, **kw):
+ __hidetraceback__ = True
+ return self.t(*a, **kw)
+
+ def make_env(self, globals, builtins):
+ return dict(globals,
+ __builtins__=builtins,
+ ForLoop=ForLoop,
+ TemplateResult=TemplateResult,
+ escape_=self._escape,
+ join_=self._join
+ )
+ def _join(self, *items):
+ return u"".join(items)
+
+ def _escape(self, value, escape=False):
+ if value is None:
+ value = ''
+
+ value = safeunicode(value)
+ if escape and self.filter:
+ value = self.filter(value)
+ return value
+
+class Template(BaseTemplate):
+ CONTENT_TYPES = {
+ '.html' : 'text/html; charset=utf-8',
+ '.xhtml' : 'application/xhtml+xml; charset=utf-8',
+ '.txt' : 'text/plain',
+ }
+ FILTERS = {
+ '.html': websafe,
+ '.xhtml': websafe,
+ '.xml': websafe
+ }
+ globals = {}
+
+ def __init__(self, text, filename='', filter=None, globals=None, builtins=None, extensions=None):
+ self.extensions = extensions or []
+ text = Template.normalize_text(text)
+ code = self.compile_template(text, filename)
+
+ _, ext = os.path.splitext(filename)
+ filter = filter or self.FILTERS.get(ext, None)
+ self.content_type = self.CONTENT_TYPES.get(ext, None)
+
+ if globals is None:
+ globals = self.globals
+ if builtins is None:
+ builtins = TEMPLATE_BUILTINS
+
+ BaseTemplate.__init__(self, code=code, filename=filename, filter=filter, globals=globals, builtins=builtins)
+
+ def normalize_text(text):
+ """Normalizes template text by correcting \r\n, tabs and BOM chars."""
+ text = text.replace('\r\n', '\n').replace('\r', '\n').expandtabs()
+ if not text.endswith('\n'):
+ text += '\n'
+
+        # ignore BOM chars at the beginning of template
+ BOM = '\xef\xbb\xbf'
+ if isinstance(text, str) and text.startswith(BOM):
+ text = text[len(BOM):]
+
+        # support for \$ for backward-compatibility
+ text = text.replace(r'\$', '$$')
+ return text
+ normalize_text = staticmethod(normalize_text)
+
+ def __call__(self, *a, **kw):
+ __hidetraceback__ = True
+ import webapi as web
+ if 'headers' in web.ctx and self.content_type:
+ web.header('Content-Type', self.content_type, unique=True)
+
+ return BaseTemplate.__call__(self, *a, **kw)
+
+ def generate_code(text, filename, parser=None):
+ # parse the text
+ parser = parser or Parser()
+ rootnode = parser.parse(text, filename)
+
+ # generate python code from the parse tree
+ code = rootnode.emit(indent="").strip()
+ return safestr(code)
+
+ generate_code = staticmethod(generate_code)
+
+ def create_parser(self):
+ p = Parser()
+ for ext in self.extensions:
+ p = ext(p)
+ return p
+
+ def compile_template(self, template_string, filename):
+ code = Template.generate_code(template_string, filename, parser=self.create_parser())
+
+ def get_source_line(filename, lineno):
+ try:
+ lines = open(filename).read().splitlines()
+ return lines[lineno]
+ except:
+ return None
+
+ try:
+ # compile the code first to report the errors, if any, with the filename
+ compiled_code = compile(code, filename, 'exec')
+ except SyntaxError, e:
+ # display template line that caused the error along with the traceback.
+ try:
+ e.msg += '\n\nTemplate traceback:\n File %s, line %s\n %s' % \
+ (repr(e.filename), e.lineno, get_source_line(e.filename, e.lineno-1))
+ except:
+ pass
+ raise
+
+ # make sure code is safe - but not with jython, it doesn't have a working compiler module
+ if not sys.platform.startswith('java'):
+ try:
+ import compiler
+ ast = compiler.parse(code)
+ SafeVisitor().walk(ast, filename)
+ except ImportError:
+ warnings.warn("Unabled to import compiler module. Unable to check templates for safety.")
+ else:
+ warnings.warn("SECURITY ISSUE: You are using Jython, which does not support checking templates for safety. Your templates can execute arbitrary code.")
+
+ return compiled_code
+
+class CompiledTemplate(Template):
+ def __init__(self, f, filename):
+ Template.__init__(self, '', filename)
+ self.t = f
+
+ def compile_template(self, *a):
+ return None
+
+ def _compile(self, *a):
+ return None
+
+class Render:
+ """The most preferred way of using templates.
+
+ render = web.template.render('templates')
+ print render.foo()
+
+ Optional parameter can be `base` can be used to pass output of
+ every template through the base template.
+
+ render = web.template.render('templates', base='layout')
+ """
+ def __init__(self, loc='templates', cache=None, base=None, **keywords):
+ self._loc = loc
+ self._keywords = keywords
+
+ if cache is None:
+ cache = not config.get('debug', False)
+
+ if cache:
+ self._cache = {}
+ else:
+ self._cache = None
+
+ if base and not hasattr(base, '__call__'):
+ # make base a function, so that it can be passed to sub-renders
+ self._base = lambda page: self._template(base)(page)
+ else:
+ self._base = base
+
+ def _add_global(self, obj, name=None):
+ """Add a global to this rendering instance."""
+ if 'globals' not in self._keywords: self._keywords['globals'] = {}
+ if not name:
+ name = obj.__name__
+ self._keywords['globals'][name] = obj
+
+ def _lookup(self, name):
+ path = os.path.join(self._loc, name)
+ if os.path.isdir(path):
+ return 'dir', path
+ else:
+ path = self._findfile(path)
+ if path:
+ return 'file', path
+ else:
+ return 'none', None
+
+ def _load_template(self, name):
+ kind, path = self._lookup(name)
+
+ if kind == 'dir':
+ return Render(path, cache=self._cache is not None, base=self._base, **self._keywords)
+ elif kind == 'file':
+ return Template(open(path).read(), filename=path, **self._keywords)
+ else:
+ raise AttributeError, "No template named " + name
+
+ def _findfile(self, path_prefix):
+ p = [f for f in glob.glob(path_prefix + '.*') if not f.endswith('~')] # skip backup files
+ p.sort() # sort the matches for deterministic order
+ return p and p[0]
+
+ def _template(self, name):
+ if self._cache is not None:
+ if name not in self._cache:
+ self._cache[name] = self._load_template(name)
+ return self._cache[name]
+ else:
+ return self._load_template(name)
+
+ def __getattr__(self, name):
+ t = self._template(name)
+ if self._base and isinstance(t, Template):
+ def template(*a, **kw):
+ return self._base(t(*a, **kw))
+ return template
+ else:
+ return self._template(name)
+
+class GAE_Render(Render):
+ # Render gets over-written. make a copy here.
+ super = Render
+ def __init__(self, loc, *a, **kw):
+ GAE_Render.super.__init__(self, loc, *a, **kw)
+
+ import types
+ if isinstance(loc, types.ModuleType):
+ self.mod = loc
+ else:
+ name = loc.rstrip('/').replace('/', '.')
+ self.mod = __import__(name, None, None, ['x'])
+
+ self.mod.__dict__.update(kw.get('builtins', TEMPLATE_BUILTINS))
+ self.mod.__dict__.update(Template.globals)
+ self.mod.__dict__.update(kw.get('globals', {}))
+
+ def _load_template(self, name):
+ t = getattr(self.mod, name)
+ import types
+ if isinstance(t, types.ModuleType):
+ return GAE_Render(t, cache=self._cache is not None, base=self._base, **self._keywords)
+ else:
+ return t
+
+render = Render
+# setup render for Google App Engine.
+try:
+ from google import appengine
+ render = Render = GAE_Render
+except ImportError:
+ pass
+
+def frender(path, **keywords):
+ """Creates a template from the given file path.
+ """
+ return Template(open(path).read(), filename=path, **keywords)
+
+def compile_templates(root):
+ """Compiles templates to python code."""
+ re_start = re_compile('^', re.M)
+
+ for dirpath, dirnames, filenames in os.walk(root):
+ filenames = [f for f in filenames if not f.startswith('.') and not f.endswith('~') and not f.startswith('__init__.py')]
+
+ for d in dirnames[:]:
+ if d.startswith('.'):
+ dirnames.remove(d) # don't visit this dir
+
+ out = open(os.path.join(dirpath, '__init__.py'), 'w')
+ out.write('from web.template import CompiledTemplate, ForLoop, TemplateResult\n\n')
+ if dirnames:
+ out.write("import " + ", ".join(dirnames))
+ out.write("\n")
+
+ for f in filenames:
+ path = os.path.join(dirpath, f)
+
+ if '.' in f:
+ name, _ = f.split('.', 1)
+ else:
+ name = f
+
+ text = open(path).read()
+ text = Template.normalize_text(text)
+ code = Template.generate_code(text, path)
+
+ code = code.replace("__template__", name, 1)
+
+ out.write(code)
+
+ out.write('\n\n')
+ out.write('%s = CompiledTemplate(%s, %s)\n' % (name, name, repr(path)))
+ out.write("join_ = %s._join; escape_ = %s._escape\n\n" % (name, name))
+
+ # create template to make sure it compiles
+ t = Template(open(path).read(), path)
+ out.close()
+
+class ParseError(Exception):
+ pass
+
+class SecurityError(Exception):
+ """The template seems to be trying to do something naughty."""
+ pass
+
+# Enumerate all the allowed AST nodes
+ALLOWED_AST_NODES = [
+ "Add", "And",
+# "AssAttr",
+ "AssList", "AssName", "AssTuple",
+# "Assert",
+ "Assign", "AugAssign",
+# "Backquote",
+ "Bitand", "Bitor", "Bitxor", "Break",
+ "CallFunc","Class", "Compare", "Const", "Continue",
+ "Decorators", "Dict", "Discard", "Div",
+ "Ellipsis", "EmptyNode",
+# "Exec",
+ "Expression", "FloorDiv", "For",
+# "From",
+ "Function",
+ "GenExpr", "GenExprFor", "GenExprIf", "GenExprInner",
+ "Getattr",
+# "Global",
+ "If", "IfExp",
+# "Import",
+ "Invert", "Keyword", "Lambda", "LeftShift",
+ "List", "ListComp", "ListCompFor", "ListCompIf", "Mod",
+ "Module",
+ "Mul", "Name", "Not", "Or", "Pass", "Power",
+# "Print", "Printnl", "Raise",
+ "Return", "RightShift", "Slice", "Sliceobj",
+ "Stmt", "Sub", "Subscript",
+# "TryExcept", "TryFinally",
+ "Tuple", "UnaryAdd", "UnarySub",
+ "While", "With", "Yield",
+]
+
+class SafeVisitor(object):
+ """
+ Make sure code is safe by walking through the AST.
+
+ Code considered unsafe if:
+ * it has restricted AST nodes
+      * it is trying to access restricted attributes
+
+ Adopted from http://www.zafar.se/bkz/uploads/safe.txt (public domain, Babar K. Zafar)
+ """
+ def __init__(self):
+ "Initialize visitor by generating callbacks for all AST node types."
+ self.errors = []
+
+ def walk(self, ast, filename):
+ "Validate each node in AST and raise SecurityError if the code is not safe."
+ self.filename = filename
+ self.visit(ast)
+
+ if self.errors:
+ raise SecurityError, '\n'.join([str(err) for err in self.errors])
+
+ def visit(self, node, *args):
+ "Recursively validate node and all of its children."
+ def classname(obj):
+ return obj.__class__.__name__
+ nodename = classname(node)
+ fn = getattr(self, 'visit' + nodename, None)
+
+ if fn:
+ fn(node, *args)
+ else:
+ if nodename not in ALLOWED_AST_NODES:
+ self.fail(node, *args)
+
+ for child in node.getChildNodes():
+ self.visit(child, *args)
+
+ def visitName(self, node, *args):
+ "Disallow any attempts to access a restricted attr."
+ #self.assert_attr(node.getChildren()[0], node)
+ pass
+
+ def visitGetattr(self, node, *args):
+ "Disallow any attempts to access a restricted attribute."
+ self.assert_attr(node.attrname, node)
+
+ def assert_attr(self, attrname, node):
+ if self.is_unallowed_attr(attrname):
+ lineno = self.get_node_lineno(node)
+ e = SecurityError("%s:%d - access to attribute '%s' is denied" % (self.filename, lineno, attrname))
+ self.errors.append(e)
+
+ def is_unallowed_attr(self, name):
+ return name.startswith('_') \
+ or name.startswith('func_') \
+ or name.startswith('im_')
+
+ def get_node_lineno(self, node):
+ return (node.lineno) and node.lineno or 0
+
+ def fail(self, node, *args):
+ "Default callback for unallowed AST nodes."
+ lineno = self.get_node_lineno(node)
+ nodename = node.__class__.__name__
+ e = SecurityError("%s:%d - execution of '%s' statements is denied" % (self.filename, lineno, nodename))
+ self.errors.append(e)
+
+class TemplateResult(object, DictMixin):
+ """Dictionary like object for storing template output.
+
+    The result of a template execution is usually a string, but sometimes it
+    contains attributes set using $var. This class provides a simple
+    dictionary like interface for storing the output of the template and the
+    attributes. The output is stored with a special key __body__. Converting
+    the TemplateResult to string or unicode returns the value of __body__.
+
+    When the template is in execution, the output is generated part by part
+    and those parts are combined at the end. Parts are added to the
+    TemplateResult by calling the `extend` method and the parts are combined
+    seamlessly when __body__ is accessed.
+
+        >>> d = TemplateResult(__body__='hello, world', x='foo')
+        >>> d
+        <TemplateResult: {'__body__': 'hello, world', 'x': 'foo'}>
+        >>> print d
+        hello, world
+        >>> d.x
+        'foo'
+        >>> d = TemplateResult()
+        >>> d.extend([u'hello', u'world'])
+        >>> d
+        <TemplateResult: {'__body__': u'helloworld'}>
+ """
+ def __init__(self, *a, **kw):
+ self.__dict__["_d"] = dict(*a, **kw)
+ self._d.setdefault("__body__", u'')
+
+ self.__dict__['_parts'] = []
+ self.__dict__["extend"] = self._parts.extend
+
+ self._d.setdefault("__body__", None)
+
+ def keys(self):
+ return self._d.keys()
+
+ def _prepare_body(self):
+ """Prepare value of __body__ by joining parts.
+ """
+ if self._parts:
+ value = u"".join(self._parts)
+ self._parts[:] = []
+ body = self._d.get('__body__')
+ if body:
+ self._d['__body__'] = body + value
+ else:
+ self._d['__body__'] = value
+
+ def __getitem__(self, name):
+ if name == "__body__":
+ self._prepare_body()
+ return self._d[name]
+
+ def __setitem__(self, name, value):
+ if name == "__body__":
+ self._prepare_body()
+ return self._d.__setitem__(name, value)
+
+ def __delitem__(self, name):
+ if name == "__body__":
+ self._prepare_body()
+ return self._d.__delitem__(name)
+
+ def __getattr__(self, key):
+ try:
+ return self[key]
+ except KeyError, k:
+ raise AttributeError, k
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ try:
+ del self[key]
+ except KeyError, k:
+ raise AttributeError, k
+
+ def __unicode__(self):
+ self._prepare_body()
+ return self["__body__"]
+
+ def __str__(self):
+ self._prepare_body()
+ return self["__body__"].encode('utf-8')
+
+ def __repr__(self):
+ self._prepare_body()
+        return "<TemplateResult: %s>" % self._d
+
+def test():
+ r"""Doctest for testing template module.
+
+ Define a utility function to run template test.
+
+ >>> class TestResult:
+ ... def __init__(self, t): self.t = t
+ ... def __getattr__(self, name): return getattr(self.t, name)
+ ... def __repr__(self): return repr(unicode(self))
+ ...
+ >>> def t(code, **keywords):
+ ... tmpl = Template(code, **keywords)
+ ... return lambda *a, **kw: TestResult(tmpl(*a, **kw))
+ ...
+
+ Simple tests.
+
+ >>> t('1')()
+ u'1\n'
+ >>> t('$def with ()\n1')()
+ u'1\n'
+ >>> t('$def with (a)\n$a')(1)
+ u'1\n'
+ >>> t('$def with (a=0)\n$a')(1)
+ u'1\n'
+ >>> t('$def with (a=0)\n$a')(a=1)
+ u'1\n'
+
+ Test complicated expressions.
+
+ >>> t('$def with (x)\n$x.upper()')('hello')
+ u'HELLO\n'
+ >>> t('$(2 * 3 + 4 * 5)')()
+ u'26\n'
+ >>> t('${2 * 3 + 4 * 5}')()
+ u'26\n'
+ >>> t('$def with (limit)\nkeep $(limit)ing.')('go')
+ u'keep going.\n'
+ >>> t('$def with (a)\n$a.b[0]')(storage(b=[1]))
+ u'1\n'
+
+ Test html escaping.
+
+        >>> t('$def with (x)\n$x', filename='a.html')('<html>')
+        u'&lt;html&gt;\n'
+        >>> t('$def with (x)\n$x', filename='a.txt')('<html>')
+        u'<html>\n'
+
+ Test if, for and while.
+
+ >>> t('$if 1: 1')()
+ u'1\n'
+ >>> t('$if 1:\n 1')()
+ u'1\n'
+ >>> t('$if 1:\n 1\\')()
+ u'1'
+ >>> t('$if 0: 0\n$elif 1: 1')()
+ u'1\n'
+ >>> t('$if 0: 0\n$elif None: 0\n$else: 1')()
+ u'1\n'
+ >>> t('$if 0 < 1 and 1 < 2: 1')()
+ u'1\n'
+ >>> t('$for x in [1, 2, 3]: $x')()
+ u'1\n2\n3\n'
+ >>> t('$def with (d)\n$for k, v in d.iteritems(): $k')({1: 1})
+ u'1\n'
+ >>> t('$for x in [1, 2, 3]:\n\t$x')()
+ u' 1\n 2\n 3\n'
+ >>> t('$def with (a)\n$while a and a.pop():1')([1, 2, 3])
+ u'1\n1\n1\n'
+
+ The space after : must be ignored.
+
+ >>> t('$if True: foo')()
+ u'foo\n'
+
+ Test loop.xxx.
+
+ >>> t("$for i in range(5):$loop.index, $loop.parity")()
+ u'1, odd\n2, even\n3, odd\n4, even\n5, odd\n'
+ >>> t("$for i in range(2):\n $for j in range(2):$loop.parent.parity $loop.parity")()
+ u'odd odd\nodd even\neven odd\neven even\n'
+
+ Test assignment.
+
+ >>> t('$ a = 1\n$a')()
+ u'1\n'
+ >>> t('$ a = [1]\n$a[0]')()
+ u'1\n'
+ >>> t('$ a = {1: 1}\n$a.keys()[0]')()
+ u'1\n'
+ >>> t('$ a = []\n$if not a: 1')()
+ u'1\n'
+ >>> t('$ a = {}\n$if not a: 1')()
+ u'1\n'
+ >>> t('$ a = -1\n$a')()
+ u'-1\n'
+ >>> t('$ a = "1"\n$a')()
+ u'1\n'
+
+ Test comments.
+
+ >>> t('$# 0')()
+ u'\n'
+ >>> t('hello$#comment1\nhello$#comment2')()
+ u'hello\nhello\n'
+ >>> t('$#comment0\nhello$#comment1\nhello$#comment2')()
+ u'\nhello\nhello\n'
+
+ Test unicode.
+
+ >>> t('$def with (a)\n$a')(u'\u203d')
+ u'\u203d\n'
+ >>> t('$def with (a)\n$a')(u'\u203d'.encode('utf-8'))
+ u'\u203d\n'
+ >>> t(u'$def with (a)\n$a $:a')(u'\u203d')
+ u'\u203d \u203d\n'
+ >>> t(u'$def with ()\nfoo')()
+ u'foo\n'
+ >>> def f(x): return x
+ ...
+ >>> t(u'$def with (f)\n$:f("x")')(f)
+ u'x\n'
+ >>> t('$def with (f)\n$:f("x")')(f)
+ u'x\n'
+
+ Test dollar escaping.
+
+ >>> t("Stop, $$money isn't evaluated.")()
+ u"Stop, $money isn't evaluated.\n"
+ >>> t("Stop, \$money isn't evaluated.")()
+ u"Stop, $money isn't evaluated.\n"
+
+ Test space sensitivity.
+
+ >>> t('$def with (x)\n$x')(1)
+ u'1\n'
+ >>> t('$def with(x ,y)\n$x')(1, 1)
+ u'1\n'
+ >>> t('$(1 + 2*3 + 4)')()
+ u'11\n'
+
+ Make sure globals are working.
+
+ >>> t('$x')()
+ Traceback (most recent call last):
+ ...
+ NameError: global name 'x' is not defined
+ >>> t('$x', globals={'x': 1})()
+ u'1\n'
+
+ Can't change globals.
+
+ >>> t('$ x = 2\n$x', globals={'x': 1})()
+ u'2\n'
+ >>> t('$ x = x + 1\n$x', globals={'x': 1})()
+ Traceback (most recent call last):
+ ...
+ UnboundLocalError: local variable 'x' referenced before assignment
+
+ Make sure builtins are customizable.
+
+ >>> t('$min(1, 2)')()
+ u'1\n'
+ >>> t('$min(1, 2)', builtins={})()
+ Traceback (most recent call last):
+ ...
+ NameError: global name 'min' is not defined
+
+ Test vars.
+
+ >>> x = t('$var x: 1')()
+ >>> x.x
+ u'1'
+ >>> x = t('$var x = 1')()
+ >>> x.x
+ 1
+ >>> x = t('$var x: \n foo\n bar')()
+ >>> x.x
+ u'foo\nbar\n'
+
+ Test BOM chars.
+
+ >>> t('\xef\xbb\xbf$def with(x)\n$x')('foo')
+ u'foo\n'
+
+ Test for with weird cases.
+
+ >>> t('$for i in range(10)[1:5]:\n $i')()
+ u'1\n2\n3\n4\n'
+ >>> t("$for k, v in {'a': 1, 'b': 2}.items():\n $k $v")()
+ u'a 1\nb 2\n'
+ >>> t("$for k, v in ({'a': 1, 'b': 2}.items():\n $k $v")()
+ Traceback (most recent call last):
+ ...
+ SyntaxError: invalid syntax
+
+ Test datetime.
+
+ >>> import datetime
+ >>> t("$def with (date)\n$date.strftime('%m %Y')")(datetime.datetime(2009, 1, 1))
+ u'01 2009\n'
+ """
+ pass
+
+if __name__ == "__main__":
+ import sys
+ if '--compile' in sys.argv:
+ compile_templates(sys.argv[2])
+ else:
+ import doctest
+ doctest.testmod()
diff --git a/web/test.py b/web/test.py
index a942a91..7ee41d0 100644
--- a/web/test.py
+++ b/web/test.py
@@ -1,51 +1,51 @@
-"""test utilities
-(part of web.py)
-"""
-import unittest
-import sys, os
-import web
-
-TestCase = unittest.TestCase
-TestSuite = unittest.TestSuite
-
-def load_modules(names):
- return [__import__(name, None, None, "x") for name in names]
-
-def module_suite(module, classnames=None):
- """Makes a suite from a module."""
- if classnames:
- return unittest.TestLoader().loadTestsFromNames(classnames, module)
- elif hasattr(module, 'suite'):
- return module.suite()
- else:
- return unittest.TestLoader().loadTestsFromModule(module)
-
-def doctest_suite(module_names):
- """Makes a test suite from doctests."""
- import doctest
- suite = TestSuite()
- for mod in load_modules(module_names):
- suite.addTest(doctest.DocTestSuite(mod))
- return suite
-
-def suite(module_names):
- """Creates a suite from multiple modules."""
- suite = TestSuite()
- for mod in load_modules(module_names):
- suite.addTest(module_suite(mod))
- return suite
-
-def runTests(suite):
- runner = unittest.TextTestRunner()
- return runner.run(suite)
-
-def main(suite=None):
- if not suite:
- main_module = __import__('__main__')
- # allow command line switches
- args = [a for a in sys.argv[1:] if not a.startswith('-')]
- suite = module_suite(main_module, args or None)
-
- result = runTests(suite)
- sys.exit(not result.wasSuccessful())
-
+"""test utilities
+(part of web.py)
+"""
+import unittest
+import sys, os
+import web
+
+TestCase = unittest.TestCase
+TestSuite = unittest.TestSuite
+
+def load_modules(names):
+ return [__import__(name, None, None, "x") for name in names]
+
+def module_suite(module, classnames=None):
+ """Makes a suite from a module."""
+ if classnames:
+ return unittest.TestLoader().loadTestsFromNames(classnames, module)
+ elif hasattr(module, 'suite'):
+ return module.suite()
+ else:
+ return unittest.TestLoader().loadTestsFromModule(module)
+
+def doctest_suite(module_names):
+ """Makes a test suite from doctests."""
+ import doctest
+ suite = TestSuite()
+ for mod in load_modules(module_names):
+ suite.addTest(doctest.DocTestSuite(mod))
+ return suite
+
+def suite(module_names):
+ """Creates a suite from multiple modules."""
+ suite = TestSuite()
+ for mod in load_modules(module_names):
+ suite.addTest(module_suite(mod))
+ return suite
+
+def runTests(suite):
+ runner = unittest.TextTestRunner()
+ return runner.run(suite)
+
+def main(suite=None):
+ if not suite:
+ main_module = __import__('__main__')
+ # allow command line switches
+ args = [a for a in sys.argv[1:] if not a.startswith('-')]
+ suite = module_suite(main_module, args or None)
+
+ result = runTests(suite)
+ sys.exit(not result.wasSuccessful())
+
diff --git a/web/utils.py b/web/utils.py
index d5f4154..51db96c 100644
--- a/web/utils.py
+++ b/web/utils.py
@@ -1,1526 +1,1526 @@
-#!/usr/bin/env python
-"""
-General Utilities
-(part of web.py)
-"""
-
-__all__ = [
- "Storage", "storage", "storify",
- "Counter", "counter",
- "iters",
- "rstrips", "lstrips", "strips",
- "safeunicode", "safestr", "utf8",
- "TimeoutError", "timelimit",
- "Memoize", "memoize",
- "re_compile", "re_subm",
- "group", "uniq", "iterview",
- "IterBetter", "iterbetter",
- "safeiter", "safewrite",
- "dictreverse", "dictfind", "dictfindall", "dictincr", "dictadd",
- "requeue", "restack",
- "listget", "intget", "datestr",
- "numify", "denumify", "commify", "dateify",
- "nthstr", "cond",
- "CaptureStdout", "capturestdout", "Profile", "profile",
- "tryall",
- "ThreadedDict", "threadeddict",
- "autoassign",
- "to36",
- "safemarkdown",
- "sendmail"
-]
-
-import re, sys, time, threading, itertools, traceback, os
-
-try:
- import subprocess
-except ImportError:
- subprocess = None
-
-try: import datetime
-except ImportError: pass
-
-try: set
-except NameError:
- from sets import Set as set
-
-try:
- from threading import local as threadlocal
-except ImportError:
- from python23 import threadlocal
-
-class Storage(dict):
- """
- A Storage object is like a dictionary except `obj.foo` can be used
- in addition to `obj['foo']`.
-
- >>> o = storage(a=1)
- >>> o.a
- 1
- >>> o['a']
- 1
- >>> o.a = 2
- >>> o['a']
- 2
- >>> del o.a
- >>> o.a
- Traceback (most recent call last):
- ...
- AttributeError: 'a'
-
- """
- def __getattr__(self, key):
- try:
- return self[key]
- except KeyError, k:
- raise AttributeError, k
-
- def __setattr__(self, key, value):
- self[key] = value
-
- def __delattr__(self, key):
- try:
- del self[key]
- except KeyError, k:
- raise AttributeError, k
-
- def __repr__(self):
- return ''
-
-storage = Storage
-
-def storify(mapping, *requireds, **defaults):
- """
- Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
- d doesn't have all of the keys in `requireds` and using the default
- values for keys found in `defaults`.
-
- For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
- `storage({'a':1, 'b':2, 'c':3})`.
-
- If a `storify` value is a list (e.g. multiple values in a form submission),
- `storify` returns the last element of the list, unless the key appears in
- `defaults` as a list. Thus:
-
- >>> storify({'a':[1, 2]}).a
- 2
- >>> storify({'a':[1, 2]}, a=[]).a
- [1, 2]
- >>> storify({'a':1}, a=[]).a
- [1]
- >>> storify({}, a=[]).a
- []
-
- Similarly, if the value has a `value` attribute, `storify will return _its_
- value, unless the key appears in `defaults` as a dictionary.
-
- >>> storify({'a':storage(value=1)}).a
- 1
- >>> storify({'a':storage(value=1)}, a={}).a
-
- >>> storify({}, a={}).a
- {}
-
- Optionally, keyword parameter `_unicode` can be passed to convert all values to unicode.
-
- >>> storify({'x': 'a'}, _unicode=True)
-
- >>> storify({'x': storage(value='a')}, x={}, _unicode=True)
- }>
- >>> storify({'x': storage(value='a')}, _unicode=True)
-
- """
- _unicode = defaults.pop('_unicode', False)
-
- # if _unicode is callable object, use it convert a string to unicode.
- to_unicode = safeunicode
- if _unicode is not False and hasattr(_unicode, "__call__"):
- to_unicode = _unicode
-
- def unicodify(s):
- if _unicode and isinstance(s, str): return to_unicode(s)
- else: return s
-
- def getvalue(x):
- if hasattr(x, 'file') and hasattr(x, 'value'):
- return x.value
- elif hasattr(x, 'value'):
- return unicodify(x.value)
- else:
- return unicodify(x)
-
- stor = Storage()
- for key in requireds + tuple(mapping.keys()):
- value = mapping[key]
- if isinstance(value, list):
- if isinstance(defaults.get(key), list):
- value = [getvalue(x) for x in value]
- else:
- value = value[-1]
- if not isinstance(defaults.get(key), dict):
- value = getvalue(value)
- if isinstance(defaults.get(key), list) and not isinstance(value, list):
- value = [value]
- setattr(stor, key, value)
-
- for (key, value) in defaults.iteritems():
- result = value
- if hasattr(stor, key):
- result = stor[key]
- if value == () and not isinstance(result, tuple):
- result = (result,)
- setattr(stor, key, result)
-
- return stor
-
-class Counter(storage):
- """Keeps count of how many times something is added.
-
- >>> c = counter()
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('y')
- >>> c
-
- >>> c.most()
- ['x']
- """
- def add(self, n):
- self.setdefault(n, 0)
- self[n] += 1
-
- def most(self):
- """Returns the keys with maximum count."""
- m = max(self.itervalues())
- return [k for k, v in self.iteritems() if v == m]
-
- def least(self):
- """Returns the keys with mininum count."""
- m = min(self.itervalues())
- return [k for k, v in self.iteritems() if v == m]
-
- def percent(self, key):
- """Returns what percentage a certain key is of all entries.
-
- >>> c = counter()
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('y')
- >>> c.percent('x')
- 0.75
- >>> c.percent('y')
- 0.25
- """
- return float(self[key])/sum(self.values())
-
- def sorted_keys(self):
- """Returns keys sorted by value.
-
- >>> c = counter()
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('y')
- >>> c.sorted_keys()
- ['x', 'y']
- """
- return sorted(self.keys(), key=lambda k: self[k], reverse=True)
-
- def sorted_values(self):
- """Returns values sorted by value.
-
- >>> c = counter()
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('y')
- >>> c.sorted_values()
- [2, 1]
- """
- return [self[k] for k in self.sorted_keys()]
-
- def sorted_items(self):
- """Returns items sorted by value.
-
- >>> c = counter()
- >>> c.add('x')
- >>> c.add('x')
- >>> c.add('y')
- >>> c.sorted_items()
- [('x', 2), ('y', 1)]
- """
- return [(k, self[k]) for k in self.sorted_keys()]
-
- def __repr__(self):
- return ''
-
-counter = Counter
-
-iters = [list, tuple]
-import __builtin__
-if hasattr(__builtin__, 'set'):
- iters.append(set)
-if hasattr(__builtin__, 'frozenset'):
- iters.append(set)
-if sys.version_info < (2,6): # sets module deprecated in 2.6
- try:
- from sets import Set
- iters.append(Set)
- except ImportError:
- pass
-
-class _hack(tuple): pass
-iters = _hack(iters)
-iters.__doc__ = """
-A list of iterable items (like lists, but not strings). Includes whichever
-of lists, tuples, sets, and Sets are available in this version of Python.
-"""
-
-def _strips(direction, text, remove):
- if isinstance(remove, iters):
- for subr in remove:
- text = _strips(direction, text, subr)
- return text
-
- if direction == 'l':
- if text.startswith(remove):
- return text[len(remove):]
- elif direction == 'r':
- if text.endswith(remove):
- return text[:-len(remove)]
- else:
- raise ValueError, "Direction needs to be r or l."
- return text
-
-def rstrips(text, remove):
- """
- removes the string `remove` from the right of `text`
-
- >>> rstrips("foobar", "bar")
- 'foo'
-
- """
- return _strips('r', text, remove)
-
-def lstrips(text, remove):
- """
- removes the string `remove` from the left of `text`
-
- >>> lstrips("foobar", "foo")
- 'bar'
- >>> lstrips('http://foo.org/', ['http://', 'https://'])
- 'foo.org/'
- >>> lstrips('FOOBARBAZ', ['FOO', 'BAR'])
- 'BAZ'
- >>> lstrips('FOOBARBAZ', ['BAR', 'FOO'])
- 'BARBAZ'
-
- """
- return _strips('l', text, remove)
-
-def strips(text, remove):
- """
- removes the string `remove` from the both sides of `text`
-
- >>> strips("foobarfoo", "foo")
- 'bar'
-
- """
- return rstrips(lstrips(text, remove), remove)
-
-def safeunicode(obj, encoding='utf-8'):
- r"""
- Converts any given object to unicode string.
-
- >>> safeunicode('hello')
- u'hello'
- >>> safeunicode(2)
- u'2'
- >>> safeunicode('\xe1\x88\xb4')
- u'\u1234'
- """
- t = type(obj)
- if t is unicode:
- return obj
- elif t is str:
- return obj.decode(encoding)
- elif t in [int, float, bool]:
- return unicode(obj)
- elif hasattr(obj, '__unicode__') or isinstance(obj, unicode):
- return unicode(obj)
- else:
- return str(obj).decode(encoding)
-
-def safestr(obj, encoding='utf-8'):
- r"""
- Converts any given object to utf-8 encoded string.
-
- >>> safestr('hello')
- 'hello'
- >>> safestr(u'\u1234')
- '\xe1\x88\xb4'
- >>> safestr(2)
- '2'
- """
- if isinstance(obj, unicode):
- return obj.encode(encoding)
- elif isinstance(obj, str):
- return obj
- elif hasattr(obj, 'next'): # iterator
- return itertools.imap(safestr, obj)
- else:
- return str(obj)
-
-# for backward-compatibility
-utf8 = safestr
-
-class TimeoutError(Exception): pass
-def timelimit(timeout):
- """
- A decorator to limit a function to `timeout` seconds, raising `TimeoutError`
- if it takes longer.
-
- >>> import time
- >>> def meaningoflife():
- ... time.sleep(.2)
- ... return 42
- >>>
- >>> timelimit(.1)(meaningoflife)()
- Traceback (most recent call last):
- ...
- TimeoutError: took too long
- >>> timelimit(1)(meaningoflife)()
- 42
-
- _Caveat:_ The function isn't stopped after `timeout` seconds but continues
- executing in a separate thread. (There seems to be no way to kill a thread.)
-
- inspired by
- """
- def _1(function):
- def _2(*args, **kw):
- class Dispatch(threading.Thread):
- def __init__(self):
- threading.Thread.__init__(self)
- self.result = None
- self.error = None
-
- self.setDaemon(True)
- self.start()
-
- def run(self):
- try:
- self.result = function(*args, **kw)
- except:
- self.error = sys.exc_info()
-
- c = Dispatch()
- c.join(timeout)
- if c.isAlive():
- raise TimeoutError, 'took too long'
- if c.error:
- raise c.error[0], c.error[1]
- return c.result
- return _2
- return _1
-
-class Memoize:
- """
- 'Memoizes' a function, caching its return values for each input.
- If `expires` is specified, values are recalculated after `expires` seconds.
- If `background` is specified, values are recalculated in a separate thread.
-
- >>> calls = 0
- >>> def howmanytimeshaveibeencalled():
- ... global calls
- ... calls += 1
- ... return calls
- >>> fastcalls = memoize(howmanytimeshaveibeencalled)
- >>> howmanytimeshaveibeencalled()
- 1
- >>> howmanytimeshaveibeencalled()
- 2
- >>> fastcalls()
- 3
- >>> fastcalls()
- 3
- >>> import time
- >>> fastcalls = memoize(howmanytimeshaveibeencalled, .1, background=False)
- >>> fastcalls()
- 4
- >>> fastcalls()
- 4
- >>> time.sleep(.2)
- >>> fastcalls()
- 5
- >>> def slowfunc():
- ... time.sleep(.1)
- ... return howmanytimeshaveibeencalled()
- >>> fastcalls = memoize(slowfunc, .2, background=True)
- >>> fastcalls()
- 6
- >>> timelimit(.05)(fastcalls)()
- 6
- >>> time.sleep(.2)
- >>> timelimit(.05)(fastcalls)()
- 6
- >>> timelimit(.05)(fastcalls)()
- 6
- >>> time.sleep(.2)
- >>> timelimit(.05)(fastcalls)()
- 7
- >>> fastcalls = memoize(slowfunc, None, background=True)
- >>> threading.Thread(target=fastcalls).start()
- >>> time.sleep(.01)
- >>> fastcalls()
- 9
- """
- def __init__(self, func, expires=None, background=True):
- self.func = func
- self.cache = {}
- self.expires = expires
- self.background = background
- self.running = {}
-
- def __call__(self, *args, **keywords):
- key = (args, tuple(keywords.items()))
- if not self.running.get(key):
- self.running[key] = threading.Lock()
- def update(block=False):
- if self.running[key].acquire(block):
- try:
- self.cache[key] = (self.func(*args, **keywords), time.time())
- finally:
- self.running[key].release()
-
- if key not in self.cache:
- update(block=True)
- elif self.expires and (time.time() - self.cache[key][1]) > self.expires:
- if self.background:
- threading.Thread(target=update).start()
- else:
- update()
- return self.cache[key][0]
-
-memoize = Memoize
-
-re_compile = memoize(re.compile) #@@ threadsafe?
-re_compile.__doc__ = """
-A memoized version of re.compile.
-"""
-
-class _re_subm_proxy:
- def __init__(self):
- self.match = None
- def __call__(self, match):
- self.match = match
- return ''
-
-def re_subm(pat, repl, string):
- """
- Like re.sub, but returns the replacement _and_ the match object.
-
- >>> t, m = re_subm('g(oo+)fball', r'f\\1lish', 'goooooofball')
- >>> t
- 'foooooolish'
- >>> m.groups()
- ('oooooo',)
- """
- compiled_pat = re_compile(pat)
- proxy = _re_subm_proxy()
- compiled_pat.sub(proxy.__call__, string)
- return compiled_pat.sub(repl, string), proxy.match
-
-def group(seq, size):
- """
- Returns an iterator over a series of lists of length size from iterable.
-
- >>> list(group([1,2,3,4], 2))
- [[1, 2], [3, 4]]
- >>> list(group([1,2,3,4,5], 2))
- [[1, 2], [3, 4], [5]]
- """
- def take(seq, n):
- for i in xrange(n):
- yield seq.next()
-
- if not hasattr(seq, 'next'):
- seq = iter(seq)
- while True:
- x = list(take(seq, size))
- if x:
- yield x
- else:
- break
-
-def uniq(seq, key=None):
- """
- Removes duplicate elements from a list while preserving the order of the rest.
-
- >>> uniq([9,0,2,1,0])
- [9, 0, 2, 1]
-
- The value of the optional `key` parameter should be a function that
- takes a single argument and returns a key to test the uniqueness.
-
- >>> uniq(["Foo", "foo", "bar"], key=lambda s: s.lower())
- ['Foo', 'bar']
- """
- key = key or (lambda x: x)
- seen = set()
- result = []
- for v in seq:
- k = key(v)
- if k in seen:
- continue
- seen.add(k)
- result.append(v)
- return result
-
-def iterview(x):
- """
- Takes an iterable `x` and returns an iterator over it
- which prints its progress to stderr as it iterates through.
- """
- WIDTH = 70
-
- def plainformat(n, lenx):
- return '%5.1f%% (%*d/%d)' % ((float(n)/lenx)*100, len(str(lenx)), n, lenx)
-
- def bars(size, n, lenx):
- val = int((float(n)*size)/lenx + 0.5)
- if size - val:
- spacing = ">" + (" "*(size-val))[1:]
- else:
- spacing = ""
- return "[%s%s]" % ("="*val, spacing)
-
- def eta(elapsed, n, lenx):
- if n == 0:
- return '--:--:--'
- if n == lenx:
- secs = int(elapsed)
- else:
- secs = int((elapsed/n) * (lenx-n))
- mins, secs = divmod(secs, 60)
- hrs, mins = divmod(mins, 60)
-
- return '%02d:%02d:%02d' % (hrs, mins, secs)
-
- def format(starttime, n, lenx):
- out = plainformat(n, lenx) + ' '
- if n == lenx:
- end = ' '
- else:
- end = ' ETA '
- end += eta(time.time() - starttime, n, lenx)
- out += bars(WIDTH - len(out) - len(end), n, lenx)
- out += end
- return out
-
- starttime = time.time()
- lenx = len(x)
- for n, y in enumerate(x):
- sys.stderr.write('\r' + format(starttime, n, lenx))
- yield y
- sys.stderr.write('\r' + format(starttime, n+1, lenx) + '\n')
-
-class IterBetter:
- """
- Returns an object that can be used as an iterator
- but can also be used via __getitem__ (although it
- cannot go backwards -- that is, you cannot request
- `iterbetter[0]` after requesting `iterbetter[1]`).
-
- >>> import itertools
- >>> c = iterbetter(itertools.count())
- >>> c[1]
- 1
- >>> c[5]
- 5
- >>> c[3]
- Traceback (most recent call last):
- ...
- IndexError: already passed 3
-
- For boolean test, IterBetter peeps at first value in the itertor without effecting the iteration.
-
- >>> c = iterbetter(iter(range(5)))
- >>> bool(c)
- True
- >>> list(c)
- [0, 1, 2, 3, 4]
- >>> c = iterbetter(iter([]))
- >>> bool(c)
- False
- >>> list(c)
- []
- """
- def __init__(self, iterator):
- self.i, self.c = iterator, 0
-
- def __iter__(self):
- if hasattr(self, "_head"):
- yield self._head
-
- while 1:
- yield self.i.next()
- self.c += 1
-
- def __getitem__(self, i):
- #todo: slices
- if i < self.c:
- raise IndexError, "already passed "+str(i)
- try:
- while i > self.c:
- self.i.next()
- self.c += 1
- # now self.c == i
- self.c += 1
- return self.i.next()
- except StopIteration:
- raise IndexError, str(i)
-
- def __nonzero__(self):
- if hasattr(self, "__len__"):
- return len(self) != 0
- elif hasattr(self, "_head"):
- return True
- else:
- try:
- self._head = self.i.next()
- except StopIteration:
- return False
- else:
- return True
-
-iterbetter = IterBetter
-
-def safeiter(it, cleanup=None, ignore_errors=True):
- """Makes an iterator safe by ignoring the exceptions occured during the iteration.
- """
- def next():
- while True:
- try:
- return it.next()
- except StopIteration:
- raise
- except:
- traceback.print_exc()
-
- it = iter(it)
- while True:
- yield next()
-
-def safewrite(filename, content):
- """Writes the content to a temp file and then moves the temp file to
- given filename to avoid overwriting the existing file in case of errors.
- """
- f = file(filename + '.tmp', 'w')
- f.write(content)
- f.close()
- os.rename(f.name, filename)
-
-def dictreverse(mapping):
- """
- Returns a new dictionary with keys and values swapped.
-
- >>> dictreverse({1: 2, 3: 4})
- {2: 1, 4: 3}
- """
- return dict([(value, key) for (key, value) in mapping.iteritems()])
-
-def dictfind(dictionary, element):
- """
- Returns a key whose value in `dictionary` is `element`
- or, if none exists, None.
-
- >>> d = {1:2, 3:4}
- >>> dictfind(d, 4)
- 3
- >>> dictfind(d, 5)
- """
- for (key, value) in dictionary.iteritems():
- if element is value:
- return key
-
-def dictfindall(dictionary, element):
- """
- Returns the keys whose values in `dictionary` are `element`
- or, if none exists, [].
-
- >>> d = {1:4, 3:4}
- >>> dictfindall(d, 4)
- [1, 3]
- >>> dictfindall(d, 5)
- []
- """
- res = []
- for (key, value) in dictionary.iteritems():
- if element is value:
- res.append(key)
- return res
-
-def dictincr(dictionary, element):
- """
- Increments `element` in `dictionary`,
- setting it to one if it doesn't exist.
-
- >>> d = {1:2, 3:4}
- >>> dictincr(d, 1)
- 3
- >>> d[1]
- 3
- >>> dictincr(d, 5)
- 1
- >>> d[5]
- 1
- """
- dictionary.setdefault(element, 0)
- dictionary[element] += 1
- return dictionary[element]
-
-def dictadd(*dicts):
- """
- Returns a dictionary consisting of the keys in the argument dictionaries.
- If they share a key, the value from the last argument is used.
-
- >>> dictadd({1: 0, 2: 0}, {2: 1, 3: 1})
- {1: 0, 2: 1, 3: 1}
- """
- result = {}
- for dct in dicts:
- result.update(dct)
- return result
-
-def requeue(queue, index=-1):
- """Returns the element at index after moving it to the beginning of the queue.
-
- >>> x = [1, 2, 3, 4]
- >>> requeue(x)
- 4
- >>> x
- [4, 1, 2, 3]
- """
- x = queue.pop(index)
- queue.insert(0, x)
- return x
-
-def restack(stack, index=0):
- """Returns the element at index after moving it to the top of stack.
-
- >>> x = [1, 2, 3, 4]
- >>> restack(x)
- 1
- >>> x
- [2, 3, 4, 1]
- """
- x = stack.pop(index)
- stack.append(x)
- return x
-
-def listget(lst, ind, default=None):
- """
- Returns `lst[ind]` if it exists, `default` otherwise.
-
- >>> listget(['a'], 0)
- 'a'
- >>> listget(['a'], 1)
- >>> listget(['a'], 1, 'b')
- 'b'
- """
- if len(lst)-1 < ind:
- return default
- return lst[ind]
-
-def intget(integer, default=None):
- """
- Returns `integer` as an int or `default` if it can't.
-
- >>> intget('3')
- 3
- >>> intget('3a')
- >>> intget('3a', 0)
- 0
- """
- try:
- return int(integer)
- except (TypeError, ValueError):
- return default
-
-def datestr(then, now=None):
- """
- Converts a (UTC) datetime object to a nice string representation.
-
- >>> from datetime import datetime, timedelta
- >>> d = datetime(1970, 5, 1)
- >>> datestr(d, now=d)
- '0 microseconds ago'
- >>> for t, v in {
- ... timedelta(microseconds=1): '1 microsecond ago',
- ... timedelta(microseconds=2): '2 microseconds ago',
- ... -timedelta(microseconds=1): '1 microsecond from now',
- ... -timedelta(microseconds=2): '2 microseconds from now',
- ... timedelta(microseconds=2000): '2 milliseconds ago',
- ... timedelta(seconds=2): '2 seconds ago',
- ... timedelta(seconds=2*60): '2 minutes ago',
- ... timedelta(seconds=2*60*60): '2 hours ago',
- ... timedelta(days=2): '2 days ago',
- ... }.iteritems():
- ... assert datestr(d, now=d+t) == v
- >>> datestr(datetime(1970, 1, 1), now=d)
- 'January 1'
- >>> datestr(datetime(1969, 1, 1), now=d)
- 'January 1, 1969'
- >>> datestr(datetime(1970, 6, 1), now=d)
- 'June 1, 1970'
- >>> datestr(None)
- ''
- """
- def agohence(n, what, divisor=None):
- if divisor: n = n // divisor
-
- out = str(abs(n)) + ' ' + what # '2 day'
- if abs(n) != 1: out += 's' # '2 days'
- out += ' ' # '2 days '
- if n < 0:
- out += 'from now'
- else:
- out += 'ago'
- return out # '2 days ago'
-
- oneday = 24 * 60 * 60
-
- if not then: return ""
- if not now: now = datetime.datetime.utcnow()
- if type(now).__name__ == "DateTime":
- now = datetime.datetime.fromtimestamp(now)
- if type(then).__name__ == "DateTime":
- then = datetime.datetime.fromtimestamp(then)
- elif type(then).__name__ == "date":
- then = datetime.datetime(then.year, then.month, then.day)
-
- delta = now - then
- deltaseconds = int(delta.days * oneday + delta.seconds + delta.microseconds * 1e-06)
- deltadays = abs(deltaseconds) // oneday
- if deltaseconds < 0: deltadays *= -1 # fix for oddity of floor
-
- if deltadays:
- if abs(deltadays) < 4:
- return agohence(deltadays, 'day')
-
- try:
- out = then.strftime('%B %e') # e.g. 'June 3'
- except ValueError:
- # %e doesn't work on Windows.
- out = then.strftime('%B %d') # e.g. 'June 03'
-
- if then.year != now.year or deltadays < 0:
- out += ', %s' % then.year
- return out
-
- if int(deltaseconds):
- if abs(deltaseconds) > (60 * 60):
- return agohence(deltaseconds, 'hour', 60 * 60)
- elif abs(deltaseconds) > 60:
- return agohence(deltaseconds, 'minute', 60)
- else:
- return agohence(deltaseconds, 'second')
-
- deltamicroseconds = delta.microseconds
- if delta.days: deltamicroseconds = int(delta.microseconds - 1e6) # datetime oddity
- if abs(deltamicroseconds) > 1000:
- return agohence(deltamicroseconds, 'millisecond', 1000)
-
- return agohence(deltamicroseconds, 'microsecond')
-
-def numify(string):
- """
- Removes all non-digit characters from `string`.
-
- >>> numify('800-555-1212')
- '8005551212'
- >>> numify('800.555.1212')
- '8005551212'
-
- """
- return ''.join([c for c in str(string) if c.isdigit()])
-
-def denumify(string, pattern):
- """
- Formats `string` according to `pattern`, where the letter X gets replaced
- by characters from `string`.
-
- >>> denumify("8005551212", "(XXX) XXX-XXXX")
- '(800) 555-1212'
-
- """
- out = []
- for c in pattern:
- if c == "X":
- out.append(string[0])
- string = string[1:]
- else:
- out.append(c)
- return ''.join(out)
-
-def commify(n):
- """
- Add commas to an integer `n`.
-
- >>> commify(1)
- '1'
- >>> commify(123)
- '123'
- >>> commify(1234)
- '1,234'
- >>> commify(1234567890)
- '1,234,567,890'
- >>> commify(123.0)
- '123.0'
- >>> commify(1234.5)
- '1,234.5'
- >>> commify(1234.56789)
- '1,234.56789'
- >>> commify('%.2f' % 1234.5)
- '1,234.50'
- >>> commify(None)
- >>>
-
- """
- if n is None: return None
- n = str(n)
- if '.' in n:
- dollars, cents = n.split('.')
- else:
- dollars, cents = n, None
-
- r = []
- for i, c in enumerate(str(dollars)[::-1]):
- if i and (not (i % 3)):
- r.insert(0, ',')
- r.insert(0, c)
- out = ''.join(r)
- if cents:
- out += '.' + cents
- return out
-
-def dateify(datestring):
- """
- Formats a numified `datestring` properly.
- """
- return denumify(datestring, "XXXX-XX-XX XX:XX:XX")
-
-
-def nthstr(n):
- """
- Formats an ordinal.
- Doesn't handle negative numbers.
-
- >>> nthstr(1)
- '1st'
- >>> nthstr(0)
- '0th'
- >>> [nthstr(x) for x in [2, 3, 4, 5, 10, 11, 12, 13, 14, 15]]
- ['2nd', '3rd', '4th', '5th', '10th', '11th', '12th', '13th', '14th', '15th']
- >>> [nthstr(x) for x in [91, 92, 93, 94, 99, 100, 101, 102]]
- ['91st', '92nd', '93rd', '94th', '99th', '100th', '101st', '102nd']
- >>> [nthstr(x) for x in [111, 112, 113, 114, 115]]
- ['111th', '112th', '113th', '114th', '115th']
-
- """
-
- assert n >= 0
- if n % 100 in [11, 12, 13]: return '%sth' % n
- return {1: '%sst', 2: '%snd', 3: '%srd'}.get(n % 10, '%sth') % n
-
-def cond(predicate, consequence, alternative=None):
- """
- Function replacement for if-else to use in expressions.
-
- >>> x = 2
- >>> cond(x % 2 == 0, "even", "odd")
- 'even'
- >>> cond(x % 2 == 0, "even", "odd") + '_row'
- 'even_row'
- """
- if predicate:
- return consequence
- else:
- return alternative
-
-class CaptureStdout:
- """
- Captures everything `func` prints to stdout and returns it instead.
-
- >>> def idiot():
- ... print "foo"
- >>> capturestdout(idiot)()
- 'foo\\n'
-
- **WARNING:** Not threadsafe!
- """
- def __init__(self, func):
- self.func = func
- def __call__(self, *args, **keywords):
- from cStringIO import StringIO
- # Not threadsafe!
- out = StringIO()
- oldstdout = sys.stdout
- sys.stdout = out
- try:
- self.func(*args, **keywords)
- finally:
- sys.stdout = oldstdout
- return out.getvalue()
-
-capturestdout = CaptureStdout
-
-class Profile:
- """
- Profiles `func` and returns a tuple containing its output
- and a string with human-readable profiling information.
-
- >>> import time
- >>> out, inf = profile(time.sleep)(.001)
- >>> out
- >>> inf[:10].strip()
- 'took 0.0'
- """
- def __init__(self, func):
- self.func = func
- def __call__(self, *args): ##, **kw): kw unused
- import hotshot, hotshot.stats, os, tempfile ##, time already imported
- f, filename = tempfile.mkstemp()
- os.close(f)
-
- prof = hotshot.Profile(filename)
-
- stime = time.time()
- result = prof.runcall(self.func, *args)
- stime = time.time() - stime
- prof.close()
-
- import cStringIO
- out = cStringIO.StringIO()
- stats = hotshot.stats.load(filename)
- stats.stream = out
- stats.strip_dirs()
- stats.sort_stats('time', 'calls')
- stats.print_stats(40)
- stats.print_callers()
-
- x = '\n\ntook '+ str(stime) + ' seconds\n'
- x += out.getvalue()
-
- # remove the tempfile
- try:
- os.remove(filename)
- except IOError:
- pass
-
- return result, x
-
-profile = Profile
-
-
-import traceback
-# hack for compatibility with Python 2.3:
-if not hasattr(traceback, 'format_exc'):
- from cStringIO import StringIO
- def format_exc(limit=None):
- strbuf = StringIO()
- traceback.print_exc(limit, strbuf)
- return strbuf.getvalue()
- traceback.format_exc = format_exc
-
-def tryall(context, prefix=None):
- """
- Tries a series of functions and prints their results.
- `context` is a dictionary mapping names to values;
- the value will only be tried if it's callable.
-
- >>> tryall(dict(j=lambda: True))
- j: True
- ----------------------------------------
- results:
- True: 1
-
- For example, you might have a file `test/stuff.py`
- with a series of functions testing various things in it.
- At the bottom, have a line:
-
- if __name__ == "__main__": tryall(globals())
-
- Then you can run `python test/stuff.py` and get the results of
- all the tests.
- """
- context = context.copy() # vars() would update
- results = {}
- for (key, value) in context.iteritems():
- if not hasattr(value, '__call__'):
- continue
- if prefix and not key.startswith(prefix):
- continue
- print key + ':',
- try:
- r = value()
- dictincr(results, r)
- print r
- except:
- print 'ERROR'
- dictincr(results, 'ERROR')
- print ' ' + '\n '.join(traceback.format_exc().split('\n'))
-
- print '-'*40
- print 'results:'
- for (key, value) in results.iteritems():
- print ' '*2, str(key)+':', value
-
-class ThreadedDict(threadlocal):
- """
- Thread local storage.
-
- >>> d = ThreadedDict()
- >>> d.x = 1
- >>> d.x
- 1
- >>> import threading
- >>> def f(): d.x = 2
- ...
- >>> t = threading.Thread(target=f)
- >>> t.start()
- >>> t.join()
- >>> d.x
- 1
- """
- _instances = set()
-
- def __init__(self):
- ThreadedDict._instances.add(self)
-
- def __del__(self):
- ThreadedDict._instances.remove(self)
-
- def __hash__(self):
- return id(self)
-
- def clear_all():
- """Clears all ThreadedDict instances.
- """
- for t in list(ThreadedDict._instances):
- t.clear()
- clear_all = staticmethod(clear_all)
-
- # Define all these methods to more or less fully emulate dict -- attribute access
- # is built into threading.local.
-
- def __getitem__(self, key):
- return self.__dict__[key]
-
- def __setitem__(self, key, value):
- self.__dict__[key] = value
-
- def __delitem__(self, key):
- del self.__dict__[key]
-
- def __contains__(self, key):
- return key in self.__dict__
-
- has_key = __contains__
-
- def clear(self):
- self.__dict__.clear()
-
- def copy(self):
- return self.__dict__.copy()
-
- def get(self, key, default=None):
- return self.__dict__.get(key, default)
-
- def items(self):
- return self.__dict__.items()
-
- def iteritems(self):
- return self.__dict__.iteritems()
-
- def keys(self):
- return self.__dict__.keys()
-
- def iterkeys(self):
- return self.__dict__.iterkeys()
-
- iter = iterkeys
-
- def values(self):
- return self.__dict__.values()
-
- def itervalues(self):
- return self.__dict__.itervalues()
-
- def pop(self, key, *args):
- return self.__dict__.pop(key, *args)
-
- def popitem(self):
- return self.__dict__.popitem()
-
- def setdefault(self, key, default=None):
- return self.__dict__.setdefault(key, default)
-
- def update(self, *args, **kwargs):
- self.__dict__.update(*args, **kwargs)
-
- def __repr__(self):
- return '' % self.__dict__
-
- __str__ = __repr__
-
-threadeddict = ThreadedDict
-
-def autoassign(self, locals):
- """
- Automatically assigns local variables to `self`.
-
- >>> self = storage()
- >>> autoassign(self, dict(a=1, b=2))
- >>> self
-
-
- Generally used in `__init__` methods, as in:
-
- def __init__(self, foo, bar, baz=1): autoassign(self, locals())
- """
- for (key, value) in locals.iteritems():
- if key == 'self':
- continue
- setattr(self, key, value)
-
-def to36(q):
- """
- Converts an integer to base 36 (a useful scheme for human-sayable IDs).
-
- >>> to36(35)
- 'z'
- >>> to36(119292)
- '2k1o'
- >>> int(to36(939387374), 36)
- 939387374
- >>> to36(0)
- '0'
- >>> to36(-393)
- Traceback (most recent call last):
- ...
- ValueError: must supply a positive integer
-
- """
- if q < 0: raise ValueError, "must supply a positive integer"
- letters = "0123456789abcdefghijklmnopqrstuvwxyz"
- converted = []
- while q != 0:
- q, r = divmod(q, 36)
- converted.insert(0, letters[r])
- return "".join(converted) or '0'
-
-
-r_url = re_compile('(?', text)
- text = markdown(text)
- return text
-
-def sendmail(from_address, to_address, subject, message, headers=None, **kw):
- """
- Sends the email message `message` with mail and envelope headers
- for from `from_address_` to `to_address` with `subject`.
- Additional email headers can be specified with the dictionary
- `headers.
-
- Optionally cc, bcc and attachments can be specified as keyword arguments.
- Attachments must be an iterable and each attachment can be either a
- filename or a file object or a dictionary with filename, content and
- optionally content_type keys.
-
- If `web.config.smtp_server` is set, it will send the message
- to that SMTP server. Otherwise it will look for
- `/usr/sbin/sendmail`, the typical location for the sendmail-style
- binary. To use sendmail from a different path, set `web.config.sendmail_path`.
- """
- attachments = kw.pop("attachments", [])
- mail = _EmailMessage(from_address, to_address, subject, message, headers, **kw)
-
- for a in attachments:
- if isinstance(a, dict):
- mail.attach(a['filename'], a['content'], a.get('content_type'))
- elif hasattr(a, 'read'): # file
- filename = os.path.basename(getattr(a, "name", ""))
- content_type = getattr(a, 'content_type', None)
- mail.attach(filename, a.read(), content_type)
- elif isinstance(a, basestring):
- f = open(a, 'rb')
- content = f.read()
- f.close()
- filename = os.path.basename(a)
- mail.attach(filename, content, None)
- else:
- raise ValueError, "Invalid attachment: %s" % repr(a)
-
- mail.send()
-
-class _EmailMessage:
- def __init__(self, from_address, to_address, subject, message, headers=None, **kw):
- def listify(x):
- if not isinstance(x, list):
- return [safestr(x)]
- else:
- return [safestr(a) for a in x]
-
- subject = safestr(subject)
- message = safestr(message)
-
- from_address = safestr(from_address)
- to_address = listify(to_address)
- cc = listify(kw.get('cc', []))
- bcc = listify(kw.get('bcc', []))
- recipients = to_address + cc + bcc
-
- import email.Utils
- self.from_address = email.Utils.parseaddr(from_address)[1]
- self.recipients = [email.Utils.parseaddr(r)[1] for r in recipients]
-
- self.headers = dictadd({
- 'From': from_address,
- 'To': ", ".join(to_address),
- 'Subject': subject
- }, headers or {})
-
- if cc:
- self.headers['Cc'] = ", ".join(cc)
-
- self.message = self.new_message()
- self.message.add_header("Content-Transfer-Encoding", "7bit")
- self.message.add_header("Content-Disposition", "inline")
- self.message.add_header("MIME-Version", "1.0")
- self.message.set_payload(message, 'utf-8')
- self.multipart = False
-
- def new_message(self):
- from email.Message import Message
- return Message()
-
- def attach(self, filename, content, content_type=None):
- if not self.multipart:
- msg = self.new_message()
- msg.add_header("Content-Type", "multipart/mixed")
- msg.attach(self.message)
- self.message = msg
- self.multipart = True
-
- import mimetypes
- try:
- from email import encoders
- except:
- from email import Encoders as encoders
-
- content_type = content_type or mimetypes.guess_type(filename)[0] or "applcation/octet-stream"
-
- msg = self.new_message()
- msg.set_payload(content)
- msg.add_header('Content-Type', content_type)
- msg.add_header('Content-Disposition', 'attachment', filename=filename)
-
- if not content_type.startswith("text/"):
- encoders.encode_base64(msg)
-
- self.message.attach(msg)
-
- def prepare_message(self):
- for k, v in self.headers.iteritems():
- if k.lower() == "content-type":
- self.message.set_type(v)
- else:
- self.message.add_header(k, v)
-
- self.headers = {}
-
- def send(self):
- try:
- import webapi
- except ImportError:
- webapi = Storage(config=Storage())
-
- self.prepare_message()
- message_text = self.message.as_string()
-
- if webapi.config.get('smtp_server'):
- server = webapi.config.get('smtp_server')
- port = webapi.config.get('smtp_port', 0)
- username = webapi.config.get('smtp_username')
- password = webapi.config.get('smtp_password')
- debug_level = webapi.config.get('smtp_debuglevel', None)
- starttls = webapi.config.get('smtp_starttls', False)
-
- import smtplib
- smtpserver = smtplib.SMTP(server, port)
-
- if debug_level:
- smtpserver.set_debuglevel(debug_level)
-
- if starttls:
- smtpserver.ehlo()
- smtpserver.starttls()
- smtpserver.ehlo()
-
- if username and password:
- smtpserver.login(username, password)
-
- smtpserver.sendmail(self.from_address, self.recipients, message_text)
- smtpserver.quit()
- elif webapi.config.get('email_engine') == 'aws':
- import boto.ses
- c = boto.ses.SESConnection(
- aws_access_key_id=webapi.config.get('aws_access_key_id'),
- aws_secret_access_key=web.api.config.get('aws_secret_access_key'))
- c.send_raw_email(self.from_address, message_text, self.from_recipients)
- else:
- sendmail = webapi.config.get('sendmail_path', '/usr/sbin/sendmail')
-
- assert not self.from_address.startswith('-'), 'security'
- for r in self.recipients:
- assert not r.startswith('-'), 'security'
-
- cmd = [sendmail, '-f', self.from_address] + self.recipients
-
- if subprocess:
- p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
- p.stdin.write(message_text)
- p.stdin.close()
- p.wait()
- else:
- i, o = os.popen2(cmd)
- i.write(message)
- i.close()
- o.close()
- del i, o
-
- def __repr__(self):
- return ""
-
- def __str__(self):
- return self.message.as_string()
-
-if __name__ == "__main__":
- import doctest
- doctest.testmod()
+#!/usr/bin/env python
+"""
+General Utilities
+(part of web.py)
+"""
+
+__all__ = [
+ "Storage", "storage", "storify",
+ "Counter", "counter",
+ "iters",
+ "rstrips", "lstrips", "strips",
+ "safeunicode", "safestr", "utf8",
+ "TimeoutError", "timelimit",
+ "Memoize", "memoize",
+ "re_compile", "re_subm",
+ "group", "uniq", "iterview",
+ "IterBetter", "iterbetter",
+ "safeiter", "safewrite",
+ "dictreverse", "dictfind", "dictfindall", "dictincr", "dictadd",
+ "requeue", "restack",
+ "listget", "intget", "datestr",
+ "numify", "denumify", "commify", "dateify",
+ "nthstr", "cond",
+ "CaptureStdout", "capturestdout", "Profile", "profile",
+ "tryall",
+ "ThreadedDict", "threadeddict",
+ "autoassign",
+ "to36",
+ "safemarkdown",
+ "sendmail"
+]
+
+import re, sys, time, threading, itertools, traceback, os
+
+try:
+ import subprocess
+except ImportError:
+ subprocess = None
+
+try: import datetime
+except ImportError: pass
+
+try: set
+except NameError:
+ from sets import Set as set
+
+try:
+ from threading import local as threadlocal
+except ImportError:
+ from python23 import threadlocal
+
+class Storage(dict):
+ """
+ A Storage object is like a dictionary except `obj.foo` can be used
+ in addition to `obj['foo']`.
+
+ >>> o = storage(a=1)
+ >>> o.a
+ 1
+ >>> o['a']
+ 1
+ >>> o.a = 2
+ >>> o['a']
+ 2
+ >>> del o.a
+ >>> o.a
+ Traceback (most recent call last):
+ ...
+ AttributeError: 'a'
+
+ """
+ def __getattr__(self, key):
+ try:
+ return self[key]
+ except KeyError, k:
+ raise AttributeError, k
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, key):
+ try:
+ del self[key]
+ except KeyError, k:
+ raise AttributeError, k
+
+ def __repr__(self):
+ return ''
+
+storage = Storage
+
+def storify(mapping, *requireds, **defaults):
+ """
+ Creates a `storage` object from dictionary `mapping`, raising `KeyError` if
+ d doesn't have all of the keys in `requireds` and using the default
+ values for keys found in `defaults`.
+
+ For example, `storify({'a':1, 'c':3}, b=2, c=0)` will return the equivalent of
+ `storage({'a':1, 'b':2, 'c':3})`.
+
+ If a `storify` value is a list (e.g. multiple values in a form submission),
+ `storify` returns the last element of the list, unless the key appears in
+ `defaults` as a list. Thus:
+
+ >>> storify({'a':[1, 2]}).a
+ 2
+ >>> storify({'a':[1, 2]}, a=[]).a
+ [1, 2]
+ >>> storify({'a':1}, a=[]).a
+ [1]
+ >>> storify({}, a=[]).a
+ []
+
+ Similarly, if the value has a `value` attribute, `storify will return _its_
+ value, unless the key appears in `defaults` as a dictionary.
+
+ >>> storify({'a':storage(value=1)}).a
+ 1
+ >>> storify({'a':storage(value=1)}, a={}).a
+
+ >>> storify({}, a={}).a
+ {}
+
+ Optionally, keyword parameter `_unicode` can be passed to convert all values to unicode.
+
+ >>> storify({'x': 'a'}, _unicode=True)
+
+ >>> storify({'x': storage(value='a')}, x={}, _unicode=True)
+ }>
+ >>> storify({'x': storage(value='a')}, _unicode=True)
+
+ """
+ _unicode = defaults.pop('_unicode', False)
+
+ # if _unicode is callable object, use it convert a string to unicode.
+ to_unicode = safeunicode
+ if _unicode is not False and hasattr(_unicode, "__call__"):
+ to_unicode = _unicode
+
+ def unicodify(s):
+ if _unicode and isinstance(s, str): return to_unicode(s)
+ else: return s
+
+ def getvalue(x):
+ if hasattr(x, 'file') and hasattr(x, 'value'):
+ return x.value
+ elif hasattr(x, 'value'):
+ return unicodify(x.value)
+ else:
+ return unicodify(x)
+
+ stor = Storage()
+ for key in requireds + tuple(mapping.keys()):
+ value = mapping[key]
+ if isinstance(value, list):
+ if isinstance(defaults.get(key), list):
+ value = [getvalue(x) for x in value]
+ else:
+ value = value[-1]
+ if not isinstance(defaults.get(key), dict):
+ value = getvalue(value)
+ if isinstance(defaults.get(key), list) and not isinstance(value, list):
+ value = [value]
+ setattr(stor, key, value)
+
+ for (key, value) in defaults.iteritems():
+ result = value
+ if hasattr(stor, key):
+ result = stor[key]
+ if value == () and not isinstance(result, tuple):
+ result = (result,)
+ setattr(stor, key, result)
+
+ return stor
+
+class Counter(storage):
+ """Keeps count of how many times something is added.
+
+ >>> c = counter()
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('y')
+ >>> c
+
+ >>> c.most()
+ ['x']
+ """
+ def add(self, n):
+ self.setdefault(n, 0)
+ self[n] += 1
+
+ def most(self):
+ """Returns the keys with maximum count."""
+ m = max(self.itervalues())
+ return [k for k, v in self.iteritems() if v == m]
+
+ def least(self):
+ """Returns the keys with mininum count."""
+ m = min(self.itervalues())
+ return [k for k, v in self.iteritems() if v == m]
+
+ def percent(self, key):
+ """Returns what percentage a certain key is of all entries.
+
+ >>> c = counter()
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('y')
+ >>> c.percent('x')
+ 0.75
+ >>> c.percent('y')
+ 0.25
+ """
+ return float(self[key])/sum(self.values())
+
+ def sorted_keys(self):
+ """Returns keys sorted by value.
+
+ >>> c = counter()
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('y')
+ >>> c.sorted_keys()
+ ['x', 'y']
+ """
+ return sorted(self.keys(), key=lambda k: self[k], reverse=True)
+
+ def sorted_values(self):
+ """Returns values sorted by value.
+
+ >>> c = counter()
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('y')
+ >>> c.sorted_values()
+ [2, 1]
+ """
+ return [self[k] for k in self.sorted_keys()]
+
+ def sorted_items(self):
+ """Returns items sorted by value.
+
+ >>> c = counter()
+ >>> c.add('x')
+ >>> c.add('x')
+ >>> c.add('y')
+ >>> c.sorted_items()
+ [('x', 2), ('y', 1)]
+ """
+ return [(k, self[k]) for k in self.sorted_keys()]
+
+ def __repr__(self):
+ return ''
+
+counter = Counter
+
+iters = [list, tuple]
+import __builtin__
+if hasattr(__builtin__, 'set'):
+ iters.append(set)
+if hasattr(__builtin__, 'frozenset'):
+ iters.append(set)
+if sys.version_info < (2,6): # sets module deprecated in 2.6
+ try:
+ from sets import Set
+ iters.append(Set)
+ except ImportError:
+ pass
+
+class _hack(tuple): pass
+iters = _hack(iters)
+iters.__doc__ = """
+A list of iterable items (like lists, but not strings). Includes whichever
+of lists, tuples, sets, and Sets are available in this version of Python.
+"""
+
+def _strips(direction, text, remove):
+ if isinstance(remove, iters):
+ for subr in remove:
+ text = _strips(direction, text, subr)
+ return text
+
+ if direction == 'l':
+ if text.startswith(remove):
+ return text[len(remove):]
+ elif direction == 'r':
+ if text.endswith(remove):
+ return text[:-len(remove)]
+ else:
+ raise ValueError, "Direction needs to be r or l."
+ return text
+
+def rstrips(text, remove):
+ """
+ removes the string `remove` from the right of `text`
+
+ >>> rstrips("foobar", "bar")
+ 'foo'
+
+ """
+ return _strips('r', text, remove)
+
+def lstrips(text, remove):
+ """
+ removes the string `remove` from the left of `text`
+
+ >>> lstrips("foobar", "foo")
+ 'bar'
+ >>> lstrips('http://foo.org/', ['http://', 'https://'])
+ 'foo.org/'
+ >>> lstrips('FOOBARBAZ', ['FOO', 'BAR'])
+ 'BAZ'
+ >>> lstrips('FOOBARBAZ', ['BAR', 'FOO'])
+ 'BARBAZ'
+
+ """
+ return _strips('l', text, remove)
+
+def strips(text, remove):
+ """
+ removes the string `remove` from the both sides of `text`
+
+ >>> strips("foobarfoo", "foo")
+ 'bar'
+
+ """
+ return rstrips(lstrips(text, remove), remove)
+
+def safeunicode(obj, encoding='utf-8'):
+ r"""
+ Converts any given object to unicode string.
+
+ >>> safeunicode('hello')
+ u'hello'
+ >>> safeunicode(2)
+ u'2'
+ >>> safeunicode('\xe1\x88\xb4')
+ u'\u1234'
+ """
+ t = type(obj)
+ if t is unicode:
+ return obj
+ elif t is str:
+ return obj.decode(encoding)
+ elif t in [int, float, bool]:
+ return unicode(obj)
+ elif hasattr(obj, '__unicode__') or isinstance(obj, unicode):
+ return unicode(obj)
+ else:
+ return str(obj).decode(encoding)
+
+def safestr(obj, encoding='utf-8'):
+ r"""
+ Converts any given object to utf-8 encoded string.
+
+ >>> safestr('hello')
+ 'hello'
+ >>> safestr(u'\u1234')
+ '\xe1\x88\xb4'
+ >>> safestr(2)
+ '2'
+ """
+ if isinstance(obj, unicode):
+ return obj.encode(encoding)
+ elif isinstance(obj, str):
+ return obj
+ elif hasattr(obj, 'next'): # iterator
+ return itertools.imap(safestr, obj)
+ else:
+ return str(obj)
+
+# for backward-compatibility
+utf8 = safestr
+
+class TimeoutError(Exception): pass
+def timelimit(timeout):
+ """
+ A decorator to limit a function to `timeout` seconds, raising `TimeoutError`
+ if it takes longer.
+
+ >>> import time
+ >>> def meaningoflife():
+ ... time.sleep(.2)
+ ... return 42
+ >>>
+ >>> timelimit(.1)(meaningoflife)()
+ Traceback (most recent call last):
+ ...
+ TimeoutError: took too long
+ >>> timelimit(1)(meaningoflife)()
+ 42
+
+ _Caveat:_ The function isn't stopped after `timeout` seconds but continues
+ executing in a separate thread. (There seems to be no way to kill a thread.)
+
+ inspired by
+ """
+ def _1(function):
+ def _2(*args, **kw):
+ class Dispatch(threading.Thread):
+ def __init__(self):
+ threading.Thread.__init__(self)
+ self.result = None
+ self.error = None
+
+ self.setDaemon(True)
+ self.start()
+
+ def run(self):
+ try:
+ self.result = function(*args, **kw)
+ except:
+ self.error = sys.exc_info()
+
+ c = Dispatch()
+ c.join(timeout)
+ if c.isAlive():
+ raise TimeoutError, 'took too long'
+ if c.error:
+ raise c.error[0], c.error[1]
+ return c.result
+ return _2
+ return _1
+
+class Memoize:
+ """
+ 'Memoizes' a function, caching its return values for each input.
+ If `expires` is specified, values are recalculated after `expires` seconds.
+ If `background` is specified, values are recalculated in a separate thread.
+
+ >>> calls = 0
+ >>> def howmanytimeshaveibeencalled():
+ ... global calls
+ ... calls += 1
+ ... return calls
+ >>> fastcalls = memoize(howmanytimeshaveibeencalled)
+ >>> howmanytimeshaveibeencalled()
+ 1
+ >>> howmanytimeshaveibeencalled()
+ 2
+ >>> fastcalls()
+ 3
+ >>> fastcalls()
+ 3
+ >>> import time
+ >>> fastcalls = memoize(howmanytimeshaveibeencalled, .1, background=False)
+ >>> fastcalls()
+ 4
+ >>> fastcalls()
+ 4
+ >>> time.sleep(.2)
+ >>> fastcalls()
+ 5
+ >>> def slowfunc():
+ ... time.sleep(.1)
+ ... return howmanytimeshaveibeencalled()
+ >>> fastcalls = memoize(slowfunc, .2, background=True)
+ >>> fastcalls()
+ 6
+ >>> timelimit(.05)(fastcalls)()
+ 6
+ >>> time.sleep(.2)
+ >>> timelimit(.05)(fastcalls)()
+ 6
+ >>> timelimit(.05)(fastcalls)()
+ 6
+ >>> time.sleep(.2)
+ >>> timelimit(.05)(fastcalls)()
+ 7
+ >>> fastcalls = memoize(slowfunc, None, background=True)
+ >>> threading.Thread(target=fastcalls).start()
+ >>> time.sleep(.01)
+ >>> fastcalls()
+ 9
+ """
+ def __init__(self, func, expires=None, background=True):
+ self.func = func
+ self.cache = {}
+ self.expires = expires
+ self.background = background
+ self.running = {}
+
+ def __call__(self, *args, **keywords):
+ key = (args, tuple(keywords.items()))
+ if not self.running.get(key):
+ self.running[key] = threading.Lock()
+ def update(block=False):
+ if self.running[key].acquire(block):
+ try:
+ self.cache[key] = (self.func(*args, **keywords), time.time())
+ finally:
+ self.running[key].release()
+
+ if key not in self.cache:
+ update(block=True)
+ elif self.expires and (time.time() - self.cache[key][1]) > self.expires:
+ if self.background:
+ threading.Thread(target=update).start()
+ else:
+ update()
+ return self.cache[key][0]
+
+memoize = Memoize
+
+re_compile = memoize(re.compile) #@@ threadsafe?
+re_compile.__doc__ = """
+A memoized version of re.compile.
+"""
+
+class _re_subm_proxy:
+ def __init__(self):
+ self.match = None
+ def __call__(self, match):
+ self.match = match
+ return ''
+
+def re_subm(pat, repl, string):
+ """
+ Like re.sub, but returns the replacement _and_ the match object.
+
+ >>> t, m = re_subm('g(oo+)fball', r'f\\1lish', 'goooooofball')
+ >>> t
+ 'foooooolish'
+ >>> m.groups()
+ ('oooooo',)
+ """
+ compiled_pat = re_compile(pat)
+ proxy = _re_subm_proxy()
+ compiled_pat.sub(proxy.__call__, string)
+ return compiled_pat.sub(repl, string), proxy.match
+
+def group(seq, size):
+ """
+ Returns an iterator over a series of lists of length size from iterable.
+
+ >>> list(group([1,2,3,4], 2))
+ [[1, 2], [3, 4]]
+ >>> list(group([1,2,3,4,5], 2))
+ [[1, 2], [3, 4], [5]]
+ """
+ def take(seq, n):
+ for i in xrange(n):
+ yield seq.next()
+
+ if not hasattr(seq, 'next'):
+ seq = iter(seq)
+ while True:
+ x = list(take(seq, size))
+ if x:
+ yield x
+ else:
+ break
+
+def uniq(seq, key=None):
+ """
+ Removes duplicate elements from a list while preserving the order of the rest.
+
+ >>> uniq([9,0,2,1,0])
+ [9, 0, 2, 1]
+
+ The value of the optional `key` parameter should be a function that
+ takes a single argument and returns a key to test the uniqueness.
+
+ >>> uniq(["Foo", "foo", "bar"], key=lambda s: s.lower())
+ ['Foo', 'bar']
+ """
+ key = key or (lambda x: x)
+ seen = set()
+ result = []
+ for v in seq:
+ k = key(v)
+ if k in seen:
+ continue
+ seen.add(k)
+ result.append(v)
+ return result
+
+def iterview(x):
+ """
+ Takes an iterable `x` and returns an iterator over it
+ which prints its progress to stderr as it iterates through.
+ """
+ WIDTH = 70
+
+ def plainformat(n, lenx):
+ return '%5.1f%% (%*d/%d)' % ((float(n)/lenx)*100, len(str(lenx)), n, lenx)
+
+ def bars(size, n, lenx):
+ val = int((float(n)*size)/lenx + 0.5)
+ if size - val:
+ spacing = ">" + (" "*(size-val))[1:]
+ else:
+ spacing = ""
+ return "[%s%s]" % ("="*val, spacing)
+
+ def eta(elapsed, n, lenx):
+ if n == 0:
+ return '--:--:--'
+ if n == lenx:
+ secs = int(elapsed)
+ else:
+ secs = int((elapsed/n) * (lenx-n))
+ mins, secs = divmod(secs, 60)
+ hrs, mins = divmod(mins, 60)
+
+ return '%02d:%02d:%02d' % (hrs, mins, secs)
+
+ def format(starttime, n, lenx):
+ out = plainformat(n, lenx) + ' '
+ if n == lenx:
+ end = ' '
+ else:
+ end = ' ETA '
+ end += eta(time.time() - starttime, n, lenx)
+ out += bars(WIDTH - len(out) - len(end), n, lenx)
+ out += end
+ return out
+
+ starttime = time.time()
+ lenx = len(x)
+ for n, y in enumerate(x):
+ sys.stderr.write('\r' + format(starttime, n, lenx))
+ yield y
+ sys.stderr.write('\r' + format(starttime, n+1, lenx) + '\n')
+
+class IterBetter:
+ """
+ Returns an object that can be used as an iterator
+ but can also be used via __getitem__ (although it
+ cannot go backwards -- that is, you cannot request
+ `iterbetter[0]` after requesting `iterbetter[1]`).
+
+ >>> import itertools
+ >>> c = iterbetter(itertools.count())
+ >>> c[1]
+ 1
+ >>> c[5]
+ 5
+ >>> c[3]
+ Traceback (most recent call last):
+ ...
+ IndexError: already passed 3
+
+ For boolean test, IterBetter peeps at first value in the itertor without effecting the iteration.
+
+ >>> c = iterbetter(iter(range(5)))
+ >>> bool(c)
+ True
+ >>> list(c)
+ [0, 1, 2, 3, 4]
+ >>> c = iterbetter(iter([]))
+ >>> bool(c)
+ False
+ >>> list(c)
+ []
+ """
+ def __init__(self, iterator):
+ self.i, self.c = iterator, 0
+
+ def __iter__(self):
+ if hasattr(self, "_head"):
+ yield self._head
+
+ while 1:
+ yield self.i.next()
+ self.c += 1
+
+ def __getitem__(self, i):
+ #todo: slices
+ if i < self.c:
+ raise IndexError, "already passed "+str(i)
+ try:
+ while i > self.c:
+ self.i.next()
+ self.c += 1
+ # now self.c == i
+ self.c += 1
+ return self.i.next()
+ except StopIteration:
+ raise IndexError, str(i)
+
+ def __nonzero__(self):
+ if hasattr(self, "__len__"):
+ return len(self) != 0
+ elif hasattr(self, "_head"):
+ return True
+ else:
+ try:
+ self._head = self.i.next()
+ except StopIteration:
+ return False
+ else:
+ return True
+
+iterbetter = IterBetter
+
+def safeiter(it, cleanup=None, ignore_errors=True):
+ """Makes an iterator safe by ignoring the exceptions occured during the iteration.
+ """
+ def next():
+ while True:
+ try:
+ return it.next()
+ except StopIteration:
+ raise
+ except:
+ traceback.print_exc()
+
+ it = iter(it)
+ while True:
+ yield next()
+
+def safewrite(filename, content):
+ """Writes the content to a temp file and then moves the temp file to
+ given filename to avoid overwriting the existing file in case of errors.
+ """
+ f = file(filename + '.tmp', 'w')
+ f.write(content)
+ f.close()
+ os.rename(f.name, filename)
+
+def dictreverse(mapping):
+ """
+ Returns a new dictionary with keys and values swapped.
+
+ >>> dictreverse({1: 2, 3: 4})
+ {2: 1, 4: 3}
+ """
+ return dict([(value, key) for (key, value) in mapping.iteritems()])
+
+def dictfind(dictionary, element):
+ """
+ Returns a key whose value in `dictionary` is `element`
+ or, if none exists, None.
+
+ >>> d = {1:2, 3:4}
+ >>> dictfind(d, 4)
+ 3
+ >>> dictfind(d, 5)
+ """
+ for (key, value) in dictionary.iteritems():
+ if element is value:
+ return key
+
+def dictfindall(dictionary, element):
+ """
+ Returns the keys whose values in `dictionary` are `element`
+ or, if none exists, [].
+
+ >>> d = {1:4, 3:4}
+ >>> dictfindall(d, 4)
+ [1, 3]
+ >>> dictfindall(d, 5)
+ []
+ """
+ res = []
+ for (key, value) in dictionary.iteritems():
+ if element is value:
+ res.append(key)
+ return res
+
+def dictincr(dictionary, element):
+ """
+ Increments `element` in `dictionary`,
+ setting it to one if it doesn't exist.
+
+ >>> d = {1:2, 3:4}
+ >>> dictincr(d, 1)
+ 3
+ >>> d[1]
+ 3
+ >>> dictincr(d, 5)
+ 1
+ >>> d[5]
+ 1
+ """
+ dictionary.setdefault(element, 0)
+ dictionary[element] += 1
+ return dictionary[element]
+
+def dictadd(*dicts):
+ """
+ Returns a dictionary consisting of the keys in the argument dictionaries.
+ If they share a key, the value from the last argument is used.
+
+ >>> dictadd({1: 0, 2: 0}, {2: 1, 3: 1})
+ {1: 0, 2: 1, 3: 1}
+ """
+ result = {}
+ for dct in dicts:
+ result.update(dct)
+ return result
+
+def requeue(queue, index=-1):
+ """Returns the element at index after moving it to the beginning of the queue.
+
+ >>> x = [1, 2, 3, 4]
+ >>> requeue(x)
+ 4
+ >>> x
+ [4, 1, 2, 3]
+ """
+ x = queue.pop(index)
+ queue.insert(0, x)
+ return x
+
+def restack(stack, index=0):
+ """Returns the element at index after moving it to the top of stack.
+
+ >>> x = [1, 2, 3, 4]
+ >>> restack(x)
+ 1
+ >>> x
+ [2, 3, 4, 1]
+ """
+ x = stack.pop(index)
+ stack.append(x)
+ return x
+
+def listget(lst, ind, default=None):
+ """
+ Returns `lst[ind]` if it exists, `default` otherwise.
+
+ >>> listget(['a'], 0)
+ 'a'
+ >>> listget(['a'], 1)
+ >>> listget(['a'], 1, 'b')
+ 'b'
+ """
+ if len(lst)-1 < ind:
+ return default
+ return lst[ind]
+
+def intget(integer, default=None):
+ """
+ Returns `integer` as an int or `default` if it can't.
+
+ >>> intget('3')
+ 3
+ >>> intget('3a')
+ >>> intget('3a', 0)
+ 0
+ """
+ try:
+ return int(integer)
+ except (TypeError, ValueError):
+ return default
+
+def datestr(then, now=None):
+ """
+ Converts a (UTC) datetime object to a nice string representation.
+
+ >>> from datetime import datetime, timedelta
+ >>> d = datetime(1970, 5, 1)
+ >>> datestr(d, now=d)
+ '0 microseconds ago'
+ >>> for t, v in {
+ ... timedelta(microseconds=1): '1 microsecond ago',
+ ... timedelta(microseconds=2): '2 microseconds ago',
+ ... -timedelta(microseconds=1): '1 microsecond from now',
+ ... -timedelta(microseconds=2): '2 microseconds from now',
+ ... timedelta(microseconds=2000): '2 milliseconds ago',
+ ... timedelta(seconds=2): '2 seconds ago',
+ ... timedelta(seconds=2*60): '2 minutes ago',
+ ... timedelta(seconds=2*60*60): '2 hours ago',
+ ... timedelta(days=2): '2 days ago',
+ ... }.iteritems():
+ ... assert datestr(d, now=d+t) == v
+ >>> datestr(datetime(1970, 1, 1), now=d)
+ 'January 1'
+ >>> datestr(datetime(1969, 1, 1), now=d)
+ 'January 1, 1969'
+ >>> datestr(datetime(1970, 6, 1), now=d)
+ 'June 1, 1970'
+ >>> datestr(None)
+ ''
+ """
+ def agohence(n, what, divisor=None):
+ if divisor: n = n // divisor
+
+ out = str(abs(n)) + ' ' + what # '2 day'
+ if abs(n) != 1: out += 's' # '2 days'
+ out += ' ' # '2 days '
+ if n < 0:
+ out += 'from now'
+ else:
+ out += 'ago'
+ return out # '2 days ago'
+
+ oneday = 24 * 60 * 60
+
+ if not then: return ""
+ if not now: now = datetime.datetime.utcnow()
+ if type(now).__name__ == "DateTime":
+ now = datetime.datetime.fromtimestamp(now)
+ if type(then).__name__ == "DateTime":
+ then = datetime.datetime.fromtimestamp(then)
+ elif type(then).__name__ == "date":
+ then = datetime.datetime(then.year, then.month, then.day)
+
+ delta = now - then
+ deltaseconds = int(delta.days * oneday + delta.seconds + delta.microseconds * 1e-06)
+ deltadays = abs(deltaseconds) // oneday
+ if deltaseconds < 0: deltadays *= -1 # fix for oddity of floor
+
+ if deltadays:
+ if abs(deltadays) < 4:
+ return agohence(deltadays, 'day')
+
+ try:
+ out = then.strftime('%B %e') # e.g. 'June 3'
+ except ValueError:
+ # %e doesn't work on Windows.
+ out = then.strftime('%B %d') # e.g. 'June 03'
+
+ if then.year != now.year or deltadays < 0:
+ out += ', %s' % then.year
+ return out
+
+ if int(deltaseconds):
+ if abs(deltaseconds) > (60 * 60):
+ return agohence(deltaseconds, 'hour', 60 * 60)
+ elif abs(deltaseconds) > 60:
+ return agohence(deltaseconds, 'minute', 60)
+ else:
+ return agohence(deltaseconds, 'second')
+
+ deltamicroseconds = delta.microseconds
+ if delta.days: deltamicroseconds = int(delta.microseconds - 1e6) # datetime oddity
+ if abs(deltamicroseconds) > 1000:
+ return agohence(deltamicroseconds, 'millisecond', 1000)
+
+ return agohence(deltamicroseconds, 'microsecond')
+
+def numify(string):
+ """
+ Removes all non-digit characters from `string`.
+
+ >>> numify('800-555-1212')
+ '8005551212'
+ >>> numify('800.555.1212')
+ '8005551212'
+
+ """
+ return ''.join([c for c in str(string) if c.isdigit()])
+
+def denumify(string, pattern):
+ """
+ Formats `string` according to `pattern`, where the letter X gets replaced
+ by characters from `string`.
+
+ >>> denumify("8005551212", "(XXX) XXX-XXXX")
+ '(800) 555-1212'
+
+ """
+ out = []
+ for c in pattern:
+ if c == "X":
+ out.append(string[0])
+ string = string[1:]
+ else:
+ out.append(c)
+ return ''.join(out)
+
+def commify(n):
+ """
+ Add commas to an integer `n`.
+
+ >>> commify(1)
+ '1'
+ >>> commify(123)
+ '123'
+ >>> commify(1234)
+ '1,234'
+ >>> commify(1234567890)
+ '1,234,567,890'
+ >>> commify(123.0)
+ '123.0'
+ >>> commify(1234.5)
+ '1,234.5'
+ >>> commify(1234.56789)
+ '1,234.56789'
+ >>> commify('%.2f' % 1234.5)
+ '1,234.50'
+ >>> commify(None)
+ >>>
+
+ """
+ if n is None: return None
+ n = str(n)
+ if '.' in n:
+ dollars, cents = n.split('.')
+ else:
+ dollars, cents = n, None
+
+ r = []
+ for i, c in enumerate(str(dollars)[::-1]):
+ if i and (not (i % 3)):
+ r.insert(0, ',')
+ r.insert(0, c)
+ out = ''.join(r)
+ if cents:
+ out += '.' + cents
+ return out
+
+def dateify(datestring):
+ """
+ Formats a numified `datestring` properly.
+ """
+ return denumify(datestring, "XXXX-XX-XX XX:XX:XX")
+
+
+def nthstr(n):
+ """
+ Formats an ordinal.
+ Doesn't handle negative numbers.
+
+ >>> nthstr(1)
+ '1st'
+ >>> nthstr(0)
+ '0th'
+ >>> [nthstr(x) for x in [2, 3, 4, 5, 10, 11, 12, 13, 14, 15]]
+ ['2nd', '3rd', '4th', '5th', '10th', '11th', '12th', '13th', '14th', '15th']
+ >>> [nthstr(x) for x in [91, 92, 93, 94, 99, 100, 101, 102]]
+ ['91st', '92nd', '93rd', '94th', '99th', '100th', '101st', '102nd']
+ >>> [nthstr(x) for x in [111, 112, 113, 114, 115]]
+ ['111th', '112th', '113th', '114th', '115th']
+
+ """
+
+ assert n >= 0
+ if n % 100 in [11, 12, 13]: return '%sth' % n
+ return {1: '%sst', 2: '%snd', 3: '%srd'}.get(n % 10, '%sth') % n
+
+def cond(predicate, consequence, alternative=None):
+ """
+ Function replacement for if-else to use in expressions.
+
+ >>> x = 2
+ >>> cond(x % 2 == 0, "even", "odd")
+ 'even'
+ >>> cond(x % 2 == 0, "even", "odd") + '_row'
+ 'even_row'
+ """
+ if predicate:
+ return consequence
+ else:
+ return alternative
+
+class CaptureStdout:
+ """
+ Captures everything `func` prints to stdout and returns it instead.
+
+ >>> def idiot():
+ ... print "foo"
+ >>> capturestdout(idiot)()
+ 'foo\\n'
+
+ **WARNING:** Not threadsafe!
+ """
+ def __init__(self, func):
+ self.func = func
+ def __call__(self, *args, **keywords):
+ from cStringIO import StringIO
+ # Not threadsafe!
+ out = StringIO()
+ oldstdout = sys.stdout
+ sys.stdout = out
+ try:
+ self.func(*args, **keywords)
+ finally:
+ sys.stdout = oldstdout
+ return out.getvalue()
+
+capturestdout = CaptureStdout
+
+class Profile:
+ """
+ Profiles `func` and returns a tuple containing its output
+ and a string with human-readable profiling information.
+
+ >>> import time
+ >>> out, inf = profile(time.sleep)(.001)
+ >>> out
+ >>> inf[:10].strip()
+ 'took 0.0'
+ """
+ def __init__(self, func):
+ self.func = func
+ def __call__(self, *args): ##, **kw): kw unused
+ import hotshot, hotshot.stats, os, tempfile ##, time already imported
+ f, filename = tempfile.mkstemp()
+ os.close(f)
+
+ prof = hotshot.Profile(filename)
+
+ stime = time.time()
+ result = prof.runcall(self.func, *args)
+ stime = time.time() - stime
+ prof.close()
+
+ import cStringIO
+ out = cStringIO.StringIO()
+ stats = hotshot.stats.load(filename)
+ stats.stream = out
+ stats.strip_dirs()
+ stats.sort_stats('time', 'calls')
+ stats.print_stats(40)
+ stats.print_callers()
+
+ x = '\n\ntook '+ str(stime) + ' seconds\n'
+ x += out.getvalue()
+
+ # remove the tempfile
+ try:
+ os.remove(filename)
+ except IOError:
+ pass
+
+ return result, x
+
+profile = Profile
+
+
+import traceback
+# hack for compatibility with Python 2.3:
+if not hasattr(traceback, 'format_exc'):
+ from cStringIO import StringIO
+ def format_exc(limit=None):
+ strbuf = StringIO()
+ traceback.print_exc(limit, strbuf)
+ return strbuf.getvalue()
+ traceback.format_exc = format_exc
+
+def tryall(context, prefix=None):
+ """
+ Tries a series of functions and prints their results.
+ `context` is a dictionary mapping names to values;
+ the value will only be tried if it's callable.
+
+ >>> tryall(dict(j=lambda: True))
+ j: True
+ ----------------------------------------
+ results:
+ True: 1
+
+ For example, you might have a file `test/stuff.py`
+ with a series of functions testing various things in it.
+ At the bottom, have a line:
+
+ if __name__ == "__main__": tryall(globals())
+
+ Then you can run `python test/stuff.py` and get the results of
+ all the tests.
+ """
+ context = context.copy() # vars() would update
+ results = {}
+ for (key, value) in context.iteritems():
+ if not hasattr(value, '__call__'):
+ continue
+ if prefix and not key.startswith(prefix):
+ continue
+ print key + ':',
+ try:
+ r = value()
+ dictincr(results, r)
+ print r
+ except:
+ print 'ERROR'
+ dictincr(results, 'ERROR')
+ print ' ' + '\n '.join(traceback.format_exc().split('\n'))
+
+ print '-'*40
+ print 'results:'
+ for (key, value) in results.iteritems():
+ print ' '*2, str(key)+':', value
+
+class ThreadedDict(threadlocal):
+ """
+ Thread local storage.
+
+ >>> d = ThreadedDict()
+ >>> d.x = 1
+ >>> d.x
+ 1
+ >>> import threading
+ >>> def f(): d.x = 2
+ ...
+ >>> t = threading.Thread(target=f)
+ >>> t.start()
+ >>> t.join()
+ >>> d.x
+ 1
+ """
+ _instances = set()
+
+ def __init__(self):
+ ThreadedDict._instances.add(self)
+
+ def __del__(self):
+ ThreadedDict._instances.remove(self)
+
+ def __hash__(self):
+ return id(self)
+
+ def clear_all():
+ """Clears all ThreadedDict instances.
+ """
+ for t in list(ThreadedDict._instances):
+ t.clear()
+ clear_all = staticmethod(clear_all)
+
+ # Define all these methods to more or less fully emulate dict -- attribute access
+ # is built into threading.local.
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __setitem__(self, key, value):
+ self.__dict__[key] = value
+
+ def __delitem__(self, key):
+ del self.__dict__[key]
+
+ def __contains__(self, key):
+ return key in self.__dict__
+
+ has_key = __contains__
+
+ def clear(self):
+ self.__dict__.clear()
+
+ def copy(self):
+ return self.__dict__.copy()
+
+ def get(self, key, default=None):
+ return self.__dict__.get(key, default)
+
+ def items(self):
+ return self.__dict__.items()
+
+ def iteritems(self):
+ return self.__dict__.iteritems()
+
+ def keys(self):
+ return self.__dict__.keys()
+
+ def iterkeys(self):
+ return self.__dict__.iterkeys()
+
+ iter = iterkeys
+
+ def values(self):
+ return self.__dict__.values()
+
+ def itervalues(self):
+ return self.__dict__.itervalues()
+
+ def pop(self, key, *args):
+ return self.__dict__.pop(key, *args)
+
+ def popitem(self):
+ return self.__dict__.popitem()
+
+ def setdefault(self, key, default=None):
+ return self.__dict__.setdefault(key, default)
+
+ def update(self, *args, **kwargs):
+ self.__dict__.update(*args, **kwargs)
+
+ def __repr__(self):
+        return '<ThreadedDict %r>' % self.__dict__
+
+ __str__ = __repr__
+
+threadeddict = ThreadedDict
+
+def autoassign(self, locals):
+ """
+ Automatically assigns local variables to `self`.
+
+ >>> self = storage()
+ >>> autoassign(self, dict(a=1, b=2))
+    >>> self
+    <Storage {'a': 1, 'b': 2}>
+
+ Generally used in `__init__` methods, as in:
+
+ def __init__(self, foo, bar, baz=1): autoassign(self, locals())
+ """
+ for (key, value) in locals.iteritems():
+ if key == 'self':
+ continue
+ setattr(self, key, value)
+
+def to36(q):
+ """
+ Converts an integer to base 36 (a useful scheme for human-sayable IDs).
+
+ >>> to36(35)
+ 'z'
+ >>> to36(119292)
+ '2k1o'
+ >>> int(to36(939387374), 36)
+ 939387374
+ >>> to36(0)
+ '0'
+ >>> to36(-393)
+ Traceback (most recent call last):
+ ...
+ ValueError: must supply a positive integer
+
+ """
+ if q < 0: raise ValueError, "must supply a positive integer"
+ letters = "0123456789abcdefghijklmnopqrstuvwxyz"
+ converted = []
+ while q != 0:
+ q, r = divmod(q, 36)
+ converted.insert(0, letters[r])
+ return "".join(converted) or '0'
+
+
+r_url = re_compile('(?<!\()(http://(\S+))')
+def safemarkdown(text):
+    """
+    Converts text to HTML following the rules of Markdown, but blocking any
+    outside HTML input, so that only the things supported by Markdown
+    can be used. Also converts raw URLs to links.
+
+    (requires [markdown.py](http://webpy.org/markdown.py))
+    """
+    from markdown import markdown
+    if text:
+        text = text.replace('<', '&lt;')
+        text = r_url.sub(r'<\1>', text)
+        text = markdown(text)
+        return text
+
+def sendmail(from_address, to_address, subject, message, headers=None, **kw):
+ """
+ Sends the email message `message` with mail and envelope headers
+    from `from_address` to `to_address` with `subject`.
+    Additional email headers can be specified with the dictionary
+    `headers`.
+
+ Optionally cc, bcc and attachments can be specified as keyword arguments.
+ Attachments must be an iterable and each attachment can be either a
+ filename or a file object or a dictionary with filename, content and
+ optionally content_type keys.
+
+ If `web.config.smtp_server` is set, it will send the message
+ to that SMTP server. Otherwise it will look for
+ `/usr/sbin/sendmail`, the typical location for the sendmail-style
+ binary. To use sendmail from a different path, set `web.config.sendmail_path`.
+ """
+ attachments = kw.pop("attachments", [])
+ mail = _EmailMessage(from_address, to_address, subject, message, headers, **kw)
+
+ for a in attachments:
+ if isinstance(a, dict):
+ mail.attach(a['filename'], a['content'], a.get('content_type'))
+ elif hasattr(a, 'read'): # file
+ filename = os.path.basename(getattr(a, "name", ""))
+ content_type = getattr(a, 'content_type', None)
+ mail.attach(filename, a.read(), content_type)
+ elif isinstance(a, basestring):
+ f = open(a, 'rb')
+ content = f.read()
+ f.close()
+ filename = os.path.basename(a)
+ mail.attach(filename, content, None)
+ else:
+ raise ValueError, "Invalid attachment: %s" % repr(a)
+
+ mail.send()
+
+class _EmailMessage:
+ def __init__(self, from_address, to_address, subject, message, headers=None, **kw):
+ def listify(x):
+ if not isinstance(x, list):
+ return [safestr(x)]
+ else:
+ return [safestr(a) for a in x]
+
+ subject = safestr(subject)
+ message = safestr(message)
+
+ from_address = safestr(from_address)
+ to_address = listify(to_address)
+ cc = listify(kw.get('cc', []))
+ bcc = listify(kw.get('bcc', []))
+ recipients = to_address + cc + bcc
+
+ import email.Utils
+ self.from_address = email.Utils.parseaddr(from_address)[1]
+ self.recipients = [email.Utils.parseaddr(r)[1] for r in recipients]
+
+ self.headers = dictadd({
+ 'From': from_address,
+ 'To': ", ".join(to_address),
+ 'Subject': subject
+ }, headers or {})
+
+ if cc:
+ self.headers['Cc'] = ", ".join(cc)
+
+ self.message = self.new_message()
+ self.message.add_header("Content-Transfer-Encoding", "7bit")
+ self.message.add_header("Content-Disposition", "inline")
+ self.message.add_header("MIME-Version", "1.0")
+ self.message.set_payload(message, 'utf-8')
+ self.multipart = False
+
+ def new_message(self):
+ from email.Message import Message
+ return Message()
+
+ def attach(self, filename, content, content_type=None):
+ if not self.multipart:
+ msg = self.new_message()
+ msg.add_header("Content-Type", "multipart/mixed")
+ msg.attach(self.message)
+ self.message = msg
+ self.multipart = True
+
+ import mimetypes
+ try:
+ from email import encoders
+ except:
+ from email import Encoders as encoders
+
+ content_type = content_type or mimetypes.guess_type(filename)[0] or "applcation/octet-stream"
+
+ msg = self.new_message()
+ msg.set_payload(content)
+ msg.add_header('Content-Type', content_type)
+ msg.add_header('Content-Disposition', 'attachment', filename=filename)
+
+ if not content_type.startswith("text/"):
+ encoders.encode_base64(msg)
+
+ self.message.attach(msg)
+
+ def prepare_message(self):
+ for k, v in self.headers.iteritems():
+ if k.lower() == "content-type":
+ self.message.set_type(v)
+ else:
+ self.message.add_header(k, v)
+
+ self.headers = {}
+
+ def send(self):
+ try:
+ import webapi
+ except ImportError:
+ webapi = Storage(config=Storage())
+
+ self.prepare_message()
+ message_text = self.message.as_string()
+
+ if webapi.config.get('smtp_server'):
+ server = webapi.config.get('smtp_server')
+ port = webapi.config.get('smtp_port', 0)
+ username = webapi.config.get('smtp_username')
+ password = webapi.config.get('smtp_password')
+ debug_level = webapi.config.get('smtp_debuglevel', None)
+ starttls = webapi.config.get('smtp_starttls', False)
+
+ import smtplib
+ smtpserver = smtplib.SMTP(server, port)
+
+ if debug_level:
+ smtpserver.set_debuglevel(debug_level)
+
+ if starttls:
+ smtpserver.ehlo()
+ smtpserver.starttls()
+ smtpserver.ehlo()
+
+ if username and password:
+ smtpserver.login(username, password)
+
+ smtpserver.sendmail(self.from_address, self.recipients, message_text)
+ smtpserver.quit()
+ elif webapi.config.get('email_engine') == 'aws':
+ import boto.ses
+ c = boto.ses.SESConnection(
+ aws_access_key_id=webapi.config.get('aws_access_key_id'),
+ aws_secret_access_key=web.api.config.get('aws_secret_access_key'))
+ c.send_raw_email(self.from_address, message_text, self.from_recipients)
+ else:
+ sendmail = webapi.config.get('sendmail_path', '/usr/sbin/sendmail')
+
+ assert not self.from_address.startswith('-'), 'security'
+ for r in self.recipients:
+ assert not r.startswith('-'), 'security'
+
+ cmd = [sendmail, '-f', self.from_address] + self.recipients
+
+ if subprocess:
+ p = subprocess.Popen(cmd, stdin=subprocess.PIPE)
+ p.stdin.write(message_text)
+ p.stdin.close()
+ p.wait()
+ else:
+ i, o = os.popen2(cmd)
+ i.write(message)
+ i.close()
+ o.close()
+ del i, o
+
+ def __repr__(self):
+        return "<EmailMessage>"
+
+ def __str__(self):
+ return self.message.as_string()
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/web/webapi.py b/web/webapi.py
index 7a233cc..2e5b10f 100644
--- a/web/webapi.py
+++ b/web/webapi.py
@@ -1,525 +1,525 @@
-"""
-Web API (wrapper around WSGI)
-(from web.py)
-"""
-
-__all__ = [
- "config",
- "header", "debug",
- "input", "data",
- "setcookie", "cookies",
- "ctx",
- "HTTPError",
-
- # 200, 201, 202
- "OK", "Created", "Accepted",
- "ok", "created", "accepted",
-
- # 301, 302, 303, 304, 307
- "Redirect", "Found", "SeeOther", "NotModified", "TempRedirect",
- "redirect", "found", "seeother", "notmodified", "tempredirect",
-
- # 400, 401, 403, 404, 405, 406, 409, 410, 412, 415
- "BadRequest", "Unauthorized", "Forbidden", "NotFound", "NoMethod", "NotAcceptable", "Conflict", "Gone", "PreconditionFailed", "UnsupportedMediaType",
- "badrequest", "unauthorized", "forbidden", "notfound", "nomethod", "notacceptable", "conflict", "gone", "preconditionfailed", "unsupportedmediatype",
-
- # 500
- "InternalError",
- "internalerror",
-]
-
-import sys, cgi, Cookie, pprint, urlparse, urllib
-from utils import storage, storify, threadeddict, dictadd, intget, safestr
-
-config = storage()
-config.__doc__ = """
-A configuration object for various aspects of web.py.
-
-`debug`
- : when True, enables reloading, disabled template caching and sets internalerror to debugerror.
-"""
-
-class HTTPError(Exception):
- def __init__(self, status, headers={}, data=""):
- ctx.status = status
- for k, v in headers.items():
- header(k, v)
- self.data = data
- Exception.__init__(self, status)
-
-def _status_code(status, data=None, classname=None, docstring=None):
- if data is None:
- data = status.split(" ", 1)[1]
- classname = status.split(" ", 1)[1].replace(' ', '') # 304 Not Modified -> NotModified
- docstring = docstring or '`%s` status' % status
-
- def __init__(self, data=data, headers={}):
- HTTPError.__init__(self, status, headers, data)
-
- # trick to create class dynamically with dynamic docstring.
- return type(classname, (HTTPError, object), {
- '__doc__': docstring,
- '__init__': __init__
- })
-
-ok = OK = _status_code("200 OK", data="")
-created = Created = _status_code("201 Created")
-accepted = Accepted = _status_code("202 Accepted")
-
-class Redirect(HTTPError):
- """A `301 Moved Permanently` redirect."""
- def __init__(self, url, status='301 Moved Permanently', absolute=False):
- """
- Returns a `status` redirect to the new URL.
- `url` is joined with the base URL so that things like
- `redirect("about") will work properly.
- """
- newloc = urlparse.urljoin(ctx.path, url)
-
- if newloc.startswith('/'):
- if absolute:
- home = ctx.realhome
- else:
- home = ctx.home
- newloc = home + newloc
-
- headers = {
- 'Content-Type': 'text/html',
- 'Location': newloc
- }
- HTTPError.__init__(self, status, headers, "")
-
-redirect = Redirect
-
-class Found(Redirect):
- """A `302 Found` redirect."""
- def __init__(self, url, absolute=False):
- Redirect.__init__(self, url, '302 Found', absolute=absolute)
-
-found = Found
-
-class SeeOther(Redirect):
- """A `303 See Other` redirect."""
- def __init__(self, url, absolute=False):
- Redirect.__init__(self, url, '303 See Other', absolute=absolute)
-
-seeother = SeeOther
-
-class NotModified(HTTPError):
- """A `304 Not Modified` status."""
- def __init__(self):
- HTTPError.__init__(self, "304 Not Modified")
-
-notmodified = NotModified
-
-class TempRedirect(Redirect):
- """A `307 Temporary Redirect` redirect."""
- def __init__(self, url, absolute=False):
- Redirect.__init__(self, url, '307 Temporary Redirect', absolute=absolute)
-
-tempredirect = TempRedirect
-
-class BadRequest(HTTPError):
- """`400 Bad Request` error."""
- message = "bad request"
- def __init__(self, message=None):
- status = "400 Bad Request"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, message or self.message)
-
-badrequest = BadRequest
-
-class Unauthorized(HTTPError):
- """`401 Unauthorized` error."""
- message = "unauthorized"
- def __init__(self):
- status = "401 Unauthorized"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-unauthorized = Unauthorized
-
-class Forbidden(HTTPError):
- """`403 Forbidden` error."""
- message = "forbidden"
- def __init__(self):
- status = "403 Forbidden"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-forbidden = Forbidden
-
-class _NotFound(HTTPError):
- """`404 Not Found` error."""
- message = "not found"
- def __init__(self, message=None):
- status = '404 Not Found'
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, message or self.message)
-
-def NotFound(message=None):
- """Returns HTTPError with '404 Not Found' error from the active application.
- """
- if message:
- return _NotFound(message)
- elif ctx.get('app_stack'):
- return ctx.app_stack[-1].notfound()
- else:
- return _NotFound()
-
-notfound = NotFound
-
-class NoMethod(HTTPError):
- """A `405 Method Not Allowed` error."""
- def __init__(self, cls=None):
- status = '405 Method Not Allowed'
- headers = {}
- headers['Content-Type'] = 'text/html'
-
- methods = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE']
- if cls:
- methods = [method for method in methods if hasattr(cls, method)]
-
- headers['Allow'] = ', '.join(methods)
- data = None
- HTTPError.__init__(self, status, headers, data)
-
-nomethod = NoMethod
-
-class NotAcceptable(HTTPError):
- """`406 Not Acceptable` error."""
- message = "not acceptable"
- def __init__(self):
- status = "406 Not Acceptable"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-notacceptable = NotAcceptable
-
-class Conflict(HTTPError):
- """`409 Conflict` error."""
- message = "conflict"
- def __init__(self):
- status = "409 Conflict"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-conflict = Conflict
-
-class Gone(HTTPError):
- """`410 Gone` error."""
- message = "gone"
- def __init__(self):
- status = '410 Gone'
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-gone = Gone
-
-class PreconditionFailed(HTTPError):
- """`412 Precondition Failed` error."""
- message = "precondition failed"
- def __init__(self):
- status = "412 Precondition Failed"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-preconditionfailed = PreconditionFailed
-
-class UnsupportedMediaType(HTTPError):
- """`415 Unsupported Media Type` error."""
- message = "unsupported media type"
- def __init__(self):
- status = "415 Unsupported Media Type"
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, self.message)
-
-unsupportedmediatype = UnsupportedMediaType
-
-class _InternalError(HTTPError):
- """500 Internal Server Error`."""
- message = "internal server error"
-
- def __init__(self, message=None):
- status = '500 Internal Server Error'
- headers = {'Content-Type': 'text/html'}
- HTTPError.__init__(self, status, headers, message or self.message)
-
-def InternalError(message=None):
- """Returns HTTPError with '500 internal error' error from the active application.
- """
- if message:
- return _InternalError(message)
- elif ctx.get('app_stack'):
- return ctx.app_stack[-1].internalerror()
- else:
- return _InternalError()
-
-internalerror = InternalError
-
-def header(hdr, value, unique=False):
- """
- Adds the header `hdr: value` with the response.
-
- If `unique` is True and a header with that name already exists,
- it doesn't add a new one.
- """
- hdr, value = safestr(hdr), safestr(value)
- # protection against HTTP response splitting attack
- if '\n' in hdr or '\r' in hdr or '\n' in value or '\r' in value:
- raise ValueError, 'invalid characters in header'
-
- if unique is True:
- for h, v in ctx.headers:
- if h.lower() == hdr.lower(): return
-
- ctx.headers.append((hdr, value))
-
-def rawinput(method=None):
- """Returns storage object with GET or POST arguments.
- """
- method = method or "both"
- from cStringIO import StringIO
-
- def dictify(fs):
- # hack to make web.input work with enctype='text/plain.
- if fs.list is None:
- fs.list = []
-
- return dict([(k, fs[k]) for k in fs.keys()])
-
- e = ctx.env.copy()
- a = b = {}
-
- if method.lower() in ['both', 'post', 'put']:
- if e['REQUEST_METHOD'] in ['POST', 'PUT']:
- if e.get('CONTENT_TYPE', '').lower().startswith('multipart/'):
- # since wsgi.input is directly passed to cgi.FieldStorage,
- # it can not be called multiple times. Saving the FieldStorage
- # object in ctx to allow calling web.input multiple times.
- a = ctx.get('_fieldstorage')
- if not a:
- fp = e['wsgi.input']
- a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
- ctx._fieldstorage = a
- else:
- fp = StringIO(data())
- a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
- a = dictify(a)
-
- if method.lower() in ['both', 'get']:
- e['REQUEST_METHOD'] = 'GET'
- b = dictify(cgi.FieldStorage(environ=e, keep_blank_values=1))
-
- def process_fieldstorage(fs):
- if isinstance(fs, list):
- return [process_fieldstorage(x) for x in fs]
- elif fs.filename is None:
- return fs.value
- else:
- return fs
-
- return storage([(k, process_fieldstorage(v)) for k, v in dictadd(b, a).items()])
-
-def input(*requireds, **defaults):
- """
- Returns a `storage` object with the GET and POST arguments.
- See `storify` for how `requireds` and `defaults` work.
- """
- _method = defaults.pop('_method', 'both')
- out = rawinput(_method)
- try:
- defaults.setdefault('_unicode', True) # force unicode conversion by default.
- return storify(out, *requireds, **defaults)
- except KeyError:
- raise badrequest()
-
-def data():
- """Returns the data sent with the request."""
- if 'data' not in ctx:
- cl = intget(ctx.env.get('CONTENT_LENGTH'), 0)
- ctx.data = ctx.env['wsgi.input'].read(cl)
- return ctx.data
-
-def setcookie(name, value, expires='', domain=None,
- secure=False, httponly=False, path=None):
- """Sets a cookie."""
- morsel = Cookie.Morsel()
- name, value = safestr(name), safestr(value)
- morsel.set(name, value, urllib.quote(value))
- if expires < 0:
- expires = -1000000000
- morsel['expires'] = expires
- morsel['path'] = path or ctx.homepath+'/'
- if domain:
- morsel['domain'] = domain
- if secure:
- morsel['secure'] = secure
- value = morsel.OutputString()
- if httponly:
- value += '; httponly'
- header('Set-Cookie', value)
-
-def decode_cookie(value):
- r"""Safely decodes a cookie value to unicode.
-
- Tries us-ascii, utf-8 and io8859 encodings, in that order.
-
- >>> decode_cookie('')
- u''
- >>> decode_cookie('asdf')
- u'asdf'
- >>> decode_cookie('foo \xC3\xA9 bar')
- u'foo \xe9 bar'
- >>> decode_cookie('foo \xE9 bar')
- u'foo \xe9 bar'
- """
- try:
- # First try plain ASCII encoding
- return unicode(value, 'us-ascii')
- except UnicodeError:
- # Then try UTF-8, and if that fails, ISO8859
- try:
- return unicode(value, 'utf-8')
- except UnicodeError:
- return unicode(value, 'iso8859', 'ignore')
-
-def parse_cookies(http_cookie):
- r"""Parse a HTTP_COOKIE header and return dict of cookie names and decoded values.
-
- >>> sorted(parse_cookies('').items())
- []
- >>> sorted(parse_cookies('a=1').items())
- [('a', '1')]
- >>> sorted(parse_cookies('a=1%202').items())
- [('a', '1 2')]
- >>> sorted(parse_cookies('a=Z%C3%A9Z').items())
- [('a', 'Z\xc3\xa9Z')]
- >>> sorted(parse_cookies('a=1; b=2; c=3').items())
- [('a', '1'), ('b', '2'), ('c', '3')]
- >>> sorted(parse_cookies('a=1; b=w("x")|y=z; c=3').items())
- [('a', '1'), ('b', 'w('), ('c', '3')]
- >>> sorted(parse_cookies('a=1; b=w(%22x%22)|y=z; c=3').items())
- [('a', '1'), ('b', 'w("x")|y=z'), ('c', '3')]
-
- >>> sorted(parse_cookies('keebler=E=mc2').items())
- [('keebler', 'E=mc2')]
- >>> sorted(parse_cookies(r'keebler="E=mc2; L=\"Loves\"; fudge=\012;"').items())
- [('keebler', 'E=mc2; L="Loves"; fudge=\n;')]
- """
- #print "parse_cookies"
- if '"' in http_cookie:
- # HTTP_COOKIE has quotes in it, use slow but correct cookie parsing
- cookie = Cookie.SimpleCookie()
- try:
- cookie.load(http_cookie)
- except Cookie.CookieError:
- # If HTTP_COOKIE header is malformed, try at least to load the cookies we can by
- # first splitting on ';' and loading each attr=value pair separately
- cookie = Cookie.SimpleCookie()
- for attr_value in http_cookie.split(';'):
- try:
- cookie.load(attr_value)
- except Cookie.CookieError:
- pass
- cookies = dict((k, urllib.unquote(v.value)) for k, v in cookie.iteritems())
- else:
- # HTTP_COOKIE doesn't have quotes, use fast cookie parsing
- cookies = {}
- for key_value in http_cookie.split(';'):
- key_value = key_value.split('=', 1)
- if len(key_value) == 2:
- key, value = key_value
- cookies[key.strip()] = urllib.unquote(value.strip())
- return cookies
-
-def cookies(*requireds, **defaults):
- r"""Returns a `storage` object with all the request cookies in it.
-
- See `storify` for how `requireds` and `defaults` work.
-
- This is forgiving on bad HTTP_COOKIE input, it tries to parse at least
- the cookies it can.
-
- The values are converted to unicode if _unicode=True is passed.
- """
- # If _unicode=True is specified, use decode_cookie to convert cookie value to unicode
- if defaults.get("_unicode") is True:
- defaults['_unicode'] = decode_cookie
-
- # parse cookie string and cache the result for next time.
- if '_parsed_cookies' not in ctx:
- http_cookie = ctx.env.get("HTTP_COOKIE", "")
- ctx._parsed_cookies = parse_cookies(http_cookie)
-
- try:
- return storify(ctx._parsed_cookies, *requireds, **defaults)
- except KeyError:
- badrequest()
- raise StopIteration
-
-def debug(*args):
- """
- Prints a prettyprinted version of `args` to stderr.
- """
- try:
- out = ctx.environ['wsgi.errors']
- except:
- out = sys.stderr
- for arg in args:
- print >> out, pprint.pformat(arg)
- return ''
-
-def _debugwrite(x):
- try:
- out = ctx.environ['wsgi.errors']
- except:
- out = sys.stderr
- out.write(x)
-debug.write = _debugwrite
-
-ctx = context = threadeddict()
-
-ctx.__doc__ = """
-A `storage` object containing various information about the request:
-
-`environ` (aka `env`)
- : A dictionary containing the standard WSGI environment variables.
-
-`host`
- : The domain (`Host` header) requested by the user.
-
-`home`
- : The base path for the application.
-
-`ip`
- : The IP address of the requester.
-
-`method`
- : The HTTP method used.
-
-`path`
- : The path request.
-
-`query`
- : If there are no query arguments, the empty string. Otherwise, a `?` followed
- by the query string.
-
-`fullpath`
- : The full path requested, including query arguments (`== path + query`).
-
-### Response Data
-
-`status` (default: "200 OK")
- : The status code to be used in the response.
-
-`headers`
- : A list of 2-tuples to be used in the response.
-
-`output`
- : A string to be used as the response.
-"""
-
-if __name__ == "__main__":
- import doctest
+"""
+Web API (wrapper around WSGI)
+(from web.py)
+"""
+
+__all__ = [
+ "config",
+ "header", "debug",
+ "input", "data",
+ "setcookie", "cookies",
+ "ctx",
+ "HTTPError",
+
+ # 200, 201, 202
+ "OK", "Created", "Accepted",
+ "ok", "created", "accepted",
+
+ # 301, 302, 303, 304, 307
+ "Redirect", "Found", "SeeOther", "NotModified", "TempRedirect",
+ "redirect", "found", "seeother", "notmodified", "tempredirect",
+
+ # 400, 401, 403, 404, 405, 406, 409, 410, 412, 415
+ "BadRequest", "Unauthorized", "Forbidden", "NotFound", "NoMethod", "NotAcceptable", "Conflict", "Gone", "PreconditionFailed", "UnsupportedMediaType",
+ "badrequest", "unauthorized", "forbidden", "notfound", "nomethod", "notacceptable", "conflict", "gone", "preconditionfailed", "unsupportedmediatype",
+
+ # 500
+ "InternalError",
+ "internalerror",
+]
+
+import sys, cgi, Cookie, pprint, urlparse, urllib
+from utils import storage, storify, threadeddict, dictadd, intget, safestr
+
+config = storage()
+config.__doc__ = """
+A configuration object for various aspects of web.py.
+
+`debug`
+  : when True, enables reloading, disables template caching and sets internalerror to debugerror.
+"""
+
+class HTTPError(Exception):
+ def __init__(self, status, headers={}, data=""):
+ ctx.status = status
+ for k, v in headers.items():
+ header(k, v)
+ self.data = data
+ Exception.__init__(self, status)
+
+def _status_code(status, data=None, classname=None, docstring=None):
+ if data is None:
+ data = status.split(" ", 1)[1]
+ classname = status.split(" ", 1)[1].replace(' ', '') # 304 Not Modified -> NotModified
+ docstring = docstring or '`%s` status' % status
+
+ def __init__(self, data=data, headers={}):
+ HTTPError.__init__(self, status, headers, data)
+
+ # trick to create class dynamically with dynamic docstring.
+ return type(classname, (HTTPError, object), {
+ '__doc__': docstring,
+ '__init__': __init__
+ })
+
+ok = OK = _status_code("200 OK", data="")
+created = Created = _status_code("201 Created")
+accepted = Accepted = _status_code("202 Accepted")
+
+class Redirect(HTTPError):
+ """A `301 Moved Permanently` redirect."""
+ def __init__(self, url, status='301 Moved Permanently', absolute=False):
+ """
+ Returns a `status` redirect to the new URL.
+ `url` is joined with the base URL so that things like
+        `redirect("about")` will work properly.
+ """
+ newloc = urlparse.urljoin(ctx.path, url)
+
+ if newloc.startswith('/'):
+ if absolute:
+ home = ctx.realhome
+ else:
+ home = ctx.home
+ newloc = home + newloc
+
+ headers = {
+ 'Content-Type': 'text/html',
+ 'Location': newloc
+ }
+ HTTPError.__init__(self, status, headers, "")
+
+redirect = Redirect
+
+class Found(Redirect):
+ """A `302 Found` redirect."""
+ def __init__(self, url, absolute=False):
+ Redirect.__init__(self, url, '302 Found', absolute=absolute)
+
+found = Found
+
+class SeeOther(Redirect):
+ """A `303 See Other` redirect."""
+ def __init__(self, url, absolute=False):
+ Redirect.__init__(self, url, '303 See Other', absolute=absolute)
+
+seeother = SeeOther
+
+class NotModified(HTTPError):
+ """A `304 Not Modified` status."""
+ def __init__(self):
+ HTTPError.__init__(self, "304 Not Modified")
+
+notmodified = NotModified
+
+class TempRedirect(Redirect):
+ """A `307 Temporary Redirect` redirect."""
+ def __init__(self, url, absolute=False):
+ Redirect.__init__(self, url, '307 Temporary Redirect', absolute=absolute)
+
+tempredirect = TempRedirect
+
+class BadRequest(HTTPError):
+ """`400 Bad Request` error."""
+ message = "bad request"
+ def __init__(self, message=None):
+ status = "400 Bad Request"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, message or self.message)
+
+badrequest = BadRequest
+
+class Unauthorized(HTTPError):
+ """`401 Unauthorized` error."""
+ message = "unauthorized"
+ def __init__(self):
+ status = "401 Unauthorized"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+unauthorized = Unauthorized
+
+class Forbidden(HTTPError):
+ """`403 Forbidden` error."""
+ message = "forbidden"
+ def __init__(self):
+ status = "403 Forbidden"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+forbidden = Forbidden
+
+class _NotFound(HTTPError):
+ """`404 Not Found` error."""
+ message = "not found"
+ def __init__(self, message=None):
+ status = '404 Not Found'
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, message or self.message)
+
+def NotFound(message=None):
+ """Returns HTTPError with '404 Not Found' error from the active application.
+ """
+ if message:
+ return _NotFound(message)
+ elif ctx.get('app_stack'):
+ return ctx.app_stack[-1].notfound()
+ else:
+ return _NotFound()
+
+notfound = NotFound
+
+class NoMethod(HTTPError):
+ """A `405 Method Not Allowed` error."""
+ def __init__(self, cls=None):
+ status = '405 Method Not Allowed'
+ headers = {}
+ headers['Content-Type'] = 'text/html'
+
+ methods = ['GET', 'HEAD', 'POST', 'PUT', 'DELETE']
+ if cls:
+ methods = [method for method in methods if hasattr(cls, method)]
+
+ headers['Allow'] = ', '.join(methods)
+ data = None
+ HTTPError.__init__(self, status, headers, data)
+
+nomethod = NoMethod
+
+class NotAcceptable(HTTPError):
+ """`406 Not Acceptable` error."""
+ message = "not acceptable"
+ def __init__(self):
+ status = "406 Not Acceptable"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+notacceptable = NotAcceptable
+
+class Conflict(HTTPError):
+ """`409 Conflict` error."""
+ message = "conflict"
+ def __init__(self):
+ status = "409 Conflict"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+conflict = Conflict
+
+class Gone(HTTPError):
+ """`410 Gone` error."""
+ message = "gone"
+ def __init__(self):
+ status = '410 Gone'
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+gone = Gone
+
+class PreconditionFailed(HTTPError):
+ """`412 Precondition Failed` error."""
+ message = "precondition failed"
+ def __init__(self):
+ status = "412 Precondition Failed"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+preconditionfailed = PreconditionFailed
+
+class UnsupportedMediaType(HTTPError):
+ """`415 Unsupported Media Type` error."""
+ message = "unsupported media type"
+ def __init__(self):
+ status = "415 Unsupported Media Type"
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, self.message)
+
+unsupportedmediatype = UnsupportedMediaType
+
+class _InternalError(HTTPError):
+    """`500 Internal Server Error`."""
+ message = "internal server error"
+
+ def __init__(self, message=None):
+ status = '500 Internal Server Error'
+ headers = {'Content-Type': 'text/html'}
+ HTTPError.__init__(self, status, headers, message or self.message)
+
+def InternalError(message=None):
+ """Returns HTTPError with '500 internal error' error from the active application.
+ """
+ if message:
+ return _InternalError(message)
+ elif ctx.get('app_stack'):
+ return ctx.app_stack[-1].internalerror()
+ else:
+ return _InternalError()
+
+internalerror = InternalError
+
+def header(hdr, value, unique=False):
+ """
+ Adds the header `hdr: value` with the response.
+
+ If `unique` is True and a header with that name already exists,
+ it doesn't add a new one.
+ """
+ hdr, value = safestr(hdr), safestr(value)
+ # protection against HTTP response splitting attack
+ if '\n' in hdr or '\r' in hdr or '\n' in value or '\r' in value:
+ raise ValueError, 'invalid characters in header'
+
+ if unique is True:
+ for h, v in ctx.headers:
+ if h.lower() == hdr.lower(): return
+
+ ctx.headers.append((hdr, value))
+
+def rawinput(method=None):
+ """Returns storage object with GET or POST arguments.
+ """
+ method = method or "both"
+ from cStringIO import StringIO
+
+ def dictify(fs):
+        # hack to make web.input work with enctype='text/plain'.
+ if fs.list is None:
+ fs.list = []
+
+ return dict([(k, fs[k]) for k in fs.keys()])
+
+ e = ctx.env.copy()
+ a = b = {}
+
+ if method.lower() in ['both', 'post', 'put']:
+ if e['REQUEST_METHOD'] in ['POST', 'PUT']:
+ if e.get('CONTENT_TYPE', '').lower().startswith('multipart/'):
+ # since wsgi.input is directly passed to cgi.FieldStorage,
+ # it can not be called multiple times. Saving the FieldStorage
+ # object in ctx to allow calling web.input multiple times.
+ a = ctx.get('_fieldstorage')
+ if not a:
+ fp = e['wsgi.input']
+ a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
+ ctx._fieldstorage = a
+ else:
+ fp = StringIO(data())
+ a = cgi.FieldStorage(fp=fp, environ=e, keep_blank_values=1)
+ a = dictify(a)
+
+ if method.lower() in ['both', 'get']:
+ e['REQUEST_METHOD'] = 'GET'
+ b = dictify(cgi.FieldStorage(environ=e, keep_blank_values=1))
+
+ def process_fieldstorage(fs):
+ if isinstance(fs, list):
+ return [process_fieldstorage(x) for x in fs]
+ elif fs.filename is None:
+ return fs.value
+ else:
+ return fs
+
+ return storage([(k, process_fieldstorage(v)) for k, v in dictadd(b, a).items()])
+
+def input(*requireds, **defaults):
+ """
+ Returns a `storage` object with the GET and POST arguments.
+ See `storify` for how `requireds` and `defaults` work.
+ """
+ _method = defaults.pop('_method', 'both')
+ out = rawinput(_method)
+ try:
+ defaults.setdefault('_unicode', True) # force unicode conversion by default.
+ return storify(out, *requireds, **defaults)
+ except KeyError:
+ raise badrequest()
+
+def data():
+ """Returns the data sent with the request."""
+ if 'data' not in ctx:
+ cl = intget(ctx.env.get('CONTENT_LENGTH'), 0)
+ ctx.data = ctx.env['wsgi.input'].read(cl)
+ return ctx.data
+
+def setcookie(name, value, expires='', domain=None,
+ secure=False, httponly=False, path=None):
+ """Sets a cookie."""
+ morsel = Cookie.Morsel()
+ name, value = safestr(name), safestr(value)
+ morsel.set(name, value, urllib.quote(value))
+ if expires < 0:
+ expires = -1000000000
+ morsel['expires'] = expires
+ morsel['path'] = path or ctx.homepath+'/'
+ if domain:
+ morsel['domain'] = domain
+ if secure:
+ morsel['secure'] = secure
+ value = morsel.OutputString()
+ if httponly:
+ value += '; httponly'
+ header('Set-Cookie', value)
+
+def decode_cookie(value):
+ r"""Safely decodes a cookie value to unicode.
+
+    Tries us-ascii, utf-8 and iso8859 encodings, in that order.
+
+ >>> decode_cookie('')
+ u''
+ >>> decode_cookie('asdf')
+ u'asdf'
+ >>> decode_cookie('foo \xC3\xA9 bar')
+ u'foo \xe9 bar'
+ >>> decode_cookie('foo \xE9 bar')
+ u'foo \xe9 bar'
+ """
+ try:
+ # First try plain ASCII encoding
+ return unicode(value, 'us-ascii')
+ except UnicodeError:
+ # Then try UTF-8, and if that fails, ISO8859
+ try:
+ return unicode(value, 'utf-8')
+ except UnicodeError:
+ return unicode(value, 'iso8859', 'ignore')
+
+def parse_cookies(http_cookie):
+ r"""Parse a HTTP_COOKIE header and return dict of cookie names and decoded values.
+
+ >>> sorted(parse_cookies('').items())
+ []
+ >>> sorted(parse_cookies('a=1').items())
+ [('a', '1')]
+ >>> sorted(parse_cookies('a=1%202').items())
+ [('a', '1 2')]
+ >>> sorted(parse_cookies('a=Z%C3%A9Z').items())
+ [('a', 'Z\xc3\xa9Z')]
+ >>> sorted(parse_cookies('a=1; b=2; c=3').items())
+ [('a', '1'), ('b', '2'), ('c', '3')]
+ >>> sorted(parse_cookies('a=1; b=w("x")|y=z; c=3').items())
+ [('a', '1'), ('b', 'w('), ('c', '3')]
+ >>> sorted(parse_cookies('a=1; b=w(%22x%22)|y=z; c=3').items())
+ [('a', '1'), ('b', 'w("x")|y=z'), ('c', '3')]
+
+ >>> sorted(parse_cookies('keebler=E=mc2').items())
+ [('keebler', 'E=mc2')]
+ >>> sorted(parse_cookies(r'keebler="E=mc2; L=\"Loves\"; fudge=\012;"').items())
+ [('keebler', 'E=mc2; L="Loves"; fudge=\n;')]
+ """
+ #print "parse_cookies"
+ if '"' in http_cookie:
+ # HTTP_COOKIE has quotes in it, use slow but correct cookie parsing
+ cookie = Cookie.SimpleCookie()
+ try:
+ cookie.load(http_cookie)
+ except Cookie.CookieError:
+ # If HTTP_COOKIE header is malformed, try at least to load the cookies we can by
+ # first splitting on ';' and loading each attr=value pair separately
+ cookie = Cookie.SimpleCookie()
+ for attr_value in http_cookie.split(';'):
+ try:
+ cookie.load(attr_value)
+ except Cookie.CookieError:
+ pass
+ cookies = dict((k, urllib.unquote(v.value)) for k, v in cookie.iteritems())
+ else:
+ # HTTP_COOKIE doesn't have quotes, use fast cookie parsing
+ cookies = {}
+ for key_value in http_cookie.split(';'):
+ key_value = key_value.split('=', 1)
+ if len(key_value) == 2:
+ key, value = key_value
+ cookies[key.strip()] = urllib.unquote(value.strip())
+ return cookies
+
+def cookies(*requireds, **defaults):
+ r"""Returns a `storage` object with all the request cookies in it.
+
+ See `storify` for how `requireds` and `defaults` work.
+
+ This is forgiving on bad HTTP_COOKIE input, it tries to parse at least
+ the cookies it can.
+
+ The values are converted to unicode if _unicode=True is passed.
+ """
+ # If _unicode=True is specified, use decode_cookie to convert cookie value to unicode
+ if defaults.get("_unicode") is True:
+ defaults['_unicode'] = decode_cookie
+
+ # parse cookie string and cache the result for next time.
+ if '_parsed_cookies' not in ctx:
+ http_cookie = ctx.env.get("HTTP_COOKIE", "")
+ ctx._parsed_cookies = parse_cookies(http_cookie)
+
+ try:
+ return storify(ctx._parsed_cookies, *requireds, **defaults)
+ except KeyError:
+ badrequest()
+ raise StopIteration
+
+def debug(*args):
+ """
+ Prints a prettyprinted version of `args` to stderr.
+ """
+ try:
+ out = ctx.environ['wsgi.errors']
+ except:
+ out = sys.stderr
+ for arg in args:
+ print >> out, pprint.pformat(arg)
+ return ''
+
+def _debugwrite(x):
+ try:
+ out = ctx.environ['wsgi.errors']
+ except:
+ out = sys.stderr
+ out.write(x)
+debug.write = _debugwrite
+
+ctx = context = threadeddict()
+
+ctx.__doc__ = """
+A `storage` object containing various information about the request:
+
+`environ` (aka `env`)
+ : A dictionary containing the standard WSGI environment variables.
+
+`host`
+ : The domain (`Host` header) requested by the user.
+
+`home`
+ : The base path for the application.
+
+`ip`
+ : The IP address of the requester.
+
+`method`
+ : The HTTP method used.
+
+`path`
+ : The path request.
+
+`query`
+ : If there are no query arguments, the empty string. Otherwise, a `?` followed
+ by the query string.
+
+`fullpath`
+ : The full path requested, including query arguments (`== path + query`).
+
+### Response Data
+
+`status` (default: "200 OK")
+ : The status code to be used in the response.
+
+`headers`
+ : A list of 2-tuples to be used in the response.
+
+`output`
+ : A string to be used as the response.
+"""
+
+if __name__ == "__main__":
+ import doctest
doctest.testmod()
\ No newline at end of file
diff --git a/web/webopenid.py b/web/webopenid.py
index b482216..28d89f4 100644
--- a/web/webopenid.py
+++ b/web/webopenid.py
@@ -1,115 +1,115 @@
-"""openid.py: an openid library for web.py
-
-Notes:
-
- - This will create a file called .openid_secret_key in the
- current directory with your secret key in it. If someone
- has access to this file they can log in as any user. And
- if the app can't find this file for any reason (e.g. you
- moved the app somewhere else) then each currently logged
- in user will get logged out.
-
- - State must be maintained through the entire auth process
- -- this means that if you have multiple web.py processes
- serving one set of URLs or if you restart your app often
- then log ins will fail. You have to replace sessions and
- store for things to work.
-
- - We set cookies starting with "openid_".
-
-"""
-
-import os
-import random
-import hmac
-import __init__ as web
-import openid.consumer.consumer
-import openid.store.memstore
-
-sessions = {}
-store = openid.store.memstore.MemoryStore()
-
-def _secret():
- try:
- secret = file('.openid_secret_key').read()
- except IOError:
- # file doesn't exist
- secret = os.urandom(20)
- file('.openid_secret_key', 'w').write(secret)
- return secret
-
-def _hmac(identity_url):
- return hmac.new(_secret(), identity_url).hexdigest()
-
-def _random_session():
- n = random.random()
- while n in sessions:
- n = random.random()
- n = str(n)
- return n
-
-def status():
- oid_hash = web.cookies().get('openid_identity_hash', '').split(',', 1)
- if len(oid_hash) > 1:
- oid_hash, identity_url = oid_hash
- if oid_hash == _hmac(identity_url):
- return identity_url
- return None
-
-def form(openid_loc):
- oid = status()
- if oid:
- return '''
- ''' % (openid_loc, oid, web.ctx.fullpath)
- else:
- return '''
- ''' % (openid_loc, web.ctx.fullpath)
-
-def logout():
- web.setcookie('openid_identity_hash', '', expires=-1)
-
-class host:
- def POST(self):
- # unlike the usual scheme of things, the POST is actually called
- # first here
- i = web.input(return_to='/')
- if i.get('action') == 'logout':
- logout()
- return web.redirect(i.return_to)
-
- i = web.input('openid', return_to='/')
-
- n = _random_session()
- sessions[n] = {'webpy_return_to': i.return_to}
-
- c = openid.consumer.consumer.Consumer(sessions[n], store)
- a = c.begin(i.openid)
- f = a.redirectURL(web.ctx.home, web.ctx.home + web.ctx.fullpath)
-
- web.setcookie('openid_session_id', n)
- return web.redirect(f)
-
- def GET(self):
- n = web.cookies('openid_session_id').openid_session_id
- web.setcookie('openid_session_id', '', expires=-1)
- return_to = sessions[n]['webpy_return_to']
-
- c = openid.consumer.consumer.Consumer(sessions[n], store)
- a = c.complete(web.input(), web.ctx.home + web.ctx.fullpath)
-
- if a.status.lower() == 'success':
- web.setcookie('openid_identity_hash', _hmac(a.identity_url) + ',' + a.identity_url)
-
- del sessions[n]
- return web.redirect(return_to)
+"""openid.py: an openid library for web.py
+
+Notes:
+
+ - This will create a file called .openid_secret_key in the
+ current directory with your secret key in it. If someone
+ has access to this file they can log in as any user. And
+ if the app can't find this file for any reason (e.g. you
+ moved the app somewhere else) then each currently logged
+ in user will get logged out.
+
+ - State must be maintained through the entire auth process
+ -- this means that if you have multiple web.py processes
+ serving one set of URLs or if you restart your app often
+ then log ins will fail. You have to replace sessions and
+ store for things to work.
+
+ - We set cookies starting with "openid_".
+
+"""
+
+import os
+import random
+import hmac
+import __init__ as web
+import openid.consumer.consumer
+import openid.store.memstore
+
+sessions = {}
+store = openid.store.memstore.MemoryStore()
+
+def _secret():
+ try:
+ secret = file('.openid_secret_key').read()
+ except IOError:
+ # file doesn't exist
+ secret = os.urandom(20)
+ file('.openid_secret_key', 'w').write(secret)
+ return secret
+
+def _hmac(identity_url):
+ return hmac.new(_secret(), identity_url).hexdigest()
+
+def _random_session():
+ n = random.random()
+ while n in sessions:
+ n = random.random()
+ n = str(n)
+ return n
+
+def status():
+ oid_hash = web.cookies().get('openid_identity_hash', '').split(',', 1)
+ if len(oid_hash) > 1:
+ oid_hash, identity_url = oid_hash
+ if oid_hash == _hmac(identity_url):
+ return identity_url
+ return None
+
+def form(openid_loc):
+ oid = status()
+ if oid:
+ return '''
+ ''' % (openid_loc, oid, web.ctx.fullpath)
+ else:
+ return '''
+ ''' % (openid_loc, web.ctx.fullpath)
+
+def logout():
+ web.setcookie('openid_identity_hash', '', expires=-1)
+
+class host:
+ def POST(self):
+ # unlike the usual scheme of things, the POST is actually called
+ # first here
+ i = web.input(return_to='/')
+ if i.get('action') == 'logout':
+ logout()
+ return web.redirect(i.return_to)
+
+ i = web.input('openid', return_to='/')
+
+ n = _random_session()
+ sessions[n] = {'webpy_return_to': i.return_to}
+
+ c = openid.consumer.consumer.Consumer(sessions[n], store)
+ a = c.begin(i.openid)
+ f = a.redirectURL(web.ctx.home, web.ctx.home + web.ctx.fullpath)
+
+ web.setcookie('openid_session_id', n)
+ return web.redirect(f)
+
+ def GET(self):
+ n = web.cookies('openid_session_id').openid_session_id
+ web.setcookie('openid_session_id', '', expires=-1)
+ return_to = sessions[n]['webpy_return_to']
+
+ c = openid.consumer.consumer.Consumer(sessions[n], store)
+ a = c.complete(web.input(), web.ctx.home + web.ctx.fullpath)
+
+ if a.status.lower() == 'success':
+ web.setcookie('openid_identity_hash', _hmac(a.identity_url) + ',' + a.identity_url)
+
+ del sessions[n]
+ return web.redirect(return_to)
diff --git a/web/wsgi.py b/web/wsgi.py
index fe53ab2..3e509bb 100644
--- a/web/wsgi.py
+++ b/web/wsgi.py
@@ -1,70 +1,70 @@
-"""
-WSGI Utilities
-(from web.py)
-"""
-
-import os, sys
-
-import http
-import webapi as web
-from utils import listget
-from net import validaddr, validip
-import httpserver
-
-def runfcgi(func, addr=('localhost', 8000)):
- """Runs a WSGI function as a FastCGI server."""
- import flup.server.fcgi as flups
- return flups.WSGIServer(func, multiplexed=True, bindAddress=addr, debug=False).run()
-
-def runscgi(func, addr=('localhost', 4000)):
- """Runs a WSGI function as an SCGI server."""
- import flup.server.scgi as flups
- return flups.WSGIServer(func, bindAddress=addr, debug=False).run()
-
-def runwsgi(func):
- """
- Runs a WSGI-compatible `func` using FCGI, SCGI, or a simple web server,
- as appropriate based on context and `sys.argv`.
- """
-
- if os.environ.has_key('SERVER_SOFTWARE'): # cgi
- os.environ['FCGI_FORCE_CGI'] = 'Y'
-
- if (os.environ.has_key('PHP_FCGI_CHILDREN') #lighttpd fastcgi
- or os.environ.has_key('SERVER_SOFTWARE')):
- return runfcgi(func, None)
-
- if 'fcgi' in sys.argv or 'fastcgi' in sys.argv:
- args = sys.argv[1:]
- if 'fastcgi' in args: args.remove('fastcgi')
- elif 'fcgi' in args: args.remove('fcgi')
- if args:
- return runfcgi(func, validaddr(args[0]))
- else:
- return runfcgi(func, None)
-
- if 'scgi' in sys.argv:
- args = sys.argv[1:]
- args.remove('scgi')
- if args:
- return runscgi(func, validaddr(args[0]))
- else:
- return runscgi(func)
-
- return httpserver.runsimple(func, validip(listget(sys.argv, 1, '')))
-
-def _is_dev_mode():
- # Some embedded python interpreters won't have sys.arv
- # For details, see https://github.com/webpy/webpy/issues/87
- argv = getattr(sys, "argv", [])
-
- # quick hack to check if the program is running in dev mode.
- if os.environ.has_key('SERVER_SOFTWARE') \
- or os.environ.has_key('PHP_FCGI_CHILDREN') \
- or 'fcgi' in argv or 'fastcgi' in argv \
- or 'mod_wsgi' in argv:
- return False
- return True
-
-# When running the builtin-server, enable debug mode if not already set.
-web.config.setdefault('debug', _is_dev_mode())
+"""
+WSGI Utilities
+(from web.py)
+"""
+
+import os, sys
+
+import http
+import webapi as web
+from utils import listget
+from net import validaddr, validip
+import httpserver
+
+def runfcgi(func, addr=('localhost', 8000)):
+ """Runs a WSGI function as a FastCGI server."""
+ import flup.server.fcgi as flups
+ return flups.WSGIServer(func, multiplexed=True, bindAddress=addr, debug=False).run()
+
+def runscgi(func, addr=('localhost', 4000)):
+ """Runs a WSGI function as an SCGI server."""
+ import flup.server.scgi as flups
+ return flups.WSGIServer(func, bindAddress=addr, debug=False).run()
+
+def runwsgi(func):
+ """
+ Runs a WSGI-compatible `func` using FCGI, SCGI, or a simple web server,
+ as appropriate based on context and `sys.argv`.
+ """
+
+ if os.environ.has_key('SERVER_SOFTWARE'): # cgi
+ os.environ['FCGI_FORCE_CGI'] = 'Y'
+
+ if (os.environ.has_key('PHP_FCGI_CHILDREN') #lighttpd fastcgi
+ or os.environ.has_key('SERVER_SOFTWARE')):
+ return runfcgi(func, None)
+
+ if 'fcgi' in sys.argv or 'fastcgi' in sys.argv:
+ args = sys.argv[1:]
+ if 'fastcgi' in args: args.remove('fastcgi')
+ elif 'fcgi' in args: args.remove('fcgi')
+ if args:
+ return runfcgi(func, validaddr(args[0]))
+ else:
+ return runfcgi(func, None)
+
+ if 'scgi' in sys.argv:
+ args = sys.argv[1:]
+ args.remove('scgi')
+ if args:
+ return runscgi(func, validaddr(args[0]))
+ else:
+ return runscgi(func)
+
+ return httpserver.runsimple(func, validip(listget(sys.argv, 1, '')))
+
+def _is_dev_mode():
+    # Some embedded python interpreters won't have sys.argv
+ # For details, see https://github.com/webpy/webpy/issues/87
+ argv = getattr(sys, "argv", [])
+
+ # quick hack to check if the program is running in dev mode.
+ if os.environ.has_key('SERVER_SOFTWARE') \
+ or os.environ.has_key('PHP_FCGI_CHILDREN') \
+ or 'fcgi' in argv or 'fastcgi' in argv \
+ or 'mod_wsgi' in argv:
+ return False
+ return True
+
+# When running the builtin-server, enable debug mode if not already set.
+web.config.setdefault('debug', _is_dev_mode())
diff --git a/web/wsgiserver/__init__.py b/web/wsgiserver/__init__.py
index 55d1dd9..0cbfd64 100644
--- a/web/wsgiserver/__init__.py
+++ b/web/wsgiserver/__init__.py
@@ -1,2219 +1,2219 @@
-"""A high-speed, production ready, thread pooled, generic HTTP server.
-
-Simplest example on how to use this module directly
-(without using CherryPy's application machinery)::
-
- from cherrypy import wsgiserver
-
- def my_crazy_app(environ, start_response):
- status = '200 OK'
- response_headers = [('Content-type','text/plain')]
- start_response(status, response_headers)
- return ['Hello world!']
-
- server = wsgiserver.CherryPyWSGIServer(
- ('0.0.0.0', 8070), my_crazy_app,
- server_name='www.cherrypy.example')
- server.start()
-
-The CherryPy WSGI server can serve as many WSGI applications
-as you want in one instance by using a WSGIPathInfoDispatcher::
-
- d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app})
- server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d)
-
-Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance.
-
-This won't call the CherryPy engine (application side) at all, only the
-HTTP server, which is independent from the rest of CherryPy. Don't
-let the name "CherryPyWSGIServer" throw you; the name merely reflects
-its origin, not its coupling.
-
-For those of you wanting to understand internals of this module, here's the
-basic call flow. The server's listening thread runs a very tight loop,
-sticking incoming connections onto a Queue::
-
- server = CherryPyWSGIServer(...)
- server.start()
- while True:
- tick()
- # This blocks until a request comes in:
- child = socket.accept()
- conn = HTTPConnection(child, ...)
- server.requests.put(conn)
-
-Worker threads are kept in a pool and poll the Queue, popping off and then
-handling each connection in turn. Each connection can consist of an arbitrary
-number of requests and their responses, so we run a nested loop::
-
- while True:
- conn = server.requests.get()
- conn.communicate()
- -> while True:
- req = HTTPRequest(...)
- req.parse_request()
- -> # Read the Request-Line, e.g. "GET /page HTTP/1.1"
- req.rfile.readline()
- read_headers(req.rfile, req.inheaders)
- req.respond()
- -> response = app(...)
- try:
- for chunk in response:
- if chunk:
- req.write(chunk)
- finally:
- if hasattr(response, "close"):
- response.close()
- if req.close_connection:
- return
-"""
-
-CRLF = '\r\n'
-import os
-import Queue
-import re
-quoted_slash = re.compile("(?i)%2F")
-import rfc822
-import socket
-import sys
-if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'):
- socket.IPPROTO_IPV6 = 41
-try:
- import cStringIO as StringIO
-except ImportError:
- import StringIO
-DEFAULT_BUFFER_SIZE = -1
-
-_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring)
-
-import threading
-import time
-import traceback
-def format_exc(limit=None):
- """Like print_exc() but return a string. Backport for Python 2.3."""
- try:
- etype, value, tb = sys.exc_info()
- return ''.join(traceback.format_exception(etype, value, tb, limit))
- finally:
- etype = value = tb = None
-
-
-from urllib import unquote
-from urlparse import urlparse
-import warnings
-
-import errno
-
-def plat_specific_errors(*errnames):
- """Return error numbers for all errors in errnames on this platform.
-
- The 'errno' module contains different global constants depending on
- the specific platform (OS). This function will return the list of
- numeric values for a given list of potential names.
- """
- errno_names = dir(errno)
- nums = [getattr(errno, k) for k in errnames if k in errno_names]
- # de-dupe the list
- return dict.fromkeys(nums).keys()
-
-socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR")
-
-socket_errors_to_ignore = plat_specific_errors(
- "EPIPE",
- "EBADF", "WSAEBADF",
- "ENOTSOCK", "WSAENOTSOCK",
- "ETIMEDOUT", "WSAETIMEDOUT",
- "ECONNREFUSED", "WSAECONNREFUSED",
- "ECONNRESET", "WSAECONNRESET",
- "ECONNABORTED", "WSAECONNABORTED",
- "ENETRESET", "WSAENETRESET",
- "EHOSTDOWN", "EHOSTUNREACH",
- )
-socket_errors_to_ignore.append("timed out")
-socket_errors_to_ignore.append("The read operation timed out")
-
-socket_errors_nonblocking = plat_specific_errors(
- 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
-
-comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding',
- 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
- 'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
- 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
- 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
- 'WWW-Authenticate']
-
-
-import logging
-if not hasattr(logging, 'statistics'): logging.statistics = {}
-
-
-def read_headers(rfile, hdict=None):
- """Read headers from the given stream into the given header dict.
-
- If hdict is None, a new header dict is created. Returns the populated
- header dict.
-
- Headers which are repeated are folded together using a comma if their
- specification so dictates.
-
- This function raises ValueError when the read bytes violate the HTTP spec.
- You should probably return "400 Bad Request" if this happens.
- """
- if hdict is None:
- hdict = {}
-
- while True:
- line = rfile.readline()
- if not line:
- # No more data--illegal end of headers
- raise ValueError("Illegal end of headers.")
-
- if line == CRLF:
- # Normal end of headers
- break
- if not line.endswith(CRLF):
- raise ValueError("HTTP requires CRLF terminators")
-
- if line[0] in ' \t':
- # It's a continuation line.
- v = line.strip()
- else:
- try:
- k, v = line.split(":", 1)
- except ValueError:
- raise ValueError("Illegal header line.")
- # TODO: what about TE and WWW-Authenticate?
- k = k.strip().title()
- v = v.strip()
- hname = k
-
- if k in comma_separated_headers:
- existing = hdict.get(hname)
- if existing:
- v = ", ".join((existing, v))
- hdict[hname] = v
-
- return hdict
-
-
-class MaxSizeExceeded(Exception):
- pass
-
-class SizeCheckWrapper(object):
- """Wraps a file-like object, raising MaxSizeExceeded if too large."""
-
- def __init__(self, rfile, maxlen):
- self.rfile = rfile
- self.maxlen = maxlen
- self.bytes_read = 0
-
- def _check_length(self):
- if self.maxlen and self.bytes_read > self.maxlen:
- raise MaxSizeExceeded()
-
- def read(self, size=None):
- data = self.rfile.read(size)
- self.bytes_read += len(data)
- self._check_length()
- return data
-
- def readline(self, size=None):
- if size is not None:
- data = self.rfile.readline(size)
- self.bytes_read += len(data)
- self._check_length()
- return data
-
- # User didn't specify a size ...
- # We read the line in chunks to make sure it's not a 100MB line !
- res = []
- while True:
- data = self.rfile.readline(256)
- self.bytes_read += len(data)
- self._check_length()
- res.append(data)
- # See http://www.cherrypy.org/ticket/421
- if len(data) < 256 or data[-1:] == "\n":
- return ''.join(res)
-
- def readlines(self, sizehint=0):
- # Shamelessly stolen from StringIO
- total = 0
- lines = []
- line = self.readline()
- while line:
- lines.append(line)
- total += len(line)
- if 0 < sizehint <= total:
- break
- line = self.readline()
- return lines
-
- def close(self):
- self.rfile.close()
-
- def __iter__(self):
- return self
-
- def next(self):
- data = self.rfile.next()
- self.bytes_read += len(data)
- self._check_length()
- return data
-
-
-class KnownLengthRFile(object):
- """Wraps a file-like object, returning an empty string when exhausted."""
-
- def __init__(self, rfile, content_length):
- self.rfile = rfile
- self.remaining = content_length
-
- def read(self, size=None):
- if self.remaining == 0:
- return ''
- if size is None:
- size = self.remaining
- else:
- size = min(size, self.remaining)
-
- data = self.rfile.read(size)
- self.remaining -= len(data)
- return data
-
- def readline(self, size=None):
- if self.remaining == 0:
- return ''
- if size is None:
- size = self.remaining
- else:
- size = min(size, self.remaining)
-
- data = self.rfile.readline(size)
- self.remaining -= len(data)
- return data
-
- def readlines(self, sizehint=0):
- # Shamelessly stolen from StringIO
- total = 0
- lines = []
- line = self.readline(sizehint)
- while line:
- lines.append(line)
- total += len(line)
- if 0 < sizehint <= total:
- break
- line = self.readline(sizehint)
- return lines
-
- def close(self):
- self.rfile.close()
-
- def __iter__(self):
- return self
-
- def __next__(self):
- data = next(self.rfile)
- self.remaining -= len(data)
- return data
-
-
-class ChunkedRFile(object):
- """Wraps a file-like object, returning an empty string when exhausted.
-
- This class is intended to provide a conforming wsgi.input value for
- request entities that have been encoded with the 'chunked' transfer
- encoding.
- """
-
- def __init__(self, rfile, maxlen, bufsize=8192):
- self.rfile = rfile
- self.maxlen = maxlen
- self.bytes_read = 0
- self.buffer = ''
- self.bufsize = bufsize
- self.closed = False
-
- def _fetch(self):
- if self.closed:
- return
-
- line = self.rfile.readline()
- self.bytes_read += len(line)
-
- if self.maxlen and self.bytes_read > self.maxlen:
- raise MaxSizeExceeded("Request Entity Too Large", self.maxlen)
-
- line = line.strip().split(";", 1)
-
- try:
- chunk_size = line.pop(0)
- chunk_size = int(chunk_size, 16)
- except ValueError:
- raise ValueError("Bad chunked transfer size: " + repr(chunk_size))
-
- if chunk_size <= 0:
- self.closed = True
- return
-
-## if line: chunk_extension = line[0]
-
- if self.maxlen and self.bytes_read + chunk_size > self.maxlen:
- raise IOError("Request Entity Too Large")
-
- chunk = self.rfile.read(chunk_size)
- self.bytes_read += len(chunk)
- self.buffer += chunk
-
- crlf = self.rfile.read(2)
- if crlf != CRLF:
- raise ValueError(
- "Bad chunked transfer coding (expected '\\r\\n', "
- "got " + repr(crlf) + ")")
-
- def read(self, size=None):
- data = ''
- while True:
- if size and len(data) >= size:
- return data
-
- if not self.buffer:
- self._fetch()
- if not self.buffer:
- # EOF
- return data
-
- if size:
- remaining = size - len(data)
- data += self.buffer[:remaining]
- self.buffer = self.buffer[remaining:]
- else:
- data += self.buffer
-
- def readline(self, size=None):
- data = ''
- while True:
- if size and len(data) >= size:
- return data
-
- if not self.buffer:
- self._fetch()
- if not self.buffer:
- # EOF
- return data
-
- newline_pos = self.buffer.find('\n')
- if size:
- if newline_pos == -1:
- remaining = size - len(data)
- data += self.buffer[:remaining]
- self.buffer = self.buffer[remaining:]
- else:
- remaining = min(size - len(data), newline_pos)
- data += self.buffer[:remaining]
- self.buffer = self.buffer[remaining:]
- else:
- if newline_pos == -1:
- data += self.buffer
- else:
- data += self.buffer[:newline_pos]
- self.buffer = self.buffer[newline_pos:]
-
- def readlines(self, sizehint=0):
- # Shamelessly stolen from StringIO
- total = 0
- lines = []
- line = self.readline(sizehint)
- while line:
- lines.append(line)
- total += len(line)
- if 0 < sizehint <= total:
- break
- line = self.readline(sizehint)
- return lines
-
- def read_trailer_lines(self):
- if not self.closed:
- raise ValueError(
- "Cannot read trailers until the request body has been read.")
-
- while True:
- line = self.rfile.readline()
- if not line:
- # No more data--illegal end of headers
- raise ValueError("Illegal end of headers.")
-
- self.bytes_read += len(line)
- if self.maxlen and self.bytes_read > self.maxlen:
- raise IOError("Request Entity Too Large")
-
- if line == CRLF:
- # Normal end of headers
- break
- if not line.endswith(CRLF):
- raise ValueError("HTTP requires CRLF terminators")
-
- yield line
-
- def close(self):
- self.rfile.close()
-
- def __iter__(self):
- # Shamelessly stolen from StringIO
- total = 0
- line = self.readline(sizehint)
- while line:
- yield line
- total += len(line)
- if 0 < sizehint <= total:
- break
- line = self.readline(sizehint)
-
-
-class HTTPRequest(object):
- """An HTTP Request (and response).
-
- A single HTTP connection may consist of multiple request/response pairs.
- """
-
- server = None
- """The HTTPServer object which is receiving this request."""
-
- conn = None
- """The HTTPConnection object on which this request connected."""
-
- inheaders = {}
- """A dict of request headers."""
-
- outheaders = []
- """A list of header tuples to write in the response."""
-
- ready = False
- """When True, the request has been parsed and is ready to begin generating
- the response. When False, signals the calling Connection that the response
- should not be generated and the connection should close."""
-
- close_connection = False
- """Signals the calling Connection that the request should close. This does
- not imply an error! The client and/or server may each request that the
- connection be closed."""
-
- chunked_write = False
- """If True, output will be encoded with the "chunked" transfer-coding.
-
- This value is set automatically inside send_headers."""
-
- def __init__(self, server, conn):
- self.server= server
- self.conn = conn
-
- self.ready = False
- self.started_request = False
- self.scheme = "http"
- if self.server.ssl_adapter is not None:
- self.scheme = "https"
- # Use the lowest-common protocol in case read_request_line errors.
- self.response_protocol = 'HTTP/1.0'
- self.inheaders = {}
-
- self.status = ""
- self.outheaders = []
- self.sent_headers = False
- self.close_connection = self.__class__.close_connection
- self.chunked_read = False
- self.chunked_write = self.__class__.chunked_write
-
- def parse_request(self):
- """Parse the next HTTP request start-line and message-headers."""
- self.rfile = SizeCheckWrapper(self.conn.rfile,
- self.server.max_request_header_size)
- try:
- self.read_request_line()
- except MaxSizeExceeded:
- self.simple_response("414 Request-URI Too Long",
- "The Request-URI sent with the request exceeds the maximum "
- "allowed bytes.")
- return
-
- try:
- success = self.read_request_headers()
- except MaxSizeExceeded:
- self.simple_response("413 Request Entity Too Large",
- "The headers sent with the request exceed the maximum "
- "allowed bytes.")
- return
- else:
- if not success:
- return
-
- self.ready = True
-
- def read_request_line(self):
- # HTTP/1.1 connections are persistent by default. If a client
- # requests a page, then idles (leaves the connection open),
- # then rfile.readline() will raise socket.error("timed out").
- # Note that it does this based on the value given to settimeout(),
- # and doesn't need the client to request or acknowledge the close
- # (although your TCP stack might suffer for it: cf Apache's history
- # with FIN_WAIT_2).
- request_line = self.rfile.readline()
-
- # Set started_request to True so communicate() knows to send 408
- # from here on out.
- self.started_request = True
- if not request_line:
- # Force self.ready = False so the connection will close.
- self.ready = False
- return
-
- if request_line == CRLF:
- # RFC 2616 sec 4.1: "...if the server is reading the protocol
- # stream at the beginning of a message and receives a CRLF
- # first, it should ignore the CRLF."
- # But only ignore one leading line! else we enable a DoS.
- request_line = self.rfile.readline()
- if not request_line:
- self.ready = False
- return
-
- if not request_line.endswith(CRLF):
- self.simple_response("400 Bad Request", "HTTP requires CRLF terminators")
- return
-
- try:
- method, uri, req_protocol = request_line.strip().split(" ", 2)
- rp = int(req_protocol[5]), int(req_protocol[7])
- except (ValueError, IndexError):
- self.simple_response("400 Bad Request", "Malformed Request-Line")
- return
-
- self.uri = uri
- self.method = method
-
- # uri may be an abs_path (including "http://host.domain.tld");
- scheme, authority, path = self.parse_request_uri(uri)
- if '#' in path:
- self.simple_response("400 Bad Request",
- "Illegal #fragment in Request-URI.")
- return
-
- if scheme:
- self.scheme = scheme
-
- qs = ''
- if '?' in path:
- path, qs = path.split('?', 1)
-
- # Unquote the path+params (e.g. "/this%20path" -> "/this path").
- # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
- #
- # But note that "...a URI must be separated into its components
- # before the escaped characters within those components can be
- # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2
- # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path".
- try:
- atoms = [unquote(x) for x in quoted_slash.split(path)]
- except ValueError, ex:
- self.simple_response("400 Bad Request", ex.args[0])
- return
- path = "%2F".join(atoms)
- self.path = path
-
- # Note that, like wsgiref and most other HTTP servers,
- # we "% HEX HEX"-unquote the path but not the query string.
- self.qs = qs
-
- # Compare request and server HTTP protocol versions, in case our
- # server does not support the requested protocol. Limit our output
- # to min(req, server). We want the following output:
- # request server actual written supported response
- # protocol protocol response protocol feature set
- # a 1.0 1.0 1.0 1.0
- # b 1.0 1.1 1.1 1.0
- # c 1.1 1.0 1.0 1.0
- # d 1.1 1.1 1.1 1.1
- # Notice that, in (b), the response will be "HTTP/1.1" even though
- # the client only understands 1.0. RFC 2616 10.5.6 says we should
- # only return 505 if the _major_ version is different.
- sp = int(self.server.protocol[5]), int(self.server.protocol[7])
-
- if sp[0] != rp[0]:
- self.simple_response("505 HTTP Version Not Supported")
- return
- self.request_protocol = req_protocol
- self.response_protocol = "HTTP/%s.%s" % min(rp, sp)
-
- def read_request_headers(self):
- """Read self.rfile into self.inheaders. Return success."""
-
- # then all the http headers
- try:
- read_headers(self.rfile, self.inheaders)
- except ValueError, ex:
- self.simple_response("400 Bad Request", ex.args[0])
- return False
-
- mrbs = self.server.max_request_body_size
- if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs:
- self.simple_response("413 Request Entity Too Large",
- "The entity sent with the request exceeds the maximum "
- "allowed bytes.")
- return False
-
- # Persistent connection support
- if self.response_protocol == "HTTP/1.1":
- # Both server and client are HTTP/1.1
- if self.inheaders.get("Connection", "") == "close":
- self.close_connection = True
- else:
- # Either the server or client (or both) are HTTP/1.0
- if self.inheaders.get("Connection", "") != "Keep-Alive":
- self.close_connection = True
-
- # Transfer-Encoding support
- te = None
- if self.response_protocol == "HTTP/1.1":
- te = self.inheaders.get("Transfer-Encoding")
- if te:
- te = [x.strip().lower() for x in te.split(",") if x.strip()]
-
- self.chunked_read = False
-
- if te:
- for enc in te:
- if enc == "chunked":
- self.chunked_read = True
- else:
- # Note that, even if we see "chunked", we must reject
- # if there is an extension we don't recognize.
- self.simple_response("501 Unimplemented")
- self.close_connection = True
- return False
-
- # From PEP 333:
- # "Servers and gateways that implement HTTP 1.1 must provide
- # transparent support for HTTP 1.1's "expect/continue" mechanism.
- # This may be done in any of several ways:
- # 1. Respond to requests containing an Expect: 100-continue request
- # with an immediate "100 Continue" response, and proceed normally.
- # 2. Proceed with the request normally, but provide the application
- # with a wsgi.input stream that will send the "100 Continue"
- # response if/when the application first attempts to read from
- # the input stream. The read request must then remain blocked
- # until the client responds.
- # 3. Wait until the client decides that the server does not support
- # expect/continue, and sends the request body on its own.
- # (This is suboptimal, and is not recommended.)
- #
- # We used to do 3, but are now doing 1. Maybe we'll do 2 someday,
- # but it seems like it would be a big slowdown for such a rare case.
- if self.inheaders.get("Expect", "") == "100-continue":
- # Don't use simple_response here, because it emits headers
- # we don't want. See http://www.cherrypy.org/ticket/951
- msg = self.server.protocol + " 100 Continue\r\n\r\n"
- try:
- self.conn.wfile.sendall(msg)
- except socket.error, x:
- if x.args[0] not in socket_errors_to_ignore:
- raise
- return True
-
- def parse_request_uri(self, uri):
- """Parse a Request-URI into (scheme, authority, path).
-
- Note that Request-URI's must be one of::
-
- Request-URI = "*" | absoluteURI | abs_path | authority
-
- Therefore, a Request-URI which starts with a double forward-slash
- cannot be a "net_path"::
-
- net_path = "//" authority [ abs_path ]
-
- Instead, it must be interpreted as an "abs_path" with an empty first
- path segment::
-
- abs_path = "/" path_segments
- path_segments = segment *( "/" segment )
- segment = *pchar *( ";" param )
- param = *pchar
- """
- if uri == "*":
- return None, None, uri
-
- i = uri.find('://')
- if i > 0 and '?' not in uri[:i]:
- # An absoluteURI.
- # If there's a scheme (and it must be http or https), then:
- # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]]
- scheme, remainder = uri[:i].lower(), uri[i + 3:]
- authority, path = remainder.split("/", 1)
- return scheme, authority, path
-
- if uri.startswith('/'):
- # An abs_path.
- return None, None, uri
- else:
- # An authority.
- return None, uri, None
-
- def respond(self):
- """Call the gateway and write its iterable output."""
- mrbs = self.server.max_request_body_size
- if self.chunked_read:
- self.rfile = ChunkedRFile(self.conn.rfile, mrbs)
- else:
- cl = int(self.inheaders.get("Content-Length", 0))
- if mrbs and mrbs < cl:
- if not self.sent_headers:
- self.simple_response("413 Request Entity Too Large",
- "The entity sent with the request exceeds the maximum "
- "allowed bytes.")
- return
- self.rfile = KnownLengthRFile(self.conn.rfile, cl)
-
- self.server.gateway(self).respond()
-
- if (self.ready and not self.sent_headers):
- self.sent_headers = True
- self.send_headers()
- if self.chunked_write:
- self.conn.wfile.sendall("0\r\n\r\n")
-
- def simple_response(self, status, msg=""):
- """Write a simple response back to the client."""
- status = str(status)
- buf = [self.server.protocol + " " +
- status + CRLF,
- "Content-Length: %s\r\n" % len(msg),
- "Content-Type: text/plain\r\n"]
-
- if status[:3] in ("413", "414"):
- # Request Entity Too Large / Request-URI Too Long
- self.close_connection = True
- if self.response_protocol == 'HTTP/1.1':
- # This will not be true for 414, since read_request_line
- # usually raises 414 before reading the whole line, and we
- # therefore cannot know the proper response_protocol.
- buf.append("Connection: close\r\n")
- else:
- # HTTP/1.0 had no 413/414 status nor Connection header.
- # Emit 400 instead and trust the message body is enough.
- status = "400 Bad Request"
-
- buf.append(CRLF)
- if msg:
- if isinstance(msg, unicode):
- msg = msg.encode("ISO-8859-1")
- buf.append(msg)
-
- try:
- self.conn.wfile.sendall("".join(buf))
- except socket.error, x:
- if x.args[0] not in socket_errors_to_ignore:
- raise
-
- def write(self, chunk):
- """Write unbuffered data to the client."""
- if self.chunked_write and chunk:
- buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF]
- self.conn.wfile.sendall("".join(buf))
- else:
- self.conn.wfile.sendall(chunk)
-
- def send_headers(self):
- """Assert, process, and send the HTTP response message-headers.
-
- You must set self.status, and self.outheaders before calling this.
- """
- hkeys = [key.lower() for key, value in self.outheaders]
- status = int(self.status[:3])
-
- if status == 413:
- # Request Entity Too Large. Close conn to avoid garbage.
- self.close_connection = True
- elif "content-length" not in hkeys:
- # "All 1xx (informational), 204 (no content),
- # and 304 (not modified) responses MUST NOT
- # include a message-body." So no point chunking.
- if status < 200 or status in (204, 205, 304):
- pass
- else:
- if (self.response_protocol == 'HTTP/1.1'
- and self.method != 'HEAD'):
- # Use the chunked transfer-coding
- self.chunked_write = True
- self.outheaders.append(("Transfer-Encoding", "chunked"))
- else:
- # Closing the conn is the only way to determine len.
- self.close_connection = True
-
- if "connection" not in hkeys:
- if self.response_protocol == 'HTTP/1.1':
- # Both server and client are HTTP/1.1 or better
- if self.close_connection:
- self.outheaders.append(("Connection", "close"))
- else:
- # Server and/or client are HTTP/1.0
- if not self.close_connection:
- self.outheaders.append(("Connection", "Keep-Alive"))
-
- if (not self.close_connection) and (not self.chunked_read):
- # Read any remaining request body data on the socket.
- # "If an origin server receives a request that does not include an
- # Expect request-header field with the "100-continue" expectation,
- # the request includes a request body, and the server responds
- # with a final status code before reading the entire request body
- # from the transport connection, then the server SHOULD NOT close
- # the transport connection until it has read the entire request,
- # or until the client closes the connection. Otherwise, the client
- # might not reliably receive the response message. However, this
- # requirement is not be construed as preventing a server from
- # defending itself against denial-of-service attacks, or from
- # badly broken client implementations."
- remaining = getattr(self.rfile, 'remaining', 0)
- if remaining > 0:
- self.rfile.read(remaining)
-
- if "date" not in hkeys:
- self.outheaders.append(("Date", rfc822.formatdate()))
-
- if "server" not in hkeys:
- self.outheaders.append(("Server", self.server.server_name))
-
- buf = [self.server.protocol + " " + self.status + CRLF]
- for k, v in self.outheaders:
- buf.append(k + ": " + v + CRLF)
- buf.append(CRLF)
- self.conn.wfile.sendall("".join(buf))
-
-
-class NoSSLError(Exception):
- """Exception raised when a client speaks HTTP to an HTTPS socket."""
- pass
-
-
-class FatalSSLAlert(Exception):
- """Exception raised when the SSL implementation signals a fatal alert."""
- pass
-
-
-class CP_fileobject(socket._fileobject):
- """Faux file object attached to a socket object."""
-
- def __init__(self, *args, **kwargs):
- self.bytes_read = 0
- self.bytes_written = 0
- socket._fileobject.__init__(self, *args, **kwargs)
-
- def sendall(self, data):
- """Sendall for non-blocking sockets."""
- while data:
- try:
- bytes_sent = self.send(data)
- data = data[bytes_sent:]
- except socket.error, e:
- if e.args[0] not in socket_errors_nonblocking:
- raise
-
- def send(self, data):
- bytes_sent = self._sock.send(data)
- self.bytes_written += bytes_sent
- return bytes_sent
-
- def flush(self):
- if self._wbuf:
- buffer = "".join(self._wbuf)
- self._wbuf = []
- self.sendall(buffer)
-
- def recv(self, size):
- while True:
- try:
- data = self._sock.recv(size)
- self.bytes_read += len(data)
- return data
- except socket.error, e:
- if (e.args[0] not in socket_errors_nonblocking
- and e.args[0] not in socket_error_eintr):
- raise
-
- if not _fileobject_uses_str_type:
- def read(self, size=-1):
- # Use max, disallow tiny reads in a loop as they are very inefficient.
- # We never leave read() with any leftover data from a new recv() call
- # in our internal buffer.
- rbufsize = max(self._rbufsize, self.default_bufsize)
- # Our use of StringIO rather than lists of string objects returned by
- # recv() minimizes memory usage and fragmentation that occurs when
- # rbufsize is large compared to the typical return value of recv().
- buf = self._rbuf
- buf.seek(0, 2) # seek end
- if size < 0:
- # Read until EOF
- self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
- while True:
- data = self.recv(rbufsize)
- if not data:
- break
- buf.write(data)
- return buf.getvalue()
- else:
- # Read until size bytes or EOF seen, whichever comes first
- buf_len = buf.tell()
- if buf_len >= size:
- # Already have size bytes in our buffer? Extract and return.
- buf.seek(0)
- rv = buf.read(size)
- self._rbuf = StringIO.StringIO()
- self._rbuf.write(buf.read())
- return rv
-
- self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
- while True:
- left = size - buf_len
- # recv() will malloc the amount of memory given as its
- # parameter even though it often returns much less data
- # than that. The returned data string is short lived
- # as we copy it into a StringIO and free it. This avoids
- # fragmentation issues on many platforms.
- data = self.recv(left)
- if not data:
- break
- n = len(data)
- if n == size and not buf_len:
- # Shortcut. Avoid buffer data copies when:
- # - We have no data in our buffer.
- # AND
- # - Our call to recv returned exactly the
- # number of bytes we were asked to read.
- return data
- if n == left:
- buf.write(data)
- del data # explicit free
- break
- assert n <= left, "recv(%d) returned %d bytes" % (left, n)
- buf.write(data)
- buf_len += n
- del data # explicit free
- #assert buf_len == buf.tell()
- return buf.getvalue()
-
- def readline(self, size=-1):
- buf = self._rbuf
- buf.seek(0, 2) # seek end
- if buf.tell() > 0:
- # check if we already have it in our buffer
- buf.seek(0)
- bline = buf.readline(size)
- if bline.endswith('\n') or len(bline) == size:
- self._rbuf = StringIO.StringIO()
- self._rbuf.write(buf.read())
- return bline
- del bline
- if size < 0:
- # Read until \n or EOF, whichever comes first
- if self._rbufsize <= 1:
- # Speed up unbuffered case
- buf.seek(0)
- buffers = [buf.read()]
- self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
- data = None
- recv = self.recv
- while data != "\n":
- data = recv(1)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
-
- buf.seek(0, 2) # seek end
- self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
- while True:
- data = self.recv(self._rbufsize)
- if not data:
- break
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- buf.write(data[:nl])
- self._rbuf.write(data[nl:])
- del data
- break
- buf.write(data)
- return buf.getvalue()
- else:
- # Read until size bytes or \n or EOF seen, whichever comes first
- buf.seek(0, 2) # seek end
- buf_len = buf.tell()
- if buf_len >= size:
- buf.seek(0)
- rv = buf.read(size)
- self._rbuf = StringIO.StringIO()
- self._rbuf.write(buf.read())
- return rv
- self._rbuf = StringIO.StringIO() # reset _rbuf. we consume it via buf.
- while True:
- data = self.recv(self._rbufsize)
- if not data:
- break
- left = size - buf_len
- # did we just receive a newline?
- nl = data.find('\n', 0, left)
- if nl >= 0:
- nl += 1
- # save the excess data to _rbuf
- self._rbuf.write(data[nl:])
- if buf_len:
- buf.write(data[:nl])
- break
- else:
- # Shortcut. Avoid data copy through buf when returning
- # a substring of our first recv().
- return data[:nl]
- n = len(data)
- if n == size and not buf_len:
- # Shortcut. Avoid data copy through buf when
- # returning exactly all of our first recv().
- return data
- if n >= left:
- buf.write(data[:left])
- self._rbuf.write(data[left:])
- break
- buf.write(data)
- buf_len += n
- #assert buf_len == buf.tell()
- return buf.getvalue()
- else:
- def read(self, size=-1):
- if size < 0:
- # Read until EOF
- buffers = [self._rbuf]
- self._rbuf = ""
- if self._rbufsize <= 1:
- recv_size = self.default_bufsize
- else:
- recv_size = self._rbufsize
-
- while True:
- data = self.recv(recv_size)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- else:
- # Read until size bytes or EOF seen, whichever comes first
- data = self._rbuf
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- left = size - buf_len
- recv_size = max(self._rbufsize, left)
- data = self.recv(recv_size)
- if not data:
- break
- buffers.append(data)
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
- def readline(self, size=-1):
- data = self._rbuf
- if size < 0:
- # Read until \n or EOF, whichever comes first
- if self._rbufsize <= 1:
- # Speed up unbuffered case
- assert data == ""
- buffers = []
- while data != "\n":
- data = self.recv(1)
- if not data:
- break
- buffers.append(data)
- return "".join(buffers)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self.recv(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- nl = data.find('\n')
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- return "".join(buffers)
- else:
- # Read until size bytes or \n or EOF seen, whichever comes first
- nl = data.find('\n', 0, size)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- return data[:nl]
- buf_len = len(data)
- if buf_len >= size:
- self._rbuf = data[size:]
- return data[:size]
- buffers = []
- if data:
- buffers.append(data)
- self._rbuf = ""
- while True:
- data = self.recv(self._rbufsize)
- if not data:
- break
- buffers.append(data)
- left = size - buf_len
- nl = data.find('\n', 0, left)
- if nl >= 0:
- nl += 1
- self._rbuf = data[nl:]
- buffers[-1] = data[:nl]
- break
- n = len(data)
- if n >= left:
- self._rbuf = data[left:]
- buffers[-1] = data[:left]
- break
- buf_len += n
- return "".join(buffers)
-
-
-class HTTPConnection(object):
- """An HTTP connection (active socket).
-
- server: the Server object which received this connection.
- socket: the raw socket object (usually TCP) for this connection.
- makefile: a fileobject class for reading from the socket.
- """
-
- remote_addr = None
- remote_port = None
- ssl_env = None
- rbufsize = DEFAULT_BUFFER_SIZE
- wbufsize = DEFAULT_BUFFER_SIZE
- RequestHandlerClass = HTTPRequest
-
- def __init__(self, server, sock, makefile=CP_fileobject):
- self.server = server
- self.socket = sock
- self.rfile = makefile(sock, "rb", self.rbufsize)
- self.wfile = makefile(sock, "wb", self.wbufsize)
- self.requests_seen = 0
-
- def communicate(self):
- """Read each request and respond appropriately."""
- request_seen = False
- try:
- while True:
- # (re)set req to None so that if something goes wrong in
- # the RequestHandlerClass constructor, the error doesn't
- # get written to the previous request.
- req = None
- req = self.RequestHandlerClass(self.server, self)
-
- # This order of operations should guarantee correct pipelining.
- req.parse_request()
- if self.server.stats['Enabled']:
- self.requests_seen += 1
- if not req.ready:
- # Something went wrong in the parsing (and the server has
- # probably already made a simple_response). Return and
- # let the conn close.
- return
-
- request_seen = True
- req.respond()
- if req.close_connection:
- return
- except socket.error, e:
- errnum = e.args[0]
- # sadly SSL sockets return a different (longer) time out string
- if errnum == 'timed out' or errnum == 'The read operation timed out':
- # Don't error if we're between requests; only error
- # if 1) no request has been started at all, or 2) we're
- # in the middle of a request.
- # See http://www.cherrypy.org/ticket/853
- if (not request_seen) or (req and req.started_request):
- # Don't bother writing the 408 if the response
- # has already started being written.
- if req and not req.sent_headers:
- try:
- req.simple_response("408 Request Timeout")
- except FatalSSLAlert:
- # Close the connection.
- return
- elif errnum not in socket_errors_to_ignore:
- if req and not req.sent_headers:
- try:
- req.simple_response("500 Internal Server Error",
- format_exc())
- except FatalSSLAlert:
- # Close the connection.
- return
- return
- except (KeyboardInterrupt, SystemExit):
- raise
- except FatalSSLAlert:
- # Close the connection.
- return
- except NoSSLError:
- if req and not req.sent_headers:
- # Unwrap our wfile
- self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize)
- req.simple_response("400 Bad Request",
- "The client sent a plain HTTP request, but "
- "this server only speaks HTTPS on this port.")
- self.linger = True
- except Exception:
- if req and not req.sent_headers:
- try:
- req.simple_response("500 Internal Server Error", format_exc())
- except FatalSSLAlert:
- # Close the connection.
- return
-
- linger = False
-
- def close(self):
- """Close the socket underlying this connection."""
- self.rfile.close()
-
- if not self.linger:
- # Python's socket module does NOT call close on the kernel socket
- # when you call socket.close(). We do so manually here because we
- # want this server to send a FIN TCP segment immediately. Note this
- # must be called *before* calling socket.close(), because the latter
- # drops its reference to the kernel socket.
- if hasattr(self.socket, '_sock'):
- self.socket._sock.close()
- self.socket.close()
- else:
- # On the other hand, sometimes we want to hang around for a bit
- # to make sure the client has a chance to read our entire
- # response. Skipping the close() calls here delays the FIN
- # packet until the socket object is garbage-collected later.
- # Someday, perhaps, we'll do the full lingering_close that
- # Apache does, but not today.
- pass
-
-
-_SHUTDOWNREQUEST = None
-
-class WorkerThread(threading.Thread):
- """Thread which continuously polls a Queue for Connection objects.
-
- Due to the timing issues of polling a Queue, a WorkerThread does not
- check its own 'ready' flag after it has started. To stop the thread,
- it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
- (one for each running WorkerThread).
- """
-
- conn = None
- """The current connection pulled off the Queue, or None."""
-
- server = None
- """The HTTP Server which spawned this thread, and which owns the
- Queue and is placing active connections into it."""
-
- ready = False
- """A simple flag for the calling server to know when this thread
- has begun polling the Queue."""
-
-
- def __init__(self, server):
- self.ready = False
- self.server = server
-
- self.requests_seen = 0
- self.bytes_read = 0
- self.bytes_written = 0
- self.start_time = None
- self.work_time = 0
- self.stats = {
- 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and 0 or self.conn.requests_seen),
- 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and 0 or self.conn.rfile.bytes_read),
- 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and 0 or self.conn.wfile.bytes_written),
- 'Work Time': lambda s: self.work_time + ((self.start_time is None) and 0 or time.time() - self.start_time),
- 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6),
- 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6),
- }
- threading.Thread.__init__(self)
-
- def run(self):
- self.server.stats['Worker Threads'][self.getName()] = self.stats
- try:
- self.ready = True
- while True:
- conn = self.server.requests.get()
- if conn is _SHUTDOWNREQUEST:
- return
-
- self.conn = conn
- if self.server.stats['Enabled']:
- self.start_time = time.time()
- try:
- conn.communicate()
- finally:
- conn.close()
- if self.server.stats['Enabled']:
- self.requests_seen += self.conn.requests_seen
- self.bytes_read += self.conn.rfile.bytes_read
- self.bytes_written += self.conn.wfile.bytes_written
- self.work_time += time.time() - self.start_time
- self.start_time = None
- self.conn = None
- except (KeyboardInterrupt, SystemExit), exc:
- self.server.interrupt = exc
-
-
-class ThreadPool(object):
- """A Request Queue for the CherryPyWSGIServer which pools threads.
-
- ThreadPool objects must provide min, get(), put(obj), start()
- and stop(timeout) attributes.
- """
-
- def __init__(self, server, min=10, max=-1):
- self.server = server
- self.min = min
- self.max = max
- self._threads = []
- self._queue = Queue.Queue()
- self.get = self._queue.get
-
- def start(self):
- """Start the pool of threads."""
- for i in range(self.min):
- self._threads.append(WorkerThread(self.server))
- for worker in self._threads:
- worker.setName("CP Server " + worker.getName())
- worker.start()
- for worker in self._threads:
- while not worker.ready:
- time.sleep(.1)
-
- def _get_idle(self):
- """Number of worker threads which are idle. Read-only."""
- return len([t for t in self._threads if t.conn is None])
- idle = property(_get_idle, doc=_get_idle.__doc__)
-
- def put(self, obj):
- self._queue.put(obj)
- if obj is _SHUTDOWNREQUEST:
- return
-
- def grow(self, amount):
- """Spawn new worker threads (not above self.max)."""
- for i in range(amount):
- if self.max > 0 and len(self._threads) >= self.max:
- break
- worker = WorkerThread(self.server)
- worker.setName("CP Server " + worker.getName())
- self._threads.append(worker)
- worker.start()
-
- def shrink(self, amount):
- """Kill off worker threads (not below self.min)."""
- # Grow/shrink the pool if necessary.
- # Remove any dead threads from our list
- for t in self._threads:
- if not t.isAlive():
- self._threads.remove(t)
- amount -= 1
-
- if amount > 0:
- for i in range(min(amount, len(self._threads) - self.min)):
- # Put a number of shutdown requests on the queue equal
- # to 'amount'. Once each of those is processed by a worker,
- # that worker will terminate and be culled from our list
- # in self.put.
- self._queue.put(_SHUTDOWNREQUEST)
-
- def stop(self, timeout=5):
- # Must shut down threads here so the code that calls
- # this method can know when all threads are stopped.
- for worker in self._threads:
- self._queue.put(_SHUTDOWNREQUEST)
-
- # Don't join currentThread (when stop is called inside a request).
- current = threading.currentThread()
- if timeout and timeout >= 0:
- endtime = time.time() + timeout
- while self._threads:
- worker = self._threads.pop()
- if worker is not current and worker.isAlive():
- try:
- if timeout is None or timeout < 0:
- worker.join()
- else:
- remaining_time = endtime - time.time()
- if remaining_time > 0:
- worker.join(remaining_time)
- if worker.isAlive():
- # We exhausted the timeout.
- # Forcibly shut down the socket.
- c = worker.conn
- if c and not c.rfile.closed:
- try:
- c.socket.shutdown(socket.SHUT_RD)
- except TypeError:
- # pyOpenSSL sockets don't take an arg
- c.socket.shutdown()
- worker.join()
- except (AssertionError,
- # Ignore repeated Ctrl-C.
- # See http://www.cherrypy.org/ticket/691.
- KeyboardInterrupt), exc1:
- pass
-
- def _get_qsize(self):
- return self._queue.qsize()
- qsize = property(_get_qsize)
-
-
-
-try:
- import fcntl
-except ImportError:
- try:
- from ctypes import windll, WinError
- except ImportError:
- def prevent_socket_inheritance(sock):
- """Dummy function, since neither fcntl nor ctypes are available."""
- pass
- else:
- def prevent_socket_inheritance(sock):
- """Mark the given socket fd as non-inheritable (Windows)."""
- if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0):
- raise WinError()
-else:
- def prevent_socket_inheritance(sock):
- """Mark the given socket fd as non-inheritable (POSIX)."""
- fd = sock.fileno()
- old_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
- fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
-
-
-class SSLAdapter(object):
- """Base class for SSL driver library adapters.
-
- Required methods:
-
- * ``wrap(sock) -> (wrapped socket, ssl environ dict)``
- * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object``
- """
-
- def __init__(self, certificate, private_key, certificate_chain=None):
- self.certificate = certificate
- self.private_key = private_key
- self.certificate_chain = certificate_chain
-
- def wrap(self, sock):
- raise NotImplemented
-
- def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
- raise NotImplemented
-
-
-class HTTPServer(object):
- """An HTTP server."""
-
- _bind_addr = "127.0.0.1"
- _interrupt = None
-
- gateway = None
- """A Gateway instance."""
-
- minthreads = None
- """The minimum number of worker threads to create (default 10)."""
-
- maxthreads = None
- """The maximum number of worker threads to create (default -1 = no limit)."""
-
- server_name = None
- """The name of the server; defaults to socket.gethostname()."""
-
- protocol = "HTTP/1.1"
- """The version string to write in the Status-Line of all HTTP responses.
-
- For example, "HTTP/1.1" is the default. This also limits the supported
- features used in the response."""
-
- request_queue_size = 5
- """The 'backlog' arg to socket.listen(); max queued connections (default 5)."""
-
- shutdown_timeout = 5
- """The total time, in seconds, to wait for worker threads to cleanly exit."""
-
- timeout = 10
- """The timeout in seconds for accepted connections (default 10)."""
-
- version = "CherryPy/3.2.0"
- """A version string for the HTTPServer."""
-
- software = None
- """The value to set for the SERVER_SOFTWARE entry in the WSGI environ.
-
- If None, this defaults to ``'%s Server' % self.version``."""
-
- ready = False
- """An internal flag which marks whether the socket is accepting connections."""
-
- max_request_header_size = 0
- """The maximum size, in bytes, for request headers, or 0 for no limit."""
-
- max_request_body_size = 0
- """The maximum size, in bytes, for request bodies, or 0 for no limit."""
-
- nodelay = True
- """If True (the default since 3.1), sets the TCP_NODELAY socket option."""
-
- ConnectionClass = HTTPConnection
- """The class to use for handling HTTP connections."""
-
- ssl_adapter = None
- """An instance of SSLAdapter (or a subclass).
-
- You must have the corresponding SSL driver library installed."""
-
- def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1,
- server_name=None):
- self.bind_addr = bind_addr
- self.gateway = gateway
-
- self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads)
-
- if not server_name:
- server_name = socket.gethostname()
- self.server_name = server_name
- self.clear_stats()
-
- def clear_stats(self):
- self._start_time = None
- self._run_time = 0
- self.stats = {
- 'Enabled': False,
- 'Bind Address': lambda s: repr(self.bind_addr),
- 'Run time': lambda s: (not s['Enabled']) and 0 or self.runtime(),
- 'Accepts': 0,
- 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(),
- 'Queue': lambda s: getattr(self.requests, "qsize", None),
- 'Threads': lambda s: len(getattr(self.requests, "_threads", [])),
- 'Threads Idle': lambda s: getattr(self.requests, "idle", None),
- 'Socket Errors': 0,
- 'Requests': lambda s: (not s['Enabled']) and 0 or sum([w['Requests'](w) for w
- in s['Worker Threads'].values()], 0),
- 'Bytes Read': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Read'](w) for w
- in s['Worker Threads'].values()], 0),
- 'Bytes Written': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Written'](w) for w
- in s['Worker Threads'].values()], 0),
- 'Work Time': lambda s: (not s['Enabled']) and 0 or sum([w['Work Time'](w) for w
- in s['Worker Threads'].values()], 0),
- 'Read Throughput': lambda s: (not s['Enabled']) and 0 or sum(
- [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
- for w in s['Worker Threads'].values()], 0),
- 'Write Throughput': lambda s: (not s['Enabled']) and 0 or sum(
- [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
- for w in s['Worker Threads'].values()], 0),
- 'Worker Threads': {},
- }
- logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats
-
- def runtime(self):
- if self._start_time is None:
- return self._run_time
- else:
- return self._run_time + (time.time() - self._start_time)
-
- def __str__(self):
- return "%s.%s(%r)" % (self.__module__, self.__class__.__name__,
- self.bind_addr)
-
- def _get_bind_addr(self):
- return self._bind_addr
- def _set_bind_addr(self, value):
- if isinstance(value, tuple) and value[0] in ('', None):
- # Despite the socket module docs, using '' does not
- # allow AI_PASSIVE to work. Passing None instead
- # returns '0.0.0.0' like we want. In other words:
- # host AI_PASSIVE result
- # '' Y 192.168.x.y
- # '' N 192.168.x.y
- # None Y 0.0.0.0
- # None N 127.0.0.1
- # But since you can get the same effect with an explicit
- # '0.0.0.0', we deny both the empty string and None as values.
- raise ValueError("Host values of '' or None are not allowed. "
- "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead "
- "to listen on all active interfaces.")
- self._bind_addr = value
- bind_addr = property(_get_bind_addr, _set_bind_addr,
- doc="""The interface on which to listen for connections.
-
- For TCP sockets, a (host, port) tuple. Host values may be any IPv4
- or IPv6 address, or any valid hostname. The string 'localhost' is a
- synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6).
- The string '0.0.0.0' is a special IPv4 entry meaning "any active
- interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for
- IPv6. The empty string or None are not allowed.
-
- For UNIX sockets, supply the filename as a string.""")
-
- def start(self):
- """Run the server forever."""
- # We don't have to trap KeyboardInterrupt or SystemExit here,
- # because cherrpy.server already does so, calling self.stop() for us.
- # If you're using this server with another framework, you should
- # trap those exceptions in whatever code block calls start().
- self._interrupt = None
-
- if self.software is None:
- self.software = "%s Server" % self.version
-
- # SSL backward compatibility
- if (self.ssl_adapter is None and
- getattr(self, 'ssl_certificate', None) and
- getattr(self, 'ssl_private_key', None)):
- warnings.warn(
- "SSL attributes are deprecated in CherryPy 3.2, and will "
- "be removed in CherryPy 3.3. Use an ssl_adapter attribute "
- "instead.",
- DeprecationWarning
- )
- try:
- from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
- except ImportError:
- pass
- else:
- self.ssl_adapter = pyOpenSSLAdapter(
- self.ssl_certificate, self.ssl_private_key,
- getattr(self, 'ssl_certificate_chain', None))
-
- # Select the appropriate socket
- if isinstance(self.bind_addr, basestring):
- # AF_UNIX socket
-
- # So we can reuse the socket...
- try: os.unlink(self.bind_addr)
- except: pass
-
- # So everyone can access the socket...
- try: os.chmod(self.bind_addr, 0777)
- except: pass
-
- info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)]
- else:
- # AF_INET or AF_INET6 socket
- # Get the correct address family for our host (allows IPv6 addresses)
- host, port = self.bind_addr
- try:
- info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
- socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
- except socket.gaierror:
- if ':' in self.bind_addr[0]:
- info = [(socket.AF_INET6, socket.SOCK_STREAM,
- 0, "", self.bind_addr + (0, 0))]
- else:
- info = [(socket.AF_INET, socket.SOCK_STREAM,
- 0, "", self.bind_addr)]
-
- self.socket = None
- msg = "No socket could be created"
- for res in info:
- af, socktype, proto, canonname, sa = res
- try:
- self.bind(af, socktype, proto)
- except socket.error:
- if self.socket:
- self.socket.close()
- self.socket = None
- continue
- break
- if not self.socket:
- raise socket.error(msg)
-
- # Timeout so KeyboardInterrupt can be caught on Win32
- self.socket.settimeout(1)
- self.socket.listen(self.request_queue_size)
-
- # Create worker threads
- self.requests.start()
-
- self.ready = True
- self._start_time = time.time()
- while self.ready:
- self.tick()
- if self.interrupt:
- while self.interrupt is True:
- # Wait for self.stop() to complete. See _set_interrupt.
- time.sleep(0.1)
- if self.interrupt:
- raise self.interrupt
-
- def bind(self, family, type, proto=0):
- """Create (or recreate) the actual socket object."""
- self.socket = socket.socket(family, type, proto)
- prevent_socket_inheritance(self.socket)
- self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- if self.nodelay and not isinstance(self.bind_addr, str):
- self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
- if self.ssl_adapter is not None:
- self.socket = self.ssl_adapter.bind(self.socket)
-
- # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
- # activate dual-stack. See http://www.cherrypy.org/ticket/871.
- if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6
- and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')):
- try:
- self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
- except (AttributeError, socket.error):
- # Apparently, the socket option is not available in
- # this machine's TCP stack
- pass
-
- self.socket.bind(self.bind_addr)
-
- def tick(self):
- """Accept a new connection and put it on the Queue."""
- try:
- s, addr = self.socket.accept()
- if self.stats['Enabled']:
- self.stats['Accepts'] += 1
- if not self.ready:
- return
-
- prevent_socket_inheritance(s)
- if hasattr(s, 'settimeout'):
- s.settimeout(self.timeout)
-
- makefile = CP_fileobject
- ssl_env = {}
- # if ssl cert and key are set, we try to be a secure HTTP server
- if self.ssl_adapter is not None:
- try:
- s, ssl_env = self.ssl_adapter.wrap(s)
- except NoSSLError:
- msg = ("The client sent a plain HTTP request, but "
- "this server only speaks HTTPS on this port.")
- buf = ["%s 400 Bad Request\r\n" % self.protocol,
- "Content-Length: %s\r\n" % len(msg),
- "Content-Type: text/plain\r\n\r\n",
- msg]
-
- wfile = CP_fileobject(s, "wb", DEFAULT_BUFFER_SIZE)
- try:
- wfile.sendall("".join(buf))
- except socket.error, x:
- if x.args[0] not in socket_errors_to_ignore:
- raise
- return
- if not s:
- return
- makefile = self.ssl_adapter.makefile
- # Re-apply our timeout since we may have a new socket object
- if hasattr(s, 'settimeout'):
- s.settimeout(self.timeout)
-
- conn = self.ConnectionClass(self, s, makefile)
-
- if not isinstance(self.bind_addr, basestring):
- # optional values
- # Until we do DNS lookups, omit REMOTE_HOST
- if addr is None: # sometimes this can happen
- # figure out if AF_INET or AF_INET6.
- if len(s.getsockname()) == 2:
- # AF_INET
- addr = ('0.0.0.0', 0)
- else:
- # AF_INET6
- addr = ('::', 0)
- conn.remote_addr = addr[0]
- conn.remote_port = addr[1]
-
- conn.ssl_env = ssl_env
-
- self.requests.put(conn)
- except socket.timeout:
- # The only reason for the timeout in start() is so we can
- # notice keyboard interrupts on Win32, which don't interrupt
- # accept() by default
- return
- except socket.error, x:
- if self.stats['Enabled']:
- self.stats['Socket Errors'] += 1
- if x.args[0] in socket_error_eintr:
- # I *think* this is right. EINTR should occur when a signal
- # is received during the accept() call; all docs say retry
- # the call, and I *think* I'm reading it right that Python
- # will then go ahead and poll for and handle the signal
- # elsewhere. See http://www.cherrypy.org/ticket/707.
- return
- if x.args[0] in socket_errors_nonblocking:
- # Just try again. See http://www.cherrypy.org/ticket/479.
- return
- if x.args[0] in socket_errors_to_ignore:
- # Our socket was closed.
- # See http://www.cherrypy.org/ticket/686.
- return
- raise
-
- def _get_interrupt(self):
- return self._interrupt
- def _set_interrupt(self, interrupt):
- self._interrupt = True
- self.stop()
- self._interrupt = interrupt
- interrupt = property(_get_interrupt, _set_interrupt,
- doc="Set this to an Exception instance to "
- "interrupt the server.")
-
- def stop(self):
- """Gracefully shutdown a server that is serving forever."""
- self.ready = False
- if self._start_time is not None:
- self._run_time += (time.time() - self._start_time)
- self._start_time = None
-
- sock = getattr(self, "socket", None)
- if sock:
- if not isinstance(self.bind_addr, basestring):
- # Touch our own socket to make accept() return immediately.
- try:
- host, port = sock.getsockname()[:2]
- except socket.error, x:
- if x.args[0] not in socket_errors_to_ignore:
- # Changed to use error code and not message
- # See http://www.cherrypy.org/ticket/860.
- raise
- else:
- # Note that we're explicitly NOT using AI_PASSIVE,
- # here, because we want an actual IP to touch.
- # localhost won't work if we've bound to a public IP,
- # but it will if we bound to '0.0.0.0' (INADDR_ANY).
- for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
- socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- s = None
- try:
- s = socket.socket(af, socktype, proto)
- # See http://groups.google.com/group/cherrypy-users/
- # browse_frm/thread/bbfe5eb39c904fe0
- s.settimeout(1.0)
- s.connect((host, port))
- s.close()
- except socket.error:
- if s:
- s.close()
- if hasattr(sock, "close"):
- sock.close()
- self.socket = None
-
- self.requests.stop(self.shutdown_timeout)
-
-
-class Gateway(object):
-
- def __init__(self, req):
- self.req = req
-
- def respond(self):
- raise NotImplemented
-
-
-# These may either be wsgiserver.SSLAdapter subclasses or the string names
-# of such classes (in which case they will be lazily loaded).
-ssl_adapters = {
- 'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
- 'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
- }
-
-def get_ssl_adapter_class(name='pyopenssl'):
- adapter = ssl_adapters[name.lower()]
- if isinstance(adapter, basestring):
- last_dot = adapter.rfind(".")
- attr_name = adapter[last_dot + 1:]
- mod_path = adapter[:last_dot]
-
- try:
- mod = sys.modules[mod_path]
- if mod is None:
- raise KeyError()
- except KeyError:
- # The last [''] is important.
- mod = __import__(mod_path, globals(), locals(), [''])
-
- # Let an AttributeError propagate outward.
- try:
- adapter = getattr(mod, attr_name)
- except AttributeError:
- raise AttributeError("'%s' object has no attribute '%s'"
- % (mod_path, attr_name))
-
- return adapter
-
-# -------------------------------- WSGI Stuff -------------------------------- #
-
-
-class CherryPyWSGIServer(HTTPServer):
-
- wsgi_version = (1, 0)
-
- def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
- max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
- self.requests = ThreadPool(self, min=numthreads or 1, max=max)
- self.wsgi_app = wsgi_app
- self.gateway = wsgi_gateways[self.wsgi_version]
-
- self.bind_addr = bind_addr
- if not server_name:
- server_name = socket.gethostname()
- self.server_name = server_name
- self.request_queue_size = request_queue_size
-
- self.timeout = timeout
- self.shutdown_timeout = shutdown_timeout
- self.clear_stats()
-
- def _get_numthreads(self):
- return self.requests.min
- def _set_numthreads(self, value):
- self.requests.min = value
- numthreads = property(_get_numthreads, _set_numthreads)
-
-
-class WSGIGateway(Gateway):
-
- def __init__(self, req):
- self.req = req
- self.started_response = False
- self.env = self.get_environ()
- self.remaining_bytes_out = None
-
- def get_environ(self):
- """Return a new environ dict targeting the given wsgi.version"""
- raise NotImplemented
-
- def respond(self):
- response = self.req.server.wsgi_app(self.env, self.start_response)
- try:
- for chunk in response:
- # "The start_response callable must not actually transmit
- # the response headers. Instead, it must store them for the
- # server or gateway to transmit only after the first
- # iteration of the application return value that yields
- # a NON-EMPTY string, or upon the application's first
- # invocation of the write() callable." (PEP 333)
- if chunk:
- if isinstance(chunk, unicode):
- chunk = chunk.encode('ISO-8859-1')
- self.write(chunk)
- finally:
- if hasattr(response, "close"):
- response.close()
-
- def start_response(self, status, headers, exc_info = None):
- """WSGI callable to begin the HTTP response."""
- # "The application may call start_response more than once,
- # if and only if the exc_info argument is provided."
- if self.started_response and not exc_info:
- raise AssertionError("WSGI start_response called a second "
- "time with no exc_info.")
- self.started_response = True
-
- # "if exc_info is provided, and the HTTP headers have already been
- # sent, start_response must raise an error, and should raise the
- # exc_info tuple."
- if self.req.sent_headers:
- try:
- raise exc_info[0], exc_info[1], exc_info[2]
- finally:
- exc_info = None
-
- self.req.status = status
- for k, v in headers:
- if not isinstance(k, str):
- raise TypeError("WSGI response header key %r is not a byte string." % k)
- if not isinstance(v, str):
- raise TypeError("WSGI response header value %r is not a byte string." % v)
- if k.lower() == 'content-length':
- self.remaining_bytes_out = int(v)
- self.req.outheaders.extend(headers)
-
- return self.write
-
- def write(self, chunk):
- """WSGI callable to write unbuffered data to the client.
-
- This method is also used internally by start_response (to write
- data from the iterable returned by the WSGI application).
- """
- if not self.started_response:
- raise AssertionError("WSGI write called before start_response.")
-
- chunklen = len(chunk)
- rbo = self.remaining_bytes_out
- if rbo is not None and chunklen > rbo:
- if not self.req.sent_headers:
- # Whew. We can send a 500 to the client.
- self.req.simple_response("500 Internal Server Error",
- "The requested resource returned more bytes than the "
- "declared Content-Length.")
- else:
- # Dang. We have probably already sent data. Truncate the chunk
- # to fit (so the client doesn't hang) and raise an error later.
- chunk = chunk[:rbo]
-
- if not self.req.sent_headers:
- self.req.sent_headers = True
- self.req.send_headers()
-
- self.req.write(chunk)
-
- if rbo is not None:
- rbo -= chunklen
- if rbo < 0:
- raise ValueError(
- "Response body exceeds the declared Content-Length.")
-
-
-class WSGIGateway_10(WSGIGateway):
-
- def get_environ(self):
- """Return a new environ dict targeting the given wsgi.version"""
- req = self.req
- env = {
- # set a non-standard environ entry so the WSGI app can know what
- # the *real* server protocol is (and what features to support).
- # See http://www.faqs.org/rfcs/rfc2145.html.
- 'ACTUAL_SERVER_PROTOCOL': req.server.protocol,
- 'PATH_INFO': req.path,
- 'QUERY_STRING': req.qs,
- 'REMOTE_ADDR': req.conn.remote_addr or '',
- 'REMOTE_PORT': str(req.conn.remote_port or ''),
- 'REQUEST_METHOD': req.method,
- 'REQUEST_URI': req.uri,
- 'SCRIPT_NAME': '',
- 'SERVER_NAME': req.server.server_name,
- # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol.
- 'SERVER_PROTOCOL': req.request_protocol,
- 'SERVER_SOFTWARE': req.server.software,
- 'wsgi.errors': sys.stderr,
- 'wsgi.input': req.rfile,
- 'wsgi.multiprocess': False,
- 'wsgi.multithread': True,
- 'wsgi.run_once': False,
- 'wsgi.url_scheme': req.scheme,
- 'wsgi.version': (1, 0),
- }
-
- if isinstance(req.server.bind_addr, basestring):
- # AF_UNIX. This isn't really allowed by WSGI, which doesn't
- # address unix domain sockets. But it's better than nothing.
- env["SERVER_PORT"] = ""
- else:
- env["SERVER_PORT"] = str(req.server.bind_addr[1])
-
- # Request headers
- for k, v in req.inheaders.iteritems():
- env["HTTP_" + k.upper().replace("-", "_")] = v
-
- # CONTENT_TYPE/CONTENT_LENGTH
- ct = env.pop("HTTP_CONTENT_TYPE", None)
- if ct is not None:
- env["CONTENT_TYPE"] = ct
- cl = env.pop("HTTP_CONTENT_LENGTH", None)
- if cl is not None:
- env["CONTENT_LENGTH"] = cl
-
- if req.conn.ssl_env:
- env.update(req.conn.ssl_env)
-
- return env
-
-
-class WSGIGateway_u0(WSGIGateway_10):
-
- def get_environ(self):
- """Return a new environ dict targeting the given wsgi.version"""
- req = self.req
- env_10 = WSGIGateway_10.get_environ(self)
- env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()])
- env[u'wsgi.version'] = ('u', 0)
-
- # Request-URI
- env.setdefault(u'wsgi.url_encoding', u'utf-8')
- try:
- for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
- env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
- except UnicodeDecodeError:
- # Fall back to latin 1 so apps can transcode if needed.
- env[u'wsgi.url_encoding'] = u'ISO-8859-1'
- for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
- env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
-
- for k, v in sorted(env.items()):
- if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'):
- env[k] = v.decode('ISO-8859-1')
-
- return env
-
-wsgi_gateways = {
- (1, 0): WSGIGateway_10,
- ('u', 0): WSGIGateway_u0,
-}
-
-class WSGIPathInfoDispatcher(object):
- """A WSGI dispatcher for dispatch based on the PATH_INFO.
-
- apps: a dict or list of (path_prefix, app) pairs.
- """
-
- def __init__(self, apps):
- try:
- apps = apps.items()
- except AttributeError:
- pass
-
- # Sort the apps by len(path), descending
- apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0])))
- apps.reverse()
-
- # The path_prefix strings must start, but not end, with a slash.
- # Use "" instead of "/".
- self.apps = [(p.rstrip("/"), a) for p, a in apps]
-
- def __call__(self, environ, start_response):
- path = environ["PATH_INFO"] or "/"
- for p, app in self.apps:
- # The apps list should be sorted by length, descending.
- if path.startswith(p + "/") or path == p:
- environ = environ.copy()
- environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p
- environ["PATH_INFO"] = path[len(p):]
- return app(environ, start_response)
-
- start_response('404 Not Found', [('Content-Type', 'text/plain'),
- ('Content-Length', '0')])
- return ['']
-
+"""A high-speed, production ready, thread pooled, generic HTTP server.
+
+Simplest example on how to use this module directly
+(without using CherryPy's application machinery)::
+
+ from cherrypy import wsgiserver
+
+ def my_crazy_app(environ, start_response):
+ status = '200 OK'
+ response_headers = [('Content-type','text/plain')]
+ start_response(status, response_headers)
+ return ['Hello world!']
+
+ server = wsgiserver.CherryPyWSGIServer(
+ ('0.0.0.0', 8070), my_crazy_app,
+ server_name='www.cherrypy.example')
+ server.start()
+
+The CherryPy WSGI server can serve as many WSGI applications
+as you want in one instance by using a WSGIPathInfoDispatcher::
+
+ d = WSGIPathInfoDispatcher({'/': my_crazy_app, '/blog': my_blog_app})
+ server = wsgiserver.CherryPyWSGIServer(('0.0.0.0', 80), d)
+
+Want SSL support? Just set server.ssl_adapter to an SSLAdapter instance.
+
+This won't call the CherryPy engine (application side) at all, only the
+HTTP server, which is independent from the rest of CherryPy. Don't
+let the name "CherryPyWSGIServer" throw you; the name merely reflects
+its origin, not its coupling.
+
+For those of you wanting to understand internals of this module, here's the
+basic call flow. The server's listening thread runs a very tight loop,
+sticking incoming connections onto a Queue::
+
+ server = CherryPyWSGIServer(...)
+ server.start()
+ while True:
+ tick()
+ # This blocks until a request comes in:
+ child = socket.accept()
+ conn = HTTPConnection(child, ...)
+ server.requests.put(conn)
+
+Worker threads are kept in a pool and poll the Queue, popping off and then
+handling each connection in turn. Each connection can consist of an arbitrary
+number of requests and their responses, so we run a nested loop::
+
+ while True:
+ conn = server.requests.get()
+ conn.communicate()
+ -> while True:
+ req = HTTPRequest(...)
+ req.parse_request()
+ -> # Read the Request-Line, e.g. "GET /page HTTP/1.1"
+ req.rfile.readline()
+ read_headers(req.rfile, req.inheaders)
+ req.respond()
+ -> response = app(...)
+ try:
+ for chunk in response:
+ if chunk:
+ req.write(chunk)
+ finally:
+ if hasattr(response, "close"):
+ response.close()
+ if req.close_connection:
+ return
+"""
+
+CRLF = '\r\n'
+import os
+import Queue
+import re
+quoted_slash = re.compile("(?i)%2F")
+import rfc822
+import socket
+import sys
+if 'win' in sys.platform and not hasattr(socket, 'IPPROTO_IPV6'):
+ socket.IPPROTO_IPV6 = 41
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+DEFAULT_BUFFER_SIZE = -1
+
+_fileobject_uses_str_type = isinstance(socket._fileobject(None)._rbuf, basestring)
+
+import threading
+import time
+import traceback
+def format_exc(limit=None):
+ """Like print_exc() but return a string. Backport for Python 2.3."""
+ try:
+ etype, value, tb = sys.exc_info()
+ return ''.join(traceback.format_exception(etype, value, tb, limit))
+ finally:
+ etype = value = tb = None
+
+
+from urllib import unquote
+from urlparse import urlparse
+import warnings
+
+import errno
+
+def plat_specific_errors(*errnames):
+ """Return error numbers for all errors in errnames on this platform.
+
+ The 'errno' module contains different global constants depending on
+ the specific platform (OS). This function will return the list of
+ numeric values for a given list of potential names.
+ """
+ errno_names = dir(errno)
+ nums = [getattr(errno, k) for k in errnames if k in errno_names]
+ # de-dupe the list
+ return dict.fromkeys(nums).keys()
+
+socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR")
+
+socket_errors_to_ignore = plat_specific_errors(
+ "EPIPE",
+ "EBADF", "WSAEBADF",
+ "ENOTSOCK", "WSAENOTSOCK",
+ "ETIMEDOUT", "WSAETIMEDOUT",
+ "ECONNREFUSED", "WSAECONNREFUSED",
+ "ECONNRESET", "WSAECONNRESET",
+ "ECONNABORTED", "WSAECONNABORTED",
+ "ENETRESET", "WSAENETRESET",
+ "EHOSTDOWN", "EHOSTUNREACH",
+ )
+socket_errors_to_ignore.append("timed out")
+socket_errors_to_ignore.append("The read operation timed out")
+
+socket_errors_nonblocking = plat_specific_errors(
+ 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK')
+
+comma_separated_headers = ['Accept', 'Accept-Charset', 'Accept-Encoding',
+ 'Accept-Language', 'Accept-Ranges', 'Allow', 'Cache-Control',
+ 'Connection', 'Content-Encoding', 'Content-Language', 'Expect',
+ 'If-Match', 'If-None-Match', 'Pragma', 'Proxy-Authenticate', 'TE',
+ 'Trailer', 'Transfer-Encoding', 'Upgrade', 'Vary', 'Via', 'Warning',
+ 'WWW-Authenticate']
+
+
+import logging
+if not hasattr(logging, 'statistics'): logging.statistics = {}
+
+
+def read_headers(rfile, hdict=None):
+ """Read headers from the given stream into the given header dict.
+
+ If hdict is None, a new header dict is created. Returns the populated
+ header dict.
+
+ Headers which are repeated are folded together using a comma if their
+ specification so dictates.
+
+ This function raises ValueError when the read bytes violate the HTTP spec.
+ You should probably return "400 Bad Request" if this happens.
+ """
+ if hdict is None:
+ hdict = {}
+
+ while True:
+ line = rfile.readline()
+ if not line:
+ # No more data--illegal end of headers
+ raise ValueError("Illegal end of headers.")
+
+ if line == CRLF:
+ # Normal end of headers
+ break
+ if not line.endswith(CRLF):
+ raise ValueError("HTTP requires CRLF terminators")
+
+ if line[0] in ' \t':
+ # It's a continuation line.
+ v = line.strip()
+ else:
+ try:
+ k, v = line.split(":", 1)
+ except ValueError:
+ raise ValueError("Illegal header line.")
+ # TODO: what about TE and WWW-Authenticate?
+ k = k.strip().title()
+ v = v.strip()
+ hname = k
+
+ if k in comma_separated_headers:
+ existing = hdict.get(hname)
+ if existing:
+ v = ", ".join((existing, v))
+ hdict[hname] = v
+
+ return hdict
+
+
+class MaxSizeExceeded(Exception):
+ pass
+
+class SizeCheckWrapper(object):
+ """Wraps a file-like object, raising MaxSizeExceeded if too large."""
+
+ def __init__(self, rfile, maxlen):
+ self.rfile = rfile
+ self.maxlen = maxlen
+ self.bytes_read = 0
+
+ def _check_length(self):
+ if self.maxlen and self.bytes_read > self.maxlen:
+ raise MaxSizeExceeded()
+
+ def read(self, size=None):
+ data = self.rfile.read(size)
+ self.bytes_read += len(data)
+ self._check_length()
+ return data
+
+ def readline(self, size=None):
+ if size is not None:
+ data = self.rfile.readline(size)
+ self.bytes_read += len(data)
+ self._check_length()
+ return data
+
+ # User didn't specify a size ...
+ # We read the line in chunks to make sure it's not a 100MB line !
+ res = []
+ while True:
+ data = self.rfile.readline(256)
+ self.bytes_read += len(data)
+ self._check_length()
+ res.append(data)
+ # See http://www.cherrypy.org/ticket/421
+ if len(data) < 256 or data[-1:] == "\n":
+ return ''.join(res)
+
+ def readlines(self, sizehint=0):
+ # Shamelessly stolen from StringIO
+ total = 0
+ lines = []
+ line = self.readline()
+ while line:
+ lines.append(line)
+ total += len(line)
+ if 0 < sizehint <= total:
+ break
+ line = self.readline()
+ return lines
+
+ def close(self):
+ self.rfile.close()
+
+ def __iter__(self):
+ return self
+
+ def next(self):
+ data = self.rfile.next()
+ self.bytes_read += len(data)
+ self._check_length()
+ return data
+
+
+class KnownLengthRFile(object):
+ """Wraps a file-like object, returning an empty string when exhausted."""
+
+ def __init__(self, rfile, content_length):
+ self.rfile = rfile
+ self.remaining = content_length
+
+ def read(self, size=None):
+ if self.remaining == 0:
+ return ''
+ if size is None:
+ size = self.remaining
+ else:
+ size = min(size, self.remaining)
+
+ data = self.rfile.read(size)
+ self.remaining -= len(data)
+ return data
+
+ def readline(self, size=None):
+ if self.remaining == 0:
+ return ''
+ if size is None:
+ size = self.remaining
+ else:
+ size = min(size, self.remaining)
+
+ data = self.rfile.readline(size)
+ self.remaining -= len(data)
+ return data
+
+ def readlines(self, sizehint=0):
+ # Shamelessly stolen from StringIO
+ total = 0
+ lines = []
+ line = self.readline(sizehint)
+ while line:
+ lines.append(line)
+ total += len(line)
+ if 0 < sizehint <= total:
+ break
+ line = self.readline(sizehint)
+ return lines
+
+ def close(self):
+ self.rfile.close()
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ data = next(self.rfile)
+ self.remaining -= len(data)
+ return data
+
+
+class ChunkedRFile(object):
+ """Wraps a file-like object, returning an empty string when exhausted.
+
+ This class is intended to provide a conforming wsgi.input value for
+ request entities that have been encoded with the 'chunked' transfer
+ encoding.
+ """
+
+ def __init__(self, rfile, maxlen, bufsize=8192):
+ self.rfile = rfile
+ self.maxlen = maxlen
+ self.bytes_read = 0
+ self.buffer = ''
+ self.bufsize = bufsize
+ self.closed = False
+
+ def _fetch(self):
+ if self.closed:
+ return
+
+ line = self.rfile.readline()
+ self.bytes_read += len(line)
+
+ if self.maxlen and self.bytes_read > self.maxlen:
+ raise MaxSizeExceeded("Request Entity Too Large", self.maxlen)
+
+ line = line.strip().split(";", 1)
+
+ try:
+ chunk_size = line.pop(0)
+ chunk_size = int(chunk_size, 16)
+ except ValueError:
+ raise ValueError("Bad chunked transfer size: " + repr(chunk_size))
+
+ if chunk_size <= 0:
+ self.closed = True
+ return
+
+## if line: chunk_extension = line[0]
+
+ if self.maxlen and self.bytes_read + chunk_size > self.maxlen:
+ raise IOError("Request Entity Too Large")
+
+ chunk = self.rfile.read(chunk_size)
+ self.bytes_read += len(chunk)
+ self.buffer += chunk
+
+ crlf = self.rfile.read(2)
+ if crlf != CRLF:
+ raise ValueError(
+ "Bad chunked transfer coding (expected '\\r\\n', "
+ "got " + repr(crlf) + ")")
+
+ def read(self, size=None):
+ data = ''
+ while True:
+ if size and len(data) >= size:
+ return data
+
+ if not self.buffer:
+ self._fetch()
+ if not self.buffer:
+ # EOF
+ return data
+
+ if size:
+ remaining = size - len(data)
+ data += self.buffer[:remaining]
+ self.buffer = self.buffer[remaining:]
+ else:
+ data += self.buffer
+
+ def readline(self, size=None):
+ data = ''
+ while True:
+ if size and len(data) >= size:
+ return data
+
+ if not self.buffer:
+ self._fetch()
+ if not self.buffer:
+ # EOF
+ return data
+
+ newline_pos = self.buffer.find('\n')
+ if size:
+ if newline_pos == -1:
+ remaining = size - len(data)
+ data += self.buffer[:remaining]
+ self.buffer = self.buffer[remaining:]
+ else:
+ remaining = min(size - len(data), newline_pos)
+ data += self.buffer[:remaining]
+ self.buffer = self.buffer[remaining:]
+ else:
+ if newline_pos == -1:
+ data += self.buffer
+ else:
+ data += self.buffer[:newline_pos]
+ self.buffer = self.buffer[newline_pos:]
+
+ def readlines(self, sizehint=0):
+ # Shamelessly stolen from StringIO
+ total = 0
+ lines = []
+ line = self.readline(sizehint)
+ while line:
+ lines.append(line)
+ total += len(line)
+ if 0 < sizehint <= total:
+ break
+ line = self.readline(sizehint)
+ return lines
+
+ def read_trailer_lines(self):
+ if not self.closed:
+ raise ValueError(
+ "Cannot read trailers until the request body has been read.")
+
+ while True:
+ line = self.rfile.readline()
+ if not line:
+ # No more data--illegal end of headers
+ raise ValueError("Illegal end of headers.")
+
+ self.bytes_read += len(line)
+ if self.maxlen and self.bytes_read > self.maxlen:
+ raise IOError("Request Entity Too Large")
+
+ if line == CRLF:
+ # Normal end of headers
+ break
+ if not line.endswith(CRLF):
+ raise ValueError("HTTP requires CRLF terminators")
+
+ yield line
+
+ def close(self):
+ self.rfile.close()
+
+ def __iter__(self):
+ # Shamelessly stolen from StringIO
+ total = 0
+ line = self.readline(sizehint)
+ while line:
+ yield line
+ total += len(line)
+ if 0 < sizehint <= total:
+ break
+ line = self.readline(sizehint)
+
+
+class HTTPRequest(object):
+ """An HTTP Request (and response).
+
+ A single HTTP connection may consist of multiple request/response pairs.
+ """
+
+ server = None
+ """The HTTPServer object which is receiving this request."""
+
+ conn = None
+ """The HTTPConnection object on which this request connected."""
+
+ inheaders = {}
+ """A dict of request headers."""
+
+ outheaders = []
+ """A list of header tuples to write in the response."""
+
+ ready = False
+ """When True, the request has been parsed and is ready to begin generating
+ the response. When False, signals the calling Connection that the response
+ should not be generated and the connection should close."""
+
+ close_connection = False
+ """Signals the calling Connection that the request should close. This does
+ not imply an error! The client and/or server may each request that the
+ connection be closed."""
+
+ chunked_write = False
+ """If True, output will be encoded with the "chunked" transfer-coding.
+
+ This value is set automatically inside send_headers."""
+
+ def __init__(self, server, conn):
+ self.server= server
+ self.conn = conn
+
+ self.ready = False
+ self.started_request = False
+ self.scheme = "http"
+ if self.server.ssl_adapter is not None:
+ self.scheme = "https"
+ # Use the lowest-common protocol in case read_request_line errors.
+ self.response_protocol = 'HTTP/1.0'
+ self.inheaders = {}
+
+ self.status = ""
+ self.outheaders = []
+ self.sent_headers = False
+ self.close_connection = self.__class__.close_connection
+ self.chunked_read = False
+ self.chunked_write = self.__class__.chunked_write
+
+ def parse_request(self):
+ """Parse the next HTTP request start-line and message-headers."""
+ self.rfile = SizeCheckWrapper(self.conn.rfile,
+ self.server.max_request_header_size)
+ try:
+ self.read_request_line()
+ except MaxSizeExceeded:
+ self.simple_response("414 Request-URI Too Long",
+ "The Request-URI sent with the request exceeds the maximum "
+ "allowed bytes.")
+ return
+
+ try:
+ success = self.read_request_headers()
+ except MaxSizeExceeded:
+ self.simple_response("413 Request Entity Too Large",
+ "The headers sent with the request exceed the maximum "
+ "allowed bytes.")
+ return
+ else:
+ if not success:
+ return
+
+ self.ready = True
+
    def read_request_line(self):
        """Read and parse the Request-Line.

        Sets self.method, self.uri, self.scheme, self.path and self.qs,
        and negotiates self.response_protocol against the server's
        protocol. On malformed input, writes a simple 4xx/505 response
        and returns without setting self.ready.
        """
        # HTTP/1.1 connections are persistent by default. If a client
        # requests a page, then idles (leaves the connection open),
        # then rfile.readline() will raise socket.error("timed out").
        # Note that it does this based on the value given to settimeout(),
        # and doesn't need the client to request or acknowledge the close
        # (although your TCP stack might suffer for it: cf Apache's history
        # with FIN_WAIT_2).
        request_line = self.rfile.readline()

        # Set started_request to True so communicate() knows to send 408
        # from here on out.
        self.started_request = True
        if not request_line:
            # Force self.ready = False so the connection will close.
            self.ready = False
            return

        if request_line == CRLF:
            # RFC 2616 sec 4.1: "...if the server is reading the protocol
            # stream at the beginning of a message and receives a CRLF
            # first, it should ignore the CRLF."
            # But only ignore one leading line! else we enable a DoS.
            request_line = self.rfile.readline()
            if not request_line:
                self.ready = False
                return

        if not request_line.endswith(CRLF):
            self.simple_response("400 Bad Request", "HTTP requires CRLF terminators")
            return

        try:
            method, uri, req_protocol = request_line.strip().split(" ", 2)
            # "HTTP/x.y" -> (x, y): characters 5 and 7 are the digits.
            rp = int(req_protocol[5]), int(req_protocol[7])
        except (ValueError, IndexError):
            self.simple_response("400 Bad Request", "Malformed Request-Line")
            return

        self.uri = uri
        self.method = method

        # uri may be an abs_path (including "http://host.domain.tld");
        scheme, authority, path = self.parse_request_uri(uri)
        if '#' in path:
            self.simple_response("400 Bad Request",
                                 "Illegal #fragment in Request-URI.")
            return

        if scheme:
            self.scheme = scheme

        qs = ''
        if '?' in path:
            path, qs = path.split('?', 1)

        # Unquote the path+params (e.g. "/this%20path" -> "/this path").
        # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
        #
        # But note that "...a URI must be separated into its components
        # before the escaped characters within those components can be
        # safely decoded." http://www.ietf.org/rfc/rfc2396.txt, sec 2.4.2
        # Therefore, "/this%2Fpath" becomes "/this%2Fpath", not "/this/path".
        try:
            atoms = [unquote(x) for x in quoted_slash.split(path)]
        except ValueError, ex:
            self.simple_response("400 Bad Request", ex.args[0])
            return
        path = "%2F".join(atoms)
        self.path = path

        # Note that, like wsgiref and most other HTTP servers,
        # we "% HEX HEX"-unquote the path but not the query string.
        self.qs = qs

        # Compare request and server HTTP protocol versions, in case our
        # server does not support the requested protocol. Limit our output
        # to min(req, server). We want the following output:
        #     request    server     actual written   supported response
        #     protocol   protocol  response protocol    feature set
        # a     1.0        1.0           1.0                1.0
        # b     1.0        1.1           1.1                1.0
        # c     1.1        1.0           1.0                1.0
        # d     1.1        1.1           1.1                1.1
        # Notice that, in (b), the response will be "HTTP/1.1" even though
        # the client only understands 1.0. RFC 2616 10.5.6 says we should
        # only return 505 if the _major_ version is different.
        sp = int(self.server.protocol[5]), int(self.server.protocol[7])

        if sp[0] != rp[0]:
            self.simple_response("505 HTTP Version Not Supported")
            return
        self.request_protocol = req_protocol
        self.response_protocol = "HTTP/%s.%s" % min(rp, sp)
+
    def read_request_headers(self):
        """Read self.rfile into self.inheaders. Return success.

        Also decides persistent-connection and Transfer-Encoding
        handling, and answers an "Expect: 100-continue" immediately.
        """

        # then all the http headers
        try:
            read_headers(self.rfile, self.inheaders)
        except ValueError, ex:
            self.simple_response("400 Bad Request", ex.args[0])
            return False

        # Reject declared bodies larger than the configured limit up front.
        mrbs = self.server.max_request_body_size
        if mrbs and int(self.inheaders.get("Content-Length", 0)) > mrbs:
            self.simple_response("413 Request Entity Too Large",
                "The entity sent with the request exceeds the maximum "
                "allowed bytes.")
            return False

        # Persistent connection support
        if self.response_protocol == "HTTP/1.1":
            # Both server and client are HTTP/1.1
            if self.inheaders.get("Connection", "") == "close":
                self.close_connection = True
        else:
            # Either the server or client (or both) are HTTP/1.0
            if self.inheaders.get("Connection", "") != "Keep-Alive":
                self.close_connection = True

        # Transfer-Encoding support
        te = None
        if self.response_protocol == "HTTP/1.1":
            te = self.inheaders.get("Transfer-Encoding")
            if te:
                te = [x.strip().lower() for x in te.split(",") if x.strip()]

        self.chunked_read = False

        if te:
            for enc in te:
                if enc == "chunked":
                    self.chunked_read = True
                else:
                    # Note that, even if we see "chunked", we must reject
                    # if there is an extension we don't recognize.
                    self.simple_response("501 Unimplemented")
                    self.close_connection = True
                    return False

        # From PEP 333:
        # "Servers and gateways that implement HTTP 1.1 must provide
        # transparent support for HTTP 1.1's "expect/continue" mechanism.
        # This may be done in any of several ways:
        #   1. Respond to requests containing an Expect: 100-continue request
        #      with an immediate "100 Continue" response, and proceed normally.
        #   2. Proceed with the request normally, but provide the application
        #      with a wsgi.input stream that will send the "100 Continue"
        #      response if/when the application first attempts to read from
        #      the input stream. The read request must then remain blocked
        #      until the client responds.
        #   3. Wait until the client decides that the server does not support
        #      expect/continue, and sends the request body on its own.
        #      (This is suboptimal, and is not recommended.)
        #
        # We used to do 3, but are now doing 1. Maybe we'll do 2 someday,
        # but it seems like it would be a big slowdown for such a rare case.
        if self.inheaders.get("Expect", "") == "100-continue":
            # Don't use simple_response here, because it emits headers
            # we don't want. See http://www.cherrypy.org/ticket/951
            msg = self.server.protocol + " 100 Continue\r\n\r\n"
            try:
                self.conn.wfile.sendall(msg)
            except socket.error, x:
                if x.args[0] not in socket_errors_to_ignore:
                    raise
        return True
+
+ def parse_request_uri(self, uri):
+ """Parse a Request-URI into (scheme, authority, path).
+
+ Note that Request-URI's must be one of::
+
+ Request-URI = "*" | absoluteURI | abs_path | authority
+
+ Therefore, a Request-URI which starts with a double forward-slash
+ cannot be a "net_path"::
+
+ net_path = "//" authority [ abs_path ]
+
+ Instead, it must be interpreted as an "abs_path" with an empty first
+ path segment::
+
+ abs_path = "/" path_segments
+ path_segments = segment *( "/" segment )
+ segment = *pchar *( ";" param )
+ param = *pchar
+ """
+ if uri == "*":
+ return None, None, uri
+
+ i = uri.find('://')
+ if i > 0 and '?' not in uri[:i]:
+ # An absoluteURI.
+ # If there's a scheme (and it must be http or https), then:
+ # http_URL = "http:" "//" host [ ":" port ] [ abs_path [ "?" query ]]
+ scheme, remainder = uri[:i].lower(), uri[i + 3:]
+ authority, path = remainder.split("/", 1)
+ return scheme, authority, path
+
+ if uri.startswith('/'):
+ # An abs_path.
+ return None, None, uri
+ else:
+ # An authority.
+ return None, uri, None
+
    def respond(self):
        """Call the gateway and write its iterable output."""
        mrbs = self.server.max_request_body_size
        # Pick the request-body reader: chunked decoding, or a reader
        # bounded by the declared Content-Length.
        if self.chunked_read:
            self.rfile = ChunkedRFile(self.conn.rfile, mrbs)
        else:
            cl = int(self.inheaders.get("Content-Length", 0))
            if mrbs and mrbs < cl:
                if not self.sent_headers:
                    self.simple_response("413 Request Entity Too Large",
                        "The entity sent with the request exceeds the maximum "
                        "allowed bytes.")
                return
            self.rfile = KnownLengthRFile(self.conn.rfile, cl)

        self.server.gateway(self).respond()

        # Flush headers if the app produced no body, then terminate a
        # chunked response with the final zero-length chunk.
        if (self.ready and not self.sent_headers):
            self.sent_headers = True
            self.send_headers()
        if self.chunked_write:
            self.conn.wfile.sendall("0\r\n\r\n")
+
    def simple_response(self, status, msg=""):
        """Write a simple response back to the client.

        status: a "NNN Reason" string; msg: optional plain-text body.
        """
        status = str(status)
        buf = [self.server.protocol + " " +
               status + CRLF,
               "Content-Length: %s\r\n" % len(msg),
               "Content-Type: text/plain\r\n"]

        if status[:3] in ("413", "414"):
            # Request Entity Too Large / Request-URI Too Long
            self.close_connection = True
            if self.response_protocol == 'HTTP/1.1':
                # This will not be true for 414, since read_request_line
                # usually raises 414 before reading the whole line, and we
                # therefore cannot know the proper response_protocol.
                buf.append("Connection: close\r\n")
            else:
                # HTTP/1.0 had no 413/414 status nor Connection header.
                # Emit 400 instead and trust the message body is enough.
                status = "400 Bad Request"

        buf.append(CRLF)
        if msg:
            if isinstance(msg, unicode):
                # Response bodies are byte strings on the wire.
                msg = msg.encode("ISO-8859-1")
            buf.append(msg)

        try:
            self.conn.wfile.sendall("".join(buf))
        except socket.error, x:
            # Ignore errors from clients that have already disconnected.
            if x.args[0] not in socket_errors_to_ignore:
                raise
+
+ def write(self, chunk):
+ """Write unbuffered data to the client."""
+ if self.chunked_write and chunk:
+ buf = [hex(len(chunk))[2:], CRLF, chunk, CRLF]
+ self.conn.wfile.sendall("".join(buf))
+ else:
+ self.conn.wfile.sendall(chunk)
+
    def send_headers(self):
        """Assert, process, and send the HTTP response message-headers.

        You must set self.status, and self.outheaders before calling this.
        Fills in Transfer-Encoding/Connection/Date/Server headers that
        the application did not supply, then writes the Status-Line and
        all headers to the client.
        """
        hkeys = [key.lower() for key, value in self.outheaders]
        status = int(self.status[:3])

        if status == 413:
            # Request Entity Too Large. Close conn to avoid garbage.
            self.close_connection = True
        elif "content-length" not in hkeys:
            # "All 1xx (informational), 204 (no content),
            # and 304 (not modified) responses MUST NOT
            # include a message-body." So no point chunking.
            if status < 200 or status in (204, 205, 304):
                pass
            else:
                if (self.response_protocol == 'HTTP/1.1'
                    and self.method != 'HEAD'):
                    # Use the chunked transfer-coding
                    self.chunked_write = True
                    self.outheaders.append(("Transfer-Encoding", "chunked"))
                else:
                    # Closing the conn is the only way to determine len.
                    self.close_connection = True

        if "connection" not in hkeys:
            if self.response_protocol == 'HTTP/1.1':
                # Both server and client are HTTP/1.1 or better
                if self.close_connection:
                    self.outheaders.append(("Connection", "close"))
            else:
                # Server and/or client are HTTP/1.0
                if not self.close_connection:
                    self.outheaders.append(("Connection", "Keep-Alive"))

        if (not self.close_connection) and (not self.chunked_read):
            # Read any remaining request body data on the socket.
            # "If an origin server receives a request that does not include an
            # Expect request-header field with the "100-continue" expectation,
            # the request includes a request body, and the server responds
            # with a final status code before reading the entire request body
            # from the transport connection, then the server SHOULD NOT close
            # the transport connection until it has read the entire request,
            # or until the client closes the connection. Otherwise, the client
            # might not reliably receive the response message. However, this
            # requirement is not to be construed as preventing a server from
            # defending itself against denial-of-service attacks, or from
            # badly broken client implementations."
            remaining = getattr(self.rfile, 'remaining', 0)
            if remaining > 0:
                self.rfile.read(remaining)

        if "date" not in hkeys:
            self.outheaders.append(("Date", rfc822.formatdate()))

        if "server" not in hkeys:
            self.outheaders.append(("Server", self.server.server_name))

        buf = [self.server.protocol + " " + self.status + CRLF]
        for k, v in self.outheaders:
            buf.append(k + ": " + v + CRLF)
        buf.append(CRLF)
        self.conn.wfile.sendall("".join(buf))
+
+
class NoSSLError(Exception):
    """Raised when a client speaks plain HTTP to an HTTPS socket."""
    pass
+
+
+class FatalSSLAlert(Exception):
+ """Exception raised when the SSL implementation signals a fatal alert."""
+ pass
+
+
class CP_fileobject(socket._fileobject):
    """Faux file object attached to a socket object.

    Extends the stdlib socket._fileobject with bytes_read/bytes_written
    counters (consumed by the server's statistics) and send/recv loops
    that retry on non-blocking socket errors.
    """

    def __init__(self, *args, **kwargs):
        self.bytes_read = 0
        self.bytes_written = 0
        socket._fileobject.__init__(self, *args, **kwargs)

    def sendall(self, data):
        """Sendall for non-blocking sockets."""
        while data:
            try:
                bytes_sent = self.send(data)
                data = data[bytes_sent:]
            except socket.error, e:
                # Retry on EAGAIN/EWOULDBLOCK-style errors.
                if e.args[0] not in socket_errors_nonblocking:
                    raise

    def send(self, data):
        # Count bytes written for server statistics.
        bytes_sent = self._sock.send(data)
        self.bytes_written += bytes_sent
        return bytes_sent

    def flush(self):
        # Drain the write buffer through our counting sendall().
        if self._wbuf:
            buffer = "".join(self._wbuf)
            self._wbuf = []
            self.sendall(buffer)

    def recv(self, size):
        while True:
            try:
                data = self._sock.recv(size)
                self.bytes_read += len(data)
                return data
            except socket.error, e:
                # Retry on EAGAIN/EWOULDBLOCK and EINTR; re-raise others.
                if (e.args[0] not in socket_errors_nonblocking
                    and e.args[0] not in socket_error_eintr):
                    raise

    # Python versions differ in whether socket._fileobject keeps its read
    # buffer (_rbuf) as a StringIO or a plain str; provide a matching
    # read/readline implementation for whichever this interpreter uses.
    if not _fileobject_uses_str_type:
        def read(self, size=-1):
            # Use max, disallow tiny reads in a loop as they are very inefficient.
            # We never leave read() with any leftover data from a new recv() call
            # in our internal buffer.
            rbufsize = max(self._rbufsize, self.default_bufsize)
            # Our use of StringIO rather than lists of string objects returned by
            # recv() minimizes memory usage and fragmentation that occurs when
            # rbufsize is large compared to the typical return value of recv().
            buf = self._rbuf
            buf.seek(0, 2)  # seek end
            if size < 0:
                # Read until EOF
                self._rbuf = StringIO.StringIO()  # reset _rbuf.  we consume it via buf.
                while True:
                    data = self.recv(rbufsize)
                    if not data:
                        break
                    buf.write(data)
                return buf.getvalue()
            else:
                # Read until size bytes or EOF seen, whichever comes first
                buf_len = buf.tell()
                if buf_len >= size:
                    # Already have size bytes in our buffer?  Extract and return.
                    buf.seek(0)
                    rv = buf.read(size)
                    self._rbuf = StringIO.StringIO()
                    self._rbuf.write(buf.read())
                    return rv

                self._rbuf = StringIO.StringIO()  # reset _rbuf.  we consume it via buf.
                while True:
                    left = size - buf_len
                    # recv() will malloc the amount of memory given as its
                    # parameter even though it often returns much less data
                    # than that.  The returned data string is short lived
                    # as we copy it into a StringIO and free it.  This avoids
                    # fragmentation issues on many platforms.
                    data = self.recv(left)
                    if not data:
                        break
                    n = len(data)
                    if n == size and not buf_len:
                        # Shortcut.  Avoid buffer data copies when:
                        # - We have no data in our buffer.
                        # AND
                        # - Our call to recv returned exactly the
                        #   number of bytes we were asked to read.
                        return data
                    if n == left:
                        buf.write(data)
                        del data  # explicit free
                        break
                    assert n <= left, "recv(%d) returned %d bytes" % (left, n)
                    buf.write(data)
                    buf_len += n
                    del data  # explicit free
                    #assert buf_len == buf.tell()
                return buf.getvalue()

        def readline(self, size=-1):
            buf = self._rbuf
            buf.seek(0, 2)  # seek end
            if buf.tell() > 0:
                # check if we already have it in our buffer
                buf.seek(0)
                bline = buf.readline(size)
                if bline.endswith('\n') or len(bline) == size:
                    self._rbuf = StringIO.StringIO()
                    self._rbuf.write(buf.read())
                    return bline
                del bline
            if size < 0:
                # Read until \n or EOF, whichever comes first
                if self._rbufsize <= 1:
                    # Speed up unbuffered case
                    buf.seek(0)
                    buffers = [buf.read()]
                    self._rbuf = StringIO.StringIO()  # reset _rbuf.  we consume it via buf.
                    data = None
                    recv = self.recv
                    while data != "\n":
                        data = recv(1)
                        if not data:
                            break
                        buffers.append(data)
                    return "".join(buffers)

                buf.seek(0, 2)  # seek end
                self._rbuf = StringIO.StringIO()  # reset _rbuf.  we consume it via buf.
                while True:
                    data = self.recv(self._rbufsize)
                    if not data:
                        break
                    nl = data.find('\n')
                    if nl >= 0:
                        nl += 1
                        buf.write(data[:nl])
                        self._rbuf.write(data[nl:])
                        del data
                        break
                    buf.write(data)
                return buf.getvalue()
            else:
                # Read until size bytes or \n or EOF seen, whichever comes first
                buf.seek(0, 2)  # seek end
                buf_len = buf.tell()
                if buf_len >= size:
                    buf.seek(0)
                    rv = buf.read(size)
                    self._rbuf = StringIO.StringIO()
                    self._rbuf.write(buf.read())
                    return rv
                self._rbuf = StringIO.StringIO()  # reset _rbuf.  we consume it via buf.
                while True:
                    data = self.recv(self._rbufsize)
                    if not data:
                        break
                    left = size - buf_len
                    # did we just receive a newline?
                    nl = data.find('\n', 0, left)
                    if nl >= 0:
                        nl += 1
                        # save the excess data to _rbuf
                        self._rbuf.write(data[nl:])
                        if buf_len:
                            buf.write(data[:nl])
                            break
                        else:
                            # Shortcut.  Avoid data copy through buf when returning
                            # a substring of our first recv().
                            return data[:nl]
                    n = len(data)
                    if n == size and not buf_len:
                        # Shortcut.  Avoid data copy through buf when
                        # returning exactly all of our first recv().
                        return data
                    if n >= left:
                        buf.write(data[:left])
                        self._rbuf.write(data[left:])
                        break
                    buf.write(data)
                    buf_len += n
                    #assert buf_len == buf.tell()
                return buf.getvalue()
    else:
        def read(self, size=-1):
            if size < 0:
                # Read until EOF
                buffers = [self._rbuf]
                self._rbuf = ""
                if self._rbufsize <= 1:
                    recv_size = self.default_bufsize
                else:
                    recv_size = self._rbufsize

                while True:
                    data = self.recv(recv_size)
                    if not data:
                        break
                    buffers.append(data)
                return "".join(buffers)
            else:
                # Read until size bytes or EOF seen, whichever comes first
                data = self._rbuf
                buf_len = len(data)
                if buf_len >= size:
                    self._rbuf = data[size:]
                    return data[:size]
                buffers = []
                if data:
                    buffers.append(data)
                self._rbuf = ""
                while True:
                    left = size - buf_len
                    recv_size = max(self._rbufsize, left)
                    data = self.recv(recv_size)
                    if not data:
                        break
                    buffers.append(data)
                    n = len(data)
                    if n >= left:
                        self._rbuf = data[left:]
                        buffers[-1] = data[:left]
                        break
                    buf_len += n
                return "".join(buffers)

        def readline(self, size=-1):
            data = self._rbuf
            if size < 0:
                # Read until \n or EOF, whichever comes first
                if self._rbufsize <= 1:
                    # Speed up unbuffered case
                    assert data == ""
                    buffers = []
                    while data != "\n":
                        data = self.recv(1)
                        if not data:
                            break
                        buffers.append(data)
                    return "".join(buffers)
                nl = data.find('\n')
                if nl >= 0:
                    nl += 1
                    self._rbuf = data[nl:]
                    return data[:nl]
                buffers = []
                if data:
                    buffers.append(data)
                self._rbuf = ""
                while True:
                    data = self.recv(self._rbufsize)
                    if not data:
                        break
                    buffers.append(data)
                    nl = data.find('\n')
                    if nl >= 0:
                        nl += 1
                        self._rbuf = data[nl:]
                        buffers[-1] = data[:nl]
                        break
                return "".join(buffers)
            else:
                # Read until size bytes or \n or EOF seen, whichever comes first
                nl = data.find('\n', 0, size)
                if nl >= 0:
                    nl += 1
                    self._rbuf = data[nl:]
                    return data[:nl]
                buf_len = len(data)
                if buf_len >= size:
                    self._rbuf = data[size:]
                    return data[:size]
                buffers = []
                if data:
                    buffers.append(data)
                self._rbuf = ""
                while True:
                    data = self.recv(self._rbufsize)
                    if not data:
                        break
                    buffers.append(data)
                    left = size - buf_len
                    nl = data.find('\n', 0, left)
                    if nl >= 0:
                        nl += 1
                        self._rbuf = data[nl:]
                        buffers[-1] = data[:nl]
                        break
                    n = len(data)
                    if n >= left:
                        self._rbuf = data[left:]
                        buffers[-1] = data[:left]
                        break
                    buf_len += n
                return "".join(buffers)
+
+
class HTTPConnection(object):
    """An HTTP connection (active socket).

    server: the Server object which received this connection.
    socket: the raw socket object (usually TCP) for this connection.
    makefile: a fileobject class for reading from the socket.
    """

    remote_addr = None
    remote_port = None
    ssl_env = None
    rbufsize = DEFAULT_BUFFER_SIZE
    wbufsize = DEFAULT_BUFFER_SIZE
    RequestHandlerClass = HTTPRequest

    def __init__(self, server, sock, makefile=CP_fileobject):
        self.server = server
        self.socket = sock
        # Buffered reader/writer file objects wrapping the raw socket.
        self.rfile = makefile(sock, "rb", self.rbufsize)
        self.wfile = makefile(sock, "wb", self.wbufsize)
        self.requests_seen = 0

    def communicate(self):
        """Read each request and respond appropriately.

        Loops over pipelined requests until the connection should
        close; maps socket timeouts to 408 and unexpected errors to 500
        where a response has not already started.
        """
        request_seen = False
        try:
            while True:
                # (re)set req to None so that if something goes wrong in
                # the RequestHandlerClass constructor, the error doesn't
                # get written to the previous request.
                req = None
                req = self.RequestHandlerClass(self.server, self)

                # This order of operations should guarantee correct pipelining.
                req.parse_request()
                if self.server.stats['Enabled']:
                    self.requests_seen += 1
                if not req.ready:
                    # Something went wrong in the parsing (and the server has
                    # probably already made a simple_response). Return and
                    # let the conn close.
                    return

                request_seen = True
                req.respond()
                if req.close_connection:
                    return
        except socket.error, e:
            errnum = e.args[0]
            # sadly SSL sockets return a different (longer) time out string
            if errnum == 'timed out' or errnum == 'The read operation timed out':
                # Don't error if we're between requests; only error
                # if 1) no request has been started at all, or 2) we're
                # in the middle of a request.
                # See http://www.cherrypy.org/ticket/853
                if (not request_seen) or (req and req.started_request):
                    # Don't bother writing the 408 if the response
                    # has already started being written.
                    if req and not req.sent_headers:
                        try:
                            req.simple_response("408 Request Timeout")
                        except FatalSSLAlert:
                            # Close the connection.
                            return
            elif errnum not in socket_errors_to_ignore:
                if req and not req.sent_headers:
                    try:
                        req.simple_response("500 Internal Server Error",
                                            format_exc())
                    except FatalSSLAlert:
                        # Close the connection.
                        return
            return
        except (KeyboardInterrupt, SystemExit):
            raise
        except FatalSSLAlert:
            # Close the connection.
            return
        except NoSSLError:
            if req and not req.sent_headers:
                # Unwrap our wfile
                self.wfile = CP_fileobject(self.socket._sock, "wb", self.wbufsize)
                req.simple_response("400 Bad Request",
                    "The client sent a plain HTTP request, but "
                    "this server only speaks HTTPS on this port.")
                self.linger = True
        except Exception:
            if req and not req.sent_headers:
                try:
                    req.simple_response("500 Internal Server Error", format_exc())
                except FatalSSLAlert:
                    # Close the connection.
                    return

    # When True, close() skips the immediate FIN so the client has a
    # chance to read our final response (set after a NoSSLError reply).
    linger = False

    def close(self):
        """Close the socket underlying this connection."""
        self.rfile.close()

        if not self.linger:
            # Python's socket module does NOT call close on the kernel socket
            # when you call socket.close(). We do so manually here because we
            # want this server to send a FIN TCP segment immediately. Note this
            # must be called *before* calling socket.close(), because the latter
            # drops its reference to the kernel socket.
            if hasattr(self.socket, '_sock'):
                self.socket._sock.close()
            self.socket.close()
        else:
            # On the other hand, sometimes we want to hang around for a bit
            # to make sure the client has a chance to read our entire
            # response. Skipping the close() calls here delays the FIN
            # packet until the socket object is garbage-collected later.
            # Someday, perhaps, we'll do the full lingering_close that
            # Apache does, but not today.
            pass
+
+
_SHUTDOWNREQUEST = None  # Sentinel queued to tell a WorkerThread to exit.
+
+class WorkerThread(threading.Thread):
+ """Thread which continuously polls a Queue for Connection objects.
+
+ Due to the timing issues of polling a Queue, a WorkerThread does not
+ check its own 'ready' flag after it has started. To stop the thread,
+ it is necessary to stick a _SHUTDOWNREQUEST object onto the Queue
+ (one for each running WorkerThread).
+ """
+
+ conn = None
+ """The current connection pulled off the Queue, or None."""
+
+ server = None
+ """The HTTP Server which spawned this thread, and which owns the
+ Queue and is placing active connections into it."""
+
+ ready = False
+ """A simple flag for the calling server to know when this thread
+ has begun polling the Queue."""
+
+
+ def __init__(self, server):
+ self.ready = False
+ self.server = server
+
+ self.requests_seen = 0
+ self.bytes_read = 0
+ self.bytes_written = 0
+ self.start_time = None
+ self.work_time = 0
+ self.stats = {
+ 'Requests': lambda s: self.requests_seen + ((self.start_time is None) and 0 or self.conn.requests_seen),
+ 'Bytes Read': lambda s: self.bytes_read + ((self.start_time is None) and 0 or self.conn.rfile.bytes_read),
+ 'Bytes Written': lambda s: self.bytes_written + ((self.start_time is None) and 0 or self.conn.wfile.bytes_written),
+ 'Work Time': lambda s: self.work_time + ((self.start_time is None) and 0 or time.time() - self.start_time),
+ 'Read Throughput': lambda s: s['Bytes Read'](s) / (s['Work Time'](s) or 1e-6),
+ 'Write Throughput': lambda s: s['Bytes Written'](s) / (s['Work Time'](s) or 1e-6),
+ }
+ threading.Thread.__init__(self)
+
+ def run(self):
+ self.server.stats['Worker Threads'][self.getName()] = self.stats
+ try:
+ self.ready = True
+ while True:
+ conn = self.server.requests.get()
+ if conn is _SHUTDOWNREQUEST:
+ return
+
+ self.conn = conn
+ if self.server.stats['Enabled']:
+ self.start_time = time.time()
+ try:
+ conn.communicate()
+ finally:
+ conn.close()
+ if self.server.stats['Enabled']:
+ self.requests_seen += self.conn.requests_seen
+ self.bytes_read += self.conn.rfile.bytes_read
+ self.bytes_written += self.conn.wfile.bytes_written
+ self.work_time += time.time() - self.start_time
+ self.start_time = None
+ self.conn = None
+ except (KeyboardInterrupt, SystemExit), exc:
+ self.server.interrupt = exc
+
+
+class ThreadPool(object):
+ """A Request Queue for the CherryPyWSGIServer which pools threads.
+
+ ThreadPool objects must provide min, get(), put(obj), start()
+ and stop(timeout) attributes.
+ """
+
+ def __init__(self, server, min=10, max=-1):
+ self.server = server
+ self.min = min
+ self.max = max
+ self._threads = []
+ self._queue = Queue.Queue()
+ self.get = self._queue.get
+
+ def start(self):
+ """Start the pool of threads."""
+ for i in range(self.min):
+ self._threads.append(WorkerThread(self.server))
+ for worker in self._threads:
+ worker.setName("CP Server " + worker.getName())
+ worker.start()
+ for worker in self._threads:
+ while not worker.ready:
+ time.sleep(.1)
+
+ def _get_idle(self):
+ """Number of worker threads which are idle. Read-only."""
+ return len([t for t in self._threads if t.conn is None])
+ idle = property(_get_idle, doc=_get_idle.__doc__)
+
+ def put(self, obj):
+ self._queue.put(obj)
+ if obj is _SHUTDOWNREQUEST:
+ return
+
+ def grow(self, amount):
+ """Spawn new worker threads (not above self.max)."""
+ for i in range(amount):
+ if self.max > 0 and len(self._threads) >= self.max:
+ break
+ worker = WorkerThread(self.server)
+ worker.setName("CP Server " + worker.getName())
+ self._threads.append(worker)
+ worker.start()
+
+ def shrink(self, amount):
+ """Kill off worker threads (not below self.min)."""
+ # Grow/shrink the pool if necessary.
+ # Remove any dead threads from our list
+ for t in self._threads:
+ if not t.isAlive():
+ self._threads.remove(t)
+ amount -= 1
+
+ if amount > 0:
+ for i in range(min(amount, len(self._threads) - self.min)):
+ # Put a number of shutdown requests on the queue equal
+ # to 'amount'. Once each of those is processed by a worker,
+ # that worker will terminate and be culled from our list
+ # in self.put.
+ self._queue.put(_SHUTDOWNREQUEST)
+
+ def stop(self, timeout=5):
+ # Must shut down threads here so the code that calls
+ # this method can know when all threads are stopped.
+ for worker in self._threads:
+ self._queue.put(_SHUTDOWNREQUEST)
+
+ # Don't join currentThread (when stop is called inside a request).
+ current = threading.currentThread()
+ if timeout and timeout >= 0:
+ endtime = time.time() + timeout
+ while self._threads:
+ worker = self._threads.pop()
+ if worker is not current and worker.isAlive():
+ try:
+ if timeout is None or timeout < 0:
+ worker.join()
+ else:
+ remaining_time = endtime - time.time()
+ if remaining_time > 0:
+ worker.join(remaining_time)
+ if worker.isAlive():
+ # We exhausted the timeout.
+ # Forcibly shut down the socket.
+ c = worker.conn
+ if c and not c.rfile.closed:
+ try:
+ c.socket.shutdown(socket.SHUT_RD)
+ except TypeError:
+ # pyOpenSSL sockets don't take an arg
+ c.socket.shutdown()
+ worker.join()
+ except (AssertionError,
+ # Ignore repeated Ctrl-C.
+ # See http://www.cherrypy.org/ticket/691.
+ KeyboardInterrupt), exc1:
+ pass
+
+ def _get_qsize(self):
+ return self._queue.qsize()
+ qsize = property(_get_qsize)
+
+
+
# Define prevent_socket_inheritance for this platform: POSIX (fcntl
# FD_CLOEXEC), Windows (kernel32 via ctypes), or a no-op fallback.
try:
    import fcntl
except ImportError:
    try:
        from ctypes import windll, WinError
    except ImportError:
        def prevent_socket_inheritance(sock):
            """Dummy function, since neither fcntl nor ctypes are available."""
            pass
    else:
        def prevent_socket_inheritance(sock):
            """Mark the given socket fd as non-inheritable (Windows)."""
            if not windll.kernel32.SetHandleInformation(sock.fileno(), 1, 0):
                raise WinError()
else:
    def prevent_socket_inheritance(sock):
        """Mark the given socket fd as non-inheritable (POSIX)."""
        fd = sock.fileno()
        old_flags = fcntl.fcntl(fd, fcntl.F_GETFD)
        fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC)
+
+
class SSLAdapter(object):
    """Base class for SSL driver library adapters.

    Required methods:

        * ``wrap(sock) -> (wrapped socket, ssl environ dict)``
        * ``makefile(sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE) -> socket file object``
    """

    def __init__(self, certificate, private_key, certificate_chain=None):
        # Filenames of the certificate/key material used by the
        # concrete driver adapter.
        self.certificate = certificate
        self.private_key = private_key
        self.certificate_chain = certificate_chain

    def wrap(self, sock):
        """Wrap the given socket; return (ssl socket, ssl environ dict)."""
        # Bug fix: 'raise NotImplemented' raised a TypeError because
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError

    def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE):
        """Return a socket file object for the wrapped socket."""
        raise NotImplementedError
+
+
+class HTTPServer(object):
+ """An HTTP server."""
+
+ _bind_addr = "127.0.0.1"
+ _interrupt = None
+
+ gateway = None
+ """A Gateway instance."""
+
+ minthreads = None
+ """The minimum number of worker threads to create (default 10)."""
+
+ maxthreads = None
+ """The maximum number of worker threads to create (default -1 = no limit)."""
+
+ server_name = None
+ """The name of the server; defaults to socket.gethostname()."""
+
+ protocol = "HTTP/1.1"
+ """The version string to write in the Status-Line of all HTTP responses.
+
+ For example, "HTTP/1.1" is the default. This also limits the supported
+ features used in the response."""
+
+ request_queue_size = 5
+ """The 'backlog' arg to socket.listen(); max queued connections (default 5)."""
+
+ shutdown_timeout = 5
+ """The total time, in seconds, to wait for worker threads to cleanly exit."""
+
+ timeout = 10
+ """The timeout in seconds for accepted connections (default 10)."""
+
+ version = "CherryPy/3.2.0"
+ """A version string for the HTTPServer."""
+
+ software = None
+ """The value to set for the SERVER_SOFTWARE entry in the WSGI environ.
+
+ If None, this defaults to ``'%s Server' % self.version``."""
+
+ ready = False
+ """An internal flag which marks whether the socket is accepting connections."""
+
+ max_request_header_size = 0
+ """The maximum size, in bytes, for request headers, or 0 for no limit."""
+
+ max_request_body_size = 0
+ """The maximum size, in bytes, for request bodies, or 0 for no limit."""
+
+ nodelay = True
+ """If True (the default since 3.1), sets the TCP_NODELAY socket option."""
+
+ ConnectionClass = HTTPConnection
+ """The class to use for handling HTTP connections."""
+
+ ssl_adapter = None
+ """An instance of SSLAdapter (or a subclass).
+
+ You must have the corresponding SSL driver library installed."""
+
    def __init__(self, bind_addr, gateway, minthreads=10, maxthreads=-1,
            server_name=None):
        # bind_addr: see the bind_addr property for accepted forms.
        # gateway: a Gateway class used to dispatch each request.
        self.bind_addr = bind_addr
        self.gateway = gateway

        # Worker pool; guarantee at least one thread even if minthreads
        # is 0 or None.
        self.requests = ThreadPool(self, min=minthreads or 1, max=maxthreads)

        if not server_name:
            server_name = socket.gethostname()
        self.server_name = server_name
        self.clear_stats()
+
+    def clear_stats(self):
+        # Reset run-time accounting and (re)build the stats dict published
+        # via logging.statistics.  Most entries are lambdas taking the stats
+        # dict itself ("s"), so values are computed lazily on read; the
+        # aggregate entries fold over s['Worker Threads'].
+        self._start_time = None
+        self._run_time = 0
+        self.stats = {
+            'Enabled': False,
+            'Bind Address': lambda s: repr(self.bind_addr),
+            'Run time': lambda s: (not s['Enabled']) and 0 or self.runtime(),
+            'Accepts': 0,
+            'Accepts/sec': lambda s: s['Accepts'] / self.runtime(),
+            'Queue': lambda s: getattr(self.requests, "qsize", None),
+            'Threads': lambda s: len(getattr(self.requests, "_threads", [])),
+            'Threads Idle': lambda s: getattr(self.requests, "idle", None),
+            'Socket Errors': 0,
+            'Requests': lambda s: (not s['Enabled']) and 0 or sum([w['Requests'](w) for w
+                                   in s['Worker Threads'].values()], 0),
+            'Bytes Read': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Read'](w) for w
+                                     in s['Worker Threads'].values()], 0),
+            'Bytes Written': lambda s: (not s['Enabled']) and 0 or sum([w['Bytes Written'](w) for w
+                                        in s['Worker Threads'].values()], 0),
+            'Work Time': lambda s: (not s['Enabled']) and 0 or sum([w['Work Time'](w) for w
+                                    in s['Worker Threads'].values()], 0),
+            'Read Throughput': lambda s: (not s['Enabled']) and 0 or sum(
+                [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6)
+                 for w in s['Worker Threads'].values()], 0),
+            'Write Throughput': lambda s: (not s['Enabled']) and 0 or sum(
+                [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6)
+                 for w in s['Worker Threads'].values()], 0),
+            'Worker Threads': {},
+            }
+        logging.statistics["CherryPy HTTPServer %d" % id(self)] = self.stats
+
+    def runtime(self):
+        # Total seconds served: past run time accumulated by stop(), plus
+        # the current session when running (_start_time set in start()).
+        if self._start_time is None:
+            return self._run_time
+        else:
+            return self._run_time + (time.time() - self._start_time)
+
+    def __str__(self):
+        # e.g. "pkg.module.HTTPServer(('0.0.0.0', 8080))"
+        return "%s.%s(%r)" % (self.__module__, self.__class__.__name__,
+                              self.bind_addr)
+
+    def _get_bind_addr(self):
+        return self._bind_addr
+    # Validating setter: rejects '' and None hosts (see the property doc).
+    def _set_bind_addr(self, value):
+        if isinstance(value, tuple) and value[0] in ('', None):
+            # Despite the socket module docs, using '' does not
+            # allow AI_PASSIVE to work. Passing None instead
+            # returns '0.0.0.0' like we want. In other words:
+            #     host    AI_PASSIVE     result
+            #      ''         Y         192.168.x.y
+            #      ''         N         192.168.x.y
+            #     None        Y         0.0.0.0
+            #     None        N         127.0.0.1
+            # But since you can get the same effect with an explicit
+            # '0.0.0.0', we deny both the empty string and None as values.
+            raise ValueError("Host values of '' or None are not allowed. "
+                             "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead "
+                             "to listen on all active interfaces.")
+        self._bind_addr = value
+    bind_addr = property(_get_bind_addr, _set_bind_addr,
+        doc="""The interface on which to listen for connections.
+
+        For TCP sockets, a (host, port) tuple. Host values may be any IPv4
+        or IPv6 address, or any valid hostname. The string 'localhost' is a
+        synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6).
+        The string '0.0.0.0' is a special IPv4 entry meaning "any active
+        interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for
+        IPv6. The empty string or None are not allowed.
+
+        For UNIX sockets, supply the filename as a string.""")
+
+    def start(self):
+        """Run the server forever.
+
+        Binds the listening socket (TCP or UNIX), starts the worker thread
+        pool, then loops on tick() until self.ready is cleared by stop().
+        """
+        # We don't have to trap KeyboardInterrupt or SystemExit here,
+        # because cherrypy.server already does so, calling self.stop() for us.
+        # If you're using this server with another framework, you should
+        # trap those exceptions in whatever code block calls start().
+        self._interrupt = None
+
+        if self.software is None:
+            self.software = "%s Server" % self.version
+
+        # SSL backward compatibility
+        if (self.ssl_adapter is None and
+            getattr(self, 'ssl_certificate', None) and
+            getattr(self, 'ssl_private_key', None)):
+            warnings.warn(
+                    "SSL attributes are deprecated in CherryPy 3.2, and will "
+                    "be removed in CherryPy 3.3. Use an ssl_adapter attribute "
+                    "instead.",
+                    DeprecationWarning
+                )
+            try:
+                from cherrypy.wsgiserver.ssl_pyopenssl import pyOpenSSLAdapter
+            except ImportError:
+                pass
+            else:
+                self.ssl_adapter = pyOpenSSLAdapter(
+                    self.ssl_certificate, self.ssl_private_key,
+                    getattr(self, 'ssl_certificate_chain', None))
+
+        # Select the appropriate socket
+        if isinstance(self.bind_addr, basestring):
+            # AF_UNIX socket
+
+            # So we can reuse the socket...
+            # (best-effort: the path may not exist yet; ignore failures)
+            try: os.unlink(self.bind_addr)
+            except: pass
+
+            # So everyone can access the socket...
+            # NOTE: 0777 is Python 2 octal syntax (0o777 in Python 3).
+            try: os.chmod(self.bind_addr, 0777)
+            except: pass
+
+            info = [(socket.AF_UNIX, socket.SOCK_STREAM, 0, "", self.bind_addr)]
+        else:
+            # AF_INET or AF_INET6 socket
+            # Get the correct address family for our host (allows IPv6 addresses)
+            host, port = self.bind_addr
+            try:
+                info = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
+                                          socket.SOCK_STREAM, 0, socket.AI_PASSIVE)
+            except socket.gaierror:
+                # Resolution failed; fall back on a literal-address guess
+                # based on the presence of ':' (IPv6) in the host string.
+                if ':' in self.bind_addr[0]:
+                    info = [(socket.AF_INET6, socket.SOCK_STREAM,
+                             0, "", self.bind_addr + (0, 0))]
+                else:
+                    info = [(socket.AF_INET, socket.SOCK_STREAM,
+                             0, "", self.bind_addr)]
+
+        # Try each getaddrinfo result in turn; keep the first that binds.
+        self.socket = None
+        msg = "No socket could be created"
+        for res in info:
+            af, socktype, proto, canonname, sa = res
+            try:
+                self.bind(af, socktype, proto)
+            except socket.error:
+                if self.socket:
+                    self.socket.close()
+                self.socket = None
+                continue
+            break
+        if not self.socket:
+            raise socket.error(msg)
+
+        # Timeout so KeyboardInterrupt can be caught on Win32
+        self.socket.settimeout(1)
+        self.socket.listen(self.request_queue_size)
+
+        # Create worker threads
+        self.requests.start()
+
+        self.ready = True
+        self._start_time = time.time()
+        while self.ready:
+            self.tick()
+            if self.interrupt:
+                while self.interrupt is True:
+                    # Wait for self.stop() to complete. See _set_interrupt.
+                    time.sleep(0.1)
+                if self.interrupt:
+                    raise self.interrupt
+
+    def bind(self, family, type, proto=0):
+        """Create (or recreate) the actual socket object."""
+        self.socket = socket.socket(family, type, proto)
+        prevent_socket_inheritance(self.socket)
+        # Allow quick rebinds after restart (TIME_WAIT sockets).
+        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        # TCP_NODELAY is meaningless on UNIX sockets (str bind_addr).
+        if self.nodelay and not isinstance(self.bind_addr, str):
+            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+        # Give the SSL adapter a chance to wrap the listening socket.
+        if self.ssl_adapter is not None:
+            self.socket = self.ssl_adapter.bind(self.socket)
+
+        # If listening on the IPV6 any address ('::' = IN6ADDR_ANY),
+        # activate dual-stack. See http://www.cherrypy.org/ticket/871.
+        if (hasattr(socket, 'AF_INET6') and family == socket.AF_INET6
+            and self.bind_addr[0] in ('::', '::0', '::0.0.0.0')):
+            try:
+                self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
+            except (AttributeError, socket.error):
+                # Apparently, the socket option is not available in
+                # this machine's TCP stack
+                pass
+
+        self.socket.bind(self.bind_addr)
+
+    def tick(self):
+        """Accept a new connection and put it on the Queue.
+
+        Called once per iteration of the start() loop.  Timeouts and
+        ignorable socket errors return silently so the loop keeps going.
+        """
+        try:
+            s, addr = self.socket.accept()
+            if self.stats['Enabled']:
+                self.stats['Accepts'] += 1
+            if not self.ready:
+                return
+
+            prevent_socket_inheritance(s)
+            if hasattr(s, 'settimeout'):
+                s.settimeout(self.timeout)
+
+            makefile = CP_fileobject
+            ssl_env = {}
+            # if ssl cert and key are set, we try to be a secure HTTP server
+            if self.ssl_adapter is not None:
+                try:
+                    s, ssl_env = self.ssl_adapter.wrap(s)
+                except NoSSLError:
+                    # Plain HTTP on an HTTPS port: answer 400 directly on
+                    # the raw socket, then drop the connection.
+                    msg = ("The client sent a plain HTTP request, but "
+                           "this server only speaks HTTPS on this port.")
+                    buf = ["%s 400 Bad Request\r\n" % self.protocol,
+                           "Content-Length: %s\r\n" % len(msg),
+                           "Content-Type: text/plain\r\n\r\n",
+                           msg]
+
+                    wfile = CP_fileobject(s, "wb", DEFAULT_BUFFER_SIZE)
+                    try:
+                        wfile.sendall("".join(buf))
+                    except socket.error, x:
+                        if x.args[0] not in socket_errors_to_ignore:
+                            raise
+                    return
+                if not s:
+                    return
+                makefile = self.ssl_adapter.makefile
+                # Re-apply our timeout since we may have a new socket object
+                if hasattr(s, 'settimeout'):
+                    s.settimeout(self.timeout)
+
+            conn = self.ConnectionClass(self, s, makefile)
+
+            if not isinstance(self.bind_addr, basestring):
+                # optional values
+                # Until we do DNS lookups, omit REMOTE_HOST
+                if addr is None: # sometimes this can happen
+                    # figure out if AF_INET or AF_INET6.
+                    if len(s.getsockname()) == 2:
+                        # AF_INET
+                        addr = ('0.0.0.0', 0)
+                    else:
+                        # AF_INET6
+                        addr = ('::', 0)
+                conn.remote_addr = addr[0]
+                conn.remote_port = addr[1]
+
+            conn.ssl_env = ssl_env
+
+            # Hand the connection to the worker thread pool.
+            self.requests.put(conn)
+        except socket.timeout:
+            # The only reason for the timeout in start() is so we can
+            # notice keyboard interrupts on Win32, which don't interrupt
+            # accept() by default
+            return
+        except socket.error, x:
+            if self.stats['Enabled']:
+                self.stats['Socket Errors'] += 1
+            if x.args[0] in socket_error_eintr:
+                # I *think* this is right. EINTR should occur when a signal
+                # is received during the accept() call; all docs say retry
+                # the call, and I *think* I'm reading it right that Python
+                # will then go ahead and poll for and handle the signal
+                # elsewhere. See http://www.cherrypy.org/ticket/707.
+                return
+            if x.args[0] in socket_errors_nonblocking:
+                # Just try again. See http://www.cherrypy.org/ticket/479.
+                return
+            if x.args[0] in socket_errors_to_ignore:
+                # Our socket was closed.
+                # See http://www.cherrypy.org/ticket/686.
+                return
+            raise
+
+    def _get_interrupt(self):
+        return self._interrupt
+    def _set_interrupt(self, interrupt):
+        # Briefly mark the flag True so the start() loop blocks until
+        # stop() completes, then store the real exception so start()
+        # can re-raise it.  See the wait loop at the bottom of start().
+        self._interrupt = True
+        self.stop()
+        self._interrupt = interrupt
+    interrupt = property(_get_interrupt, _set_interrupt,
+                         doc="Set this to an Exception instance to "
+                             "interrupt the server.")
+
+    def stop(self):
+        """Gracefully shutdown a server that is serving forever."""
+        self.ready = False
+        # Fold the current session into the accumulated run time.
+        if self._start_time is not None:
+            self._run_time += (time.time() - self._start_time)
+        self._start_time = None
+
+        sock = getattr(self, "socket", None)
+        if sock:
+            if not isinstance(self.bind_addr, basestring):
+                # Touch our own socket to make accept() return immediately.
+                try:
+                    host, port = sock.getsockname()[:2]
+                except socket.error, x:
+                    if x.args[0] not in socket_errors_to_ignore:
+                        # Changed to use error code and not message
+                        # See http://www.cherrypy.org/ticket/860.
+                        raise
+                else:
+                    # Note that we're explicitly NOT using AI_PASSIVE,
+                    # here, because we want an actual IP to touch.
+                    # localhost won't work if we've bound to a public IP,
+                    # but it will if we bound to '0.0.0.0' (INADDR_ANY).
+                    for res in socket.getaddrinfo(host, port, socket.AF_UNSPEC,
+                                                  socket.SOCK_STREAM):
+                        af, socktype, proto, canonname, sa = res
+                        s = None
+                        try:
+                            s = socket.socket(af, socktype, proto)
+                            # See http://groups.google.com/group/cherrypy-users/
+                            # browse_frm/thread/bbfe5eb39c904fe0
+                            s.settimeout(1.0)
+                            s.connect((host, port))
+                            s.close()
+                        except socket.error:
+                            if s:
+                                s.close()
+            if hasattr(sock, "close"):
+                sock.close()
+            self.socket = None
+
+        # Let in-flight requests drain (bounded by shutdown_timeout).
+        self.requests.stop(self.shutdown_timeout)
+
+
+class Gateway(object):
+
+ def __init__(self, req):
+ self.req = req
+
+ def respond(self):
+ raise NotImplemented
+
+
+# These may either be wsgiserver.SSLAdapter subclasses or the string names
+# of such classes (in which case they will be lazily loaded).
+# Keys are the names accepted by get_ssl_adapter_class() below.
+ssl_adapters = {
+    'builtin': 'cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter',
+    'pyopenssl': 'cherrypy.wsgiserver.ssl_pyopenssl.pyOpenSSLAdapter',
+    }
+
+def get_ssl_adapter_class(name='pyopenssl'):
+    """Return an SSL adapter class for the given name.
+
+    Looks the name up in ssl_adapters; if the registered value is a dotted
+    string, the module is imported lazily and the class attribute resolved.
+    Raises KeyError for unknown names and AttributeError if the module
+    lacks the named attribute.
+    """
+    adapter = ssl_adapters[name.lower()]
+    if isinstance(adapter, basestring):
+        last_dot = adapter.rfind(".")
+        attr_name = adapter[last_dot + 1:]
+        mod_path = adapter[:last_dot]
+
+        try:
+            mod = sys.modules[mod_path]
+            if mod is None:
+                raise KeyError()
+        except KeyError:
+            # The last [''] is important.
+            mod = __import__(mod_path, globals(), locals(), [''])
+
+        # Let an AttributeError propagate outward.
+        try:
+            adapter = getattr(mod, attr_name)
+        except AttributeError:
+            raise AttributeError("'%s' object has no attribute '%s'"
+                                 % (mod_path, attr_name))
+
+    return adapter
+
+# -------------------------------- WSGI Stuff -------------------------------- #
+
+
+class CherryPyWSGIServer(HTTPServer):
+    """An HTTPServer subclass wired to a WSGI application.
+
+    Selects a WSGI gateway from wsgi_gateways based on wsgi_version.
+    """
+
+    wsgi_version = (1, 0)
+
+    # NOTE: ``max`` shadows the builtin but is kept for interface
+    # compatibility with existing callers.
+    def __init__(self, bind_addr, wsgi_app, numthreads=10, server_name=None,
+                 max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5):
+        self.requests = ThreadPool(self, min=numthreads or 1, max=max)
+        self.wsgi_app = wsgi_app
+        self.gateway = wsgi_gateways[self.wsgi_version]
+
+        self.bind_addr = bind_addr
+        if not server_name:
+            server_name = socket.gethostname()
+        self.server_name = server_name
+        self.request_queue_size = request_queue_size
+
+        self.timeout = timeout
+        self.shutdown_timeout = shutdown_timeout
+        self.clear_stats()
+
+    # numthreads is an alias for the thread pool's minimum size.
+    def _get_numthreads(self):
+        return self.requests.min
+    def _set_numthreads(self, value):
+        self.requests.min = value
+    numthreads = property(_get_numthreads, _set_numthreads)
+
+
+class WSGIGateway(Gateway):
+
+ def __init__(self, req):
+ self.req = req
+ self.started_response = False
+ self.env = self.get_environ()
+ self.remaining_bytes_out = None
+
+ def get_environ(self):
+ """Return a new environ dict targeting the given wsgi.version"""
+ raise NotImplemented
+
+ def respond(self):
+ response = self.req.server.wsgi_app(self.env, self.start_response)
+ try:
+ for chunk in response:
+ # "The start_response callable must not actually transmit
+ # the response headers. Instead, it must store them for the
+ # server or gateway to transmit only after the first
+ # iteration of the application return value that yields
+ # a NON-EMPTY string, or upon the application's first
+ # invocation of the write() callable." (PEP 333)
+ if chunk:
+ if isinstance(chunk, unicode):
+ chunk = chunk.encode('ISO-8859-1')
+ self.write(chunk)
+ finally:
+ if hasattr(response, "close"):
+ response.close()
+
+ def start_response(self, status, headers, exc_info = None):
+ """WSGI callable to begin the HTTP response."""
+ # "The application may call start_response more than once,
+ # if and only if the exc_info argument is provided."
+ if self.started_response and not exc_info:
+ raise AssertionError("WSGI start_response called a second "
+ "time with no exc_info.")
+ self.started_response = True
+
+ # "if exc_info is provided, and the HTTP headers have already been
+ # sent, start_response must raise an error, and should raise the
+ # exc_info tuple."
+ if self.req.sent_headers:
+ try:
+ raise exc_info[0], exc_info[1], exc_info[2]
+ finally:
+ exc_info = None
+
+ self.req.status = status
+ for k, v in headers:
+ if not isinstance(k, str):
+ raise TypeError("WSGI response header key %r is not a byte string." % k)
+ if not isinstance(v, str):
+ raise TypeError("WSGI response header value %r is not a byte string." % v)
+ if k.lower() == 'content-length':
+ self.remaining_bytes_out = int(v)
+ self.req.outheaders.extend(headers)
+
+ return self.write
+
+ def write(self, chunk):
+ """WSGI callable to write unbuffered data to the client.
+
+ This method is also used internally by start_response (to write
+ data from the iterable returned by the WSGI application).
+ """
+ if not self.started_response:
+ raise AssertionError("WSGI write called before start_response.")
+
+ chunklen = len(chunk)
+ rbo = self.remaining_bytes_out
+ if rbo is not None and chunklen > rbo:
+ if not self.req.sent_headers:
+ # Whew. We can send a 500 to the client.
+ self.req.simple_response("500 Internal Server Error",
+ "The requested resource returned more bytes than the "
+ "declared Content-Length.")
+ else:
+ # Dang. We have probably already sent data. Truncate the chunk
+ # to fit (so the client doesn't hang) and raise an error later.
+ chunk = chunk[:rbo]
+
+ if not self.req.sent_headers:
+ self.req.sent_headers = True
+ self.req.send_headers()
+
+ self.req.write(chunk)
+
+ if rbo is not None:
+ rbo -= chunklen
+ if rbo < 0:
+ raise ValueError(
+ "Response body exceeds the declared Content-Length.")
+
+
+class WSGIGateway_10(WSGIGateway):
+    """A gateway building WSGI 1.0 (PEP 333) environ dicts."""
+
+    def get_environ(self):
+        """Return a new environ dict targeting the given wsgi.version"""
+        req = self.req
+        env = {
+            # set a non-standard environ entry so the WSGI app can know what
+            # the *real* server protocol is (and what features to support).
+            # See http://www.faqs.org/rfcs/rfc2145.html.
+            'ACTUAL_SERVER_PROTOCOL': req.server.protocol,
+            'PATH_INFO': req.path,
+            'QUERY_STRING': req.qs,
+            'REMOTE_ADDR': req.conn.remote_addr or '',
+            'REMOTE_PORT': str(req.conn.remote_port or ''),
+            'REQUEST_METHOD': req.method,
+            'REQUEST_URI': req.uri,
+            'SCRIPT_NAME': '',
+            'SERVER_NAME': req.server.server_name,
+            # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol.
+            'SERVER_PROTOCOL': req.request_protocol,
+            'SERVER_SOFTWARE': req.server.software,
+            'wsgi.errors': sys.stderr,
+            'wsgi.input': req.rfile,
+            'wsgi.multiprocess': False,
+            'wsgi.multithread': True,
+            'wsgi.run_once': False,
+            'wsgi.url_scheme': req.scheme,
+            'wsgi.version': (1, 0),
+            }
+
+        if isinstance(req.server.bind_addr, basestring):
+            # AF_UNIX. This isn't really allowed by WSGI, which doesn't
+            # address unix domain sockets. But it's better than nothing.
+            env["SERVER_PORT"] = ""
+        else:
+            env["SERVER_PORT"] = str(req.server.bind_addr[1])
+
+        # Request headers
+        for k, v in req.inheaders.iteritems():
+            env["HTTP_" + k.upper().replace("-", "_")] = v
+
+        # CONTENT_TYPE/CONTENT_LENGTH
+        # WSGI requires these two without the HTTP_ prefix.
+        ct = env.pop("HTTP_CONTENT_TYPE", None)
+        if ct is not None:
+            env["CONTENT_TYPE"] = ct
+        cl = env.pop("HTTP_CONTENT_LENGTH", None)
+        if cl is not None:
+            env["CONTENT_LENGTH"] = cl
+
+        # Merge any SSL variables gathered by the adapter at handshake time.
+        if req.conn.ssl_env:
+            env.update(req.conn.ssl_env)
+
+        return env
+
+
+class WSGIGateway_u0(WSGIGateway_10):
+    """A gateway for the experimental ('u', 0) unicode-environ WSGI variant.
+
+    Builds the byte-string 1.0 environ, then decodes keys and most values
+    to unicode.
+    """
+
+    def get_environ(self):
+        """Return a new environ dict targeting the given wsgi.version"""
+        req = self.req
+        env_10 = WSGIGateway_10.get_environ(self)
+        # Decode all keys to unicode (values handled below).
+        env = dict([(k.decode('ISO-8859-1'), v) for k, v in env_10.iteritems()])
+        env[u'wsgi.version'] = ('u', 0)
+
+        # Request-URI
+        env.setdefault(u'wsgi.url_encoding', u'utf-8')
+        try:
+            for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
+                env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
+        except UnicodeDecodeError:
+            # Fall back to latin 1 so apps can transcode if needed.
+            env[u'wsgi.url_encoding'] = u'ISO-8859-1'
+            for key in [u"PATH_INFO", u"SCRIPT_NAME", u"QUERY_STRING"]:
+                env[key] = env_10[str(key)].decode(env[u'wsgi.url_encoding'])
+
+        # Decode remaining byte-string values, except the raw request URI
+        # and the input stream, which stay as bytes.
+        for k, v in sorted(env.items()):
+            if isinstance(v, str) and k not in ('REQUEST_URI', 'wsgi.input'):
+                env[k] = v.decode('ISO-8859-1')
+
+        return env
+
+# Map each supported wsgi.version tuple to its gateway class; consulted by
+# CherryPyWSGIServer.__init__.
+wsgi_gateways = {
+    (1, 0): WSGIGateway_10,
+    ('u', 0): WSGIGateway_u0,
+}
+
+class WSGIPathInfoDispatcher(object):
+    """A WSGI dispatcher for dispatch based on the PATH_INFO.
+
+    apps: a dict or list of (path_prefix, app) pairs.
+
+    The longest matching prefix wins; an unmatched path gets an empty
+    404 response.
+    """
+
+    def __init__(self, apps):
+        # Accept a dict by converting it to (prefix, app) pairs.
+        try:
+            apps = apps.items()
+        except AttributeError:
+            pass
+
+        # Sort the apps by len(path), descending
+        apps.sort(cmp=lambda x,y: cmp(len(x[0]), len(y[0])))
+        apps.reverse()
+
+        # The path_prefix strings must start, but not end, with a slash.
+        # Use "" instead of "/".
+        self.apps = [(p.rstrip("/"), a) for p, a in apps]
+
+    def __call__(self, environ, start_response):
+        path = environ["PATH_INFO"] or "/"
+        for p, app in self.apps:
+            # The apps list should be sorted by length, descending.
+            if path.startswith(p + "/") or path == p:
+                # Shift the matched prefix from PATH_INFO to SCRIPT_NAME
+                # before delegating, per the WSGI environ contract.
+                environ = environ.copy()
+                environ["SCRIPT_NAME"] = environ["SCRIPT_NAME"] + p
+                environ["PATH_INFO"] = path[len(p):]
+                return app(environ, start_response)
+
+        start_response('404 Not Found', [('Content-Type', 'text/plain'),
+                                         ('Content-Length', '0')])
+        return ['']
+
diff --git a/web/wsgiserver/ssl_builtin.py b/web/wsgiserver/ssl_builtin.py
index 64c0eeb..c488a3e 100644
--- a/web/wsgiserver/ssl_builtin.py
+++ b/web/wsgiserver/ssl_builtin.py
@@ -1,72 +1,72 @@
-"""A library for integrating Python's builtin ``ssl`` library with CherryPy.
-
-The ssl module must be importable for SSL functionality.
-
-To use this module, set ``CherryPyWSGIServer.ssl_adapter`` to an instance of
-``BuiltinSSLAdapter``.
-"""
-
-try:
- import ssl
-except ImportError:
- ssl = None
-
-from cherrypy import wsgiserver
-
-
-class BuiltinSSLAdapter(wsgiserver.SSLAdapter):
- """A wrapper for integrating Python's builtin ssl module with CherryPy."""
-
- certificate = None
- """The filename of the server SSL certificate."""
-
- private_key = None
- """The filename of the server's private key file."""
-
- def __init__(self, certificate, private_key, certificate_chain=None):
- if ssl is None:
- raise ImportError("You must install the ssl module to use HTTPS.")
- self.certificate = certificate
- self.private_key = private_key
- self.certificate_chain = certificate_chain
-
- def bind(self, sock):
- """Wrap and return the given socket."""
- return sock
-
- def wrap(self, sock):
- """Wrap and return the given socket, plus WSGI environ entries."""
- try:
- s = ssl.wrap_socket(sock, do_handshake_on_connect=True,
- server_side=True, certfile=self.certificate,
- keyfile=self.private_key, ssl_version=ssl.PROTOCOL_SSLv23)
- except ssl.SSLError, e:
- if e.errno == ssl.SSL_ERROR_EOF:
- # This is almost certainly due to the cherrypy engine
- # 'pinging' the socket to assert it's connectable;
- # the 'ping' isn't SSL.
- return None, {}
- elif e.errno == ssl.SSL_ERROR_SSL:
- if e.args[1].endswith('http request'):
- # The client is speaking HTTP to an HTTPS server.
- raise wsgiserver.NoSSLError
- raise
- return s, self.get_environ(s)
-
- # TODO: fill this out more with mod ssl env
- def get_environ(self, sock):
- """Create WSGI environ entries to be merged into each request."""
- cipher = sock.cipher()
- ssl_environ = {
- "wsgi.url_scheme": "https",
- "HTTPS": "on",
- 'SSL_PROTOCOL': cipher[1],
- 'SSL_CIPHER': cipher[0]
-## SSL_VERSION_INTERFACE string The mod_ssl program version
-## SSL_VERSION_LIBRARY string The OpenSSL program version
- }
- return ssl_environ
-
- def makefile(self, sock, mode='r', bufsize=-1):
- return wsgiserver.CP_fileobject(sock, mode, bufsize)
-
+"""A library for integrating Python's builtin ``ssl`` library with CherryPy.
+
+The ssl module must be importable for SSL functionality.
+
+To use this module, set ``CherryPyWSGIServer.ssl_adapter`` to an instance of
+``BuiltinSSLAdapter``.
+"""
+
+try:
+ import ssl
+except ImportError:
+ ssl = None
+
+from cherrypy import wsgiserver
+
+
+class BuiltinSSLAdapter(wsgiserver.SSLAdapter):
+    """A wrapper for integrating Python's builtin ssl module with CherryPy."""
+
+    certificate = None
+    """The filename of the server SSL certificate."""
+
+    private_key = None
+    """The filename of the server's private key file."""
+
+    def __init__(self, certificate, private_key, certificate_chain=None):
+        if ssl is None:
+            raise ImportError("You must install the ssl module to use HTTPS.")
+        self.certificate = certificate
+        self.private_key = private_key
+        self.certificate_chain = certificate_chain
+
+    def bind(self, sock):
+        """Wrap and return the given socket."""
+        # The builtin ssl module wraps per-connection in wrap(), so the
+        # listening socket is returned unchanged here.
+        return sock
+
+    def wrap(self, sock):
+        """Wrap and return the given socket, plus WSGI environ entries."""
+        try:
+            s = ssl.wrap_socket(sock, do_handshake_on_connect=True,
+                    server_side=True, certfile=self.certificate,
+                    keyfile=self.private_key, ssl_version=ssl.PROTOCOL_SSLv23)
+        except ssl.SSLError, e:
+            if e.errno == ssl.SSL_ERROR_EOF:
+                # This is almost certainly due to the cherrypy engine
+                # 'pinging' the socket to assert it's connectable;
+                # the 'ping' isn't SSL.
+                return None, {}
+            elif e.errno == ssl.SSL_ERROR_SSL:
+                if e.args[1].endswith('http request'):
+                    # The client is speaking HTTP to an HTTPS server.
+                    raise wsgiserver.NoSSLError
+            raise
+        return s, self.get_environ(s)
+
+    # TODO: fill this out more with mod ssl env
+    def get_environ(self, sock):
+        """Create WSGI environ entries to be merged into each request."""
+        cipher = sock.cipher()
+        ssl_environ = {
+            "wsgi.url_scheme": "https",
+            "HTTPS": "on",
+            'SSL_PROTOCOL': cipher[1],
+            'SSL_CIPHER': cipher[0]
+##            SSL_VERSION_INTERFACE 	string 	The mod_ssl program version
+##            SSL_VERSION_LIBRARY 	string 	The OpenSSL program version
+            }
+        return ssl_environ
+
+    def makefile(self, sock, mode='r', bufsize=-1):
+        return wsgiserver.CP_fileobject(sock, mode, bufsize)
+
diff --git a/web/wsgiserver/ssl_pyopenssl.py b/web/wsgiserver/ssl_pyopenssl.py
index f3d9bf5..c6d7401 100644
--- a/web/wsgiserver/ssl_pyopenssl.py
+++ b/web/wsgiserver/ssl_pyopenssl.py
@@ -1,256 +1,256 @@
-"""A library for integrating pyOpenSSL with CherryPy.
-
-The OpenSSL module must be importable for SSL functionality.
-You can obtain it from http://pyopenssl.sourceforge.net/
-
-To use this module, set CherryPyWSGIServer.ssl_adapter to an instance of
-SSLAdapter. There are two ways to use SSL:
-
-Method One
-----------
-
- * ``ssl_adapter.context``: an instance of SSL.Context.
-
-If this is not None, it is assumed to be an SSL.Context instance,
-and will be passed to SSL.Connection on bind(). The developer is
-responsible for forming a valid Context object. This approach is
-to be preferred for more flexibility, e.g. if the cert and key are
-streams instead of files, or need decryption, or SSL.SSLv3_METHOD
-is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
-the pyOpenSSL documentation for complete options.
-
-Method Two (shortcut)
----------------------
-
- * ``ssl_adapter.certificate``: the filename of the server SSL certificate.
- * ``ssl_adapter.private_key``: the filename of the server's private key file.
-
-Both are None by default. If ssl_adapter.context is None, but .private_key
-and .certificate are both given and valid, they will be read, and the
-context will be automatically created from them.
-"""
-
-import socket
-import threading
-import time
-
-from cherrypy import wsgiserver
-
-try:
- from OpenSSL import SSL
- from OpenSSL import crypto
-except ImportError:
- SSL = None
-
-
-class SSL_fileobject(wsgiserver.CP_fileobject):
- """SSL file object attached to a socket object."""
-
- ssl_timeout = 3
- ssl_retry = .01
-
- def _safe_call(self, is_reader, call, *args, **kwargs):
- """Wrap the given call with SSL error-trapping.
-
- is_reader: if False EOF errors will be raised. If True, EOF errors
- will return "" (to emulate normal sockets).
- """
- start = time.time()
- while True:
- try:
- return call(*args, **kwargs)
- except SSL.WantReadError:
- # Sleep and try again. This is dangerous, because it means
- # the rest of the stack has no way of differentiating
- # between a "new handshake" error and "client dropped".
- # Note this isn't an endless loop: there's a timeout below.
- time.sleep(self.ssl_retry)
- except SSL.WantWriteError:
- time.sleep(self.ssl_retry)
- except SSL.SysCallError, e:
- if is_reader and e.args == (-1, 'Unexpected EOF'):
- return ""
-
- errnum = e.args[0]
- if is_reader and errnum in wsgiserver.socket_errors_to_ignore:
- return ""
- raise socket.error(errnum)
- except SSL.Error, e:
- if is_reader and e.args == (-1, 'Unexpected EOF'):
- return ""
-
- thirdarg = None
- try:
- thirdarg = e.args[0][0][2]
- except IndexError:
- pass
-
- if thirdarg == 'http request':
- # The client is talking HTTP to an HTTPS server.
- raise wsgiserver.NoSSLError()
-
- raise wsgiserver.FatalSSLAlert(*e.args)
- except:
- raise
-
- if time.time() - start > self.ssl_timeout:
- raise socket.timeout("timed out")
-
- def recv(self, *args, **kwargs):
- buf = []
- r = super(SSL_fileobject, self).recv
- while True:
- data = self._safe_call(True, r, *args, **kwargs)
- buf.append(data)
- p = self._sock.pending()
- if not p:
- return "".join(buf)
-
- def sendall(self, *args, **kwargs):
- return self._safe_call(False, super(SSL_fileobject, self).sendall,
- *args, **kwargs)
-
- def send(self, *args, **kwargs):
- return self._safe_call(False, super(SSL_fileobject, self).send,
- *args, **kwargs)
-
-
-class SSLConnection:
- """A thread-safe wrapper for an SSL.Connection.
-
- ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
- """
-
- def __init__(self, *args):
- self._ssl_conn = SSL.Connection(*args)
- self._lock = threading.RLock()
-
- for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
- 'renegotiate', 'bind', 'listen', 'connect', 'accept',
- 'setblocking', 'fileno', 'close', 'get_cipher_list',
- 'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
- 'makefile', 'get_app_data', 'set_app_data', 'state_string',
- 'sock_shutdown', 'get_peer_certificate', 'want_read',
- 'want_write', 'set_connect_state', 'set_accept_state',
- 'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
- exec("""def %s(self, *args):
- self._lock.acquire()
- try:
- return self._ssl_conn.%s(*args)
- finally:
- self._lock.release()
-""" % (f, f))
-
- def shutdown(self, *args):
- self._lock.acquire()
- try:
- # pyOpenSSL.socket.shutdown takes no args
- return self._ssl_conn.shutdown()
- finally:
- self._lock.release()
-
-
-class pyOpenSSLAdapter(wsgiserver.SSLAdapter):
- """A wrapper for integrating pyOpenSSL with CherryPy."""
-
- context = None
- """An instance of SSL.Context."""
-
- certificate = None
- """The filename of the server SSL certificate."""
-
- private_key = None
- """The filename of the server's private key file."""
-
- certificate_chain = None
- """Optional. The filename of CA's intermediate certificate bundle.
-
- This is needed for cheaper "chained root" SSL certificates, and should be
- left as None if not required."""
-
- def __init__(self, certificate, private_key, certificate_chain=None):
- if SSL is None:
- raise ImportError("You must install pyOpenSSL to use HTTPS.")
-
- self.context = None
- self.certificate = certificate
- self.private_key = private_key
- self.certificate_chain = certificate_chain
- self._environ = None
-
- def bind(self, sock):
- """Wrap and return the given socket."""
- if self.context is None:
- self.context = self.get_context()
- conn = SSLConnection(self.context, sock)
- self._environ = self.get_environ()
- return conn
-
- def wrap(self, sock):
- """Wrap and return the given socket, plus WSGI environ entries."""
- return sock, self._environ.copy()
-
- def get_context(self):
- """Return an SSL.Context from self attributes."""
- # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
- c = SSL.Context(SSL.SSLv23_METHOD)
- c.use_privatekey_file(self.private_key)
- if self.certificate_chain:
- c.load_verify_locations(self.certificate_chain)
- c.use_certificate_file(self.certificate)
- return c
-
- def get_environ(self):
- """Return WSGI environ entries to be merged into each request."""
- ssl_environ = {
- "HTTPS": "on",
- # pyOpenSSL doesn't provide access to any of these AFAICT
-## 'SSL_PROTOCOL': 'SSLv2',
-## SSL_CIPHER string The cipher specification name
-## SSL_VERSION_INTERFACE string The mod_ssl program version
-## SSL_VERSION_LIBRARY string The OpenSSL program version
- }
-
- if self.certificate:
- # Server certificate attributes
- cert = open(self.certificate, 'rb').read()
- cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
- ssl_environ.update({
- 'SSL_SERVER_M_VERSION': cert.get_version(),
- 'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
-## 'SSL_SERVER_V_START': Validity of server's certificate (start time),
-## 'SSL_SERVER_V_END': Validity of server's certificate (end time),
- })
-
- for prefix, dn in [("I", cert.get_issuer()),
- ("S", cert.get_subject())]:
- # X509Name objects don't seem to have a way to get the
- # complete DN string. Use str() and slice it instead,
- # because str(dn) == ""
- dnstr = str(dn)[18:-2]
-
- wsgikey = 'SSL_SERVER_%s_DN' % prefix
- ssl_environ[wsgikey] = dnstr
-
- # The DN should be of the form: /k1=v1/k2=v2, but we must allow
- # for any value to contain slashes itself (in a URL).
- while dnstr:
- pos = dnstr.rfind("=")
- dnstr, value = dnstr[:pos], dnstr[pos + 1:]
- pos = dnstr.rfind("/")
- dnstr, key = dnstr[:pos], dnstr[pos + 1:]
- if key and value:
- wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
- ssl_environ[wsgikey] = value
-
- return ssl_environ
-
- def makefile(self, sock, mode='r', bufsize=-1):
- if SSL and isinstance(sock, SSL.ConnectionType):
- timeout = sock.gettimeout()
- f = SSL_fileobject(sock, mode, bufsize)
- f.ssl_timeout = timeout
- return f
- else:
- return wsgiserver.CP_fileobject(sock, mode, bufsize)
-
+"""A library for integrating pyOpenSSL with CherryPy.
+
+The OpenSSL module must be importable for SSL functionality.
+You can obtain it from http://pyopenssl.sourceforge.net/
+
+To use this module, set CherryPyWSGIServer.ssl_adapter to an instance of
+SSLAdapter. There are two ways to use SSL:
+
+Method One
+----------
+
+ * ``ssl_adapter.context``: an instance of SSL.Context.
+
+If this is not None, it is assumed to be an SSL.Context instance,
+and will be passed to SSL.Connection on bind(). The developer is
+responsible for forming a valid Context object. This approach is
+to be preferred for more flexibility, e.g. if the cert and key are
+streams instead of files, or need decryption, or SSL.SSLv3_METHOD
+is desired instead of the default SSL.SSLv23_METHOD, etc. Consult
+the pyOpenSSL documentation for complete options.
+
+Method Two (shortcut)
+---------------------
+
+ * ``ssl_adapter.certificate``: the filename of the server SSL certificate.
+ * ``ssl_adapter.private_key``: the filename of the server's private key file.
+
+Both are None by default. If ssl_adapter.context is None, but .private_key
+and .certificate are both given and valid, they will be read, and the
+context will be automatically created from them.
+"""
+
+import socket
+import threading
+import time
+
+from cherrypy import wsgiserver
+
+try:
+ from OpenSSL import SSL
+ from OpenSSL import crypto
+except ImportError:
+ SSL = None
+
+
class SSL_fileobject(wsgiserver.CP_fileobject):
    """SSL file object attached to a socket object.

    Translates pyOpenSSL's retry/EOF exceptions into the plain-socket
    behaviour the rest of the wsgiserver stack expects.
    """

    # Seconds to keep retrying an SSL call before raising socket.timeout.
    ssl_timeout = 3
    # Seconds to sleep between retries while OpenSSL wants more I/O.
    ssl_retry = .01

    def _safe_call(self, is_reader, call, *args, **kwargs):
        """Wrap the given call with SSL error-trapping.

        is_reader: if False EOF errors will be raised. If True, EOF errors
        will return "" (to emulate normal sockets).
        """
        start = time.time()
        while True:
            try:
                return call(*args, **kwargs)
            except SSL.WantReadError:
                # Sleep and try again. This is dangerous, because it means
                # the rest of the stack has no way of differentiating
                # between a "new handshake" error and "client dropped".
                # Note this isn't an endless loop: there's a timeout below.
                time.sleep(self.ssl_retry)
            except SSL.WantWriteError:
                time.sleep(self.ssl_retry)
            except SSL.SysCallError as e:
                # 'except X as e' replaces the Python-2-only 'except X, e'
                # form; it is valid on Python 2.6+ as well as Python 3.
                if is_reader and e.args == (-1, 'Unexpected EOF'):
                    return ""

                errnum = e.args[0]
                if is_reader and errnum in wsgiserver.socket_errors_to_ignore:
                    return ""
                raise socket.error(errnum)
            except SSL.Error as e:
                if is_reader and e.args == (-1, 'Unexpected EOF'):
                    return ""

                # pyOpenSSL nests its error details; the third slot of the
                # first entry carries the alert text when present.
                thirdarg = None
                try:
                    thirdarg = e.args[0][0][2]
                except IndexError:
                    pass

                if thirdarg == 'http request':
                    # The client is talking HTTP to an HTTPS server.
                    raise wsgiserver.NoSSLError()

                raise wsgiserver.FatalSSLAlert(*e.args)
            # NOTE: a redundant 'except: raise' clause was removed here;
            # re-raising every exception unchanged is identical to having
            # no handler at all.

            if time.time() - start > self.ssl_timeout:
                raise socket.timeout("timed out")

    def recv(self, *args, **kwargs):
        """Read from the SSL connection, draining OpenSSL's buffer.

        pending() reports data already decrypted and buffered inside
        OpenSSL; keep reading until it is empty so callers see one
        contiguous result.
        """
        buf = []
        r = super(SSL_fileobject, self).recv
        while True:
            data = self._safe_call(True, r, *args, **kwargs)
            buf.append(data)
            p = self._sock.pending()
            if not p:
                return "".join(buf)

    def sendall(self, *args, **kwargs):
        """sendall with SSL retry/EOF trapping applied."""
        return self._safe_call(False, super(SSL_fileobject, self).sendall,
                               *args, **kwargs)

    def send(self, *args, **kwargs):
        """send with SSL retry/EOF trapping applied."""
        return self._safe_call(False, super(SSL_fileobject, self).send,
                               *args, **kwargs)
+
+
class SSLConnection:
    """A thread-safe wrapper for an SSL.Connection.

    ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``.
    """

    def __init__(self, *args):
        self._ssl_conn = SSL.Connection(*args)
        self._lock = threading.RLock()

    # Generate one lock-guarded delegating method per wrapped call.  The
    # method names are only known as strings, so each wrapper is built from
    # a source template at class-definition time; every generated method
    # simply holds the RLock while forwarding to the wrapped connection.
    for f in ('get_context', 'pending', 'send', 'write', 'recv', 'read',
              'renegotiate', 'bind', 'listen', 'connect', 'accept',
              'setblocking', 'fileno', 'close', 'get_cipher_list',
              'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
              'makefile', 'get_app_data', 'set_app_data', 'state_string',
              'sock_shutdown', 'get_peer_certificate', 'want_read',
              'want_write', 'set_connect_state', 'set_accept_state',
              'connect_ex', 'sendall', 'settimeout', 'gettimeout'):
        exec("def %s(self, *args):\n"
             "    with self._lock:\n"
             "        return self._ssl_conn.%s(*args)\n" % (f, f))

    def shutdown(self, *args):
        """Delegate to the wrapped connection's no-argument shutdown().

        Any supplied arguments are deliberately dropped, because
        pyOpenSSL's socket shutdown takes no args.
        """
        with self._lock:
            return self._ssl_conn.shutdown()
+
+
class pyOpenSSLAdapter(wsgiserver.SSLAdapter):
    """A wrapper for integrating pyOpenSSL with CherryPy."""

    context = None
    """An instance of SSL.Context."""

    certificate = None
    """The filename of the server SSL certificate."""

    private_key = None
    """The filename of the server's private key file."""

    certificate_chain = None
    """Optional. The filename of CA's intermediate certificate bundle.

    This is needed for cheaper "chained root" SSL certificates, and should be
    left as None if not required."""

    def __init__(self, certificate, private_key, certificate_chain=None):
        """Store cert/key filenames; context creation is deferred to bind().

        Raises ImportError when pyOpenSSL is not installed.
        """
        if SSL is None:
            raise ImportError("You must install pyOpenSSL to use HTTPS.")

        self.context = None
        self.certificate = certificate
        self.private_key = private_key
        self.certificate_chain = certificate_chain
        self._environ = None

    def bind(self, sock):
        """Wrap and return the given socket."""
        if self.context is None:
            self.context = self.get_context()
        conn = SSLConnection(self.context, sock)
        # Compute the per-server WSGI entries once; wrap() hands a copy
        # to every request.
        self._environ = self.get_environ()
        return conn

    def wrap(self, sock):
        """Wrap and return the given socket, plus WSGI environ entries."""
        return sock, self._environ.copy()

    def get_context(self):
        """Return an SSL.Context built from self attributes."""
        # See http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/442473
        c = SSL.Context(SSL.SSLv23_METHOD)
        c.use_privatekey_file(self.private_key)
        if self.certificate_chain:
            c.load_verify_locations(self.certificate_chain)
        c.use_certificate_file(self.certificate)
        return c

    def get_environ(self):
        """Return WSGI environ entries to be merged into each request."""
        ssl_environ = {
            "HTTPS": "on",
            # pyOpenSSL doesn't provide access to any of these AFAICT
##            'SSL_PROTOCOL': 'SSLv2',
##            SSL_CIPHER         string  The cipher specification name
##            SSL_VERSION_INTERFACE      string  The mod_ssl program version
##            SSL_VERSION_LIBRARY        string  The OpenSSL program version
            }

        if self.certificate:
            # Server certificate attributes.
            # BUGFIX: close the certificate file deterministically instead
            # of leaking the open file object until garbage collection.
            with open(self.certificate, 'rb') as cert_file:
                cert = crypto.load_certificate(crypto.FILETYPE_PEM,
                                               cert_file.read())
            ssl_environ.update({
                'SSL_SERVER_M_VERSION': cert.get_version(),
                'SSL_SERVER_M_SERIAL': cert.get_serial_number(),
##                'SSL_SERVER_V_START': Validity of server's certificate (start time),
##                'SSL_SERVER_V_END': Validity of server's certificate (end time),
                })

            for prefix, dn in [("I", cert.get_issuer()),
                               ("S", cert.get_subject())]:
                # X509Name objects don't seem to have a way to get the
                # complete DN string. Use str() and slice it instead,
                # because str(dn) == ""
                dnstr = str(dn)[18:-2]

                wsgikey = 'SSL_SERVER_%s_DN' % prefix
                ssl_environ[wsgikey] = dnstr

                # The DN should be of the form: /k1=v1/k2=v2, but we must allow
                # for any value to contain slashes itself (in a URL).
                # Scanning right-to-left ensures only the '/' nearest each
                # '=' is treated as a field separator.
                while dnstr:
                    pos = dnstr.rfind("=")
                    dnstr, value = dnstr[:pos], dnstr[pos + 1:]
                    pos = dnstr.rfind("/")
                    dnstr, key = dnstr[:pos], dnstr[pos + 1:]
                    if key and value:
                        wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key)
                        ssl_environ[wsgikey] = value

        return ssl_environ

    def makefile(self, sock, mode='r', bufsize=-1):
        """Return a file object for *sock*.

        SSL connections get an SSL_fileobject (with retry/EOF handling);
        anything else falls back to the stock CP_fileobject.
        """
        if SSL and isinstance(sock, SSL.ConnectionType):
            timeout = sock.gettimeout()
            f = SSL_fileobject(sock, mode, bufsize)
            f.ssl_timeout = timeout
            return f
        else:
            return wsgiserver.CP_fileobject(sock, mode, bufsize)