merge conflicts

This commit is contained in:
Alexandre Van de Sande 2015-02-16 14:43:30 +01:00
commit 3068e2688d
769 changed files with 442356 additions and 2884 deletions

3
.gitignore vendored
View File

@ -17,3 +17,6 @@
*~
.project
.settings
cmd/ethereum/ethereum
cmd/mist/mist

View File

@ -12,8 +12,9 @@ install:
- if ! go get code.google.com/p/go.tools/cmd/cover; then go get golang.org/x/tools/cmd/cover; fi
- go get github.com/mattn/goveralls
- go get gopkg.in/check.v1
- DEPS=$(go list -f '{{.Imports}}' ./... | sed -e 's/\[//g' | sed -e 's/\]//g' | sed -e 's/C //g'); if [ "$DEPS" ]; then go get -d -v $DEPS; fi
- go get github.com/tools/godep
before_script:
- godep restore
- gofmt -l -w .
- goimports -l -w .
- golint .

View File

@ -1,27 +1,26 @@
FROM ubuntu:14.04
FROM ubuntu:14.04.1
## Environment setup
ENV HOME /root
ENV GOPATH /root/go
ENV PATH /golang/bin:/root/go/bin:/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games
ENV PKG_CONFIG_PATH /opt/qt54/lib/pkgconfig
ENV PATH /root/go/bin:/usr/local/go/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games
RUN mkdir -p /root/go
ENV DEBIAN_FRONTEND noninteractive
## Install base dependencies
RUN apt-get update && apt-get upgrade -y
RUN apt-get install -y git mercurial build-essential software-properties-common pkg-config libgmp3-dev libreadline6-dev libpcre3-dev libpcre++-dev mesa-common-dev libglu1-mesa-dev
RUN apt-get install -y git mercurial build-essential software-properties-common wget pkg-config libgmp3-dev libreadline6-dev libpcre3-dev libpcre++-dev
## Install Qt5.4 dependencies from PPA
RUN add-apt-repository ppa:beineri/opt-qt54-trusty -y
RUN apt-get update -y
RUN apt-get install -y qt54quickcontrols qt54webengine
## Install Qt5.4
# RUN add-apt-repository ppa:beineri/opt-qt54-trusty -y
# RUN apt-get update -y
# RUN apt-get install -y qt54quickcontrols qt54webengine mesa-common-dev libglu1-mesa-dev
# ENV PKG_CONFIG_PATH /opt/qt54/lib/pkgconfig
## Build and install latest Go
RUN git clone https://go.googlesource.com/go golang
RUN cd golang && git checkout go1.4.1
RUN cd golang/src && ./make.bash && go version
# Install Golang
RUN wget https://storage.googleapis.com/golang/go1.4.1.linux-amd64.tar.gz
RUN tar -C /usr/local -xzf go*.tar.gz && go version
# this is a workaround, to make sure that docker's cache is invalidated whenever the git repo changes
ADD https://api.github.com/repos/ethereum/go-ethereum/git/refs/heads/develop file_does_not_exist

130
Godeps/Godeps.json generated Normal file
View File

@ -0,0 +1,130 @@
{
"ImportPath": "github.com/ethereum/go-ethereum",
"GoVersion": "go1.4",
"Packages": [
"./..."
],
"Deps": [
{
"ImportPath": "bitbucket.org/kardianos/osext",
"Comment": "null-13",
"Rev": "5d3ddcf53a508cc2f7404eaebf546ef2cb5cdb6e"
},
{
"ImportPath": "code.google.com/p/go-uuid/uuid",
"Comment": "null-12",
"Rev": "7dda39b2e7d5e265014674c5af696ba4186679e9"
},
{
"ImportPath": "code.google.com/p/go.crypto/pbkdf2",
"Comment": "null-236",
"Rev": "69e2a90ed92d03812364aeb947b7068dc42e561e"
},
{
"ImportPath": "code.google.com/p/go.crypto/ripemd160",
"Comment": "null-236",
"Rev": "69e2a90ed92d03812364aeb947b7068dc42e561e"
},
{
"ImportPath": "code.google.com/p/go.crypto/scrypt",
"Comment": "null-236",
"Rev": "69e2a90ed92d03812364aeb947b7068dc42e561e"
},
{
"ImportPath": "code.google.com/p/go.net/websocket",
"Comment": "null-173",
"Rev": "4231557d7c726df4cf9a4e8cdd8a417c8c200bdb"
},
{
"ImportPath": "code.google.com/p/snappy-go/snappy",
"Comment": "null-15",
"Rev": "12e4b4183793ac4b061921e7980845e750679fd0"
},
{
"ImportPath": "github.com/ethereum/serpent-go",
"Rev": "5767a0dbd759d313df3f404dadb7f98d7ab51443"
},
{
"ImportPath": "github.com/fjl/goupnp",
"Rev": "fa95df6feb61e136b499d01711fcd410ccaf20c1"
},
{
"ImportPath": "github.com/howeyc/fsnotify",
"Comment": "v0.9.0-11-g6b1ef89",
"Rev": "6b1ef893dc11e0447abda6da20a5203481878dda"
},
{
"ImportPath": "github.com/jackpal/go-nat-pmp",
"Rev": "a45aa3d54aef73b504e15eb71bea0e5565b5e6e1"
},
{
"ImportPath": "github.com/obscuren/ecies",
"Rev": "d899334bba7bf4a157cab19d8ad836dcb1de0c34"
},
{
"ImportPath": "github.com/obscuren/otto",
"Rev": "cf13cc4228c5e5ce0fe27a7aea90bc10091c4f19"
},
{
"ImportPath": "github.com/obscuren/qml",
"Rev": "807b51d4104231784fa5e336ccd26d61759a3cb2"
},
{
"ImportPath": "github.com/rakyll/globalconf",
"Rev": "415abc325023f1a00cd2d9fa512e0e71745791a2"
},
{
"ImportPath": "github.com/rakyll/goini",
"Rev": "907cca0f578a5316fb864ec6992dc3d9730ec58c"
},
{
"ImportPath": "github.com/robertkrimen/otto/ast",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/robertkrimen/otto/dbg",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/robertkrimen/otto/file",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/robertkrimen/otto/parser",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/robertkrimen/otto/registry",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/robertkrimen/otto/token",
"Rev": "dea31a3d392779af358ec41f77a07fcc7e9d04ba"
},
{
"ImportPath": "github.com/syndtr/goleveldb/leveldb",
"Rev": "832fa7ed4d28545eab80f19e1831fc004305cade"
},
{
"ImportPath": "golang.org/x/crypto/pbkdf2",
"Rev": "4ed45ec682102c643324fae5dff8dab085b6c300"
},
{
"ImportPath": "gopkg.in/check.v1",
"Rev": "64131543e7896d5bcc6bd5a76287eb75ea96c673"
},
{
"ImportPath": "gopkg.in/fatih/set.v0",
"Comment": "v0.1.0-3-g27c4092",
"Rev": "27c40922c40b43fe04554d8223a402af3ea333f3"
},
{
"ImportPath": "gopkg.in/qml.v1/cdata",
"Rev": "1116cb9cd8dee23f8d444ded354eb53122739f99"
},
{
"ImportPath": "gopkg.in/qml.v1/gl/glbase",
"Rev": "1116cb9cd8dee23f8d444ded354eb53122739f99"
}
]
}

5
Godeps/Readme generated Normal file
View File

@ -0,0 +1,5 @@
This directory tree is generated automatically by godep.
Please do not edit.
See https://github.com/tools/godep for more information.

2
Godeps/_workspace/.gitignore generated vendored Normal file
View File

@ -0,0 +1,2 @@
/pkg
/bin

View File

@ -0,0 +1,20 @@
Copyright (c) 2012 Daniel Theophanes
This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.
Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:
1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.
3. This notice may not be removed or altered from any source
distribution.

View File

@ -0,0 +1,32 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Extensions to the standard "os" package.
package osext
import "path/filepath"
// Executable returns an absolute path that can be used to
// re-invoke the current program.
// It may not be valid after the current program exits.
func Executable() (string, error) {
p, err := executable()
return filepath.Clean(p), err
}
// Returns same path as Executable, returns just the folder
// path. Excludes the executable name.
func ExecutableFolder() (string, error) {
p, err := Executable()
if err != nil {
return "", err
}
folder, _ := filepath.Split(p)
return folder, nil
}
// Depricated. Same as Executable().
func GetExePath() (exePath string, err error) {
return Executable()
}

View File

@ -0,0 +1,20 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package osext
import (
"syscall"
"os"
"strconv"
)
func executable() (string, error) {
f, err := os.Open("/proc/" + strconv.Itoa(os.Getpid()) + "/text")
if err != nil {
return "", err
}
defer f.Close()
return syscall.Fd2path(int(f.Fd()))
}

View File

@ -0,0 +1,25 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build linux netbsd openbsd
package osext
import (
"errors"
"os"
"runtime"
)
func executable() (string, error) {
switch runtime.GOOS {
case "linux":
return os.Readlink("/proc/self/exe")
case "netbsd":
return os.Readlink("/proc/curproc/exe")
case "openbsd":
return os.Readlink("/proc/curproc/file")
}
return "", errors.New("ExecPath not implemented for " + runtime.GOOS)
}

View File

@ -0,0 +1,79 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin freebsd
package osext
import (
"os"
"path/filepath"
"runtime"
"syscall"
"unsafe"
)
var initCwd, initCwdErr = os.Getwd()
func executable() (string, error) {
var mib [4]int32
switch runtime.GOOS {
case "freebsd":
mib = [4]int32{1 /* CTL_KERN */, 14 /* KERN_PROC */, 12 /* KERN_PROC_PATHNAME */, -1}
case "darwin":
mib = [4]int32{1 /* CTL_KERN */, 38 /* KERN_PROCARGS */, int32(os.Getpid()), -1}
}
n := uintptr(0)
// Get length.
_, _, errNum := syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, 0, uintptr(unsafe.Pointer(&n)), 0, 0)
if errNum != 0 {
return "", errNum
}
if n == 0 { // This shouldn't happen.
return "", nil
}
buf := make([]byte, n)
_, _, errNum = syscall.Syscall6(syscall.SYS___SYSCTL, uintptr(unsafe.Pointer(&mib[0])), 4, uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&n)), 0, 0)
if errNum != 0 {
return "", errNum
}
if n == 0 { // This shouldn't happen.
return "", nil
}
for i, v := range buf {
if v == 0 {
buf = buf[:i]
break
}
}
var err error
execPath := string(buf)
// execPath will not be empty due to above checks.
// Try to get the absolute path if the execPath is not rooted.
if execPath[0] != '/' {
execPath, err = getAbs(execPath)
if err != nil {
return execPath, err
}
}
// For darwin KERN_PROCARGS may return the path to a symlink rather than the
// actual executable.
if runtime.GOOS == "darwin" {
if execPath, err = filepath.EvalSymlinks(execPath); err != nil {
return execPath, err
}
}
return execPath, nil
}
func getAbs(execPath string) (string, error) {
if initCwdErr != nil {
return execPath, initCwdErr
}
// The execPath may begin with a "../" or a "./" so clean it first.
// Join the two paths, trailing and starting slashes undetermined, so use
// the generic Join function.
return filepath.Join(initCwd, filepath.Clean(execPath)), nil
}

View File

@ -0,0 +1,79 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin linux freebsd netbsd windows
package osext
import (
"fmt"
"os"
oexec "os/exec"
"path/filepath"
"runtime"
"testing"
)
const execPath_EnvVar = "OSTEST_OUTPUT_EXECPATH"
func TestExecPath(t *testing.T) {
ep, err := Executable()
if err != nil {
t.Fatalf("ExecPath failed: %v", err)
}
// we want fn to be of the form "dir/prog"
dir := filepath.Dir(filepath.Dir(ep))
fn, err := filepath.Rel(dir, ep)
if err != nil {
t.Fatalf("filepath.Rel: %v", err)
}
cmd := &oexec.Cmd{}
// make child start with a relative program path
cmd.Dir = dir
cmd.Path = fn
// forge argv[0] for child, so that we can verify we could correctly
// get real path of the executable without influenced by argv[0].
cmd.Args = []string{"-", "-test.run=XXXX"}
cmd.Env = []string{fmt.Sprintf("%s=1", execPath_EnvVar)}
out, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("exec(self) failed: %v", err)
}
outs := string(out)
if !filepath.IsAbs(outs) {
t.Fatalf("Child returned %q, want an absolute path", out)
}
if !sameFile(outs, ep) {
t.Fatalf("Child returned %q, not the same file as %q", out, ep)
}
}
func sameFile(fn1, fn2 string) bool {
fi1, err := os.Stat(fn1)
if err != nil {
return false
}
fi2, err := os.Stat(fn2)
if err != nil {
return false
}
return os.SameFile(fi1, fi2)
}
func init() {
if e := os.Getenv(execPath_EnvVar); e != "" {
// first chdir to another path
dir := "/"
if runtime.GOOS == "windows" {
dir = filepath.VolumeName(".")
}
os.Chdir(dir)
if ep, err := Executable(); err != nil {
fmt.Fprint(os.Stderr, "ERROR: ", err)
} else {
fmt.Fprint(os.Stderr, ep)
}
os.Exit(0)
}
}

View File

@ -0,0 +1,34 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package osext
import (
"syscall"
"unicode/utf16"
"unsafe"
)
var (
kernel = syscall.MustLoadDLL("kernel32.dll")
getModuleFileNameProc = kernel.MustFindProc("GetModuleFileNameW")
)
// GetModuleFileName() with hModule = NULL
func executable() (exePath string, err error) {
return getModuleFileName()
}
func getModuleFileName() (string, error) {
var n uint32
b := make([]uint16, syscall.MAX_PATH)
size := uint32(len(b))
r0, _, e1 := getModuleFileNameProc.Call(0, uintptr(unsafe.Pointer(&b[0])), uintptr(size))
n = uint32(r0)
if n == 0 {
return "", e1
}
return string(utf16.Decode(b[0:n])), nil
}

View File

@ -0,0 +1,27 @@
Copyright (c) 2009 Google Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -0,0 +1,84 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"fmt"
"os"
)
// A Domain represents a Version 2 domain
type Domain byte
// Domain constants for DCE Security (Version 2) UUIDs.
const (
Person = Domain(0)
Group = Domain(1)
Org = Domain(2)
)
// NewDCESecurity returns a DCE Security (Version 2) UUID.
//
// The domain should be one of Person, Group or Org.
// On a POSIX system the id should be the users UID for the Person
// domain and the users GID for the Group. The meaning of id for
// the domain Org or on non-POSIX systems is site defined.
//
// For a given domain/id pair the same token may be returned for up to
// 7 minutes and 10 seconds.
func NewDCESecurity(domain Domain, id uint32) UUID {
uuid := NewUUID()
if uuid != nil {
uuid[6] = (uuid[6] & 0x0f) | 0x20 // Version 2
uuid[9] = byte(domain)
binary.BigEndian.PutUint32(uuid[0:], id)
}
return uuid
}
// NewDCEPerson returns a DCE Security (Version 2) UUID in the person
// domain with the id returned by os.Getuid.
//
// NewDCEPerson(Person, uint32(os.Getuid()))
func NewDCEPerson() UUID {
return NewDCESecurity(Person, uint32(os.Getuid()))
}
// NewDCEGroup returns a DCE Security (Version 2) UUID in the group
// domain with the id returned by os.Getgid.
//
// NewDCEGroup(Group, uint32(os.Getgid()))
func NewDCEGroup() UUID {
return NewDCESecurity(Group, uint32(os.Getgid()))
}
// Domain returns the domain for a Version 2 UUID or false.
func (uuid UUID) Domain() (Domain, bool) {
if v, _ := uuid.Version(); v != 2 {
return 0, false
}
return Domain(uuid[9]), true
}
// Id returns the id for a Version 2 UUID or false.
func (uuid UUID) Id() (uint32, bool) {
if v, _ := uuid.Version(); v != 2 {
return 0, false
}
return binary.BigEndian.Uint32(uuid[0:4]), true
}
func (d Domain) String() string {
switch d {
case Person:
return "Person"
case Group:
return "Group"
case Org:
return "Org"
}
return fmt.Sprintf("Domain%d", int(d))
}

View File

@ -0,0 +1,8 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The uuid package generates and inspects UUIDs.
//
// UUIDs are based on RFC 4122 and DCE 1.1: Authentication and Security Services.
package uuid

View File

@ -0,0 +1,53 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"crypto/md5"
"crypto/sha1"
"hash"
)
// Well known Name Space IDs and UUIDs
var (
NameSpace_DNS = Parse("6ba7b810-9dad-11d1-80b4-00c04fd430c8")
NameSpace_URL = Parse("6ba7b811-9dad-11d1-80b4-00c04fd430c8")
NameSpace_OID = Parse("6ba7b812-9dad-11d1-80b4-00c04fd430c8")
NameSpace_X500 = Parse("6ba7b814-9dad-11d1-80b4-00c04fd430c8")
NIL = Parse("00000000-0000-0000-0000-000000000000")
)
// NewHash returns a new UUID dervied from the hash of space concatenated with
// data generated by h. The hash should be at least 16 byte in length. The
// first 16 bytes of the hash are used to form the UUID. The version of the
// UUID will be the lower 4 bits of version. NewHash is used to implement
// NewMD5 and NewSHA1.
func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID {
h.Reset()
h.Write(space)
h.Write([]byte(data))
s := h.Sum(nil)
uuid := make([]byte, 16)
copy(uuid, s)
uuid[6] = (uuid[6] & 0x0f) | uint8((version&0xf)<<4)
uuid[8] = (uuid[8] & 0x3f) | 0x80 // RFC 4122 variant
return uuid
}
// NewMD5 returns a new MD5 (Version 3) UUID based on the
// supplied name space and data.
//
// NewHash(md5.New(), space, data, 3)
func NewMD5(space UUID, data []byte) UUID {
return NewHash(md5.New(), space, data, 3)
}
// NewSHA1 returns a new SHA1 (Version 5) UUID based on the
// supplied name space and data.
//
// NewHash(sha1.New(), space, data, 5)
func NewSHA1(space UUID, data []byte) UUID {
return NewHash(sha1.New(), space, data, 5)
}

View File

@ -0,0 +1,101 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import "net"
var (
interfaces []net.Interface // cached list of interfaces
ifname string // name of interface being used
nodeID []byte // hardware for version 1 UUIDs
)
// NodeInterface returns the name of the interface from which the NodeID was
// derived. The interface "user" is returned if the NodeID was set by
// SetNodeID.
func NodeInterface() string {
return ifname
}
// SetNodeInterface selects the hardware address to be used for Version 1 UUIDs.
// If name is "" then the first usable interface found will be used or a random
// Node ID will be generated. If a named interface cannot be found then false
// is returned.
//
// SetNodeInterface never fails when name is "".
func SetNodeInterface(name string) bool {
if interfaces == nil {
var err error
interfaces, err = net.Interfaces()
if err != nil && name != "" {
return false
}
}
for _, ifs := range interfaces {
if len(ifs.HardwareAddr) >= 6 && (name == "" || name == ifs.Name) {
if setNodeID(ifs.HardwareAddr) {
ifname = ifs.Name
return true
}
}
}
// We found no interfaces with a valid hardware address. If name
// does not specify a specific interface generate a random Node ID
// (section 4.1.6)
if name == "" {
if nodeID == nil {
nodeID = make([]byte, 6)
}
randomBits(nodeID)
return true
}
return false
}
// NodeID returns a slice of a copy of the current Node ID, setting the Node ID
// if not already set.
func NodeID() []byte {
if nodeID == nil {
SetNodeInterface("")
}
nid := make([]byte, 6)
copy(nid, nodeID)
return nid
}
// SetNodeID sets the Node ID to be used for Version 1 UUIDs. The first 6 bytes
// of id are used. If id is less than 6 bytes then false is returned and the
// Node ID is not set.
func SetNodeID(id []byte) bool {
if setNodeID(id) {
ifname = "user"
return true
}
return false
}
func setNodeID(id []byte) bool {
if len(id) < 6 {
return false
}
if nodeID == nil {
nodeID = make([]byte, 6)
}
copy(nodeID, id)
return true
}
// NodeID returns the 6 byte node id encoded in uuid. It returns nil if uuid is
// not valid. The NodeID is only well defined for version 1 and 2 UUIDs.
func (uuid UUID) NodeID() []byte {
if len(uuid) != 16 {
return nil
}
node := make([]byte, 6)
copy(node, uuid[10:])
return node
}

View File

@ -0,0 +1,132 @@
// Copyright 2014 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
"sync"
"time"
)
// A Time represents a time as the number of 100's of nanoseconds since 15 Oct
// 1582.
type Time int64
const (
lillian = 2299160 // Julian day of 15 Oct 1582
unix = 2440587 // Julian day of 1 Jan 1970
epoch = unix - lillian // Days between epochs
g1582 = epoch * 86400 // seconds between epochs
g1582ns100 = g1582 * 10000000 // 100s of a nanoseconds between epochs
)
var (
mu sync.Mutex
lasttime uint64 // last time we returned
clock_seq uint16 // clock sequence for this run
timeNow = time.Now // for testing
)
// UnixTime converts t the number of seconds and nanoseconds using the Unix
// epoch of 1 Jan 1970.
func (t Time) UnixTime() (sec, nsec int64) {
sec = int64(t - g1582ns100)
nsec = (sec % 10000000) * 100
sec /= 10000000
return sec, nsec
}
// GetTime returns the current Time (100s of nanoseconds since 15 Oct 1582) and
// adjusts the clock sequence as needed. An error is returned if the current
// time cannot be determined.
func GetTime() (Time, error) {
defer mu.Unlock()
mu.Lock()
return getTime()
}
func getTime() (Time, error) {
t := timeNow()
// If we don't have a clock sequence already, set one.
if clock_seq == 0 {
setClockSequence(-1)
}
now := uint64(t.UnixNano()/100) + g1582ns100
// If time has gone backwards with this clock sequence then we
// increment the clock sequence
if now <= lasttime {
clock_seq = ((clock_seq + 1) & 0x3fff) | 0x8000
}
lasttime = now
return Time(now), nil
}
// ClockSequence returns the current clock sequence, generating one if not
// already set. The clock sequence is only used for Version 1 UUIDs.
//
// The uuid package does not use global static storage for the clock sequence or
// the last time a UUID was generated. Unless SetClockSequence a new random
// clock sequence is generated the first time a clock sequence is requested by
// ClockSequence, GetTime, or NewUUID. (section 4.2.1.1) sequence is generated
// for
func ClockSequence() int {
defer mu.Unlock()
mu.Lock()
return clockSequence()
}
func clockSequence() int {
if clock_seq == 0 {
setClockSequence(-1)
}
return int(clock_seq & 0x3fff)
}
// SetClockSeq sets the clock sequence to the lower 14 bits of seq. Setting to
// -1 causes a new sequence to be generated.
func SetClockSequence(seq int) {
defer mu.Unlock()
mu.Lock()
setClockSequence(seq)
}
func setClockSequence(seq int) {
if seq == -1 {
var b [2]byte
randomBits(b[:]) // clock sequence
seq = int(b[0])<<8 | int(b[1])
}
old_seq := clock_seq
clock_seq = uint16(seq&0x3fff) | 0x8000 // Set our variant
if old_seq != clock_seq {
lasttime = 0
}
}
// Time returns the time in 100s of nanoseconds since 15 Oct 1582 encoded in
// uuid. It returns false if uuid is not valid. The time is only well defined
// for version 1 and 2 UUIDs.
func (uuid UUID) Time() (Time, bool) {
if len(uuid) != 16 {
return 0, false
}
time := int64(binary.BigEndian.Uint32(uuid[0:4]))
time |= int64(binary.BigEndian.Uint16(uuid[4:6])) << 32
time |= int64(binary.BigEndian.Uint16(uuid[6:8])&0xfff) << 48
return Time(time), true
}
// ClockSequence returns the clock sequence encoded in uuid. It returns false
// if uuid is not valid. The clock sequence is only well defined for version 1
// and 2 UUIDs.
func (uuid UUID) ClockSequence() (int, bool) {
if len(uuid) != 16 {
return 0, false
}
return int(binary.BigEndian.Uint16(uuid[8:10])) & 0x3fff, true
}

View File

@ -0,0 +1,43 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"io"
)
// randomBits completely fills slice b with random data.
func randomBits(b []byte) {
if _, err := io.ReadFull(rander, b); err != nil {
panic(err.Error()) // rand should never fail
}
}
// xvalues returns the value of a byte as a hexadecimal digit or 255.
var xvalues = []byte{
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 10, 11, 12, 13, 14, 15, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
}
// xtob converts the the first two hex bytes of x into a byte.
func xtob(x string) (byte, bool) {
b1 := xvalues[x[0]]
b2 := xvalues[x[1]]
return (b1 << 4) | b2, b1 != 255 && b2 != 255
}

View File

@ -0,0 +1,163 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"crypto/rand"
"fmt"
"io"
"strings"
)
// A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC
// 4122.
type UUID []byte
// A Version represents a UUIDs version.
type Version byte
// A Variant represents a UUIDs variant.
type Variant byte
// Constants returned by Variant.
const (
Invalid = Variant(iota) // Invalid UUID
RFC4122 // The variant specified in RFC4122
Reserved // Reserved, NCS backward compatibility.
Microsoft // Reserved, Microsoft Corporation backward compatibility.
Future // Reserved for future definition.
)
var rander = rand.Reader // random function
// New returns a new random (version 4) UUID as a string. It is a convenience
// function for NewRandom().String().
func New() string {
return NewRandom().String()
}
// Parse decodes s into a UUID or returns nil. Both the UUID form of
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded.
func Parse(s string) UUID {
if len(s) == 36+9 {
if strings.ToLower(s[:9]) != "urn:uuid:" {
return nil
}
s = s[9:]
} else if len(s) != 36 {
return nil
}
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return nil
}
uuid := make([]byte, 16)
for i, x := range []int{
0, 2, 4, 6,
9, 11,
14, 16,
19, 21,
24, 26, 28, 30, 32, 34} {
if v, ok := xtob(s[x:]); !ok {
return nil
} else {
uuid[i] = v
}
}
return uuid
}
// Equal returns true if uuid1 and uuid2 are equal.
func Equal(uuid1, uuid2 UUID) bool {
return bytes.Equal(uuid1, uuid2)
}
// String returns the string form of uuid, xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
// , or "" if uuid is invalid.
func (uuid UUID) String() string {
if uuid == nil || len(uuid) != 16 {
return ""
}
b := []byte(uuid)
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// URN returns the RFC 2141 URN form of uuid,
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx, or "" if uuid is invalid.
func (uuid UUID) URN() string {
if uuid == nil || len(uuid) != 16 {
return ""
}
b := []byte(uuid)
return fmt.Sprintf("urn:uuid:%08x-%04x-%04x-%04x-%012x",
b[:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// Variant returns the variant encoded in uuid. It returns Invalid if
// uuid is invalid.
func (uuid UUID) Variant() Variant {
if len(uuid) != 16 {
return Invalid
}
switch {
case (uuid[8] & 0xc0) == 0x80:
return RFC4122
case (uuid[8] & 0xe0) == 0xc0:
return Microsoft
case (uuid[8] & 0xe0) == 0xe0:
return Future
default:
return Reserved
}
panic("unreachable")
}
// Version returns the verison of uuid. It returns false if uuid is not
// valid.
func (uuid UUID) Version() (Version, bool) {
if len(uuid) != 16 {
return 0, false
}
return Version(uuid[6] >> 4), true
}
func (v Version) String() string {
if v > 15 {
return fmt.Sprintf("BAD_VERSION_%d", v)
}
return fmt.Sprintf("VERSION_%d", v)
}
func (v Variant) String() string {
switch v {
case RFC4122:
return "RFC4122"
case Reserved:
return "Reserved"
case Microsoft:
return "Microsoft"
case Future:
return "Future"
case Invalid:
return "Invalid"
}
return fmt.Sprintf("BadVariant%d", int(v))
}
// SetRand sets the random number generator to r, which implents io.Reader.
// If r.Read returns an error when the package requests random data then
// a panic will be issued.
//
// Calling SetRand with nil sets the random number generator to the default
// generator.
func SetRand(r io.Reader) {
if r == nil {
rander = rand.Reader
return
}
rander = r
}

View File

@ -0,0 +1,390 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"bytes"
"fmt"
"os"
"strings"
"testing"
"time"
)
type test struct {
in string
version Version
variant Variant
isuuid bool
}
var tests = []test{
{"f47ac10b-58cc-0372-8567-0e02b2c3d479", 0, RFC4122, true},
{"f47ac10b-58cc-1372-8567-0e02b2c3d479", 1, RFC4122, true},
{"f47ac10b-58cc-2372-8567-0e02b2c3d479", 2, RFC4122, true},
{"f47ac10b-58cc-3372-8567-0e02b2c3d479", 3, RFC4122, true},
{"f47ac10b-58cc-4372-8567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-5372-8567-0e02b2c3d479", 5, RFC4122, true},
{"f47ac10b-58cc-6372-8567-0e02b2c3d479", 6, RFC4122, true},
{"f47ac10b-58cc-7372-8567-0e02b2c3d479", 7, RFC4122, true},
{"f47ac10b-58cc-8372-8567-0e02b2c3d479", 8, RFC4122, true},
{"f47ac10b-58cc-9372-8567-0e02b2c3d479", 9, RFC4122, true},
{"f47ac10b-58cc-a372-8567-0e02b2c3d479", 10, RFC4122, true},
{"f47ac10b-58cc-b372-8567-0e02b2c3d479", 11, RFC4122, true},
{"f47ac10b-58cc-c372-8567-0e02b2c3d479", 12, RFC4122, true},
{"f47ac10b-58cc-d372-8567-0e02b2c3d479", 13, RFC4122, true},
{"f47ac10b-58cc-e372-8567-0e02b2c3d479", 14, RFC4122, true},
{"f47ac10b-58cc-f372-8567-0e02b2c3d479", 15, RFC4122, true},
{"urn:uuid:f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"URN:UUID:f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-0567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-1567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-2567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-3567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-4567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-5567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-6567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-7567-0e02b2c3d479", 4, Reserved, true},
{"f47ac10b-58cc-4372-8567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-9567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-a567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-b567-0e02b2c3d479", 4, RFC4122, true},
{"f47ac10b-58cc-4372-c567-0e02b2c3d479", 4, Microsoft, true},
{"f47ac10b-58cc-4372-d567-0e02b2c3d479", 4, Microsoft, true},
{"f47ac10b-58cc-4372-e567-0e02b2c3d479", 4, Future, true},
{"f47ac10b-58cc-4372-f567-0e02b2c3d479", 4, Future, true},
{"f47ac10b158cc-5372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc25372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-53723a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a56740e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a567-0e02-2c3d479", 0, Invalid, false},
{"g47ac10b-58cc-4372-a567-0e02b2c3d479", 0, Invalid, false},
}
var constants = []struct {
c interface{}
name string
}{
{Person, "Person"},
{Group, "Group"},
{Org, "Org"},
{Invalid, "Invalid"},
{RFC4122, "RFC4122"},
{Reserved, "Reserved"},
{Microsoft, "Microsoft"},
{Future, "Future"},
{Domain(17), "Domain17"},
{Variant(42), "BadVariant42"},
}
func testTest(t *testing.T, in string, tt test) {
uuid := Parse(in)
if ok := (uuid != nil); ok != tt.isuuid {
t.Errorf("Parse(%s) got %v expected %v\b", in, ok, tt.isuuid)
}
if uuid == nil {
return
}
if v := uuid.Variant(); v != tt.variant {
t.Errorf("Variant(%s) got %d expected %d\b", in, v, tt.variant)
}
if v, _ := uuid.Version(); v != tt.version {
t.Errorf("Version(%s) got %d expected %d\b", in, v, tt.version)
}
}
func TestUUID(t *testing.T) {
for _, tt := range tests {
testTest(t, tt.in, tt)
testTest(t, strings.ToUpper(tt.in), tt)
}
}
func TestConstants(t *testing.T) {
for x, tt := range constants {
v, ok := tt.c.(fmt.Stringer)
if !ok {
t.Errorf("%x: %v: not a stringer", x, v)
} else if s := v.String(); s != tt.name {
v, _ := tt.c.(int)
t.Errorf("%x: Constant %T:%d gives %q, expected %q\n", x, tt.c, v, s, tt.name)
}
}
}
func TestRandomUUID(t *testing.T) {
m := make(map[string]bool)
for x := 1; x < 32; x++ {
uuid := NewRandom()
s := uuid.String()
if m[s] {
t.Errorf("NewRandom returned duplicated UUID %s\n", s)
}
m[s] = true
if v, _ := uuid.Version(); v != 4 {
t.Errorf("Random UUID of version %s\n", v)
}
if uuid.Variant() != RFC4122 {
t.Errorf("Random UUID is variant %d\n", uuid.Variant())
}
}
}
func TestNew(t *testing.T) {
m := make(map[string]bool)
for x := 1; x < 32; x++ {
s := New()
if m[s] {
t.Errorf("New returned duplicated UUID %s\n", s)
}
m[s] = true
uuid := Parse(s)
if uuid == nil {
t.Errorf("New returned %q which does not decode\n", s)
continue
}
if v, _ := uuid.Version(); v != 4 {
t.Errorf("Random UUID of version %s\n", v)
}
if uuid.Variant() != RFC4122 {
t.Errorf("Random UUID is variant %d\n", uuid.Variant())
}
}
}
func clockSeq(t *testing.T, uuid UUID) int {
seq, ok := uuid.ClockSequence()
if !ok {
t.Fatalf("%s: invalid clock sequence\n", uuid)
}
return seq
}
func TestClockSeq(t *testing.T) {
// Fake time.Now for this test to return a monotonically advancing time; restore it at end.
defer func(orig func() time.Time) { timeNow = orig }(timeNow)
monTime := time.Now()
timeNow = func() time.Time {
monTime = monTime.Add(1 * time.Second)
return monTime
}
SetClockSequence(-1)
uuid1 := NewUUID()
uuid2 := NewUUID()
if clockSeq(t, uuid1) != clockSeq(t, uuid2) {
t.Errorf("clock sequence %d != %d\n", clockSeq(t, uuid1), clockSeq(t, uuid2))
}
SetClockSequence(-1)
uuid2 = NewUUID()
// Just on the very off chance we generated the same sequence
// two times we try again.
if clockSeq(t, uuid1) == clockSeq(t, uuid2) {
SetClockSequence(-1)
uuid2 = NewUUID()
}
if clockSeq(t, uuid1) == clockSeq(t, uuid2) {
t.Errorf("Duplicate clock sequence %d\n", clockSeq(t, uuid1))
}
SetClockSequence(0x1234)
uuid1 = NewUUID()
if seq := clockSeq(t, uuid1); seq != 0x1234 {
t.Errorf("%s: expected seq 0x1234 got 0x%04x\n", uuid1, seq)
}
}
func TestCoding(t *testing.T) {
text := "7d444840-9dc0-11d1-b245-5ffdce74fad2"
urn := "urn:uuid:7d444840-9dc0-11d1-b245-5ffdce74fad2"
data := UUID{
0x7d, 0x44, 0x48, 0x40,
0x9d, 0xc0,
0x11, 0xd1,
0xb2, 0x45,
0x5f, 0xfd, 0xce, 0x74, 0xfa, 0xd2,
}
if v := data.String(); v != text {
t.Errorf("%x: encoded to %s, expected %s\n", data, v, text)
}
if v := data.URN(); v != urn {
t.Errorf("%x: urn is %s, expected %s\n", data, v, urn)
}
uuid := Parse(text)
if !Equal(uuid, data) {
t.Errorf("%s: decoded to %s, expected %s\n", text, uuid, data)
}
}
func TestVersion1(t *testing.T) {
uuid1 := NewUUID()
uuid2 := NewUUID()
if Equal(uuid1, uuid2) {
t.Errorf("%s:duplicate uuid\n", uuid1)
}
if v, _ := uuid1.Version(); v != 1 {
t.Errorf("%s: version %s expected 1\n", uuid1, v)
}
if v, _ := uuid2.Version(); v != 1 {
t.Errorf("%s: version %s expected 1\n", uuid2, v)
}
n1 := uuid1.NodeID()
n2 := uuid2.NodeID()
if !bytes.Equal(n1, n2) {
t.Errorf("Different nodes %x != %x\n", n1, n2)
}
t1, ok := uuid1.Time()
if !ok {
t.Errorf("%s: invalid time\n", uuid1)
}
t2, ok := uuid2.Time()
if !ok {
t.Errorf("%s: invalid time\n", uuid2)
}
q1, ok := uuid1.ClockSequence()
if !ok {
t.Errorf("%s: invalid clock sequence\n", uuid1)
}
q2, ok := uuid2.ClockSequence()
if !ok {
t.Errorf("%s: invalid clock sequence", uuid2)
}
switch {
case t1 == t2 && q1 == q2:
t.Errorf("time stopped\n")
case t1 > t2 && q1 == q2:
t.Errorf("time reversed\n")
case t1 < t2 && q1 != q2:
t.Errorf("clock sequence chaned unexpectedly\n")
}
}
func TestNodeAndTime(t *testing.T) {
// Time is February 5, 1998 12:30:23.136364800 AM GMT
uuid := Parse("7d444840-9dc0-11d1-b245-5ffdce74fad2")
node := []byte{0x5f, 0xfd, 0xce, 0x74, 0xfa, 0xd2}
ts, ok := uuid.Time()
if ok {
c := time.Unix(ts.UnixTime())
want := time.Date(1998, 2, 5, 0, 30, 23, 136364800, time.UTC)
if !c.Equal(want) {
t.Errorf("Got time %v, want %v", c, want)
}
} else {
t.Errorf("%s: bad time\n", uuid)
}
if !bytes.Equal(node, uuid.NodeID()) {
t.Errorf("Expected node %v got %v\n", node, uuid.NodeID())
}
}
func TestMD5(t *testing.T) {
uuid := NewMD5(NameSpace_DNS, []byte("python.org")).String()
want := "6fa459ea-ee8a-3ca4-894e-db77e160355e"
if uuid != want {
t.Errorf("MD5: got %q expected %q\n", uuid, want)
}
}
func TestSHA1(t *testing.T) {
uuid := NewSHA1(NameSpace_DNS, []byte("python.org")).String()
want := "886313e1-3b8a-5372-9b90-0c9aee199e5d"
if uuid != want {
t.Errorf("SHA1: got %q expected %q\n", uuid, want)
}
}
func TestNodeID(t *testing.T) {
nid := []byte{1, 2, 3, 4, 5, 6}
SetNodeInterface("")
s := NodeInterface()
if s == "" || s == "user" {
t.Errorf("NodeInterface %q after SetInteface\n", s)
}
node1 := NodeID()
if node1 == nil {
t.Errorf("NodeID nil after SetNodeInterface\n", s)
}
SetNodeID(nid)
s = NodeInterface()
if s != "user" {
t.Errorf("Expected NodeInterface %q got %q\n", "user", s)
}
node2 := NodeID()
if node2 == nil {
t.Errorf("NodeID nil after SetNodeID\n", s)
}
if bytes.Equal(node1, node2) {
t.Errorf("NodeID not changed after SetNodeID\n", s)
} else if !bytes.Equal(nid, node2) {
t.Errorf("NodeID is %x, expected %x\n", node2, nid)
}
}
func testDCE(t *testing.T, name string, uuid UUID, domain Domain, id uint32) {
if uuid == nil {
t.Errorf("%s failed\n", name)
return
}
if v, _ := uuid.Version(); v != 2 {
t.Errorf("%s: %s: expected version 2, got %s\n", name, uuid, v)
return
}
if v, ok := uuid.Domain(); !ok || v != domain {
if !ok {
t.Errorf("%s: %d: Domain failed\n", name, uuid)
} else {
t.Errorf("%s: %s: expected domain %d, got %d\n", name, uuid, domain, v)
}
}
if v, ok := uuid.Id(); !ok || v != id {
if !ok {
t.Errorf("%s: %d: Id failed\n", name, uuid)
} else {
t.Errorf("%s: %s: expected id %d, got %d\n", name, uuid, id, v)
}
}
}
func TestDCE(t *testing.T) {
testDCE(t, "NewDCESecurity", NewDCESecurity(42, 12345678), 42, 12345678)
testDCE(t, "NewDCEPerson", NewDCEPerson(), Person, uint32(os.Getuid()))
testDCE(t, "NewDCEGroup", NewDCEGroup(), Group, uint32(os.Getgid()))
}
type badRand struct{}
func (r badRand) Read(buf []byte) (int, error) {
for i, _ := range buf {
buf[i] = byte(i)
}
return len(buf), nil
}
func TestBadRand(t *testing.T) {
SetRand(badRand{})
uuid1 := New()
uuid2 := New()
if uuid1 != uuid2 {
t.Errorf("execpted duplicates, got %q and %q\n", uuid1, uuid2)
}
SetRand(nil)
uuid1 = New()
uuid2 = New()
if uuid1 == uuid2 {
t.Errorf("unexecpted duplicates, got %q\n", uuid1)
}
}

View File

@ -0,0 +1,41 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
import (
"encoding/binary"
)
// NewUUID returns a Version 1 UUID based on the current NodeID and clock
// sequence, and the current time. If the NodeID has not been set by SetNodeID
// or SetNodeInterface then it will be set automatically. If the NodeID cannot
// be set NewUUID returns nil. If clock sequence has not been set by
// SetClockSequence then it will be set automatically. If GetTime fails to
// return the current NewUUID returns nil.
func NewUUID() UUID {
if nodeID == nil {
SetNodeInterface("")
}
now, err := GetTime()
if err != nil {
return nil
}
uuid := make([]byte, 16)
time_low := uint32(now & 0xffffffff)
time_mid := uint16((now >> 32) & 0xffff)
time_hi := uint16((now >> 48) & 0x0fff)
time_hi |= 0x1000 // Version 1
binary.BigEndian.PutUint32(uuid[0:], time_low)
binary.BigEndian.PutUint16(uuid[4:], time_mid)
binary.BigEndian.PutUint16(uuid[6:], time_hi)
binary.BigEndian.PutUint16(uuid[8:], clock_seq)
copy(uuid[10:], nodeID)
return uuid
}

View File

@ -0,0 +1,25 @@
// Copyright 2011 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package uuid
// Random returns a Random (Version 4) UUID or panics.
//
// The strength of the UUIDs is based on the strength of the crypto/rand
// package.
//
// A note about uniqueness derived from from the UUID Wikipedia entry:
//
// Randomly generated UUIDs have 122 random bits. One's annual risk of being
// hit by a meteorite is estimated to be one chance in 17 billion, that
// means the probability is about 0.00000000006 (6 × 1011),
// equivalent to the odds of creating a few tens of trillions of UUIDs in a
// year and having one duplicate.
func NewRandom() UUID {
uuid := make([]byte, 16)
randomBits([]byte(uuid))
uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4
uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10
return uuid
}

View File

@ -0,0 +1,77 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package pbkdf2 implements the key derivation function PBKDF2 as defined in RFC
2898 / PKCS #5 v2.0.
A key derivation function is useful when encrypting data based on a password
or any other not-fully-random data. It uses a pseudorandom function to derive
a secure encryption key based on the password.
While v2.0 of the standard defines only one pseudorandom function to use,
HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS Approved
Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for HMAC. To
choose, you can pass the `New` functions from the different SHA packages to
pbkdf2.Key.
*/
package pbkdf2
import (
"crypto/hmac"
"hash"
)
// Key derives a key from the password, salt and iteration count, returning a
// []byte of length keylen that can be used as cryptographic key. The key is
// derived based on the method described as PBKDF2 with the HMAC variant using
// the supplied hash function.
//
// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you
// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by
// doing:
//
// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New)
//
// Remember to get a good random salt. At least 8 bytes is recommended by the
// RFC.
//
// Using a higher iteration count will increase the cost of an exhaustive
// search but will also make derivation proportionally slower.
func Key(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
prf := hmac.New(h, password)
hashLen := prf.Size()
numBlocks := (keyLen + hashLen - 1) / hashLen
var buf [4]byte
dk := make([]byte, 0, numBlocks*hashLen)
U := make([]byte, hashLen)
for block := 1; block <= numBlocks; block++ {
// N.B.: || means concatenation, ^ means XOR
// for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
// U_1 = PRF(password, salt || uint(i))
prf.Reset()
prf.Write(salt)
buf[0] = byte(block >> 24)
buf[1] = byte(block >> 16)
buf[2] = byte(block >> 8)
buf[3] = byte(block)
prf.Write(buf[:4])
dk = prf.Sum(dk)
T := dk[len(dk)-hashLen:]
copy(U, T)
// U_n = PRF(password, U_(n-1))
for n := 2; n <= iter; n++ {
prf.Reset()
prf.Write(U)
U = U[:0]
U = prf.Sum(U)
for x := range U {
T[x] ^= U[x]
}
}
}
return dk[:keyLen]
}

View File

@ -0,0 +1,157 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package pbkdf2
import (
"bytes"
"crypto/sha1"
"crypto/sha256"
"hash"
"testing"
)
type testVector struct {
password string
salt string
iter int
output []byte
}
// Test vectors from RFC 6070, http://tools.ietf.org/html/rfc6070
var sha1TestVectors = []testVector{
{
"password",
"salt",
1,
[]byte{
0x0c, 0x60, 0xc8, 0x0f, 0x96, 0x1f, 0x0e, 0x71,
0xf3, 0xa9, 0xb5, 0x24, 0xaf, 0x60, 0x12, 0x06,
0x2f, 0xe0, 0x37, 0xa6,
},
},
{
"password",
"salt",
2,
[]byte{
0xea, 0x6c, 0x01, 0x4d, 0xc7, 0x2d, 0x6f, 0x8c,
0xcd, 0x1e, 0xd9, 0x2a, 0xce, 0x1d, 0x41, 0xf0,
0xd8, 0xde, 0x89, 0x57,
},
},
{
"password",
"salt",
4096,
[]byte{
0x4b, 0x00, 0x79, 0x01, 0xb7, 0x65, 0x48, 0x9a,
0xbe, 0xad, 0x49, 0xd9, 0x26, 0xf7, 0x21, 0xd0,
0x65, 0xa4, 0x29, 0xc1,
},
},
// // This one takes too long
// {
// "password",
// "salt",
// 16777216,
// []byte{
// 0xee, 0xfe, 0x3d, 0x61, 0xcd, 0x4d, 0xa4, 0xe4,
// 0xe9, 0x94, 0x5b, 0x3d, 0x6b, 0xa2, 0x15, 0x8c,
// 0x26, 0x34, 0xe9, 0x84,
// },
// },
{
"passwordPASSWORDpassword",
"saltSALTsaltSALTsaltSALTsaltSALTsalt",
4096,
[]byte{
0x3d, 0x2e, 0xec, 0x4f, 0xe4, 0x1c, 0x84, 0x9b,
0x80, 0xc8, 0xd8, 0x36, 0x62, 0xc0, 0xe4, 0x4a,
0x8b, 0x29, 0x1a, 0x96, 0x4c, 0xf2, 0xf0, 0x70,
0x38,
},
},
{
"pass\000word",
"sa\000lt",
4096,
[]byte{
0x56, 0xfa, 0x6a, 0xa7, 0x55, 0x48, 0x09, 0x9d,
0xcc, 0x37, 0xd7, 0xf0, 0x34, 0x25, 0xe0, 0xc3,
},
},
}
// Test vectors from
// http://stackoverflow.com/questions/5130513/pbkdf2-hmac-sha2-test-vectors
var sha256TestVectors = []testVector{
{
"password",
"salt",
1,
[]byte{
0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c,
0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8, 0x37,
0xa8, 0x65, 0x48, 0xc9,
},
},
{
"password",
"salt",
2,
[]byte{
0xae, 0x4d, 0x0c, 0x95, 0xaf, 0x6b, 0x46, 0xd3,
0x2d, 0x0a, 0xdf, 0xf9, 0x28, 0xf0, 0x6d, 0xd0,
0x2a, 0x30, 0x3f, 0x8e,
},
},
{
"password",
"salt",
4096,
[]byte{
0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41,
0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c, 0x4c, 0x8d,
0x96, 0x28, 0x93, 0xa0,
},
},
{
"passwordPASSWORDpassword",
"saltSALTsaltSALTsaltSALTsaltSALTsalt",
4096,
[]byte{
0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f,
0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e, 0x84, 0xcf,
0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18,
0x1c,
},
},
{
"pass\000word",
"sa\000lt",
4096,
[]byte{
0x89, 0xb6, 0x9d, 0x05, 0x16, 0xf8, 0x29, 0x89,
0x3c, 0x69, 0x62, 0x26, 0x65, 0x0a, 0x86, 0x87,
},
},
}
func testHash(t *testing.T, h func() hash.Hash, hashName string, vectors []testVector) {
for i, v := range vectors {
o := Key([]byte(v.password), []byte(v.salt), v.iter, len(v.output), h)
if !bytes.Equal(o, v.output) {
t.Errorf("%s %d: expected %x, got %x", hashName, i, v.output, o)
}
}
}
func TestWithHMACSHA1(t *testing.T) {
testHash(t, sha1.New, "SHA1", sha1TestVectors)
}
func TestWithHMACSHA256(t *testing.T) {
testHash(t, sha256.New, "SHA256", sha256TestVectors)
}

View File

@ -0,0 +1,120 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package ripemd160 implements the RIPEMD-160 hash algorithm.
package ripemd160
// RIPEMD-160 is designed by by Hans Dobbertin, Antoon Bosselaers, and Bart
// Preneel with specifications available at:
// http://homes.esat.kuleuven.be/~cosicart/pdf/AB-9601/AB-9601.pdf.
import (
"crypto"
"hash"
)
func init() {
crypto.RegisterHash(crypto.RIPEMD160, New)
}
// The size of the checksum in bytes.
const Size = 20
// The block size of the hash algorithm in bytes.
const BlockSize = 64
const (
_s0 = 0x67452301
_s1 = 0xefcdab89
_s2 = 0x98badcfe
_s3 = 0x10325476
_s4 = 0xc3d2e1f0
)
// digest represents the partial evaluation of a checksum.
type digest struct {
s [5]uint32 // running context
x [BlockSize]byte // temporary buffer
nx int // index into x
tc uint64 // total count of bytes processed
}
func (d *digest) Reset() {
d.s[0], d.s[1], d.s[2], d.s[3], d.s[4] = _s0, _s1, _s2, _s3, _s4
d.nx = 0
d.tc = 0
}
// New returns a new hash.Hash computing the checksum.
func New() hash.Hash {
result := new(digest)
result.Reset()
return result
}
func (d *digest) Size() int { return Size }
func (d *digest) BlockSize() int { return BlockSize }
func (d *digest) Write(p []byte) (nn int, err error) {
nn = len(p)
d.tc += uint64(nn)
if d.nx > 0 {
n := len(p)
if n > BlockSize-d.nx {
n = BlockSize - d.nx
}
for i := 0; i < n; i++ {
d.x[d.nx+i] = p[i]
}
d.nx += n
if d.nx == BlockSize {
_Block(d, d.x[0:])
d.nx = 0
}
p = p[n:]
}
n := _Block(d, p)
p = p[n:]
if len(p) > 0 {
d.nx = copy(d.x[:], p)
}
return
}
func (d0 *digest) Sum(in []byte) []byte {
// Make a copy of d0 so that caller can keep writing and summing.
d := *d0
// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
tc := d.tc
var tmp [64]byte
tmp[0] = 0x80
if tc%64 < 56 {
d.Write(tmp[0 : 56-tc%64])
} else {
d.Write(tmp[0 : 64+56-tc%64])
}
// Length in bits.
tc <<= 3
for i := uint(0); i < 8; i++ {
tmp[i] = byte(tc >> (8 * i))
}
d.Write(tmp[0:8])
if d.nx != 0 {
panic("d.nx != 0")
}
var digest [Size]byte
for i, s := range d.s {
digest[i*4] = byte(s)
digest[i*4+1] = byte(s >> 8)
digest[i*4+2] = byte(s >> 16)
digest[i*4+3] = byte(s >> 24)
}
return append(in, digest[:]...)
}

View File

@ -0,0 +1,64 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ripemd160
// Test vectors are from:
// http://homes.esat.kuleuven.be/~bosselae/ripemd160.html
import (
"fmt"
"io"
"testing"
)
type mdTest struct {
out string
in string
}
var vectors = [...]mdTest{
{"9c1185a5c5e9fc54612808977ee8f548b2258d31", ""},
{"0bdc9d2d256b3ee9daae347be6f4dc835a467ffe", "a"},
{"8eb208f7e05d987a9b044a8e98c6b087f15a0bfc", "abc"},
{"5d0689ef49d2fae572b881b123a85ffa21595f36", "message digest"},
{"f71c27109c692c1b56bbdceb5b9d2865b3708dbc", "abcdefghijklmnopqrstuvwxyz"},
{"12a053384a9c0c88e405a06c27dcf49ada62eb2b", "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"},
{"b0e20b6e3116640286ed3a87a5713079b21f5189", "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"},
{"9b752e45573d4b39f4dbd3323cab82bf63326bfb", "12345678901234567890123456789012345678901234567890123456789012345678901234567890"},
}
func TestVectors(t *testing.T) {
for i := 0; i < len(vectors); i++ {
tv := vectors[i]
md := New()
for j := 0; j < 3; j++ {
if j < 2 {
io.WriteString(md, tv.in)
} else {
io.WriteString(md, tv.in[0:len(tv.in)/2])
md.Sum(nil)
io.WriteString(md, tv.in[len(tv.in)/2:])
}
s := fmt.Sprintf("%x", md.Sum(nil))
if s != tv.out {
t.Fatalf("RIPEMD-160[%d](%s) = %s, expected %s", j, tv.in, s, tv.out)
}
md.Reset()
}
}
}
func TestMillionA(t *testing.T) {
md := New()
for i := 0; i < 100000; i++ {
io.WriteString(md, "aaaaaaaaaa")
}
out := "52783243c1697bdbe16d37f97f68f08325dc1528"
s := fmt.Sprintf("%x", md.Sum(nil))
if s != out {
t.Fatalf("RIPEMD-160 (1 million 'a') = %s, expected %s", s, out)
}
md.Reset()
}

View File

@ -0,0 +1,161 @@
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// RIPEMD-160 block step.
// In its own file so that a faster assembly or C version
// can be substituted easily.
package ripemd160
// work buffer indices and roll amounts for one line
var _n = [80]uint{
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
7, 4, 13, 1, 10, 6, 15, 3, 12, 0, 9, 5, 2, 14, 11, 8,
3, 10, 14, 4, 9, 15, 8, 1, 2, 7, 0, 6, 13, 11, 5, 12,
1, 9, 11, 10, 0, 8, 12, 4, 13, 3, 7, 15, 14, 5, 6, 2,
4, 0, 5, 9, 7, 12, 2, 10, 14, 1, 3, 8, 11, 6, 15, 13,
}
var _r = [80]uint{
11, 14, 15, 12, 5, 8, 7, 9, 11, 13, 14, 15, 6, 7, 9, 8,
7, 6, 8, 13, 11, 9, 7, 15, 7, 12, 15, 9, 11, 7, 13, 12,
11, 13, 6, 7, 14, 9, 13, 15, 14, 8, 13, 6, 5, 12, 7, 5,
11, 12, 14, 15, 14, 15, 9, 8, 9, 14, 5, 6, 8, 6, 5, 12,
9, 15, 5, 11, 6, 8, 13, 12, 5, 12, 13, 14, 11, 8, 5, 6,
}
// same for the other parallel one
var n_ = [80]uint{
5, 14, 7, 0, 9, 2, 11, 4, 13, 6, 15, 8, 1, 10, 3, 12,
6, 11, 3, 7, 0, 13, 5, 10, 14, 15, 8, 12, 4, 9, 1, 2,
15, 5, 1, 3, 7, 14, 6, 9, 11, 8, 12, 2, 10, 0, 4, 13,
8, 6, 4, 1, 3, 11, 15, 0, 5, 12, 2, 13, 9, 7, 10, 14,
12, 15, 10, 4, 1, 5, 8, 7, 6, 2, 13, 14, 0, 3, 9, 11,
}
var r_ = [80]uint{
8, 9, 9, 11, 13, 15, 15, 5, 7, 7, 8, 11, 14, 14, 12, 6,
9, 13, 15, 7, 12, 8, 9, 11, 7, 7, 12, 7, 6, 15, 13, 11,
9, 7, 15, 11, 8, 6, 6, 14, 12, 13, 5, 14, 13, 13, 7, 5,
15, 5, 8, 11, 14, 14, 6, 14, 6, 9, 12, 9, 12, 5, 15, 8,
8, 5, 12, 9, 12, 5, 14, 6, 8, 13, 6, 5, 15, 13, 11, 11,
}
func _Block(md *digest, p []byte) int {
n := 0
var x [16]uint32
var alpha, beta uint32
for len(p) >= BlockSize {
a, b, c, d, e := md.s[0], md.s[1], md.s[2], md.s[3], md.s[4]
aa, bb, cc, dd, ee := a, b, c, d, e
j := 0
for i := 0; i < 16; i++ {
x[i] = uint32(p[j]) | uint32(p[j+1])<<8 | uint32(p[j+2])<<16 | uint32(p[j+3])<<24
j += 4
}
// round 1
i := 0
for i < 16 {
alpha = a + (b ^ c ^ d) + x[_n[i]]
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb ^ (cc | ^dd)) + x[n_[i]] + 0x50a28be6
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 2
for i < 32 {
alpha = a + (b&c | ^b&d) + x[_n[i]] + 0x5a827999
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb&dd | cc&^dd) + x[n_[i]] + 0x5c4dd124
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 3
for i < 48 {
alpha = a + (b | ^c ^ d) + x[_n[i]] + 0x6ed9eba1
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb | ^cc ^ dd) + x[n_[i]] + 0x6d703ef3
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 4
for i < 64 {
alpha = a + (b&d | c&^d) + x[_n[i]] + 0x8f1bbcdc
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb&cc | ^bb&dd) + x[n_[i]] + 0x7a6d76e9
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// round 5
for i < 80 {
alpha = a + (b ^ (c | ^d)) + x[_n[i]] + 0xa953fd4e
s := _r[i]
alpha = (alpha<<s | alpha>>(32-s)) + e
beta = c<<10 | c>>22
a, b, c, d, e = e, alpha, b, beta, d
// parallel line
alpha = aa + (bb ^ cc ^ dd) + x[n_[i]]
s = r_[i]
alpha = (alpha<<s | alpha>>(32-s)) + ee
beta = cc<<10 | cc>>22
aa, bb, cc, dd, ee = ee, alpha, bb, beta, dd
i++
}
// combine results
dd += c + md.s[1]
md.s[1] = md.s[2] + d + ee
md.s[2] = md.s[3] + e + aa
md.s[3] = md.s[4] + a + bb
md.s[4] = md.s[0] + b + cc
md.s[0] = dd
p = p[BlockSize:]
n += BlockSize
}
return n
}

View File

@ -0,0 +1,243 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package scrypt implements the scrypt key derivation function as defined in
// Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard
// Functions" (http://www.tarsnap.com/scrypt/scrypt.pdf).
package scrypt
import (
"crypto/sha256"
"errors"
"golang.org/x/crypto/pbkdf2"
)
const maxInt = int(^uint(0) >> 1)
// blockCopy copies n numbers from src into dst.
func blockCopy(dst, src []uint32, n int) {
copy(dst, src[:n])
}
// blockXOR XORs numbers from dst with n numbers from src.
func blockXOR(dst, src []uint32, n int) {
for i, v := range src[:n] {
dst[i] ^= v
}
}
// salsaXOR applies Salsa20/8 to the XOR of 16 numbers from tmp and in,
// and puts the result into both both tmp and out.
func salsaXOR(tmp *[16]uint32, in, out []uint32) {
w0 := tmp[0] ^ in[0]
w1 := tmp[1] ^ in[1]
w2 := tmp[2] ^ in[2]
w3 := tmp[3] ^ in[3]
w4 := tmp[4] ^ in[4]
w5 := tmp[5] ^ in[5]
w6 := tmp[6] ^ in[6]
w7 := tmp[7] ^ in[7]
w8 := tmp[8] ^ in[8]
w9 := tmp[9] ^ in[9]
w10 := tmp[10] ^ in[10]
w11 := tmp[11] ^ in[11]
w12 := tmp[12] ^ in[12]
w13 := tmp[13] ^ in[13]
w14 := tmp[14] ^ in[14]
w15 := tmp[15] ^ in[15]
x0, x1, x2, x3, x4, x5, x6, x7, x8 := w0, w1, w2, w3, w4, w5, w6, w7, w8
x9, x10, x11, x12, x13, x14, x15 := w9, w10, w11, w12, w13, w14, w15
for i := 0; i < 8; i += 2 {
u := x0 + x12
x4 ^= u<<7 | u>>(32-7)
u = x4 + x0
x8 ^= u<<9 | u>>(32-9)
u = x8 + x4
x12 ^= u<<13 | u>>(32-13)
u = x12 + x8
x0 ^= u<<18 | u>>(32-18)
u = x5 + x1
x9 ^= u<<7 | u>>(32-7)
u = x9 + x5
x13 ^= u<<9 | u>>(32-9)
u = x13 + x9
x1 ^= u<<13 | u>>(32-13)
u = x1 + x13
x5 ^= u<<18 | u>>(32-18)
u = x10 + x6
x14 ^= u<<7 | u>>(32-7)
u = x14 + x10
x2 ^= u<<9 | u>>(32-9)
u = x2 + x14
x6 ^= u<<13 | u>>(32-13)
u = x6 + x2
x10 ^= u<<18 | u>>(32-18)
u = x15 + x11
x3 ^= u<<7 | u>>(32-7)
u = x3 + x15
x7 ^= u<<9 | u>>(32-9)
u = x7 + x3
x11 ^= u<<13 | u>>(32-13)
u = x11 + x7
x15 ^= u<<18 | u>>(32-18)
u = x0 + x3
x1 ^= u<<7 | u>>(32-7)
u = x1 + x0
x2 ^= u<<9 | u>>(32-9)
u = x2 + x1
x3 ^= u<<13 | u>>(32-13)
u = x3 + x2
x0 ^= u<<18 | u>>(32-18)
u = x5 + x4
x6 ^= u<<7 | u>>(32-7)
u = x6 + x5
x7 ^= u<<9 | u>>(32-9)
u = x7 + x6
x4 ^= u<<13 | u>>(32-13)
u = x4 + x7
x5 ^= u<<18 | u>>(32-18)
u = x10 + x9
x11 ^= u<<7 | u>>(32-7)
u = x11 + x10
x8 ^= u<<9 | u>>(32-9)
u = x8 + x11
x9 ^= u<<13 | u>>(32-13)
u = x9 + x8
x10 ^= u<<18 | u>>(32-18)
u = x15 + x14
x12 ^= u<<7 | u>>(32-7)
u = x12 + x15
x13 ^= u<<9 | u>>(32-9)
u = x13 + x12
x14 ^= u<<13 | u>>(32-13)
u = x14 + x13
x15 ^= u<<18 | u>>(32-18)
}
x0 += w0
x1 += w1
x2 += w2
x3 += w3
x4 += w4
x5 += w5
x6 += w6
x7 += w7
x8 += w8
x9 += w9
x10 += w10
x11 += w11
x12 += w12
x13 += w13
x14 += w14
x15 += w15
out[0], tmp[0] = x0, x0
out[1], tmp[1] = x1, x1
out[2], tmp[2] = x2, x2
out[3], tmp[3] = x3, x3
out[4], tmp[4] = x4, x4
out[5], tmp[5] = x5, x5
out[6], tmp[6] = x6, x6
out[7], tmp[7] = x7, x7
out[8], tmp[8] = x8, x8
out[9], tmp[9] = x9, x9
out[10], tmp[10] = x10, x10
out[11], tmp[11] = x11, x11
out[12], tmp[12] = x12, x12
out[13], tmp[13] = x13, x13
out[14], tmp[14] = x14, x14
out[15], tmp[15] = x15, x15
}
func blockMix(tmp *[16]uint32, in, out []uint32, r int) {
blockCopy(tmp[:], in[(2*r-1)*16:], 16)
for i := 0; i < 2*r; i += 2 {
salsaXOR(tmp, in[i*16:], out[i*8:])
salsaXOR(tmp, in[i*16+16:], out[i*8+r*16:])
}
}
func integer(b []uint32, r int) uint64 {
j := (2*r - 1) * 16
return uint64(b[j]) | uint64(b[j+1])<<32
}
func smix(b []byte, r, N int, v, xy []uint32) {
var tmp [16]uint32
x := xy
y := xy[32*r:]
j := 0
for i := 0; i < 32*r; i++ {
x[i] = uint32(b[j]) | uint32(b[j+1])<<8 | uint32(b[j+2])<<16 | uint32(b[j+3])<<24
j += 4
}
for i := 0; i < N; i += 2 {
blockCopy(v[i*(32*r):], x, 32*r)
blockMix(&tmp, x, y, r)
blockCopy(v[(i+1)*(32*r):], y, 32*r)
blockMix(&tmp, y, x, r)
}
for i := 0; i < N; i += 2 {
j := int(integer(x, r) & uint64(N-1))
blockXOR(x, v[j*(32*r):], 32*r)
blockMix(&tmp, x, y, r)
j = int(integer(y, r) & uint64(N-1))
blockXOR(y, v[j*(32*r):], 32*r)
blockMix(&tmp, y, x, r)
}
j = 0
for _, v := range x[:32*r] {
b[j+0] = byte(v >> 0)
b[j+1] = byte(v >> 8)
b[j+2] = byte(v >> 16)
b[j+3] = byte(v >> 24)
j += 4
}
}
// Key derives a key from the password, salt, and cost parameters, returning
// a byte slice of length keyLen that can be used as cryptographic key.
//
// N is a CPU/memory cost parameter, which must be a power of two greater than 1.
// r and p must satisfy r * p < 2³⁰. If the parameters do not satisfy the
// limits, the function returns a nil byte slice and an error.
//
// For example, you can get a derived key for e.g. AES-256 (which needs a
// 32-byte key) by doing:
//
// dk := scrypt.Key([]byte("some password"), salt, 16384, 8, 1, 32)
//
// The recommended parameters for interactive logins as of 2009 are N=16384,
// r=8, p=1. They should be increased as memory latency and CPU parallelism
// increases. Remember to get a good random salt.
func Key(password, salt []byte, N, r, p, keyLen int) ([]byte, error) {
if N <= 1 || N&(N-1) != 0 {
return nil, errors.New("scrypt: N must be > 1 and a power of 2")
}
if uint64(r)*uint64(p) >= 1<<30 || r > maxInt/128/p || r > maxInt/256 || N > maxInt/128/r {
return nil, errors.New("scrypt: parameters are too large")
}
xy := make([]uint32, 64*r)
v := make([]uint32, 32*N*r)
b := pbkdf2.Key(password, salt, 1, p*128*r, sha256.New)
for i := 0; i < p; i++ {
smix(b[i*128*r:], r, N, v, xy)
}
return pbkdf2.Key(password, b, 1, keyLen, sha256.New), nil
}

View File

@ -0,0 +1,160 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package scrypt
import (
"bytes"
"testing"
)
type testVector struct {
password string
salt string
N, r, p int
output []byte
}
var good = []testVector{
{
"password",
"salt",
2, 10, 10,
[]byte{
0x48, 0x2c, 0x85, 0x8e, 0x22, 0x90, 0x55, 0xe6, 0x2f,
0x41, 0xe0, 0xec, 0x81, 0x9a, 0x5e, 0xe1, 0x8b, 0xdb,
0x87, 0x25, 0x1a, 0x53, 0x4f, 0x75, 0xac, 0xd9, 0x5a,
0xc5, 0xe5, 0xa, 0xa1, 0x5f,
},
},
{
"password",
"salt",
16, 100, 100,
[]byte{
0x88, 0xbd, 0x5e, 0xdb, 0x52, 0xd1, 0xdd, 0x0, 0x18,
0x87, 0x72, 0xad, 0x36, 0x17, 0x12, 0x90, 0x22, 0x4e,
0x74, 0x82, 0x95, 0x25, 0xb1, 0x8d, 0x73, 0x23, 0xa5,
0x7f, 0x91, 0x96, 0x3c, 0x37,
},
},
{
"this is a long \000 password",
"and this is a long \000 salt",
16384, 8, 1,
[]byte{
0xc3, 0xf1, 0x82, 0xee, 0x2d, 0xec, 0x84, 0x6e, 0x70,
0xa6, 0x94, 0x2f, 0xb5, 0x29, 0x98, 0x5a, 0x3a, 0x09,
0x76, 0x5e, 0xf0, 0x4c, 0x61, 0x29, 0x23, 0xb1, 0x7f,
0x18, 0x55, 0x5a, 0x37, 0x07, 0x6d, 0xeb, 0x2b, 0x98,
0x30, 0xd6, 0x9d, 0xe5, 0x49, 0x26, 0x51, 0xe4, 0x50,
0x6a, 0xe5, 0x77, 0x6d, 0x96, 0xd4, 0x0f, 0x67, 0xaa,
0xee, 0x37, 0xe1, 0x77, 0x7b, 0x8a, 0xd5, 0xc3, 0x11,
0x14, 0x32, 0xbb, 0x3b, 0x6f, 0x7e, 0x12, 0x64, 0x40,
0x18, 0x79, 0xe6, 0x41, 0xae,
},
},
{
"p",
"s",
2, 1, 1,
[]byte{
0x48, 0xb0, 0xd2, 0xa8, 0xa3, 0x27, 0x26, 0x11, 0x98,
0x4c, 0x50, 0xeb, 0xd6, 0x30, 0xaf, 0x52,
},
},
{
"",
"",
16, 1, 1,
[]byte{
0x77, 0xd6, 0x57, 0x62, 0x38, 0x65, 0x7b, 0x20, 0x3b,
0x19, 0xca, 0x42, 0xc1, 0x8a, 0x04, 0x97, 0xf1, 0x6b,
0x48, 0x44, 0xe3, 0x07, 0x4a, 0xe8, 0xdf, 0xdf, 0xfa,
0x3f, 0xed, 0xe2, 0x14, 0x42, 0xfc, 0xd0, 0x06, 0x9d,
0xed, 0x09, 0x48, 0xf8, 0x32, 0x6a, 0x75, 0x3a, 0x0f,
0xc8, 0x1f, 0x17, 0xe8, 0xd3, 0xe0, 0xfb, 0x2e, 0x0d,
0x36, 0x28, 0xcf, 0x35, 0xe2, 0x0c, 0x38, 0xd1, 0x89,
0x06,
},
},
{
"password",
"NaCl",
1024, 8, 16,
[]byte{
0xfd, 0xba, 0xbe, 0x1c, 0x9d, 0x34, 0x72, 0x00, 0x78,
0x56, 0xe7, 0x19, 0x0d, 0x01, 0xe9, 0xfe, 0x7c, 0x6a,
0xd7, 0xcb, 0xc8, 0x23, 0x78, 0x30, 0xe7, 0x73, 0x76,
0x63, 0x4b, 0x37, 0x31, 0x62, 0x2e, 0xaf, 0x30, 0xd9,
0x2e, 0x22, 0xa3, 0x88, 0x6f, 0xf1, 0x09, 0x27, 0x9d,
0x98, 0x30, 0xda, 0xc7, 0x27, 0xaf, 0xb9, 0x4a, 0x83,
0xee, 0x6d, 0x83, 0x60, 0xcb, 0xdf, 0xa2, 0xcc, 0x06,
0x40,
},
},
{
"pleaseletmein", "SodiumChloride",
16384, 8, 1,
[]byte{
0x70, 0x23, 0xbd, 0xcb, 0x3a, 0xfd, 0x73, 0x48, 0x46,
0x1c, 0x06, 0xcd, 0x81, 0xfd, 0x38, 0xeb, 0xfd, 0xa8,
0xfb, 0xba, 0x90, 0x4f, 0x8e, 0x3e, 0xa9, 0xb5, 0x43,
0xf6, 0x54, 0x5d, 0xa1, 0xf2, 0xd5, 0x43, 0x29, 0x55,
0x61, 0x3f, 0x0f, 0xcf, 0x62, 0xd4, 0x97, 0x05, 0x24,
0x2a, 0x9a, 0xf9, 0xe6, 0x1e, 0x85, 0xdc, 0x0d, 0x65,
0x1e, 0x40, 0xdf, 0xcf, 0x01, 0x7b, 0x45, 0x57, 0x58,
0x87,
},
},
/*
// Disabled: needs 1 GiB RAM and takes too long for a simple test.
{
"pleaseletmein", "SodiumChloride",
1048576, 8, 1,
[]byte{
0x21, 0x01, 0xcb, 0x9b, 0x6a, 0x51, 0x1a, 0xae, 0xad,
0xdb, 0xbe, 0x09, 0xcf, 0x70, 0xf8, 0x81, 0xec, 0x56,
0x8d, 0x57, 0x4a, 0x2f, 0xfd, 0x4d, 0xab, 0xe5, 0xee,
0x98, 0x20, 0xad, 0xaa, 0x47, 0x8e, 0x56, 0xfd, 0x8f,
0x4b, 0xa5, 0xd0, 0x9f, 0xfa, 0x1c, 0x6d, 0x92, 0x7c,
0x40, 0xf4, 0xc3, 0x37, 0x30, 0x40, 0x49, 0xe8, 0xa9,
0x52, 0xfb, 0xcb, 0xf4, 0x5c, 0x6f, 0xa7, 0x7a, 0x41,
0xa4,
},
},
*/
}
var bad = []testVector{
{"p", "s", 0, 1, 1, nil}, // N == 0
{"p", "s", 1, 1, 1, nil}, // N == 1
{"p", "s", 7, 8, 1, nil}, // N is not power of 2
{"p", "s", 16, maxInt / 2, maxInt / 2, nil}, // p * r too large
}
func TestKey(t *testing.T) {
for i, v := range good {
k, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, len(v.output))
if err != nil {
t.Errorf("%d: got unexpected error: %s", i, err)
}
if !bytes.Equal(k, v.output) {
t.Errorf("%d: expected %x, got %x", i, v.output, k)
}
}
for i, v := range bad {
_, err := Key([]byte(v.password), []byte(v.salt), v.N, v.r, v.p, 32)
if err == nil {
t.Errorf("%d: expected error, got nil", i)
}
}
}
func BenchmarkKey(b *testing.B) {
for i := 0; i < b.N; i++ {
Key([]byte("password"), []byte("salt"), 16384, 8, 1, 64)
}
}

View File

@ -0,0 +1,98 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
)
// DialError is an error that occurs while dialling a websocket server.
type DialError struct {
*Config
Err error
}
func (e *DialError) Error() string {
return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
}
// NewConfig creates a new WebSocket config for client connection.
func NewConfig(server, origin string) (config *Config, err error) {
config = new(Config)
config.Version = ProtocolVersionHybi13
config.Location, err = url.ParseRequestURI(server)
if err != nil {
return
}
config.Origin, err = url.ParseRequestURI(origin)
if err != nil {
return
}
config.Header = http.Header(make(map[string][]string))
return
}
// NewClient creates a new WebSocket client connection over rwc.
func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
err = hybiClientHandshake(config, br, bw)
if err != nil {
return
}
buf := bufio.NewReadWriter(br, bw)
ws = newHybiClientConn(config, buf, rwc)
return
}
// Dial opens a new client connection to a WebSocket.
func Dial(url_, protocol, origin string) (ws *Conn, err error) {
config, err := NewConfig(url_, origin)
if err != nil {
return nil, err
}
if protocol != "" {
config.Protocol = []string{protocol}
}
return DialConfig(config)
}
// DialConfig opens a new client connection to a WebSocket with a config.
func DialConfig(config *Config) (ws *Conn, err error) {
var client net.Conn
if config.Location == nil {
return nil, &DialError{config, ErrBadWebSocketLocation}
}
if config.Origin == nil {
return nil, &DialError{config, ErrBadWebSocketOrigin}
}
switch config.Location.Scheme {
case "ws":
client, err = net.Dial("tcp", config.Location.Host)
case "wss":
client, err = tls.Dial("tcp", config.Location.Host, config.TlsConfig)
default:
err = ErrBadScheme
}
if err != nil {
goto Error
}
ws, err = NewClient(config, client)
if err != nil {
goto Error
}
return
Error:
return nil, &DialError{config, err}
}

View File

@ -0,0 +1,31 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"fmt"
"log"
"code.google.com/p/go.net/websocket"
)
// This example demonstrates a trivial client.
func ExampleDial() {
origin := "http://localhost/"
url := "ws://localhost:12345/ws"
ws, err := websocket.Dial(url, "", origin)
if err != nil {
log.Fatal(err)
}
if _, err := ws.Write([]byte("hello, world!\n")); err != nil {
log.Fatal(err)
}
var msg = make([]byte, 512)
var n int
if n, err = ws.Read(msg); err != nil {
log.Fatal(err)
}
fmt.Printf("Received: %s.\n", msg[:n])
}

View File

@ -0,0 +1,26 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket_test
import (
"io"
"net/http"
"code.google.com/p/go.net/websocket"
)
// Echo the data received on the WebSocket.
func EchoServer(ws *websocket.Conn) {
io.Copy(ws, ws)
}
// This example demonstrates a trivial echo server.
func ExampleHandler() {
http.Handle("/echo", websocket.Handler(EchoServer))
err := http.ListenAndServe(":12345", nil)
if err != nil {
panic("ListenAndServe: " + err.Error())
}
}

View File

@ -0,0 +1,564 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
// This file implements a protocol of hybi draft.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
const (
websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
closeStatusNormal = 1000
closeStatusGoingAway = 1001
closeStatusProtocolError = 1002
closeStatusUnsupportedData = 1003
closeStatusFrameTooLarge = 1004
closeStatusNoStatusRcvd = 1005
closeStatusAbnormalClosure = 1006
closeStatusBadMessageData = 1007
closeStatusPolicyViolation = 1008
closeStatusTooBigData = 1009
closeStatusExtensionMismatch = 1010
maxControlFramePayloadLength = 125
)
var (
ErrBadMaskingKey = &ProtocolError{"bad masking key"}
ErrBadPongMessage = &ProtocolError{"bad pong message"}
ErrBadClosingStatus = &ProtocolError{"bad closing status"}
ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
ErrNotImplemented = &ProtocolError{"not implemented"}
handshakeHeader = map[string]bool{
"Host": true,
"Upgrade": true,
"Connection": true,
"Sec-Websocket-Key": true,
"Sec-Websocket-Origin": true,
"Sec-Websocket-Version": true,
"Sec-Websocket-Protocol": true,
"Sec-Websocket-Accept": true,
}
)
// A hybiFrameHeader is a frame header as defined in hybi draft.
type hybiFrameHeader struct {
Fin bool
Rsv [3]bool
OpCode byte
Length int64
MaskingKey []byte
data *bytes.Buffer
}
// A hybiFrameReader is a reader for hybi frame.
type hybiFrameReader struct {
reader io.Reader
header hybiFrameHeader
pos int64
length int
}
func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
n, err = frame.reader.Read(msg)
if err != nil {
return 0, err
}
if frame.header.MaskingKey != nil {
for i := 0; i < n; i++ {
msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
frame.pos++
}
}
return n, err
}
func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
func (frame *hybiFrameReader) HeaderReader() io.Reader {
if frame.header.data == nil {
return nil
}
if frame.header.data.Len() == 0 {
return nil
}
return frame.header.data
}
func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
func (frame *hybiFrameReader) Len() (n int) { return frame.length }
// A hybiFrameReaderFactory creates new frame reader based on its frame type.
type hybiFrameReaderFactory struct {
*bufio.Reader
}
// NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
// See Section 5.2 Base Framing protocol for detail.
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
hybiFrame := new(hybiFrameReader)
frame = hybiFrame
var header []byte
var b byte
// First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
for i := 0; i < 3; i++ {
j := uint(6 - i)
hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
}
hybiFrame.header.OpCode = header[0] & 0x0f
// Second byte. Mask/Payload len(7bits)
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
mask := (b & 0x80) != 0
b &= 0x7f
lengthFields := 0
switch {
case b <= 125: // Payload length 7bits.
hybiFrame.header.Length = int64(b)
case b == 126: // Payload length 7+16bits
lengthFields = 2
case b == 127: // Payload length 7+64bits
lengthFields = 8
}
for i := 0; i < lengthFields; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
}
if mask {
// Masking key. 4 bytes.
for i := 0; i < 4; i++ {
b, err = buf.ReadByte()
if err != nil {
return
}
header = append(header, b)
hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
}
}
hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
hybiFrame.header.data = bytes.NewBuffer(header)
hybiFrame.length = len(header) + int(hybiFrame.header.Length)
return
}
// A HybiFrameWriter is a writer for hybi frame.
type hybiFrameWriter struct {
writer *bufio.Writer
header *hybiFrameHeader
}
func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
var header []byte
var b byte
if frame.header.Fin {
b |= 0x80
}
for i := 0; i < 3; i++ {
if frame.header.Rsv[i] {
j := uint(6 - i)
b |= 1 << j
}
}
b |= frame.header.OpCode
header = append(header, b)
if frame.header.MaskingKey != nil {
b = 0x80
} else {
b = 0
}
lengthFields := 0
length := len(msg)
switch {
case length <= 125:
b |= byte(length)
case length < 65536:
b |= 126
lengthFields = 2
default:
b |= 127
lengthFields = 8
}
header = append(header, b)
for i := 0; i < lengthFields; i++ {
j := uint((lengthFields - i - 1) * 8)
b = byte((length >> j) & 0xff)
header = append(header, b)
}
if frame.header.MaskingKey != nil {
if len(frame.header.MaskingKey) != 4 {
return 0, ErrBadMaskingKey
}
header = append(header, frame.header.MaskingKey...)
frame.writer.Write(header)
data := make([]byte, length)
for i := range data {
data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
}
frame.writer.Write(data)
err = frame.writer.Flush()
return length, err
}
frame.writer.Write(header)
frame.writer.Write(msg)
err = frame.writer.Flush()
return length, err
}
func (frame *hybiFrameWriter) Close() error { return nil }
type hybiFrameWriterFactory struct {
*bufio.Writer
needMaskingKey bool
}
func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
if buf.needMaskingKey {
frameHeader.MaskingKey, err = generateMaskingKey()
if err != nil {
return nil, err
}
}
return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
}
type hybiFrameHandler struct {
conn *Conn
payloadType byte
}
func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (r frameReader, err error) {
if handler.conn.IsServerConn() {
// The client MUST mask all frames sent to the server.
if frame.(*hybiFrameReader).header.MaskingKey == nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
} else {
// The server MUST NOT mask all frames.
if frame.(*hybiFrameReader).header.MaskingKey != nil {
handler.WriteClose(closeStatusProtocolError)
return nil, io.EOF
}
}
if header := frame.HeaderReader(); header != nil {
io.Copy(ioutil.Discard, header)
}
switch frame.PayloadType() {
case ContinuationFrame:
frame.(*hybiFrameReader).header.OpCode = handler.payloadType
case TextFrame, BinaryFrame:
handler.payloadType = frame.PayloadType()
case CloseFrame:
return nil, io.EOF
case PingFrame:
pingMsg := make([]byte, maxControlFramePayloadLength)
n, err := io.ReadFull(frame, pingMsg)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, err
}
io.Copy(ioutil.Discard, frame)
n, err = handler.WritePong(pingMsg[:n])
if err != nil {
return nil, err
}
return nil, nil
case PongFrame:
return nil, ErrNotImplemented
}
return frame, nil
}
func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
if err != nil {
return err
}
msg := make([]byte, 2)
binary.BigEndian.PutUint16(msg, uint16(status))
_, err = w.Write(msg)
w.Close()
return err
}
func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
handler.conn.wio.Lock()
defer handler.conn.wio.Unlock()
w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
return n, err
}
// newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
if buf == nil {
br := bufio.NewReader(rwc)
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
frameWriterFactory: hybiFrameWriterFactory{
buf.Writer, request == nil},
PayloadType: TextFrame,
defaultCloseStatus: closeStatusNormal}
ws.frameHandler = &hybiFrameHandler{conn: ws}
return ws
}
// generateMaskingKey generates a masking key for a frame.
func generateMaskingKey() (maskingKey []byte, err error) {
maskingKey = make([]byte, 4)
if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil {
return
}
return
}
// generateNonce generates a nonce consisting of a randomly selected 16-byte
// value that has been base64-encoded.
func generateNonce() (nonce []byte) {
key := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
panic(err)
}
nonce = make([]byte, 24)
base64.StdEncoding.Encode(nonce, key)
return
}
// getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
// the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
func getNonceAccept(nonce []byte) (expected []byte, err error) {
h := sha1.New()
if _, err = h.Write(nonce); err != nil {
return
}
if _, err = h.Write([]byte(websocketGUID)); err != nil {
return
}
expected = make([]byte, 28)
base64.StdEncoding.Encode(expected, h.Sum(nil))
return
}
// Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
bw.WriteString("Host: " + config.Location.Host + "\r\n")
bw.WriteString("Upgrade: websocket\r\n")
bw.WriteString("Connection: Upgrade\r\n")
nonce := generateNonce()
if config.handshakeData != nil {
nonce = []byte(config.handshakeData["key"])
}
bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
if config.Version != ProtocolVersionHybi13 {
return ErrBadProtocolVersion
}
bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
if len(config.Protocol) > 0 {
bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
err = config.Header.WriteSubset(bw, handshakeHeader)
if err != nil {
return err
}
bw.WriteString("\r\n")
if err = bw.Flush(); err != nil {
return err
}
resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
if err != nil {
return err
}
if resp.StatusCode != 101 {
return ErrBadStatus
}
if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
return ErrBadUpgrade
}
expectedAccept, err := getNonceAccept(nonce)
if err != nil {
return err
}
if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
return ErrChallengeResponse
}
if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
return ErrUnsupportedExtensions
}
offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
if offeredProtocol != "" {
protocolMatched := false
for i := 0; i < len(config.Protocol); i++ {
if config.Protocol[i] == offeredProtocol {
protocolMatched = true
break
}
}
if !protocolMatched {
return ErrBadWebSocketProtocol
}
config.Protocol = []string{offeredProtocol}
}
return nil
}
// newHybiClientConn creates a client WebSocket connection after handshake.
func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
return newHybiConn(config, buf, rwc, nil)
}
// A HybiServerHandshaker performs a server handshake using hybi draft protocol.
type hybiServerHandshaker struct {
*Config
accept []byte
}
func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
c.Version = ProtocolVersionHybi13
if req.Method != "GET" {
return http.StatusMethodNotAllowed, ErrBadRequestMethod
}
// HTTP version can be safely ignored.
if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
!strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
return http.StatusBadRequest, ErrNotWebSocket
}
key := req.Header.Get("Sec-Websocket-Key")
if key == "" {
return http.StatusBadRequest, ErrChallengeResponse
}
version := req.Header.Get("Sec-Websocket-Version")
switch version {
case "13":
c.Version = ProtocolVersionHybi13
default:
return http.StatusBadRequest, ErrBadWebSocketVersion
}
var scheme string
if req.TLS != nil {
scheme = "wss"
} else {
scheme = "ws"
}
c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
if err != nil {
return http.StatusBadRequest, err
}
protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
if protocol != "" {
protocols := strings.Split(protocol, ",")
for i := 0; i < len(protocols); i++ {
c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
}
}
c.accept, err = getNonceAccept([]byte(key))
if err != nil {
return http.StatusInternalServerError, err
}
return http.StatusSwitchingProtocols, nil
}
// Origin parses Origin header in "req".
// If origin is "null", returns (nil, nil).
func Origin(config *Config, req *http.Request) (*url.URL, error) {
var origin string
switch config.Version {
case ProtocolVersionHybi13:
origin = req.Header.Get("Origin")
}
if origin == "null" {
return nil, nil
}
return url.ParseRequestURI(origin)
}
func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
if len(c.Protocol) > 0 {
if len(c.Protocol) != 1 {
// You need choose a Protocol in Handshake func in Server.
return ErrBadWebSocketProtocol
}
}
buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
if len(c.Protocol) > 0 {
buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
}
// TODO(ukai): send Sec-WebSocket-Extensions.
if c.Header != nil {
err := c.Header.WriteSubset(buf, handshakeHeader)
if err != nil {
return err
}
}
buf.WriteString("\r\n")
return buf.Flush()
}
func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiServerConn(c.Config, buf, rwc, request)
}
// newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
return newHybiConn(config, buf, rwc, request)
}

View File

@ -0,0 +1,590 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
// Test the getNonceAccept function with values in
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
func TestSecWebSocketAccept(t *testing.T) {
nonce := []byte("dGhlIHNhbXBsZSBub25jZQ==")
expected := []byte("s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
accept, err := getNonceAccept(nonce)
if err != nil {
t.Errorf("getNonceAccept: returned error %v", err)
return
}
if !bytes.Equal(expected, accept) {
t.Errorf("getNonceAccept: expected %q got %q", expected, accept)
}
}
func TestHybiClientHandshake(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiClientHandshakeWithHeader(t *testing.T) {
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
br := bufio.NewReader(strings.NewReader(`HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
Sec-WebSocket-Protocol: chat
`))
var err error
config := new(Config)
config.Location, err = url.ParseRequestURI("ws://server.example.com/chat")
if err != nil {
t.Fatal("location url", err)
}
config.Origin, err = url.ParseRequestURI("http://example.com")
if err != nil {
t.Fatal("origin url", err)
}
config.Protocol = append(config.Protocol, "chat")
config.Protocol = append(config.Protocol, "superchat")
config.Version = ProtocolVersionHybi13
config.Header = http.Header(make(map[string][]string))
config.Header.Add("User-Agent", "test")
config.handshakeData = map[string]string{
"key": "dGhlIHNhbXBsZSBub25jZQ==",
}
err = hybiClientHandshake(config, br, bw)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
req, err := http.ReadRequest(bufio.NewReader(b))
if err != nil {
t.Fatalf("read request: %v", err)
}
if req.Method != "GET" {
t.Errorf("request method expected GET, but got %q", req.Method)
}
if req.URL.Path != "/chat" {
t.Errorf("request path expected /chat, but got %q", req.URL.Path)
}
if req.Proto != "HTTP/1.1" {
t.Errorf("request proto expected HTTP/1.1, but got %q", req.Proto)
}
if req.Host != "server.example.com" {
t.Errorf("request Host expected server.example.com, but got %v", req.Host)
}
var expectedHeader = map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-Websocket-Key": config.handshakeData["key"],
"Origin": config.Origin.String(),
"Sec-Websocket-Protocol": "chat, superchat",
"Sec-Websocket-Version": fmt.Sprintf("%d", ProtocolVersionHybi13),
"User-Agent": "test",
}
for k, v := range expectedHeader {
if req.Header.Get(k) != v {
t.Errorf(fmt.Sprintf("%s expected %q but got %q", k, v, req.Header.Get(k)))
}
}
}
func TestHybiServerHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
expectedProtocols := []string{"chat", "superchat"}
if fmt.Sprintf("%v", config.Protocol) != fmt.Sprintf("%v", expectedProtocols) {
t.Errorf("protocol expected %q but got %q", expectedProtocols, config.Protocol)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = config.Protocol[:1]
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeNoSubProtocol(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
if len(config.Protocol) != 0 {
t.Errorf("len(config.Protocol) expected 0, but got %q", len(config.Protocol))
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}
func TestHybiServerHandshakeHybiBadVersion(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 9
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != ErrBadWebSocketVersion {
t.Errorf("handshake expected err %q but got %q", ErrBadWebSocketVersion, err)
}
if code != http.StatusBadRequest {
t.Errorf("status expected %q but got %q", http.StatusBadRequest, code)
}
}
func testHybiFrame(t *testing.T, testHeader, testPayload, testMaskedPayload []byte, frameHeader *hybiFrameHeader) {
b := bytes.NewBuffer([]byte{})
frameWriterFactory := &hybiFrameWriterFactory{bufio.NewWriter(b), false}
w, _ := frameWriterFactory.NewFrameWriter(TextFrame)
w.(*hybiFrameWriter).header = frameHeader
_, err := w.Write(testPayload)
w.Close()
if err != nil {
t.Errorf("Write error %q", err)
}
var expectedFrame []byte
expectedFrame = append(expectedFrame, testHeader...)
expectedFrame = append(expectedFrame, testMaskedPayload...)
if !bytes.Equal(expectedFrame, b.Bytes()) {
t.Errorf("frame expected %q got %q", expectedFrame, b.Bytes())
}
frameReaderFactory := &hybiFrameReaderFactory{bufio.NewReader(b)}
r, err := frameReaderFactory.NewFrameReader()
if err != nil {
t.Errorf("Read error %q", err)
}
if header := r.HeaderReader(); header == nil {
t.Errorf("no header")
} else {
actualHeader := make([]byte, r.Len())
n, err := header.Read(actualHeader)
if err != nil {
t.Errorf("Read header error %q", err)
} else {
if n < len(testHeader) {
t.Errorf("header too short %q got %q", testHeader, actualHeader[:n])
}
if !bytes.Equal(testHeader, actualHeader[:n]) {
t.Errorf("header expected %q got %q", testHeader, actualHeader[:n])
}
}
}
if trailer := r.TrailerReader(); trailer != nil {
t.Errorf("unexpected trailer %q", trailer)
}
frame := r.(*hybiFrameReader)
if frameHeader.Fin != frame.header.Fin ||
frameHeader.OpCode != frame.header.OpCode ||
len(testPayload) != int(frame.header.Length) {
t.Errorf("mismatch %v (%d) vs %v", frameHeader, len(testPayload), frame)
}
payload := make([]byte, len(testPayload))
_, err = r.Read(payload)
if err != nil {
t.Errorf("read %v", err)
}
if !bytes.Equal(testPayload, payload) {
t.Errorf("payload %q vs %q", testPayload, payload)
}
}
func TestHybiShortTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x81, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x81, 125}, payload, payload, frameHeader)
}
func TestHybiShortMaskedTextFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame,
MaskingKey: []byte{0xcc, 0x55, 0x80, 0x20}}
payload := []byte("hello")
maskedPayload := []byte{0xa4, 0x30, 0xec, 0x4c, 0xa3}
header := []byte{0x81, 0x85}
header = append(header, frameHeader.MaskingKey...)
testHybiFrame(t, header, payload, maskedPayload, frameHeader)
}
func TestHybiShortBinaryFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: BinaryFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x82, 0x05}, payload, payload, frameHeader)
payload = make([]byte, 125)
testHybiFrame(t, []byte{0x82, 125}, payload, payload, frameHeader)
}
func TestHybiControlFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: PingFrame}
payload := []byte("hello")
testHybiFrame(t, []byte{0x89, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: PongFrame}
testHybiFrame(t, []byte{0x8A, 0x05}, payload, payload, frameHeader)
frameHeader = &hybiFrameHeader{Fin: true, OpCode: CloseFrame}
payload = []byte{0x03, 0xe8} // 1000
testHybiFrame(t, []byte{0x88, 0x02}, payload, payload, frameHeader)
}
func TestHybiLongFrame(t *testing.T) {
frameHeader := &hybiFrameHeader{Fin: true, OpCode: TextFrame}
payload := make([]byte, 126)
testHybiFrame(t, []byte{0x81, 126, 0x00, 126}, payload, payload, frameHeader)
payload = make([]byte, 65535)
testHybiFrame(t, []byte{0x81, 126, 0xff, 0xff}, payload, payload, frameHeader)
payload = make([]byte, 65536)
testHybiFrame(t, []byte{0x81, 127, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00}, payload, payload, frameHeader)
}
func TestHybiClientRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[2:7], msg[:n]) {
t.Errorf("read 1st frame %v, got %v", wireData[2:7], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(wireData[16:21], msg[:n]) {
t.Errorf("read 2nd frame %v, got %v", wireData[16:21], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiShortRead(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o',
0x89, 0x05, 'h', 'e', 'l', 'l', 'o', // ping
0x81, 0x05, 'w', 'o', 'r', 'l', 'd'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
step := 0
pos := 0
expectedPos := []int{2, 5, 16, 19}
expectedLen := []int{3, 2, 3, 2}
for {
msg := make([]byte, 3)
n, err := conn.Read(msg)
if step >= len(expectedPos) {
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
return
}
pos = expectedPos[step]
endPos := pos + expectedLen[step]
if err != nil {
t.Errorf("read from %d, got error %q", pos, err)
return
}
if n != endPos-pos {
t.Errorf("read from %d, expect %d, got %d", pos, endPos-pos, n)
}
if !bytes.Equal(wireData[pos:endPos], msg[:n]) {
t.Errorf("read from %d, frame %v, got %v", pos, wireData[pos:endPos], msg[:n])
}
step++
}
}
func TestHybiServerRead(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
0x89, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // ping: hello
0x81, 0x85, 0xed, 0x83, 0xb4, 0x24,
0x9a, 0xec, 0xc6, 0x48, 0x89, // world
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
expected := [][]byte{[]byte("hello"), []byte("world")}
msg := make([]byte, 512)
n, err := conn.Read(msg)
if err != nil {
t.Errorf("read 1st frame, error %q", err)
}
if n != 5 {
t.Errorf("read 1st frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[0], msg[:n]) {
t.Errorf("read 1st frame %q, got %q", expected[0], msg[:n])
}
n, err = conn.Read(msg)
if err != nil {
t.Errorf("read 2nd frame, error %q", err)
}
if n != 5 {
t.Errorf("read 2nd frame, expect 5, got %d", n)
}
if !bytes.Equal(expected[1], msg[:n]) {
t.Errorf("read 2nd frame %q, got %q", expected[1], msg[:n])
}
n, err = conn.Read(msg)
if err == nil {
t.Errorf("read not EOF")
}
if n != 0 {
t.Errorf("expect read 0, got %d", n)
}
}
func TestHybiServerReadWithoutMasking(t *testing.T) {
wireData := []byte{0x81, 0x05, 'h', 'e', 'l', 'l', 'o'}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, new(http.Request))
// server MUST close the connection upon receiving a non-masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
func TestHybiClientReadWithMasking(t *testing.T) {
wireData := []byte{0x81, 0x85, 0xcc, 0x55, 0x80, 0x20,
0xa4, 0x30, 0xec, 0x4c, 0xa3, // hello
}
br := bufio.NewReader(bytes.NewBuffer(wireData))
bw := bufio.NewWriter(bytes.NewBuffer([]byte{}))
conn := newHybiConn(newConfig(t, "/"), bufio.NewReadWriter(br, bw), nil, nil)
// client MUST close the connection upon receiving a masked frame.
msg := make([]byte, 512)
_, err := conn.Read(msg)
if err != io.EOF {
t.Errorf("read 1st frame, expect %q, but got %q", io.EOF, err)
}
}
// Test the hybiServerHandshaker supports firefox implementation and
// checks Connection request header include (but it's not necessary
// equal to) "upgrade"
func TestHybiServerFirefoxHandshake(t *testing.T) {
config := new(Config)
handshaker := &hybiServerHandshaker{Config: config}
br := bufio.NewReader(strings.NewReader(`GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: keep-alive, upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
`))
req, err := http.ReadRequest(br)
if err != nil {
t.Fatal("request", err)
}
code, err := handshaker.ReadHandshake(br, req)
if err != nil {
t.Errorf("handshake failed: %v", err)
}
if code != http.StatusSwitchingProtocols {
t.Errorf("status expected %q but got %q", http.StatusSwitchingProtocols, code)
}
b := bytes.NewBuffer([]byte{})
bw := bufio.NewWriter(b)
config.Protocol = []string{"chat"}
err = handshaker.AcceptHandshake(bw)
if err != nil {
t.Errorf("handshake response failed: %v", err)
}
expectedResponse := strings.Join([]string{
"HTTP/1.1 101 Switching Protocols",
"Upgrade: websocket",
"Connection: Upgrade",
"Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
"Sec-WebSocket-Protocol: chat",
"", ""}, "\r\n")
if b.String() != expectedResponse {
t.Errorf("handshake expected %q but got %q", expectedResponse, b.String())
}
}

View File

@ -0,0 +1,114 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"fmt"
"io"
"net/http"
)
func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
var hs serverHandshaker = &hybiServerHandshaker{Config: config}
code, err := hs.ReadHandshake(buf.Reader, req)
if err == ErrBadWebSocketVersion {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if err != nil {
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.WriteString(err.Error())
buf.Flush()
return
}
if handshake != nil {
err = handshake(config, req)
if err != nil {
code = http.StatusForbidden
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
}
err = hs.AcceptHandshake(buf.Writer)
if err != nil {
code = http.StatusBadRequest
fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
buf.WriteString("\r\n")
buf.Flush()
return
}
conn = hs.NewServerConn(buf, rwc, req)
return
}
// Server represents a server of a WebSocket.
type Server struct {
// Config is a WebSocket configuration for new WebSocket connection.
Config
// Handshake is an optional function in WebSocket handshake.
// For example, you can check, or don't check Origin header.
// Another example, you can select config.Protocol.
Handshake func(*Config, *http.Request) error
// Handler handles a WebSocket connection.
Handler
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s.serveWebSocket(w, req)
}
func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
rwc, buf, err := w.(http.Hijacker).Hijack()
if err != nil {
panic("Hijack failed: " + err.Error())
return
}
// The server should abort the WebSocket connection if it finds
// the client did not send a handshake that matches with protocol
// specification.
defer rwc.Close()
conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
if err != nil {
return
}
if conn == nil {
panic("unexpected nil conn")
}
s.Handler(conn)
}
// Handler is a simple interface to a WebSocket browser client.
// It checks if Origin header is valid URL by default.
// You might want to verify websocket.Conn.Config().Origin in the func.
// If you use Server instead of Handler, you could call websocket.Origin and
// check the origin in your Handshake func. So, if you want to accept
// non-browser client, which doesn't send Origin header, you could use Server
//. that doesn't check origin in its Handshake.
type Handler func(*Conn)
func checkOrigin(config *Config, req *http.Request) (err error) {
config.Origin, err = Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
return err
}
// ServeHTTP implements the http.Handler interface for a WebSocket
func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
s := Server{Handler: h, Handshake: checkOrigin}
s.serveWebSocket(w, req)
}

View File

@ -0,0 +1,411 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements a client and server for the WebSocket protocol
// as specified in RFC 6455.
package websocket
import (
"bufio"
"crypto/tls"
"encoding/json"
"errors"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"sync"
"time"
)
const (
ProtocolVersionHybi13 = 13
ProtocolVersionHybi = ProtocolVersionHybi13
SupportedProtocolVersion = "13"
ContinuationFrame = 0
TextFrame = 1
BinaryFrame = 2
CloseFrame = 8
PingFrame = 9
PongFrame = 10
UnknownFrame = 255
)
// ProtocolError represents WebSocket protocol errors.
type ProtocolError struct {
ErrorString string
}
func (err *ProtocolError) Error() string { return err.ErrorString }
var (
ErrBadProtocolVersion = &ProtocolError{"bad protocol version"}
ErrBadScheme = &ProtocolError{"bad scheme"}
ErrBadStatus = &ProtocolError{"bad status"}
ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"}
ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"}
ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
ErrBadWebSocketVersion = &ProtocolError{"missing or bad WebSocket Version"}
ErrChallengeResponse = &ProtocolError{"mismatch challenge/response"}
ErrBadFrame = &ProtocolError{"bad frame"}
ErrBadFrameBoundary = &ProtocolError{"not on frame boundary"}
ErrNotWebSocket = &ProtocolError{"not websocket protocol"}
ErrBadRequestMethod = &ProtocolError{"bad method"}
ErrNotSupported = &ProtocolError{"not supported"}
)
// Addr is an implementation of net.Addr for WebSocket.
type Addr struct {
*url.URL
}
// Network returns the network type for a WebSocket, "websocket".
func (addr *Addr) Network() string { return "websocket" }
// Config is a WebSocket configuration
type Config struct {
// A WebSocket server address.
Location *url.URL
// A Websocket client origin.
Origin *url.URL
// WebSocket subprotocols.
Protocol []string
// WebSocket protocol version.
Version int
// TLS config for secure WebSocket (wss).
TlsConfig *tls.Config
// Additional header fields to be sent in WebSocket opening handshake.
Header http.Header
handshakeData map[string]string
}
// serverHandshaker is an interface to handle WebSocket server side handshake.
type serverHandshaker interface {
// ReadHandshake reads handshake request message from client.
// Returns http response code and error if any.
ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
// AcceptHandshake accepts the client handshake request and sends
// handshake response back to client.
AcceptHandshake(buf *bufio.Writer) (err error)
// NewServerConn creates a new WebSocket connection.
NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
// frameReader is an interface to read a WebSocket frame.
type frameReader interface {
// Reader is to read payload of the frame.
io.Reader
// PayloadType returns payload type.
PayloadType() byte
// HeaderReader returns a reader to read header of the frame.
HeaderReader() io.Reader
// TrailerReader returns a reader to read trailer of the frame.
// If it returns nil, there is no trailer in the frame.
TrailerReader() io.Reader
// Len returns total length of the frame, including header and trailer.
Len() int
}
// frameReaderFactory is an interface to creates new frame reader.
type frameReaderFactory interface {
NewFrameReader() (r frameReader, err error)
}
// frameWriter is an interface to write a WebSocket frame.
type frameWriter interface {
// Writer is to write payload of the frame.
io.WriteCloser
}
// frameWriterFactory is an interface to create new frame writer.
type frameWriterFactory interface {
NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
type frameHandler interface {
HandleFrame(frame frameReader) (r frameReader, err error)
WriteClose(status int) (err error)
}
// Conn represents a WebSocket connection.
type Conn struct {
config *Config
request *http.Request
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
rio sync.Mutex
frameReaderFactory
frameReader
wio sync.Mutex
frameWriterFactory
frameHandler
PayloadType byte
defaultCloseStatus int
}
// Read implements the io.Reader interface:
// it reads data of a frame from the WebSocket connection.
// if msg is not large enough for the frame data, it fills the msg and next Read
// will read the rest of the frame data.
// it reads Text frame or Binary frame.
func (ws *Conn) Read(msg []byte) (n int, err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
again:
if ws.frameReader == nil {
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return 0, err
}
ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return 0, err
}
if ws.frameReader == nil {
goto again
}
}
n, err = ws.frameReader.Read(msg)
if err == io.EOF {
if trailer := ws.frameReader.TrailerReader(); trailer != nil {
io.Copy(ioutil.Discard, trailer)
}
ws.frameReader = nil
goto again
}
return n, err
}
// Write implements the io.Writer interface:
// it writes data as a frame to the WebSocket connection.
func (ws *Conn) Write(msg []byte) (n int, err error) {
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
if err != nil {
return 0, err
}
n, err = w.Write(msg)
w.Close()
if err != nil {
return n, err
}
return n, err
}
// Close implements the io.Closer interface.
func (ws *Conn) Close() error {
err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
if err != nil {
return err
}
return ws.rwc.Close()
}
func (ws *Conn) IsClientConn() bool { return ws.request == nil }
func (ws *Conn) IsServerConn() bool { return ws.request != nil }
// LocalAddr returns the WebSocket Origin for the connection for client, or
// the WebSocket location for server.
func (ws *Conn) LocalAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Origin}
}
return &Addr{ws.config.Location}
}
// RemoteAddr returns the WebSocket location for the connection for client, or
// the Websocket Origin for server.
func (ws *Conn) RemoteAddr() net.Addr {
if ws.IsClientConn() {
return &Addr{ws.config.Location}
}
return &Addr{ws.config.Origin}
}
var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
// SetDeadline sets the connection's network read & write deadlines.
func (ws *Conn) SetDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetDeadline(t)
}
return errSetDeadline
}
// SetReadDeadline sets the connection's network read deadline.
func (ws *Conn) SetReadDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetReadDeadline(t)
}
return errSetDeadline
}
// SetWriteDeadline sets the connection's network write deadline.
func (ws *Conn) SetWriteDeadline(t time.Time) error {
if conn, ok := ws.rwc.(net.Conn); ok {
return conn.SetWriteDeadline(t)
}
return errSetDeadline
}
// Config returns the WebSocket config.
func (ws *Conn) Config() *Config { return ws.config }
// Request returns the http request upgraded to the WebSocket.
// It is nil for client side.
func (ws *Conn) Request() *http.Request { return ws.request }
// Codec represents a symmetric pair of functions that implement a codec.
type Codec struct {
Marshal func(v interface{}) (data []byte, payloadType byte, err error)
Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
// Send sends v marshaled by cd.Marshal as single frame to ws.
func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
data, payloadType, err := cd.Marshal(v)
if err != nil {
return err
}
ws.wio.Lock()
defer ws.wio.Unlock()
w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
if err != nil {
return err
}
_, err = w.Write(data)
w.Close()
return err
}
// Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores in v.
func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
ws.rio.Lock()
defer ws.rio.Unlock()
if ws.frameReader != nil {
_, err = io.Copy(ioutil.Discard, ws.frameReader)
if err != nil {
return err
}
ws.frameReader = nil
}
again:
frame, err := ws.frameReaderFactory.NewFrameReader()
if err != nil {
return err
}
frame, err = ws.frameHandler.HandleFrame(frame)
if err != nil {
return err
}
if frame == nil {
goto again
}
payloadType := frame.PayloadType()
data, err := ioutil.ReadAll(frame)
if err != nil {
return err
}
return cd.Unmarshal(data, payloadType, v)
}
func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
switch data := v.(type) {
case string:
return []byte(data), TextFrame, nil
case []byte:
return data, BinaryFrame, nil
}
return nil, UnknownFrame, ErrNotSupported
}
func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
switch data := v.(type) {
case *string:
*data = string(msg)
return nil
case *[]byte:
*data = msg
return nil
}
return ErrNotSupported
}
/*
Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
To send/receive text frame, use string type.
To send/receive binary frame, use []byte type.
Trivial usage:
import "websocket"
// receive text frame
var message string
websocket.Message.Receive(ws, &message)
// send text frame
message = "hello"
websocket.Message.Send(ws, message)
// receive binary frame
var data []byte
websocket.Message.Receive(ws, &data)
// send binary frame
data = []byte{0, 1, 2}
websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
msg, err = json.Marshal(v)
return msg, TextFrame, err
}
func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
return json.Unmarshal(msg, v)
}
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Trivial usage:
import "websocket"
type T struct {
Msg string
Count int
}
// receive JSON type T
var data T
websocket.JSON.Receive(ws, &data)
// send JSON type T
websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}

View File

@ -0,0 +1,341 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
)
var serverAddr string
var once sync.Once
func echoServer(ws *Conn) { io.Copy(ws, ws) }
type Count struct {
S string
N int
}
func countServer(ws *Conn) {
for {
var count Count
err := JSON.Receive(ws, &count)
if err != nil {
return
}
count.N++
count.S = strings.Repeat(count.S, count.N)
err = JSON.Send(ws, count)
if err != nil {
return
}
}
}
func subProtocolHandshake(config *Config, req *http.Request) error {
for _, proto := range config.Protocol {
if proto == "chat" {
config.Protocol = []string{proto}
return nil
}
}
return ErrBadWebSocketProtocol
}
func subProtoServer(ws *Conn) {
for _, proto := range ws.Config().Protocol {
io.WriteString(ws, proto)
}
}
func startServer() {
http.Handle("/echo", Handler(echoServer))
http.Handle("/count", Handler(countServer))
subproto := Server{
Handshake: subProtocolHandshake,
Handler: Handler(subProtoServer),
}
http.Handle("/subproto", subproto)
server := httptest.NewServer(nil)
serverAddr = server.Listener.Addr().String()
log.Print("Test WebSocket server listening on ", serverAddr)
}
func newConfig(t *testing.T, path string) *Config {
config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
return config
}
func TestEcho(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var actual_msg = make([]byte, 512)
n, err := conn.Read(actual_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
actual_msg = actual_msg[0:n]
if !bytes.Equal(msg, actual_msg) {
t.Errorf("Echo: expected %q got %q", msg, actual_msg)
}
conn.Close()
}
func TestAddr(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
ra := conn.RemoteAddr().String()
if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
t.Errorf("Bad remote addr: %v", ra)
}
la := conn.LocalAddr().String()
if !strings.HasPrefix(la, "http://") {
t.Errorf("Bad local addr: %v", la)
}
conn.Close()
}
func TestCount(t *testing.T) {
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/count"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
var count Count
count.S = "hello"
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 1 {
t.Errorf("count: expected %d got %d", 1, count.N)
}
if count.S != "hello" {
t.Errorf("count: expected %q got %q", "hello", count.S)
}
if err := JSON.Send(conn, count); err != nil {
t.Errorf("Write: %v", err)
}
if err := JSON.Receive(conn, &count); err != nil {
t.Errorf("Read: %v", err)
}
if count.N != 2 {
t.Errorf("count: expected %d got %d", 2, count.N)
}
if count.S != "hellohello" {
t.Errorf("count: expected %q got %q", "hellohello", count.S)
}
conn.Close()
}
func TestWithQuery(t *testing.T) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/echo")
config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
if err != nil {
t.Fatal("location url", err)
}
ws, err := NewClient(config, client)
if err != nil {
t.Errorf("WebSocket handshake: %v", err)
return
}
ws.Close()
}
func testWithProtocol(t *testing.T, subproto []string) (string, error) {
once.Do(startServer)
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
config := newConfig(t, "/subproto")
config.Protocol = subproto
ws, err := NewClient(config, client)
if err != nil {
return "", err
}
msg := make([]byte, 16)
n, err := ws.Read(msg)
if err != nil {
return "", err
}
ws.Close()
return string(msg[:n]), nil
}
func TestWithProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithTwoProtocol(t *testing.T) {
proto, err := testWithProtocol(t, []string{"test", "chat"})
if err != nil {
t.Errorf("SubProto: unexpected error: %v", err)
}
if proto != "chat" {
t.Errorf("SubProto: expected %q, got %q", "chat", proto)
}
}
func TestWithBadProtocol(t *testing.T) {
_, err := testWithProtocol(t, []string{"test"})
if err != ErrBadStatus {
t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
}
}
func TestHTTP(t *testing.T) {
once.Do(startServer)
// If the client did not send a handshake that matches the protocol
// specification, the server MUST return an HTTP response with an
// appropriate error code (such as 400 Bad Request)
resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
if err != nil {
t.Errorf("Get: error %#v", err)
return
}
if resp == nil {
t.Error("Get: resp is null")
return
}
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
}
}
func TestTrailingSpaces(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=955
// The last runs of this create keys with trailing spaces that should not be
// generated by the client.
once.Do(startServer)
config := newConfig(t, "/echo")
for i := 0; i < 30; i++ {
// body
ws, err := DialConfig(config)
if err != nil {
t.Errorf("Dial #%d failed: %v", i, err)
break
}
ws.Close()
}
}
func TestDialConfigBadVersion(t *testing.T) {
once.Do(startServer)
config := newConfig(t, "/echo")
config.Version = 1234
_, err := DialConfig(config)
if dialerr, ok := err.(*DialError); ok {
if dialerr.Err != ErrBadProtocolVersion {
t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
}
}
}
func TestSmallBuffer(t *testing.T) {
// http://code.google.com/p/go/issues/detail?id=1145
// Read should be able to handle reading a fragment of a frame.
once.Do(startServer)
// websocket.Dial()
client, err := net.Dial("tcp", serverAddr)
if err != nil {
t.Fatal("dialing", err)
}
conn, err := NewClient(newConfig(t, "/echo"), client)
if err != nil {
t.Errorf("WebSocket handshake error: %v", err)
return
}
msg := []byte("hello, world\n")
if _, err := conn.Write(msg); err != nil {
t.Errorf("Write: %v", err)
}
var small_msg = make([]byte, 8)
n, err := conn.Read(small_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
if !bytes.Equal(msg[:len(small_msg)], small_msg) {
t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
}
var second_msg = make([]byte, len(msg))
n, err = conn.Read(second_msg)
if err != nil {
t.Errorf("Read: %v", err)
}
second_msg = second_msg[0:n]
if !bytes.Equal(msg[len(small_msg):], second_msg) {
t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
}
conn.Close()
}

View File

@ -0,0 +1,124 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
"errors"
)
// ErrCorrupt reports that the input is invalid.
var ErrCorrupt = errors.New("snappy: corrupt input")
// DecodedLen returns the length of the decoded block.
func DecodedLen(src []byte) (int, error) {
v, _, err := decodedLen(src)
return v, err
}
// decodedLen returns the length of the decoded block and the number of bytes
// that the length header occupied.
func decodedLen(src []byte) (blockLen, headerLen int, err error) {
v, n := binary.Uvarint(src)
if n == 0 {
return 0, 0, ErrCorrupt
}
if uint64(int(v)) != v {
return 0, 0, errors.New("snappy: decoded block is too large")
}
return int(v), n, nil
}
// Decode returns the decoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire decoded block.
// Otherwise, a newly allocated slice will be returned.
// It is valid to pass a nil dst.
func Decode(dst, src []byte) ([]byte, error) {
dLen, s, err := decodedLen(src)
if err != nil {
return nil, err
}
if len(dst) < dLen {
dst = make([]byte, dLen)
}
var d, offset, length int
for s < len(src) {
switch src[s] & 0x03 {
case tagLiteral:
x := uint(src[s] >> 2)
switch {
case x < 60:
s += 1
case x == 60:
s += 2
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-1])
case x == 61:
s += 3
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-2]) | uint(src[s-1])<<8
case x == 62:
s += 4
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-3]) | uint(src[s-2])<<8 | uint(src[s-1])<<16
case x == 63:
s += 5
if s > len(src) {
return nil, ErrCorrupt
}
x = uint(src[s-4]) | uint(src[s-3])<<8 | uint(src[s-2])<<16 | uint(src[s-1])<<24
}
length = int(x + 1)
if length <= 0 {
return nil, errors.New("snappy: unsupported literal length")
}
if length > len(dst)-d || length > len(src)-s {
return nil, ErrCorrupt
}
copy(dst[d:], src[s:s+length])
d += length
s += length
continue
case tagCopy1:
s += 2
if s > len(src) {
return nil, ErrCorrupt
}
length = 4 + int(src[s-2])>>2&0x7
offset = int(src[s-2])&0xe0<<3 | int(src[s-1])
case tagCopy2:
s += 3
if s > len(src) {
return nil, ErrCorrupt
}
length = 1 + int(src[s-3])>>2
offset = int(src[s-2]) | int(src[s-1])<<8
case tagCopy4:
return nil, errors.New("snappy: unsupported COPY_4 tag")
}
end := d + length
if offset > d || end > len(dst) {
return nil, ErrCorrupt
}
for ; d < end; d++ {
dst[d] = dst[d-offset]
}
}
if d != dLen {
return nil, ErrCorrupt
}
return dst[:d], nil
}

View File

@ -0,0 +1,174 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"encoding/binary"
)
// We limit how far copy back-references can go, the same as the C++ code.
const maxOffset = 1 << 15
// emitLiteral writes a literal chunk and returns the number of bytes written.
func emitLiteral(dst, lit []byte) int {
i, n := 0, uint(len(lit)-1)
switch {
case n < 60:
dst[0] = uint8(n)<<2 | tagLiteral
i = 1
case n < 1<<8:
dst[0] = 60<<2 | tagLiteral
dst[1] = uint8(n)
i = 2
case n < 1<<16:
dst[0] = 61<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
i = 3
case n < 1<<24:
dst[0] = 62<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
dst[3] = uint8(n >> 16)
i = 4
case int64(n) < 1<<32:
dst[0] = 63<<2 | tagLiteral
dst[1] = uint8(n)
dst[2] = uint8(n >> 8)
dst[3] = uint8(n >> 16)
dst[4] = uint8(n >> 24)
i = 5
default:
panic("snappy: source buffer is too long")
}
if copy(dst[i:], lit) != len(lit) {
panic("snappy: destination buffer is too short")
}
return i + len(lit)
}
// emitCopy writes a copy chunk and returns the number of bytes written.
func emitCopy(dst []byte, offset, length int) int {
i := 0
for length > 0 {
x := length - 4
if 0 <= x && x < 1<<3 && offset < 1<<11 {
dst[i+0] = uint8(offset>>8)&0x07<<5 | uint8(x)<<2 | tagCopy1
dst[i+1] = uint8(offset)
i += 2
break
}
x = length
if x > 1<<6 {
x = 1 << 6
}
dst[i+0] = uint8(x-1)<<2 | tagCopy2
dst[i+1] = uint8(offset)
dst[i+2] = uint8(offset >> 8)
i += 3
length -= x
}
return i
}
// Encode returns the encoded form of src. The returned slice may be a sub-
// slice of dst if dst was large enough to hold the entire encoded block.
// Otherwise, a newly allocated slice will be returned.
// It is valid to pass a nil dst.
func Encode(dst, src []byte) ([]byte, error) {
if n := MaxEncodedLen(len(src)); len(dst) < n {
dst = make([]byte, n)
}
// The block starts with the varint-encoded length of the decompressed bytes.
d := binary.PutUvarint(dst, uint64(len(src)))
// Return early if src is short.
if len(src) <= 4 {
if len(src) != 0 {
d += emitLiteral(dst[d:], src)
}
return dst[:d], nil
}
// Initialize the hash table. Its size ranges from 1<<8 to 1<<14 inclusive.
const maxTableSize = 1 << 14
shift, tableSize := uint(32-8), 1<<8
for tableSize < maxTableSize && tableSize < len(src) {
shift--
tableSize *= 2
}
var table [maxTableSize]int
// Iterate over the source bytes.
var (
s int // The iterator position.
t int // The last position with the same hash as s.
lit int // The start position of any pending literal bytes.
)
for s+3 < len(src) {
// Update the hash table.
b0, b1, b2, b3 := src[s], src[s+1], src[s+2], src[s+3]
h := uint32(b0) | uint32(b1)<<8 | uint32(b2)<<16 | uint32(b3)<<24
p := &table[(h*0x1e35a7bd)>>shift]
// We need to to store values in [-1, inf) in table. To save
// some initialization time, (re)use the table's zero value
// and shift the values against this zero: add 1 on writes,
// subtract 1 on reads.
t, *p = *p-1, s+1
// If t is invalid or src[s:s+4] differs from src[t:t+4], accumulate a literal byte.
if t < 0 || s-t >= maxOffset || b0 != src[t] || b1 != src[t+1] || b2 != src[t+2] || b3 != src[t+3] {
s++
continue
}
// Otherwise, we have a match. First, emit any pending literal bytes.
if lit != s {
d += emitLiteral(dst[d:], src[lit:s])
}
// Extend the match to be as long as possible.
s0 := s
s, t = s+4, t+4
for s < len(src) && src[s] == src[t] {
s++
t++
}
// Emit the copied bytes.
d += emitCopy(dst[d:], s-t, s-s0)
lit = s
}
// Emit any final pending literal bytes and return.
if lit != len(src) {
d += emitLiteral(dst[d:], src[lit:])
}
return dst[:d], nil
}
// MaxEncodedLen returns the maximum length of a snappy block, given its
// uncompressed length.
func MaxEncodedLen(srcLen int) int {
// Compressed data can be defined as:
// compressed := item* literal*
// item := literal* copy
//
// The trailing literal sequence has a space blowup of at most 62/60
// since a literal of length 60 needs one tag byte + one extra byte
// for length information.
//
// Item blowup is trickier to measure. Suppose the "copy" op copies
// 4 bytes of data. Because of a special check in the encoding code,
// we produce a 4-byte copy only if the offset is < 65536. Therefore
// the copy op takes 3 bytes to encode, and this type of item leads
// to at most the 62/60 blowup for representing literals.
//
// Suppose the "copy" op copies 5 bytes of data. If the offset is big
// enough, it will take 5 bytes to encode the copy op. Therefore the
// worst case here is a one-byte literal followed by a five-byte copy.
// That is, 6 bytes of input turn into 7 bytes of "compressed" data.
//
// This last factor dominates the blowup, so the final estimate is:
return 32 + srcLen + srcLen/6
}

View File

@ -0,0 +1,38 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package snappy implements the snappy block-based compression format.
// It aims for very high speeds and reasonable compression.
//
// The C++ snappy implementation is at http://code.google.com/p/snappy/
package snappy
/*
Each encoded block begins with the varint-encoded length of the decoded data,
followed by a sequence of chunks. Chunks begin and end on byte boundaries. The
first byte of each chunk is broken into its 2 least and 6 most significant bits
called l and m: l ranges in [0, 4) and m ranges in [0, 64). l is the chunk tag.
Zero means a literal tag. All other values mean a copy tag.
For literal tags:
- If m < 60, the next 1 + m bytes are literal bytes.
- Otherwise, let n be the little-endian unsigned integer denoted by the next
m - 59 bytes. The next 1 + n bytes after that are literal bytes.
For copy tags, length bytes are copied from offset bytes ago, in the style of
Lempel-Ziv compression algorithms. In particular:
- For l == 1, the offset ranges in [0, 1<<11) and the length in [4, 12).
The length is 4 + the low 3 bits of m. The high 3 bits of m form bits 8-10
of the offset. The next byte is bits 0-7 of the offset.
- For l == 2, the offset ranges in [0, 1<<16) and the length in [1, 65).
The length is 1 + m. The offset is the little-endian unsigned integer
denoted by the next 2 bytes.
- For l == 3, this tag is a legacy format that is no longer supported.
*/
const (
tagLiteral = 0x00
tagCopy1 = 0x01
tagCopy2 = 0x02
tagCopy4 = 0x03
)

View File

@ -0,0 +1,261 @@
// Copyright 2011 The Snappy-Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package snappy
import (
"bytes"
"flag"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
)
var download = flag.Bool("download", false, "If true, download any missing files before running benchmarks")
func roundtrip(b, ebuf, dbuf []byte) error {
e, err := Encode(ebuf, b)
if err != nil {
return fmt.Errorf("encoding error: %v", err)
}
d, err := Decode(dbuf, e)
if err != nil {
return fmt.Errorf("decoding error: %v", err)
}
if !bytes.Equal(b, d) {
return fmt.Errorf("roundtrip mismatch:\n\twant %v\n\tgot %v", b, d)
}
return nil
}
func TestEmpty(t *testing.T) {
if err := roundtrip(nil, nil, nil); err != nil {
t.Fatal(err)
}
}
func TestSmallCopy(t *testing.T) {
for _, ebuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
for _, dbuf := range [][]byte{nil, make([]byte, 20), make([]byte, 64)} {
for i := 0; i < 32; i++ {
s := "aaaa" + strings.Repeat("b", i) + "aaaabbbb"
if err := roundtrip([]byte(s), ebuf, dbuf); err != nil {
t.Errorf("len(ebuf)=%d, len(dbuf)=%d, i=%d: %v", len(ebuf), len(dbuf), i, err)
}
}
}
}
}
func TestSmallRand(t *testing.T) {
rand.Seed(27354294)
for n := 1; n < 20000; n += 23 {
b := make([]byte, n)
for i, _ := range b {
b[i] = uint8(rand.Uint32())
}
if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}
}
func TestSmallRegular(t *testing.T) {
for n := 1; n < 20000; n += 23 {
b := make([]byte, n)
for i, _ := range b {
b[i] = uint8(i%10 + 'a')
}
if err := roundtrip(b, nil, nil); err != nil {
t.Fatal(err)
}
}
}
func benchDecode(b *testing.B, src []byte) {
encoded, err := Encode(nil, src)
if err != nil {
b.Fatal(err)
}
// Bandwidth is in amount of uncompressed data.
b.SetBytes(int64(len(src)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
Decode(src, encoded)
}
}
func benchEncode(b *testing.B, src []byte) {
// Bandwidth is in amount of uncompressed data.
b.SetBytes(int64(len(src)))
dst := make([]byte, MaxEncodedLen(len(src)))
b.ResetTimer()
for i := 0; i < b.N; i++ {
Encode(dst, src)
}
}
func readFile(b *testing.B, filename string) []byte {
src, err := ioutil.ReadFile(filename)
if err != nil {
b.Fatalf("failed reading %s: %s", filename, err)
}
if len(src) == 0 {
b.Fatalf("%s has zero length", filename)
}
return src
}
// expand returns a slice of length n containing repeated copies of src.
func expand(src []byte, n int) []byte {
dst := make([]byte, n)
for x := dst; len(x) > 0; {
i := copy(x, src)
x = x[i:]
}
return dst
}
func benchWords(b *testing.B, n int, decode bool) {
// Note: the file is OS-language dependent so the resulting values are not
// directly comparable for non-US-English OS installations.
data := expand(readFile(b, "/usr/share/dict/words"), n)
if decode {
benchDecode(b, data)
} else {
benchEncode(b, data)
}
}
func BenchmarkWordsDecode1e3(b *testing.B) { benchWords(b, 1e3, true) }
func BenchmarkWordsDecode1e4(b *testing.B) { benchWords(b, 1e4, true) }
func BenchmarkWordsDecode1e5(b *testing.B) { benchWords(b, 1e5, true) }
func BenchmarkWordsDecode1e6(b *testing.B) { benchWords(b, 1e6, true) }
func BenchmarkWordsEncode1e3(b *testing.B) { benchWords(b, 1e3, false) }
func BenchmarkWordsEncode1e4(b *testing.B) { benchWords(b, 1e4, false) }
func BenchmarkWordsEncode1e5(b *testing.B) { benchWords(b, 1e5, false) }
func BenchmarkWordsEncode1e6(b *testing.B) { benchWords(b, 1e6, false) }
// testFiles' values are copied directly from
// https://code.google.com/p/snappy/source/browse/trunk/snappy_unittest.cc.
// The label field is unused in snappy-go.
var testFiles = []struct {
label string
filename string
}{
{"html", "html"},
{"urls", "urls.10K"},
{"jpg", "house.jpg"},
{"pdf", "mapreduce-osdi-1.pdf"},
{"html4", "html_x_4"},
{"cp", "cp.html"},
{"c", "fields.c"},
{"lsp", "grammar.lsp"},
{"xls", "kennedy.xls"},
{"txt1", "alice29.txt"},
{"txt2", "asyoulik.txt"},
{"txt3", "lcet10.txt"},
{"txt4", "plrabn12.txt"},
{"bin", "ptt5"},
{"sum", "sum"},
{"man", "xargs.1"},
{"pb", "geo.protodata"},
{"gaviota", "kppkn.gtb"},
}
// The test data files are present at this canonical URL.
const baseURL = "https://snappy.googlecode.com/svn/trunk/testdata/"
func downloadTestdata(basename string) (errRet error) {
filename := filepath.Join("testdata", basename)
f, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create %s: %s", filename, err)
}
defer f.Close()
defer func() {
if errRet != nil {
os.Remove(filename)
}
}()
resp, err := http.Get(baseURL + basename)
if err != nil {
return fmt.Errorf("failed to download %s: %s", baseURL+basename, err)
}
defer resp.Body.Close()
_, err = io.Copy(f, resp.Body)
if err != nil {
return fmt.Errorf("failed to write %s: %s", filename, err)
}
return nil
}
func benchFile(b *testing.B, n int, decode bool) {
filename := filepath.Join("testdata", testFiles[n].filename)
if stat, err := os.Stat(filename); err != nil || stat.Size() == 0 {
if !*download {
b.Fatal("test data not found; skipping benchmark without the -download flag")
}
// Download the official snappy C++ implementation reference test data
// files for benchmarking.
if err := os.Mkdir("testdata", 0777); err != nil && !os.IsExist(err) {
b.Fatalf("failed to create testdata: %s", err)
}
for _, tf := range testFiles {
if err := downloadTestdata(tf.filename); err != nil {
b.Fatalf("failed to download testdata: %s", err)
}
}
}
data := readFile(b, filename)
if decode {
benchDecode(b, data)
} else {
benchEncode(b, data)
}
}
// Naming convention is kept similar to what snappy's C++ implementation uses.
func Benchmark_UFlat0(b *testing.B) { benchFile(b, 0, true) }
func Benchmark_UFlat1(b *testing.B) { benchFile(b, 1, true) }
func Benchmark_UFlat2(b *testing.B) { benchFile(b, 2, true) }
func Benchmark_UFlat3(b *testing.B) { benchFile(b, 3, true) }
func Benchmark_UFlat4(b *testing.B) { benchFile(b, 4, true) }
func Benchmark_UFlat5(b *testing.B) { benchFile(b, 5, true) }
func Benchmark_UFlat6(b *testing.B) { benchFile(b, 6, true) }
func Benchmark_UFlat7(b *testing.B) { benchFile(b, 7, true) }
func Benchmark_UFlat8(b *testing.B) { benchFile(b, 8, true) }
func Benchmark_UFlat9(b *testing.B) { benchFile(b, 9, true) }
func Benchmark_UFlat10(b *testing.B) { benchFile(b, 10, true) }
func Benchmark_UFlat11(b *testing.B) { benchFile(b, 11, true) }
func Benchmark_UFlat12(b *testing.B) { benchFile(b, 12, true) }
func Benchmark_UFlat13(b *testing.B) { benchFile(b, 13, true) }
func Benchmark_UFlat14(b *testing.B) { benchFile(b, 14, true) }
func Benchmark_UFlat15(b *testing.B) { benchFile(b, 15, true) }
func Benchmark_UFlat16(b *testing.B) { benchFile(b, 16, true) }
func Benchmark_UFlat17(b *testing.B) { benchFile(b, 17, true) }
func Benchmark_ZFlat0(b *testing.B) { benchFile(b, 0, false) }
func Benchmark_ZFlat1(b *testing.B) { benchFile(b, 1, false) }
func Benchmark_ZFlat2(b *testing.B) { benchFile(b, 2, false) }
func Benchmark_ZFlat3(b *testing.B) { benchFile(b, 3, false) }
func Benchmark_ZFlat4(b *testing.B) { benchFile(b, 4, false) }
func Benchmark_ZFlat5(b *testing.B) { benchFile(b, 5, false) }
func Benchmark_ZFlat6(b *testing.B) { benchFile(b, 6, false) }
func Benchmark_ZFlat7(b *testing.B) { benchFile(b, 7, false) }
func Benchmark_ZFlat8(b *testing.B) { benchFile(b, 8, false) }
func Benchmark_ZFlat9(b *testing.B) { benchFile(b, 9, false) }
func Benchmark_ZFlat10(b *testing.B) { benchFile(b, 10, false) }
func Benchmark_ZFlat11(b *testing.B) { benchFile(b, 11, false) }
func Benchmark_ZFlat12(b *testing.B) { benchFile(b, 12, false) }
func Benchmark_ZFlat13(b *testing.B) { benchFile(b, 13, false) }
func Benchmark_ZFlat14(b *testing.B) { benchFile(b, 14, false) }
func Benchmark_ZFlat15(b *testing.B) { benchFile(b, 15, false) }
func Benchmark_ZFlat16(b *testing.B) { benchFile(b, 16, false) }
func Benchmark_ZFlat17(b *testing.B) { benchFile(b, 17, false) }

View File

@ -0,0 +1,5 @@
/tmp
*/**/*un~
*un~
.DS_Store
*/**/.DS_Store

View File

@ -0,0 +1,3 @@
[submodule "serp"]
path = serpent
url = https://github.com/ethereum/serpent.git

View File

@ -0,0 +1,12 @@
[serpent](https://github.com/ethereum/serpent) go bindings.
## Build instructions
```
go get -d github.com/ethereum/serpent-go
cd $GOPATH/src/github.com/ethereum/serpent-go
git submodule init
git submodule update
```
You're now ready to go :-)

View File

@ -0,0 +1,16 @@
#include "serpent/bignum.cpp"
#include "serpent/util.cpp"
#include "serpent/tokenize.cpp"
#include "serpent/parser.cpp"
#include "serpent/compiler.cpp"
#include "serpent/funcs.cpp"
#include "serpent/lllparser.cpp"
#include "serpent/rewriter.cpp"
#include "serpent/opcodes.cpp"
#include "serpent/optimize.cpp"
#include "serpent/functions.cpp"
#include "serpent/preprocess.cpp"
#include "serpent/rewriteutils.cpp"
#include "cpp/api.cpp"

View File

@ -0,0 +1,26 @@
#include <string>
#include "serpent/lllparser.h"
#include "serpent/bignum.h"
#include "serpent/util.h"
#include "serpent/tokenize.h"
#include "serpent/parser.h"
#include "serpent/compiler.h"
#include "cpp/api.h"
const char *compileGo(char *code, int *err)
{
try {
std::string c = binToHex(compile(std::string(code)));
return c.c_str();
}
catch(std::string &error) {
*err = 1;
return error.c_str();
}
catch(...) {
return "Unknown error";
}
}

View File

@ -0,0 +1,14 @@
#ifndef CPP_API_H
#define CPP_API_H
#ifdef __cplusplus
extern "C" {
#endif
const char *compileGo(char *code, int *err);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -0,0 +1,27 @@
package serpent
// #cgo CXXFLAGS: -I. -Ilangs/ -std=c++0x -Wall -fno-strict-aliasing
// #cgo LDFLAGS: -lstdc++
//
// #include "cpp/api.h"
//
import "C"
import (
"encoding/hex"
"errors"
"unsafe"
)
func Compile(str string) ([]byte, error) {
var err C.int
out := C.GoString(C.compileGo(C.CString(str), (*C.int)(unsafe.Pointer(&err))))
if err == C.int(1) {
return nil, errors.New(out)
}
bytes, _ := hex.DecodeString(out)
return bytes, nil
}

View File

@ -0,0 +1,12 @@
[._]*.s[a-w][a-z]
[._]s[a-w][a-z]
*.un~
Session.vim
.netrwhist
*~
*.o
serpent
libserpent.a
pyserpent.so
dist
*.egg-info

View File

@ -0,0 +1,5 @@
include *.cpp
include *.h
include *py
include README.md
include Makefile

View File

@ -0,0 +1,55 @@
PLATFORM_OPTS =
PYTHON = /usr/include/python2.7
CXXFLAGS = -fPIC
# -g3 -O0
BOOST_INC = /usr/include
BOOST_LIB = /usr/lib
TARGET = pyserpent
COMMON_OBJS = bignum.o util.o tokenize.o lllparser.o parser.o opcodes.o optimize.o functions.o rewriteutils.o preprocess.o rewriter.o compiler.o funcs.o
HEADERS = bignum.h util.h tokenize.h lllparser.h parser.h opcodes.h functions.h optimize.h rewriteutils.h preprocess.h rewriter.h compiler.h funcs.h
PYTHON_VERSION = 2.7
serpent : serpentc lib
lib:
ar rvs libserpent.a $(COMMON_OBJS)
g++ $(CXXFLAGS) -shared $(COMMON_OBJS) -o libserpent.so
serpentc: $(COMMON_OBJS) cmdline.o
rm -rf serpent
g++ -Wall $(COMMON_OBJS) cmdline.o -o serpent
bignum.o : bignum.cpp bignum.h
opcodes.o : opcodes.cpp opcodes.h
util.o : util.cpp util.h bignum.o
tokenize.o : tokenize.cpp tokenize.h util.o
lllparser.o : lllparser.cpp lllparser.h tokenize.o util.o
parser.o : parser.cpp parser.h tokenize.o util.o
rewriter.o : rewriter.cpp rewriter.h lllparser.o util.o rewriteutils.o preprocess.o opcodes.o functions.o
preprocessor.o: rewriteutils.o functions.o
compiler.o : compiler.cpp compiler.h util.o
funcs.o : funcs.cpp funcs.h
cmdline.o: cmdline.cpp
pyext.o: pyext.cpp
clean:
rm -f serpent *\.o libserpent.a libserpent.so
install:
cp serpent /usr/local/bin
cp libserpent.a /usr/local/lib
cp libserpent.so /usr/local/lib
rm -rf /usr/local/include/libserpent
mkdir -p /usr/local/include/libserpent
cp $(HEADERS) /usr/local/include/libserpent

View File

@ -0,0 +1,3 @@
Installation:
```make && sudo make install```

View File

@ -0,0 +1,112 @@
#include <stdio.h>
#include <iostream>
#include <vector>
#include <map>
#include "bignum.h"
//Integer to string conversion
std::string unsignedToDecimal(unsigned branch) {
if (branch < 10) return nums.substr(branch, 1);
else return unsignedToDecimal(branch / 10) + nums.substr(branch % 10,1);
}
//Add two strings representing decimal values
std::string decimalAdd(std::string a, std::string b) {
std::string o = a;
while (b.length() < a.length()) b = "0" + b;
while (o.length() < b.length()) o = "0" + o;
bool carry = false;
for (int i = o.length() - 1; i >= 0; i--) {
o[i] = o[i] + b[i] - '0';
if (carry) o[i]++;
if (o[i] > '9') {
o[i] -= 10;
carry = true;
}
else carry = false;
}
if (carry) o = "1" + o;
return o;
}
//Helper function for decimalMul
std::string decimalDigitMul(std::string a, int dig) {
if (dig == 0) return "0";
else return decimalAdd(a, decimalDigitMul(a, dig - 1));
}
//Multiply two strings representing decimal values
std::string decimalMul(std::string a, std::string b) {
std::string o = "0";
for (unsigned i = 0; i < b.length(); i++) {
std::string n = decimalDigitMul(a, b[i] - '0');
if (n != "0") {
for (unsigned j = i + 1; j < b.length(); j++) n += "0";
}
o = decimalAdd(o, n);
}
return o;
}
//Modexp
std::string decimalModExp(std::string b, std::string e, std::string m) {
if (e == "0") return "1";
else if (e == "1") return b;
else if (decimalMod(e, "2") == "0") {
std::string o = decimalModExp(b, decimalDiv(e, "2"), m);
return decimalMod(decimalMul(o, o), m);
}
else {
std::string o = decimalModExp(b, decimalDiv(e, "2"), m);
return decimalMod(decimalMul(decimalMul(o, o), b), m);
}
}
//Is a greater than b? Flag allows equality
bool decimalGt(std::string a, std::string b, bool eqAllowed) {
if (a == b) return eqAllowed;
return (a.length() > b.length()) || (a.length() >= b.length() && a > b);
}
//Subtract the two strings representing decimal values
std::string decimalSub(std::string a, std::string b) {
if (b == "0") return a;
if (b == a) return "0";
while (b.length() < a.length()) b = "0" + b;
std::string c = b;
for (unsigned i = 0; i < c.length(); i++) c[i] = '0' + ('9' - c[i]);
std::string o = decimalAdd(decimalAdd(a, c).substr(1), "1");
while (o.size() > 1 && o[0] == '0') o = o.substr(1);
return o;
}
//Divide the two strings representing decimal values
std::string decimalDiv(std::string a, std::string b) {
std::string c = b;
if (decimalGt(c, a)) return "0";
int zeroes = -1;
while (decimalGt(a, c, true)) {
zeroes += 1;
c = c + "0";
}
c = c.substr(0, c.size() - 1);
std::string quot = "0";
while (decimalGt(a, c, true)) {
a = decimalSub(a, c);
quot = decimalAdd(quot, "1");
}
for (int i = 0; i < zeroes; i++) quot += "0";
return decimalAdd(quot, decimalDiv(a, b));
}
//Modulo the two strings representing decimal values
std::string decimalMod(std::string a, std::string b) {
return decimalSub(a, decimalMul(decimalDiv(a, b), b));
}
//String to int conversion
unsigned decimalToUnsigned(std::string a) {
if (a.size() == 0) return 0;
else return (a[a.size() - 1] - '0')
+ decimalToUnsigned(a.substr(0,a.size()-1)) * 10;
}

View File

@ -0,0 +1,41 @@
#ifndef ETHSERP_BIGNUM
#define ETHSERP_BIGNUM
const std::string nums = "0123456789";
const std::string tt256 =
"115792089237316195423570985008687907853269984665640564039457584007913129639936"
;
const std::string tt256m1 =
"115792089237316195423570985008687907853269984665640564039457584007913129639935"
;
const std::string tt255 =
"57896044618658097711785492504343953926634992332820282019728792003956564819968";
const std::string tt176 =
"95780971304118053647396689196894323976171195136475136";
std::string unsignedToDecimal(unsigned branch);
std::string decimalAdd(std::string a, std::string b);
std::string decimalMul(std::string a, std::string b);
std::string decimalSub(std::string a, std::string b);
std::string decimalDiv(std::string a, std::string b);
std::string decimalMod(std::string a, std::string b);
std::string decimalModExp(std::string b, std::string e, std::string m);
bool decimalGt(std::string a, std::string b, bool eqAllowed=false);
unsigned decimalToUnsigned(std::string a);
#define utd unsignedToDecimal
#define dtu decimalToUnsigned
#endif

View File

@ -0,0 +1,132 @@
#include <stdio.h>
#include <string>
#include <iostream>
#include <vector>
#include <map>
#include "funcs.h"
int main(int argv, char** argc) {
if (argv == 1) {
std::cerr << "Must provide a command and arguments! Try parse, rewrite, compile, assemble\n";
return 0;
}
if (argv == 2 && std::string(argc[1]) == "--help" || std::string(argc[1]) == "-h" ) {
std::cout << argc[1] << "\n";
std::cout << "serpent command input\n";
std::cout << "where input -s for from stdin, a file, or interpreted as serpent code if does not exist as file.";
std::cout << "where command: \n";
std::cout << " parse: Just parses and returns s-expression code.\n";
std::cout << " rewrite: Parse, use rewrite rules print s-expressions of result.\n";
std::cout << " compile: Return resulting compiled EVM code in hex.\n";
std::cout << " assemble: Return result from step before compilation.\n";
return 0;
}
std::string flag = "";
std::string command = argc[1];
std::string input;
std::string secondInput;
if (std::string(argc[1]) == "-s") {
flag = command.substr(1);
command = argc[2];
input = "";
std::string line;
while (std::getline(std::cin, line)) {
input += line + "\n";
}
secondInput = argv == 3 ? "" : argc[3];
}
else {
if (argv == 2) {
std::cerr << "Not enough arguments for serpent cmdline\n";
throw(0);
}
input = argc[2];
secondInput = argv == 3 ? "" : argc[3];
}
bool haveSec = secondInput.length() > 0;
if (command == "parse" || command == "parse_serpent") {
std::cout << printAST(parseSerpent(input), haveSec) << "\n";
}
else if (command == "rewrite") {
std::cout << printAST(rewrite(parseLLL(input, true)), haveSec) << "\n";
}
else if (command == "compile_to_lll") {
std::cout << printAST(compileToLLL(input), haveSec) << "\n";
}
else if (command == "rewrite_chunk") {
std::cout << printAST(rewriteChunk(parseLLL(input, true)), haveSec) << "\n";
}
else if (command == "compile_chunk_to_lll") {
std::cout << printAST(compileChunkToLLL(input), haveSec) << "\n";
}
else if (command == "build_fragtree") {
std::cout << printAST(buildFragmentTree(parseLLL(input, true))) << "\n";
}
else if (command == "compile_lll") {
std::cout << binToHex(compileLLL(parseLLL(input, true))) << "\n";
}
else if (command == "dereference") {
std::cout << printAST(dereference(parseLLL(input, true)), haveSec) <<"\n";
}
else if (command == "pretty_assemble") {
std::cout << printTokens(prettyAssemble(parseLLL(input, true))) <<"\n";
}
else if (command == "pretty_compile_lll") {
std::cout << printTokens(prettyCompileLLL(parseLLL(input, true))) << "\n";
}
else if (command == "pretty_compile") {
std::cout << printTokens(prettyCompile(input)) << "\n";
}
else if (command == "pretty_compile_chunk") {
std::cout << printTokens(prettyCompileChunk(input)) << "\n";
}
else if (command == "assemble") {
std::cout << assemble(parseLLL(input, true)) << "\n";
}
else if (command == "serialize") {
std::cout << binToHex(serialize(tokenize(input, Metadata(), false))) << "\n";
}
else if (command == "flatten") {
std::cout << printTokens(flatten(parseLLL(input, true))) << "\n";
}
else if (command == "deserialize") {
std::cout << printTokens(deserialize(hexToBin(input))) << "\n";
}
else if (command == "compile") {
std::cout << binToHex(compile(input)) << "\n";
}
else if (command == "compile_chunk") {
std::cout << binToHex(compileChunk(input)) << "\n";
}
else if (command == "encode_datalist") {
std::vector<Node> tokens = tokenize(input);
std::vector<std::string> o;
for (int i = 0; i < (int)tokens.size(); i++) {
o.push_back(tokens[i].val);
}
std::cout << binToHex(encodeDatalist(o)) << "\n";
}
else if (command == "decode_datalist") {
std::vector<std::string> o = decodeDatalist(hexToBin(input));
std::vector<Node> tokens;
for (int i = 0; i < (int)o.size(); i++)
tokens.push_back(token(o[i]));
std::cout << printTokens(tokens) << "\n";
}
else if (command == "tokenize") {
std::cout << printTokens(tokenize(input));
}
else if (command == "biject") {
if (argv == 3)
std::cerr << "Not enough arguments for biject\n";
int pos = decimalToUnsigned(secondInput);
std::vector<Node> n = prettyCompile(input);
if (pos >= (int)n.size())
std::cerr << "Code position too high\n";
Metadata m = n[pos].metadata;
std::cout << "Opcode: " << n[pos].val << ", file: " << m.file <<
", line: " << m.ln << ", char: " << m.ch << "\n";
}
}

View File

@ -0,0 +1,554 @@
#include <stdio.h>
#include <iostream>
#include <vector>
#include <map>
#include "util.h"
#include "bignum.h"
#include "opcodes.h"
struct programAux {
std::map<std::string, std::string> vars;
int nextVarMem;
bool allocUsed;
bool calldataUsed;
int step;
int labelLength;
};
struct programVerticalAux {
int height;
std::string innerScopeName;
std::map<std::string, int> dupvars;
std::map<std::string, int> funvars;
std::vector<mss> scopes;
};
struct programData {
programAux aux;
Node code;
int outs;
};
programAux Aux() {
programAux o;
o.allocUsed = false;
o.calldataUsed = false;
o.step = 0;
o.nextVarMem = 32;
return o;
}
programVerticalAux verticalAux() {
programVerticalAux o;
o.height = 0;
o.dupvars = std::map<std::string, int>();
o.funvars = std::map<std::string, int>();
o.scopes = std::vector<mss>();
return o;
}
programData pd(programAux aux = Aux(), Node code=token("_"), int outs=0) {
programData o;
o.aux = aux;
o.code = code;
o.outs = outs;
return o;
}
Node multiToken(Node nodes[], int len, Metadata met) {
std::vector<Node> out;
for (int i = 0; i < len; i++) {
out.push_back(nodes[i]);
}
return astnode("_", out, met);
}
Node finalize(programData c);
Node popwrap(Node node) {
Node nodelist[] = {
node,
token("POP", node.metadata)
};
return multiToken(nodelist, 2, node.metadata);
}
// Grabs variables
mss getVariables(Node node, mss cur=mss()) {
Metadata m = node.metadata;
// Tokens don't contain any variables
if (node.type == TOKEN)
return cur;
// Don't descend into call fragments
else if (node.val == "lll")
return getVariables(node.args[1], cur);
// At global scope get/set/ref also declare
else if (node.val == "get" || node.val == "set" || node.val == "ref") {
if (node.args[0].type != TOKEN)
err("Variable name must be simple token,"
" not complex expression!", m);
if (!cur.count(node.args[0].val)) {
cur[node.args[0].val] = utd(cur.size() * 32 + 32);
//std::cerr << node.args[0].val << " " << cur[node.args[0].val] << "\n";
}
}
// Recursively process children
for (unsigned i = 0; i < node.args.size(); i++) {
cur = getVariables(node.args[i], cur);
}
return cur;
}
// Turns LLL tree into tree of code fragments
programData opcodeify(Node node,
programAux aux=Aux(),
programVerticalAux vaux=verticalAux()) {
std::string symb = "_"+mkUniqueToken();
Metadata m = node.metadata;
// Get variables
if (!aux.vars.size()) {
aux.vars = getVariables(node);
aux.nextVarMem = aux.vars.size() * 32 + 32;
}
// Numbers
if (node.type == TOKEN) {
return pd(aux, nodeToNumeric(node), 1);
}
else if (node.val == "ref" || node.val == "get" || node.val == "set") {
std::string varname = node.args[0].val;
// Determine reference to variable
Node varNode = tkn(aux.vars[varname], m);
//std::cerr << varname << " " << printSimple(varNode) << "\n";
// Set variable
if (node.val == "set") {
programData sub = opcodeify(node.args[1], aux, vaux);
if (!sub.outs)
err("Value to set variable must have nonzero arity!", m);
// What if we are setting a stack variable?
if (vaux.dupvars.count(node.args[0].val)) {
int h = vaux.height - vaux.dupvars[node.args[0].val];
if (h > 16) err("Too deep for stack variable (max 16)", m);
Node nodelist[] = {
sub.code,
token("SWAP"+unsignedToDecimal(h), m),
token("POP", m)
};
return pd(sub.aux, multiToken(nodelist, 3, m), 0);
}
// Setting a memory variable
else {
Node nodelist[] = {
sub.code,
varNode,
token("MSTORE", m),
};
return pd(sub.aux, multiToken(nodelist, 3, m), 0);
}
}
// Get variable
else if (node.val == "get") {
// Getting a stack variable
if (vaux.dupvars.count(node.args[0].val)) {
int h = vaux.height - vaux.dupvars[node.args[0].val];
if (h > 16) err("Too deep for stack variable (max 16)", m);
return pd(aux, token("DUP"+unsignedToDecimal(h)), 1);
}
// Getting a memory variable
else {
Node nodelist[] =
{ varNode, token("MLOAD", m) };
return pd(aux, multiToken(nodelist, 2, m), 1);
}
}
// Refer variable
else if (node.val == "ref") {
if (vaux.dupvars.count(node.args[0].val))
err("Cannot ref stack variable!", m);
return pd(aux, varNode, 1);
}
}
// Comments do nothing
else if (node.val == "comment") {
Node nodelist[] = { };
return pd(aux, multiToken(nodelist, 0, m), 0);
}
// Custom operation sequence
// eg. (ops bytez id msize swap1 msize add 0 swap1 mstore) == alloc
if (node.val == "ops") {
std::vector<Node> subs2;
int depth = 0;
for (unsigned i = 0; i < node.args.size(); i++) {
std::string op = upperCase(node.args[i].val);
if (node.args[i].type == ASTNODE || opinputs(op) == -1) {
programVerticalAux vaux2 = vaux;
vaux2.height = vaux.height - i - 1 + node.args.size();
programData sub = opcodeify(node.args[i], aux, vaux2);
aux = sub.aux;
depth += sub.outs;
subs2.push_back(sub.code);
}
else {
subs2.push_back(token(op, m));
depth += opoutputs(op) - opinputs(op);
}
}
if (depth < 0 || depth > 1) err("Stack depth mismatch", m);
return pd(aux, astnode("_", subs2, m), 0);
}
// Code blocks
if (node.val == "lll" && node.args.size() == 2) {
if (node.args[1].val != "0") aux.allocUsed = true;
std::vector<Node> o;
o.push_back(finalize(opcodeify(node.args[0])));
programData sub = opcodeify(node.args[1], aux, vaux);
Node code = astnode("____CODE", o, m);
Node nodelist[] = {
token("$begincode"+symb+".endcode"+symb, m), token("DUP1", m),
token("$begincode"+symb, m), sub.code, token("CODECOPY", m),
token("$endcode"+symb, m), token("JUMP", m),
token("~begincode"+symb, m), code,
token("~endcode"+symb, m), token("JUMPDEST", m)
};
return pd(sub.aux, multiToken(nodelist, 11, m), 1);
}
// Stack variables
if (node.val == "with") {
programData initial = opcodeify(node.args[1], aux, vaux);
programVerticalAux vaux2 = vaux;
vaux2.dupvars[node.args[0].val] = vaux.height;
vaux2.height += 1;
if (!initial.outs)
err("Initial variable value must have nonzero arity!", m);
programData sub = opcodeify(node.args[2], initial.aux, vaux2);
Node nodelist[] = {
initial.code,
sub.code
};
programData o = pd(sub.aux, multiToken(nodelist, 2, m), sub.outs);
if (sub.outs)
o.code.args.push_back(token("SWAP1", m));
o.code.args.push_back(token("POP", m));
return o;
}
// Seq of multiple statements
if (node.val == "seq") {
std::vector<Node> children;
int lastOut = 0;
for (unsigned i = 0; i < node.args.size(); i++) {
programData sub = opcodeify(node.args[i], aux, vaux);
aux = sub.aux;
if (sub.outs == 1) {
if (i < node.args.size() - 1) sub.code = popwrap(sub.code);
else lastOut = 1;
}
children.push_back(sub.code);
}
return pd(aux, astnode("_", children, m), lastOut);
}
// 2-part conditional (if gets rewritten to unless in rewrites)
else if (node.val == "unless" && node.args.size() == 2) {
programData cond = opcodeify(node.args[0], aux, vaux);
programData action = opcodeify(node.args[1], cond.aux, vaux);
aux = action.aux;
if (!cond.outs) err("Condition of if/unless statement has arity 0", m);
if (action.outs) action.code = popwrap(action.code);
Node nodelist[] = {
cond.code,
token("$endif"+symb, m), token("JUMPI", m),
action.code,
token("~endif"+symb, m), token("JUMPDEST", m)
};
return pd(aux, multiToken(nodelist, 6, m), 0);
}
// 3-part conditional
else if (node.val == "if" && node.args.size() == 3) {
programData ifd = opcodeify(node.args[0], aux, vaux);
programData thend = opcodeify(node.args[1], ifd.aux, vaux);
programData elsed = opcodeify(node.args[2], thend.aux, vaux);
aux = elsed.aux;
if (!ifd.outs)
err("Condition of if/unless statement has arity 0", m);
// Handle cases where one conditional outputs something
// and the other does not
int outs = (thend.outs && elsed.outs) ? 1 : 0;
if (thend.outs > outs) thend.code = popwrap(thend.code);
if (elsed.outs > outs) elsed.code = popwrap(elsed.code);
Node nodelist[] = {
ifd.code,
token("ISZERO", m),
token("$else"+symb, m), token("JUMPI", m),
thend.code,
token("$endif"+symb, m), token("JUMP", m),
token("~else"+symb, m), token("JUMPDEST", m),
elsed.code,
token("~endif"+symb, m), token("JUMPDEST", m)
};
return pd(aux, multiToken(nodelist, 12, m), outs);
}
// While (rewritten to this in rewrites)
else if (node.val == "until") {
programData cond = opcodeify(node.args[0], aux, vaux);
programData action = opcodeify(node.args[1], cond.aux, vaux);
aux = action.aux;
if (!cond.outs)
err("Condition of while/until loop has arity 0", m);
if (action.outs) action.code = popwrap(action.code);
Node nodelist[] = {
token("~beg"+symb, m), token("JUMPDEST", m),
cond.code,
token("$end"+symb, m), token("JUMPI", m),
action.code,
token("$beg"+symb, m), token("JUMP", m),
token("~end"+symb, m), token("JUMPDEST", m),
};
return pd(aux, multiToken(nodelist, 10, m));
}
// Memory allocations
else if (node.val == "alloc") {
programData bytez = opcodeify(node.args[0], aux, vaux);
aux = bytez.aux;
if (!bytez.outs)
err("Alloc input has arity 0", m);
aux.allocUsed = true;
Node nodelist[] = {
bytez.code,
token("MSIZE", m), token("SWAP1", m), token("MSIZE", m),
token("ADD", m),
token("0", m), token("SWAP1", m), token("MSTORE", m)
};
return pd(aux, multiToken(nodelist, 8, m), 1);
}
// All other functions/operators
else {
std::vector<Node> subs2;
int depth = opinputs(upperCase(node.val));
if (depth == -1)
err("Not a function or opcode: "+node.val, m);
if ((int)node.args.size() != depth)
err("Invalid arity for "+node.val, m);
for (int i = node.args.size() - 1; i >= 0; i--) {
programVerticalAux vaux2 = vaux;
vaux2.height = vaux.height - i - 1 + node.args.size();
programData sub = opcodeify(node.args[i], aux, vaux2);
aux = sub.aux;
if (!sub.outs)
err("Input "+unsignedToDecimal(i)+" has arity 0", sub.code.metadata);
subs2.push_back(sub.code);
}
subs2.push_back(token(upperCase(node.val), m));
int outdepth = opoutputs(upperCase(node.val));
return pd(aux, astnode("_", subs2, m), outdepth);
}
}
// Adds necessary wrappers to a program
Node finalize(programData c) {
std::vector<Node> bottom;
Metadata m = c.code.metadata;
// If we are using both alloc and variables, we need to pre-zfill
// some memory
if ((c.aux.allocUsed || c.aux.calldataUsed) && c.aux.vars.size() > 0) {
Node nodelist[] = {
token("0", m),
token(unsignedToDecimal(c.aux.nextVarMem - 1)),
token("MSTORE8", m)
};
bottom.push_back(multiToken(nodelist, 3, m));
}
// The actual code
bottom.push_back(c.code);
return astnode("_", bottom, m);
}
//LLL -> code fragment tree
Node buildFragmentTree(Node node) {
return finalize(opcodeify(node));
}
// Builds a dictionary mapping labels to variable names
programAux buildDict(Node program, programAux aux, int labelLength) {
Metadata m = program.metadata;
// Token
if (program.type == TOKEN) {
if (isNumberLike(program)) {
aux.step += 1 + toByteArr(program.val, m).size();
}
else if (program.val[0] == '~') {
aux.vars[program.val.substr(1)] = unsignedToDecimal(aux.step);
}
else if (program.val[0] == '$') {
aux.step += labelLength + 1;
}
else aux.step += 1;
}
// A sub-program (ie. LLL)
else if (program.val == "____CODE") {
programAux auks = Aux();
for (unsigned i = 0; i < program.args.size(); i++) {
auks = buildDict(program.args[i], auks, labelLength);
}
for (std::map<std::string,std::string>::iterator it=auks.vars.begin();
it != auks.vars.end();
it++) {
aux.vars[(*it).first] = (*it).second;
}
aux.step += auks.step;
}
// Normal sub-block
else {
for (unsigned i = 0; i < program.args.size(); i++) {
aux = buildDict(program.args[i], aux, labelLength);
}
}
return aux;
}
// Applies that dictionary
Node substDict(Node program, programAux aux, int labelLength) {
Metadata m = program.metadata;
std::vector<Node> out;
std::vector<Node> inner;
if (program.type == TOKEN) {
if (program.val[0] == '$') {
std::string tokStr = "PUSH"+unsignedToDecimal(labelLength);
out.push_back(token(tokStr, m));
int dotLoc = program.val.find('.');
if (dotLoc == -1) {
std::string val = aux.vars[program.val.substr(1)];
inner = toByteArr(val, m, labelLength);
}
else {
std::string start = aux.vars[program.val.substr(1, dotLoc-1)],
end = aux.vars[program.val.substr(dotLoc + 1)],
dist = decimalSub(end, start);
inner = toByteArr(dist, m, labelLength);
}
out.push_back(astnode("_", inner, m));
}
else if (program.val[0] == '~') { }
else if (isNumberLike(program)) {
inner = toByteArr(program.val, m);
out.push_back(token("PUSH"+unsignedToDecimal(inner.size())));
out.push_back(astnode("_", inner, m));
}
else return program;
}
else {
for (unsigned i = 0; i < program.args.size(); i++) {
Node n = substDict(program.args[i], aux, labelLength);
if (n.type == TOKEN || n.args.size()) out.push_back(n);
}
}
return astnode("_", out, m);
}
// Compiled fragtree -> compiled fragtree without labels
Node dereference(Node program) {
int sz = treeSize(program) * 4;
int labelLength = 1;
while (sz >= 256) { labelLength += 1; sz /= 256; }
programAux aux = buildDict(program, Aux(), labelLength);
return substDict(program, aux, labelLength);
}
// Dereferenced fragtree -> opcodes
std::vector<Node> flatten(Node derefed) {
std::vector<Node> o;
if (derefed.type == TOKEN) {
o.push_back(derefed);
}
else {
for (unsigned i = 0; i < derefed.args.size(); i++) {
std::vector<Node> oprime = flatten(derefed.args[i]);
for (unsigned j = 0; j < oprime.size(); j++) o.push_back(oprime[j]);
}
}
return o;
}
// Opcodes -> bin
std::string serialize(std::vector<Node> codons) {
std::string o;
for (unsigned i = 0; i < codons.size(); i++) {
int v;
if (isNumberLike(codons[i])) {
v = decimalToUnsigned(codons[i].val);
}
else if (codons[i].val.substr(0,4) == "PUSH") {
v = 95 + decimalToUnsigned(codons[i].val.substr(4));
}
else {
v = opcode(codons[i].val);
}
o += (char)v;
}
return o;
}
// Bin -> opcodes
std::vector<Node> deserialize(std::string ser) {
std::vector<Node> o;
int backCount = 0;
for (unsigned i = 0; i < ser.length(); i++) {
unsigned char v = (unsigned char)ser[i];
std::string oper = op((int)v);
if (oper != "" && backCount <= 0) o.push_back(token(oper));
else if (v >= 96 && v < 128 && backCount <= 0) {
o.push_back(token("PUSH"+unsignedToDecimal(v - 95)));
}
else o.push_back(token(unsignedToDecimal(v)));
if (v >= 96 && v < 128 && backCount <= 0) {
backCount = v - 95;
}
else backCount--;
}
return o;
}
// Fragtree -> bin
std::string assemble(Node fragTree) {
return serialize(flatten(dereference(fragTree)));
}
// Fragtree -> tokens
std::vector<Node> prettyAssemble(Node fragTree) {
return flatten(dereference(fragTree));
}
// LLL -> bin
std::string compileLLL(Node program) {
return assemble(buildFragmentTree(program));
}
// LLL -> tokens
std::vector<Node> prettyCompileLLL(Node program) {
return prettyAssemble(buildFragmentTree(program));
}
// Converts a list of integer values to binary transaction data
std::string encodeDatalist(std::vector<std::string> vals) {
std::string o;
for (unsigned i = 0; i < vals.size(); i++) {
std::vector<Node> n = toByteArr(strToNumeric(vals[i]), Metadata(), 32);
for (unsigned j = 0; j < n.size(); j++) {
int v = decimalToUnsigned(n[j].val);
o += (char)v;
}
}
return o;
}
// Converts binary transaction data into a list of integer values
std::vector<std::string> decodeDatalist(std::string ser) {
std::vector<std::string> out;
for (unsigned i = 0; i < ser.length(); i+= 32) {
std::string o = "0";
for (unsigned j = i; j < i + 32; j++) {
int vj = (int)(unsigned char)ser[j];
o = decimalAdd(decimalMul(o, "256"), unsignedToDecimal(vj));
}
out.push_back(o);
}
return out;
}

View File

@ -0,0 +1,43 @@
#ifndef ETHSERP_COMPILER
#define ETHSERP_COMPILER
#include <stdio.h>
#include <iostream>
#include <vector>
#include <map>
#include "util.h"
// Compiled fragtree -> compiled fragtree without labels
Node dereference(Node program);
// LLL -> fragtree
Node buildFragmentTree(Node program);
// Dereferenced fragtree -> opcodes
std::vector<Node> flatten(Node derefed);
// opcodes -> bin
std::string serialize(std::vector<Node> codons);
// Fragtree -> bin
std::string assemble(Node fragTree);
// Fragtree -> opcodes
std::vector<Node> prettyAssemble(Node fragTree);
// LLL -> bin
std::string compileLLL(Node program);
// LLL -> opcodes
std::vector<Node> prettyCompileLLL(Node program);
// bin -> opcodes
std::vector<Node> deserialize(std::string ser);
// Converts a list of integer values to binary transaction data
std::string encodeDatalist(std::vector<std::string> vals);
// Converts binary transaction data into a list of integer values
std::vector<std::string> decodeDatalist(std::string ser);
#endif

View File

@ -0,0 +1,11 @@
#include <libserpent/funcs.h>
#include <libserpent/bignum.h>
#include <iostream>
using namespace std;
int main() {
cout << printAST(compileToLLL(get_file_contents("examples/namecoin.se"))) << "\n";
cout << decimalSub("10234", "10234") << "\n";
cout << decimalSub("10234", "10233") << "\n";
}

View File

@ -0,0 +1,11 @@
x = msg.data[0]
steps = 0
while x > 1:
steps += 1
if (x % 2) == 0:
x /= 2
else:
x = 3 * x + 1
return(steps)

View File

@ -0,0 +1,274 @@
# Ethereum forks Counterparty in 340 lines of serpent
# Not yet tested
# assets[i] = a registered asset, assets[i].holders[j] = former or current i-holder
data assets[2^50](creator, name, calldate, callprice, dividend_paid, holders[2^50], holdersCount)
data nextAssetId
# holdersMap: holdersMap[addr][asset] = 1 if addr holds asset
data holdersMap[2^160][2^50]
# balances[x][y] = how much of y x holds
data balances[2^160][2^50]
# orders[a][b] = heap of indices to (c, d, e)
# = c offers to sell d units of a at a price of e units of b per 10^18 units
# of a
data orderbooks[2^50][2^50]
# store of general order data
data orders[2^50](seller, asset_sold, quantity, price)
data ordersCount
# data feeds
data feeds[2^50](owner, value)
data feedCount
# heap
data heap
extern heap: [register, push, pop, top, size]
data cfds[2^50](maker, acceptor, feed, asset, strike, leverage, min, max, maturity)
data cfdCount
data bets[2^50](maker, acceptor, feed, asset, makerstake, acceptorstake, eqtest, maturity)
data betCount
def init():
heap = create('heap.se')
# Add units (internal method)
def add(to, asset, value):
assert msg.sender == self
self.balances[to][asset] += value
# Add the holder to the holders list
if not self.holdersMap[to][asset]:
self.holdersMap[to][asset] = 1
c = self.assets[asset].holdersCount
self.assets[asset].holders[c] = to
self.assets[asset].holdersCount = c + 1
# Register a new asset
def register_asset(q, name, calldate, callprice):
newid = self.nextAssetId
self.assets[newid].creator = msg.sender
self.assets[newid].name = name
self.assets[newid].calldate = calldate
self.assets[newid].callprice = callprice
self.assets[newid].holders[0] = msg.sender
self.assets[newid].holdersCount = 1
self.balances[msg.sender][newid] = q
self.holdersMap[msg.sender][newid] = 1
# Send
def send(to, asset, value):
fromval = self.balances[msg.sender][asset]
if fromval >= value:
self.balances[msg.sender][asset] -= value
self.add(to, asset, value)
# Order
def mkorder(selling, buying, quantity, price):
# Make sure you have enough to pay for the order
assert self.balances[msg.sender][selling] >= quantity:
# Try to match existing orders
o = orderbooks[buying][selling]
if not o:
o = self.heap.register()
orderbooks[selling][buying] = o
sz = self.heap.size(o)
invprice = 10^36 / price
while quantity > 0 and sz > 0:
orderid = self.heap.pop()
p = self.orders[orderid].price
if p > invprice:
sz = 0
else:
q = self.orders[orderid].quantity
oq = min(q, quantity)
b = self.orders[orderid].seller
self.balances[msg.sender][selling] -= oq * p / 10^18
self.add(msg.sender, buying, oq)
self.add(b, selling, oq * p / 10^18)
self.orders[orderid].quantity = q - oq
if oq == q:
self.orders[orderid].seller = 0
self.orders[orderid].price = 0
self.orders[orderid].asset_sold = 0
quantity -= oq
sz -= 1
assert quantity > 0
# Make the order
c = self.ordersCount
self.orders[c].seller = msg.sender
self.orders[c].asset_sold = selling
self.orders[c].quantity = quantity
self.orders[c].price = price
self.ordersCount += 1
# Add it to the heap
o = orderbooks[selling][buying]
if not o:
o = self.heap.register()
orderbooks[selling][buying] = o
self.balances[msg.sender][selling] -= quantity
self.heap.push(o, price, c)
return(c)
def cancel_order(id):
if self.orders[id].seller == msg.sender:
self.orders[id].seller = 0
self.orders[id].price = 0
self.balances[msg.sender][self.orders[id].asset_sold] += self.orders[id].quantity
self.orders[id].quantity = 0
self.orders[id].asset_sold = 0
def register_feed():
c = self.feedCount
self.feeds[c].owner = msg.sender
self.feedCount = c + 1
return(c)
def set_feed(id, v):
if self.feeds[id].owner == msg.sender:
self.feeds[id].value = v
def mk_cfd_offer(feed, asset, strike, leverage, min, max, maturity):
b = self.balances[msg.sender][asset]
req = max((strike - min) * leverage, (strike - max) * leverage)
assert b >= req
self.balances[msg.sender][asset] = b - req
c = self.cfdCount
self.cfds[c].maker = msg.sender
self.cfds[c].feed = feed
self.cfds[c].asset = asset
self.cfds[c].strike = strike
self.cfds[c].leverage = leverage
self.cfds[c].min = min
self.cfds[c].max = max
self.cfds[c].maturity = maturity
self.cfdCount = c + 1
return(c)
def accept_cfd_offer(c):
assert not self.cfds[c].acceptor and self.cfds[c].maker
asset = self.cfds[c].asset
strike = self.cfds[c].strike
min = self.cfds[c].min
max = self.cfds[c].max
leverage = self.cfds[c].leverage
b = self.balances[msg.sender][asset]
req = max((min - strike) * leverage, (max - strike) * leverage)
assert b >= req
self.balances[msg.sender][asset] = b - req
self.cfds[c].acceptor = msg.sender
self.cfds[c].maturity += block.timestamp
def claim_cfd_offer(c):
asset = self.cfds[c].asset
strike = self.cfds[c].strike
min = self.cfds[c].min
max = self.cfds[c].max
leverage = self.cfds[c].leverage
v = self.feeds[self.cfds[c].feed].value
assert v <= min or v >= max or block.timestamp >= self.cfds[c].maturity
maker_req = max((strike - min) * leverage, (strike - max) * leverage)
acceptor_req = max((min - strike) * leverage, (max - strike) * leverage)
paydelta = (strike - v) * leverage
self.add(self.cfds[c].maker, asset, maker_req + paydelta)
self.add(self.cfds[c].acceptor, asset, acceptor_req - paydelta)
self.cfds[c].maker = 0
self.cfds[c].acceptor = 0
self.cfds[c].feed = 0
self.cfds[c].asset = 0
self.cfds[c].strike = 0
self.cfds[c].leverage = 0
self.cfds[c].min = 0
self.cfds[c].max = 0
self.cfds[c].maturity = 0
def withdraw_cfd_offer(c):
if self.cfds[c].maker == msg.sender and not self.cfds[c].acceptor:
asset = self.cfds[c].asset
strike = self.cfds[c].strike
min = self.cfds[c].min
max = self.cfds[c].max
leverage = self.cfds[c].leverage
maker_req = max((strike - min) * leverage, (strike - max) * leverage)
self.balances[self.cfds[c].maker][asset] += maker_req
self.cfds[c].maker = 0
self.cfds[c].acceptor = 0
self.cfds[c].feed = 0
self.cfds[c].asset = 0
self.cfds[c].strike = 0
self.cfds[c].leverage = 0
self.cfds[c].min = 0
self.cfds[c].max = 0
self.cfds[c].maturity = 0
def mk_bet_offer(feed, asset, makerstake, acceptorstake, eqtest, maturity):
assert self.balances[msg.sender][asset] >= makerstake
c = self.betCount
self.bets[c].maker = msg.sender
self.bets[c].feed = feed
self.bets[c].asset = asset
self.bets[c].makerstake = makerstake
self.bets[c].acceptorstake = acceptorstake
self.bets[c].eqtest = eqtest
self.bets[c].maturity = maturity
self.balances[msg.sender][asset] -= makerstake
self.betCount = c + 1
return(c)
def accept_bet_offer(c):
assert self.bets[c].maker and not self.bets[c].acceptor
asset = self.bets[c].asset
acceptorstake = self.bets[c].acceptorstake
assert self.balances[msg.sender][asset] >= acceptorstake
self.balances[msg.sender][asset] -= acceptorstake
self.bets[c].acceptor = msg.sender
def claim_bet_offer(c):
assert block.timestamp >= self.bets[c].maturity
v = self.feeds[self.bets[c].feed].value
totalstake = self.bets[c].makerstake + self.bets[c].acceptorstake
if v == self.bets[c].eqtest:
self.add(self.bets[c].maker, self.bets[c].asset, totalstake)
else:
self.add(self.bets[c].acceptor, self.bets[c].asset, totalstake)
self.bets[c].maker = 0
self.bets[c].feed = 0
self.bets[c].asset = 0
self.bets[c].makerstake = 0
self.bets[c].acceptorstake = 0
self.bets[c].eqtest = 0
self.bets[c].maturity = 0
def cancel_bet(c):
assert not self.bets[c].acceptor and msg.sender == self.bets[c].maker
self.balances[msg.sender][self.bets[c].asset] += self.bets[c].makerstake
self.bets[c].maker = 0
self.bets[c].feed = 0
self.bets[c].asset = 0
self.bets[c].makerstake = 0
self.bets[c].acceptorstake = 0
self.bets[c].eqtest = 0
self.bets[c].maturity = 0
def dividend(holder_asset, divvying_asset, ratio):
i = 0
sz = self.assets[holder_asset].holdersCount
t = 0
holders = array(sz)
payments = array(sz)
while i < sz:
holders[i] = self.assets[holder_asset].holders[i]
payments[i] = self.balances[holders[i]][holder_asset] * ratio / 10^18
t += payments[i]
i += 1
if self.balances[msg.sender][divvying_asset] >= t:
i = 0
while i < sz:
self.add(holders[i], divvying_asset, payments[i])
i += 1
self.balances[msg.sender][divvying_asset] -= t

View File

@ -0,0 +1,69 @@
data heaps[2^50](owner, size, nodes[2^50](key, value))
data heapIndex
def register():
i = self.heapIndex
self.heaps[i].owner = msg.sender
self.heapIndex = i + 1
return(i)
def push(heap, key, value):
assert msg.sender == self.heaps[heap].owner
sz = self.heaps[heap].size
self.heaps[heap].nodes[sz].key = key
self.heaps[heap].nodes[sz].value = value
k = sz + 1
while k > 1:
bottom = self.heaps[heap].nodes[k].key
top = self.heaps[heap].nodes[k/2].key
if bottom < top:
tvalue = self.heaps[heap].nodes[k/2].value
bvalue = self.heaps[heap].nodes[k].value
self.heaps[heap].nodes[k].key = top
self.heaps[heap].nodes[k].value = tvalue
self.heaps[heap].nodes[k/2].key = bottom
self.heaps[heap].nodes[k/2].value = bvalue
k /= 2
else:
k = 0
self.heaps[heap].size = sz + 1
def pop(heap):
sz = self.heaps[heap].size
assert sz
prevtop = self.heaps[heap].nodes[1].value
self.heaps[heap].nodes[1].key = self.heaps[heap].nodes[sz].key
self.heaps[heap].nodes[1].value = self.heaps[heap].nodes[sz].value
self.heaps[heap].nodes[sz].key = 0
self.heaps[heap].nodes[sz].value = 0
top = self.heaps[heap].nodes[1].key
k = 1
while k * 2 < sz:
bottom1 = self.heaps[heap].nodes[k * 2].key
bottom2 = self.heaps[heap].nodes[k * 2 + 1].key
if bottom1 < top and (bottom1 < bottom2 or k * 2 + 1 >= sz):
tvalue = self.heaps[heap].nodes[1].value
bvalue = self.heaps[heap].nodes[k * 2].value
self.heaps[heap].nodes[k].key = bottom1
self.heaps[heap].nodes[k].value = bvalue
self.heaps[heap].nodes[k * 2].key = top
self.heaps[heap].nodes[k * 2].value = tvalue
k = k * 2
elif bottom2 < top and bottom2 < bottom1 and k * 2 + 1 < sz:
tvalue = self.heaps[heap].nodes[1].value
bvalue = self.heaps[heap].nodes[k * 2 + 1].value
self.heaps[heap].nodes[k].key = bottom2
self.heaps[heap].nodes[k].value = bvalue
self.heaps[heap].nodes[k * 2 + 1].key = top
self.heaps[heap].nodes[k * 2 + 1].value = tvalue
k = k * 2 + 1
else:
k = sz
self.heaps[heap].size = sz - 1
return(prevtop)
def top(heap):
return(self.heaps[heap].nodes[1].value)
def size(heap):
return(self.heaps[heap].size)

View File

@ -0,0 +1,53 @@
data campaigns[2^80](recipient, goal, deadline, contrib_total, contrib_count, contribs[2^50](sender, value))
def create_campaign(id, recipient, goal, timelimit):
if self.campaigns[id].recipient:
return(0)
self.campaigns[id].recipient = recipient
self.campaigns[id].goal = goal
self.campaigns[id].deadline = block.timestamp + timelimit
def contribute(id):
# Update contribution total
total_contributed = self.campaigns[id].contrib_total + msg.value
self.campaigns[id].contrib_total = total_contributed
# Record new contribution
sub_index = self.campaigns[id].contrib_count
self.campaigns[id].contribs[sub_index].sender = msg.sender
self.campaigns[id].contribs[sub_index].value = msg.value
self.campaigns[id].contrib_count = sub_index + 1
# Enough funding?
if total_contributed >= self.campaigns[id].goal:
send(self.campaigns[id].recipient, total_contributed)
self.clear(id)
return(1)
# Expired?
if block.timestamp > self.campaigns[id].deadline:
i = 0
c = self.campaigns[id].contrib_count
while i < c:
send(self.campaigns[id].contribs[i].sender, self.campaigns[id].contribs[i].value)
i += 1
self.clear(id)
return(2)
def progress_report(id):
return(self.campaigns[id].contrib_total)
# Clearing function for internal use
def clear(id):
if self == msg.sender:
self.campaigns[id].recipient = 0
self.campaigns[id].goal = 0
self.campaigns[id].deadline = 0
c = self.campaigns[id].contrib_count
self.campaigns[id].contrib_count = 0
self.campaigns[id].contrib_total = 0
i = 0
while i < c:
self.campaigns[id].contribs[i].sender = 0
self.campaigns[id].contribs[i].value = 0
i += 1

View File

@ -0,0 +1,136 @@
# 0: current epoch
# 1: number of proposals
# 2: master currency
# 3: last winning market
# 4: last txid
# 5: long-term ema currency units purchased
# 6: last block when currency units purchased
# 7: ether allocated to last round
# 8: last block when currency units claimed
# 9: ether allocated to current round
# 1000+: [proposal address, market ID, totprice, totvolume]
init:
# We technically have two levels of epoch here. We have
# one epoch of 1000, to synchronize with the 1000 epoch
# of the market, and then 100 of those epochs make a
# meta-epoch (I'll nominate the term "seculum") over
# which the futarchy protocol will take place
contract.storage[0] = block.number / 1000
# The master currency of the futarchy. The futarchy will
# assign currency units to whoever the prediction market
# thinks will best increase the currency's value
master_currency = create('subcurrency.se')
contract.storage[2] = master_currency
code:
curepoch = block.number / 1000
prevepoch = contract.storage[0]
if curepoch > prevepoch:
if (curepoch % 100) > 50:
# Collect price data
# We take an average over 50 subepochs to determine
# the price of each asset, weighting by volume to
# prevent abuse
contract.storage[0] = curepoch
i = 0
numprop = contract.storage[1]
while i < numprop:
market = contract.storage[1001 + i * 4]
price = call(market, 2)
volume = call(market, 3)
contract.storage[1002 + i * 4] += price
contract.storage[1003 + i * 4] += volume * price
i += 1
if (curepoch / 100) > (prevepoch / 100):
# If we are entering a new seculum, we determine the
# market with the highest total average price
best = 0
bestmarket = 0
besti = 0
i = 0
while i < numprop:
curtotprice = contract.storage[1002 + i * 4]
curvolume = contract.storage[1002 + i * 4]
curavgprice = curtotprice / curvolume
if curavgprice > best:
best = curavgprice
besti = i
bestmarket = contract.storage[1003 + i * 4]
i += 1
# Reset the number of proposals to 0
contract.storage[1] = 0
# Reward the highest proposal
call(contract.storage[2], [best, 10^9, 0], 3)
# Record the winning market so we can later appropriately
# compensate the participants
contract.storage[2] = bestmarket
# The amount of ether allocated to the last round
contract.storage[7] = contract.storage[9]
# The amount of ether allocated to the next round
contract.storage[9] = contract.balance / 2
# Make a proposal [0, address]
if msg.data[0] == 0 and curepoch % 100 < 50:
pid = contract.storage[1]
market = create('market.se')
c1 = create('subcurrency.se')
c2 = create('subcurrency.se')
call(market, [c1, c2], 2)
contract.storage[1000 + pid * 4] = msg.data[1]
contract.storage[1001 + pid * 4] = market
contract.storage[1] += 1
# Claim ether [1, address]
# One unit of the first currency in the last round's winning
# market entitles you to a quantity of ether that was decided
# at the start of that epoch
elif msg.data[0] == 1:
first_subcurrency = call(contract.storage[2], 3)
# We ask the first subcurrency contract what the last transaction was. The
# way to make a claim is to send the amount of first currency units that
# you wish to claim with, and then immediately call this contract. For security
# it makes sense to set up a tx which sends both messages in sequence atomically
data = call(first_subcurrency, [], 0, 4)
from = data[0]
to = data[1]
value = data[2]
txid = data[3]
if txid > contract.storage[4] and to == contract.address:
send(to, contract.storage[7] * value / 10^9)
contract.storage[4] = txid
# Claim second currency [2, address]
# One unit of the second currency in the last round's winning
# market entitles you to one unit of the futarchy's master
# currency
elif msg.data[0] == 2:
second_subcurrency = call(contract.storage[2], 3)
data = call(first_subcurrency, [], 0, 4)
from = data[0]
to = data[1]
value = data[2]
txid = data[3]
if txid > contract.storage[4] and to == contract.address:
call(contract.storage[2], [to, value], 2)
contract.storage[4] = txid
# Purchase currency for ether (target releasing 10^9 units per seculum)
# Price starts off 1 eth for 10^9 units but increases hyperbolically to
# limit issuance
elif msg.data[0] == 3:
pre_ema = contract.storage[5]
post_ema = pre_ema + msg.value
pre_reserve = 10^18 / (10^9 + pre_ema / 10^9)
post_reserve = 10^18 / (10^9 + post_ema / 10^9)
call(contract.storage[2], [msg.sender, pre_reserve - post_reserve], 2)
last_sold = contract.storage[6]
contract.storage[5] = pre_ema * (100000 + last_sold - block.number) + msg.value
contract.storage[6] = block.number
# Claim all currencies as the ether miner of the current block
elif msg.data[0] == 2 and msg.sender == block.coinbase and block.number > contract.storage[8]:
i = 0
numproposals = contract.storage[1]
while i < numproposals:
market = contract.storage[1001 + i * 3]
fc = call(market, 4)
sc = call(market, 5)
call(fc, [msg.sender, 1000], 2)
call(sc, [msg.sender, 1000], 2)
i += 1
contract.storage[8] = block.number

View File

@ -0,0 +1,55 @@
# 0: size
# 1-n: elements
init:
contract.storage[1000] = msg.sender
code:
# Only owner of the heap is allowed to modify it
if contract.storage[1000] != msg.sender:
stop
# push
if msg.data[0] == 0:
sz = contract.storage[0]
contract.storage[sz + 1] = msg.data[1]
k = sz + 1
while k > 1:
bottom = contract.storage[k]
top = contract.storage[k/2]
if bottom < top:
contract.storage[k] = top
contract.storage[k/2] = bottom
k /= 2
else:
k = 0
contract.storage[0] = sz + 1
# pop
elif msg.data[0] == 1:
sz = contract.storage[0]
if !sz:
return(0)
prevtop = contract.storage[1]
contract.storage[1] = contract.storage[sz]
contract.storage[sz] = 0
top = contract.storage[1]
k = 1
while k * 2 < sz:
bottom1 = contract.storage[k * 2]
bottom2 = contract.storage[k * 2 + 1]
if bottom1 < top and (bottom1 < bottom2 or k * 2 + 1 >= sz):
contract.storage[k] = bottom1
contract.storage[k * 2] = top
k = k * 2
elif bottom2 < top and bottom2 < bottom1 and k * 2 + 1 < sz:
contract.storage[k] = bottom2
contract.storage[k * 2 + 1] = top
k = k * 2 + 1
else:
k = sz
contract.storage[0] = sz - 1
return(prevtop)
# top
elif msg.data[0] == 2:
return(contract.storage[1])
# size
elif msg.data[0] == 3:
return(contract.storage[0])

View File

@ -0,0 +1,117 @@
# Creates a decentralized market between any two subcurrencies
# Here, the first subcurrency is the base asset and the second
# subcurrency is the asset priced against the base asset. Hence,
# "buying" refers to trading the first for the second, and
# "selling" refers to trading the second for the first
# storage 0: buy orders
# storage 1: sell orders
# storage 1000: first subcurrency
# storage 1001: last first subcurrency txid
# storage 2000: second subcurrency
# storage 2001: last second subcurrency txid
# storage 3000: current epoch
# storage 4000: price
# storage 4001: volume
init:
# Heap for buy orders
contract.storage[0] = create('heap.se')
# Heap for sell orders
contract.storage[1] = create('heap.se')
code:
# Initialize with [ first_subcurrency, second_subcurrency ]
if !contract.storage[1000]:
contract.storage[1000] = msg.data[0] # First subcurrency
contract.storage[1001] = -1
contract.storage[2000] = msg.data[1] # Second subcurrency
contract.storage[2001] = -1
contract.storage[3000] = block.number / 1000
stop
first_subcurrency = contract.storage[1000]
second_subcurrency = contract.storage[2000]
buy_heap = contract.storage[0]
sell_heap = contract.storage[1]
# This contract operates in "epochs" of 100 blocks
# At the end of each epoch, we process all orders
# simultaneously, independent of order. This algorithm
# prevents front-running, and generates a profit from
# the spread. The profit is permanently kept in the
# market (ie. destroyed), making both subcurrencies
# more valuable
# Epoch transition code
if contract.storage[3000] < block.number / 100:
done = 0
volume = 0
while !done:
# Grab the top buy and sell order from each heap
topbuy = call(buy_heap, 1)
topsell = call(sell_heap, 1)
# An order is recorded in the heap as:
# Buys: (2^48 - 1 - price) * 2^208 + units of first currency * 2^160 + from
# Sells: price * 2^208 + units of second currency * 2^160 + from
buyprice = -(topbuy / 2^208)
buyfcvalue = (topbuy / 2^160) % 2^48
buyer = topbuy % 2^160
sellprice = topsell / 2^208
sellscvalue = (topsell / 2^160) % 2^48
seller = topsell % 2^160
# Heap empty, or no more matching orders
if not topbuy or not topsell or buyprice < sellprice:
done = 1
else:
# Add to volume counter
volume += buyfcvalue
# Calculate how much of the second currency the buyer gets, and
# how much of the first currency the seller gets
sellfcvalue = sellscvalue / buyprice
buyscvalue = buyfcvalue * sellprice
# Send the currency units along
call(second_subcurrency, [buyer, buyscvalue], 2)
call(first_subcurrency, [seller, sellfcvalue], 2)
if volume:
contract.storage[4000] = (buyprice + sellprice) / 2
contract.storage[4001] = volume
contract.storage[3000] = block.number / 100
# Make buy order [0, price]
if msg.data[0] == 0:
# We ask the first subcurrency contract what the last transaction was. The
# way to make a buy order is to send the amount of first currency units that
# you wish to buy with, and then immediately call this contract. For security
# it makes sense to set up a tx which sends both messages in sequence atomically
data = call(first_subcurrency, [], 0, 4)
from = data[0]
to = data[1]
value = data[2]
txid = data[3]
price = msg.data[1]
if txid > contract.storage[1001] and to == contract.address:
contract.storage[1001] = txid
# Adds the order to the heap
call(buy_heap, [0, -price * 2^208 + (value % 2^48) * 2^160 + from], 2)
# Make sell order [1, price]
elif msg.data[0] == 1:
# Same mechanics as buying
data = call(second_subcurrency, [], 0, 4)
from = data[0]
to = data[1]
value = data[2]
txid = data[3]
price = msg.data[1]
if txid > contract.storage[2001] and to == contract.address:
contract.storage[2001] = txid
call(sell_heap, [0, price * 2^208 + (value % 2^48) * 2^160 + from], 2)
# Ask for price
elif msg.data[0] == 2:
return(contract.storage[4000])
# Ask for volume
elif msg.data[0] == 3:
return(contract.storage[1000])
# Ask for first currency
elif msg.data[0] == 4:
return(contract.storage[2000])
# Ask for second currency
elif msg.data[0] == 5:
return(contract.storage[4001])

View File

@ -0,0 +1,35 @@
# Initialization
# Admin can issue and delete at will
init:
contract.storage[0] = msg.sender
code:
# If a message with one item is sent, that's a balance query
if msg.datasize == 1:
addr = msg.data[0]
return(contract.storage[addr])
# If a message with two items [to, value] are sent, that's a transfer request
elif msg.datasize == 2:
from = msg.sender
fromvalue = contract.storage[from]
to = msg.data[0]
value = msg.data[1]
if fromvalue >= value and value > 0 and to > 4:
contract.storage[from] = fromvalue - value
contract.storage[to] += value
contract.storage[2] = from
contract.storage[3] = to
contract.storage[4] = value
contract.storage[5] += 1
return(1)
return(0)
elif msg.datasize == 3 and msg.sender == contract.storage[0]:
# Admin can issue at will by sending a [to, value, 0] message
if msg.data[2] == 0:
contract.storage[msg.data[0]] += msg.data[1]
# Change admin [ newadmin, 0, 1 ]
# Set admin to 0 to disable administration
elif msg.data[2] == 1:
contract.storage[0] = msg.data[0]
# Fetch last transaction
else:
return([contract.storage[2], contract.storage[3], contract.storage[4], contract.storage[5]], 4)

View File

@ -0,0 +1,39 @@
from __future__ import print_function
import pyethereum
t = pyethereum.tester
s = t.state()
# Create currencies
c1 = s.contract('subcurrency.se')
print("First currency: %s" % c1)
c2 = s.contract('subcurrency.se')
print("First currency: %s" % c2)
# Allocate units
s.send(t.k0, c1, 0, [t.a0, 1000, 0])
s.send(t.k0, c1, 0, [t.a1, 1000, 0])
s.send(t.k0, c2, 0, [t.a2, 1000000, 0])
s.send(t.k0, c2, 0, [t.a3, 1000000, 0])
print("Allocated units")
# Market
m = s.contract('market.se')
s.send(t.k0, m, 0, [c1, c2])
# Place orders
s.send(t.k0, c1, 0, [m, 1000])
s.send(t.k0, m, 0, [0, 1200])
s.send(t.k1, c1, 0, [m, 1000])
s.send(t.k1, m, 0, [0, 1400])
s.send(t.k2, c2, 0, [m, 1000000])
s.send(t.k2, m, 0, [1, 800])
s.send(t.k3, c2, 0, [m, 1000000])
s.send(t.k3, m, 0, [1, 600])
print("Orders placed")
# Next epoch and ping
s.mine(100)
print("Mined 100")
s.send(t.k0, m, 0, [])
print("Updating")
# Check
assert s.send(t.k0, c2, 0, [t.a0]) == [800000]
assert s.send(t.k0, c2, 0, [t.a1]) == [600000]
assert s.send(t.k0, c1, 0, [t.a2]) == [833]
assert s.send(t.k0, c1, 0, [t.a3]) == [714]
print("Balance checks passed")

View File

@ -0,0 +1,12 @@
# Database updateable only by the original creator
data creator
def init():
self.creator = msg.sender
def update(k, v):
if msg.sender == self.creator:
self.storage[k] = v
def query(k):
return(self.storage[k])

View File

@ -0,0 +1,40 @@
# So I looked up on Wikipedia what Jacobian form actually is, and noticed that it's
# actually a rather different and more clever construction than the naive version
# that I created. It may possible to achieve a further 20-50% savings by applying
# that version.
extern all: [call]
data JORDANMUL
data JORDANADD
data EXP
def init():
self.JORDANMUL = create('jacobian_mul.se')
self.JORDANADD = create('jacobian_add.se')
self.EXP = create('modexp.se')
def call(h, v, r, s):
N = -432420386565659656852420866394968145599
P = -4294968273
h = mod(h, N)
r = mod(r, P)
s = mod(s, N)
Gx = 55066263022277343669578718895168534326250603453777594175500187360389116729240
Gy = 32670510020758816978083085130507043184471273380659243275938904335757337482424
x = r
xcubed = mulmod(mulmod(x, x, P), x, P)
beta = self.EXP.call(addmod(xcubed, 7, P), div(P + 1, 4), P)
# Static-gascost ghetto conditional
y_is_positive = mod(v, 2) xor mod(beta, 2)
y = beta * y_is_positive + (P - beta) * (1 - y_is_positive)
GZ = self.JORDANMUL.call(Gx, 1, Gy, 1, N - h, outsz=4)
XY = self.JORDANMUL.call(x, 1, y, 1, s, outsz=4)
COMB = self.JORDANADD.call(GZ[0], GZ[1], GZ[2], GZ[3], XY[0], XY[1], XY[2], XY[3], 1, outsz=5)
COMB[4] = self.EXP.call(r, N - 2, N)
Q = self.JORDANMUL.call(data=COMB, datasz=5, outsz=4)
ox = mulmod(Q[0], self.EXP.call(Q[1], P - 2, P), P)
oy = mulmod(Q[2], self.EXP.call(Q[3], P - 2, P), P)
return([ox, oy], 2)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,32 @@
extern all: [call]
data DOUBLE
def init():
self.DOUBLE = create('jacobian_double.se')
def call(axn, axd, ayn, ayd, bxn, bxd, byn, byd):
if !axn and !ayn:
o = [bxn, bxd, byn, byd]
if !bxn and !byn:
o = [axn, axd, ayn, ayd]
if o:
return(o, 4)
with P = -4294968273:
if addmod(mulmod(axn, bxd, P), P - mulmod(axd, bxn, P), P) == 0:
if addmod(mulmod(ayn, byd, P), P - mulmod(ayd, byn, P), P) == 0:
return(self.DOUBLE.call(axn, axd, ayn, ayd, outsz=4), 4)
else:
return([0, 1, 0, 1], 4)
with mn = mulmod(addmod(mulmod(byn, ayd, P), P - mulmod(ayn, byd, P), P), mulmod(bxd, axd, P), P):
with md = mulmod(mulmod(byd, ayd, P), addmod(mulmod(bxn, axd, P), P - mulmod(axn, bxd, P), P), P):
with msqn = mulmod(mn, mn, P):
with msqd = mulmod(md, md, P):
with msqman = addmod(mulmod(msqn, axd, P), P - mulmod(msqd, axn, P), P):
with msqmad = mulmod(msqd, axd, P):
with xn = addmod(mulmod(msqman, bxd, P), P - mulmod(msqmad, bxn, P), P):
with xd = mulmod(msqmad, bxd, P):
with mamxn = mulmod(mn, addmod(mulmod(axn, xd, P), P - mulmod(xn, axd, P), P), P):
with mamxd = mulmod(md, mulmod(axd, xd, P), P):
with yn = addmod(mulmod(mamxn, ayd, P), P - mulmod(mamxd, ayn, P), P):
with yd = mulmod(mamxd, ayd, P):
return([xn, xd, yn, yd], 4)

View File

@ -0,0 +1,16 @@
def call(axn, axd, ayn, ayd):
if !axn and !ayn:
return([0, 1, 0, 1], 4)
with P = -4294968273:
# No need to add (A, 1) because A = 0 for bitcoin
with mn = mulmod(mulmod(mulmod(axn, axn, P), 3, P), ayd, P):
with md = mulmod(mulmod(axd, axd, P), mulmod(ayn, 2, P), P):
with msqn = mulmod(mn, mn, P):
with msqd = mulmod(md, md, P):
with xn = addmod(mulmod(msqn, axd, P), P - mulmod(msqd, mulmod(axn, 2, P), P), P):
with xd = mulmod(msqd, axd, P):
with mamxn = mulmod(addmod(mulmod(axn, xd, P), P - mulmod(axd, xn, P), P), mn, P):
with mamxd = mulmod(mulmod(axd, xd, P), md, P):
with yn = addmod(mulmod(mamxn, ayd, P), P - mulmod(mamxd, ayn, P), P):
with yd = mulmod(mamxd, ayd, P):
return([xn, xd, yn, yd], 4)

View File

@ -0,0 +1,37 @@
# Expected gas cost
#
# def expect(n, point_at_infinity=False):
# n = n % (2**256 - 432420386565659656852420866394968145599)
# if point_at_infinity:
# return 79
# if n == 0:
# return 34479
# L = int(1 + math.log(n) / math.log(2))
# H = len([x for x in b.encode(n, 2) if x == '1'])
# return 34221 + 94 * L + 343 * H
data DOUBLE
data ADD
def init():
self.DOUBLE = create('jacobian_double.se')
self.ADD = create('jacobian_add.se')
def call(axn, axd, ayn, ayd, n):
n = mod(n, -432420386565659656852420866394968145599)
if !axn * !ayn + !n: # Constant-gas version of !axn and !ayn or !n
return([0, 1, 0, 1], 4)
with o = [0, 0, 1, 0, 1, 0, 0, 0, 0]:
with b = 2 ^ 255:
while gt(b, 0):
if n & b:
~call(20000, self.DOUBLE, 0, o + 31, 129, o + 32, 128)
o[5] = axn
o[6] = axd
o[7] = ayn
o[8] = ayd
~call(20000, self.ADD, 0, o + 31, 257, o + 32, 128)
else:
~call(20000, self.DOUBLE, 0, o + 31, 129, o + 32, 128)
b = div(b, 2)
return(o + 32, 4)

View File

@ -0,0 +1,11 @@
def call(b, e, m):
with o = 1:
with bit = 2 ^ 255:
while gt(bit, 0):
# A touch of loop unrolling for 20% efficiency gain
o = mulmod(mulmod(o, o, m), b ^ !(!(e & bit)), m)
o = mulmod(mulmod(o, o, m), b ^ !(!(e & div(bit, 2))), m)
o = mulmod(mulmod(o, o, m), b ^ !(!(e & div(bit, 4))), m)
o = mulmod(mulmod(o, o, m), b ^ !(!(e & div(bit, 8))), m)
bit = div(bit, 16)
return(o)

View File

@ -0,0 +1,78 @@
import bitcoin as b
import math
import sys
def signed(o):
return map(lambda x: x - 2**256 if x >= 2**255 else x, o)
def hamming_weight(n):
return len([x for x in b.encode(n, 2) if x == '1'])
def binary_length(n):
return len(b.encode(n, 2))
def jacobian_mul_substitute(A, B, C, D, N):
if A == 0 and C == 0 or (N % b.N) == 0:
return {"gas": 86, "output": [0, 1, 0, 1]}
else:
output = b.jordan_multiply(((A, B), (C, D)), N)
return {
"gas": 35262 + 95 * binary_length(N % b.N) + 355 * hamming_weight(N % b.N),
"output": signed(list(output[0]) + list(output[1]))
}
def jacobian_add_substitute(A, B, C, D, E, F, G, H):
if A == 0 or E == 0:
gas = 149
elif (A * F - B * E) % b.P == 0:
if (C * H - D * G) % b.P == 0:
gas = 442
else:
gas = 177
else:
gas = 301
output = b.jordan_add(((A, B), (C, D)), ((E, F), (G, H)))
return {
"gas": gas,
"output": signed(list(output[0]) + list(output[1]))
}
def modexp_substitute(base, exp, mod):
return {
"gas": 5150,
"output": signed([pow(base, exp, mod) if mod > 0 else 0])
}
def ecrecover_substitute(z, v, r, s):
P, A, B, N, Gx, Gy = b.P, b.A, b.B, b.N, b.Gx, b.Gy
x = r
beta = pow(x*x*x+A*x+B, (P + 1) / 4, P)
BETA_PREMIUM = modexp_substitute(x, (P + 1) / 4, P)["gas"]
y = beta if v % 2 ^ beta % 2 else (P - beta)
Gz = b.jordan_multiply(((Gx, 1), (Gy, 1)), (N - z) % N)
GZ_PREMIUM = jacobian_mul_substitute(Gx, 1, Gy, 1, (N - z) % N)["gas"]
XY = b.jordan_multiply(((x, 1), (y, 1)), s)
XY_PREMIUM = jacobian_mul_substitute(x, 1, y, 1, s % N)["gas"]
Qr = b.jordan_add(Gz, XY)
QR_PREMIUM = jacobian_add_substitute(Gz[0][0], Gz[0][1], Gz[1][0], Gz[1][1],
XY[0][0], XY[0][1], XY[1][0], XY[1][1]
)["gas"]
Q = b.jordan_multiply(Qr, pow(r, N - 2, N))
Q_PREMIUM = jacobian_mul_substitute(Qr[0][0], Qr[0][1], Qr[1][0], Qr[1][1],
pow(r, N - 2, N))["gas"]
R_PREMIUM = modexp_substitute(r, N - 2, N)["gas"]
OX_PREMIUM = modexp_substitute(Q[0][1], P - 2, P)["gas"]
OY_PREMIUM = modexp_substitute(Q[1][1], P - 2, P)["gas"]
Q = b.from_jordan(Q)
return {
"gas": 991 + BETA_PREMIUM + GZ_PREMIUM + XY_PREMIUM + QR_PREMIUM +
Q_PREMIUM + R_PREMIUM + OX_PREMIUM + OY_PREMIUM,
"output": signed(Q)
}

View File

@ -0,0 +1,129 @@
import bitcoin as b
import random
import sys
import math
from pyethereum import tester as t
import substitutes
import time
vals = [random.randrange(2**256) for i in range(12)]
test_points = [list(p[0]) + list(p[1]) for p in
[b.jordan_multiply(((b.Gx, 1), (b.Gy, 1)), r) for r in vals]]
G = [b.Gx, 1, b.Gy, 1]
Z = [0, 1, 0, 1]
def neg_point(p):
return [p[0], b.P - p[1], p[2], b.P - p[3]]
s = t.state()
s.block.gas_limit = 10000000
t.gas_limit = 1000000
c = s.contract('modexp.se')
print "Starting modexp tests"
for i in range(0, len(vals) - 2, 3):
o1 = substitutes.modexp_substitute(vals[i], vals[i+1], vals[i+2])
o2 = s.profile(t.k0, c, 0, funid=0, abi=vals[i:i+3])
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
c = s.contract('jacobian_add.se')
print "Starting addition tests"
for i in range(2):
P = test_points[i * 2]
Q = test_points[i * 2 + 1]
NP = neg_point(P)
o1 = substitutes.jacobian_add_substitute(*(P + Q))
o2 = s.profile(t.k0, c, 0, funid=0, abi=P + Q)
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
o1 = substitutes.jacobian_add_substitute(*(P + NP))
o2 = s.profile(t.k0, c, 0, funid=0, abi=P + NP)
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
o1 = substitutes.jacobian_add_substitute(*(P + P))
o2 = s.profile(t.k0, c, 0, funid=0, abi=P + P)
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
o1 = substitutes.jacobian_add_substitute(*(P + Z))
o2 = s.profile(t.k0, c, 0, funid=0, abi=P + Z)
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
o1 = substitutes.jacobian_add_substitute(*(Z + P))
o2 = s.profile(t.k0, c, 0, funid=0, abi=Z + P)
#assert o1["gas"] == o2["gas"], (o1, o2)
assert o1["output"] == o2["output"], (o1, o2)
c = s.contract('jacobian_mul.se')
print "Starting multiplication tests"
mul_tests = [
Z + [0],
Z + [vals[0]],
test_points[0] + [0],
test_points[1] + [b.N],
test_points[2] + [1],
test_points[2] + [2],
test_points[2] + [3],
test_points[2] + [4],
test_points[3] + [5],
test_points[3] + [6],
test_points[4] + [7],
test_points[4] + [2**254],
test_points[4] + [vals[1]],
test_points[4] + [vals[2]],
test_points[4] + [vals[3]],
test_points[5] + [2**256 - 1],
]
for i, test in enumerate(mul_tests):
print 'trying mul_test %i' % i, test
o1 = substitutes.jacobian_mul_substitute(*test)
o2 = s.profile(t.k0, c, 0, funid=0, abi=test)
# assert o1["gas"] == o2["gas"], (o1, o2, test)
assert o1["output"] == o2["output"], (o1, o2, test)
c = s.contract('ecrecover.se')
print "Starting ecrecover tests"
for i in range(5):
print 'trying ecrecover_test', vals[i*2], vals[i*2+1]
k = vals[i*2]
h = vals[i*2+1]
V, R, S = b.ecdsa_raw_sign(b.encode(h, 256, 32), k)
aa = time.time()
o1 = substitutes.ecrecover_substitute(h, V, R, S)
print 'sub', time.time() - aa
a = time.time()
o2 = s.profile(t.k0, c, 0, funid=0, abi=[h, V, R, S])
print time.time() - a
# assert o1["gas"] == o2["gas"], (o1, o2, h, V, R, S)
assert o1["output"] == o2["output"], (o1, o2, h, V, R, S)
# Explicit tests
data = [[
0xf007a9c78a4b2213220adaaf50c89a49d533fbefe09d52bbf9b0da55b0b90b60,
0x1b,
0x5228fc9e2fabfe470c32f459f4dc17ef6a0a81026e57e4d61abc3bc268fc92b5,
0x697d4221cd7bc5943b482173de95d3114b9f54c5f37cc7f02c6910c6dd8bd107
]]
for datum in data:
o1 = substitutes.ecrecover_substitute(*datum)
o2 = s.profile(t.k0, c, 0, funid=0, abi=datum)
#assert o1["gas"] == o2["gas"], (o1, o2, datum)
assert o1["output"] == o2["output"], (o1, o2, datum)

View File

@ -0,0 +1,45 @@
if msg.data[0] == 0:
new_id = contract.storage[-1]
# store [from, to, value, maxvalue, timeout] in contract storage
contract.storage[new_id] = msg.sender
contract.storage[new_id + 1] = msg.data[1]
contract.storage[new_id + 2] = 0
contract.storage[new_id + 3] = msg.value
contract.storage[new_id + 4] = 2^254
# increment next id
contract.storage[-1] = new_id + 10
# return id of this channel
return(new_id)
# Increase payment on channel: [1, id, value, v, r, s]
elif msg.data[0] == 1:
# Ecrecover native extension; will be a different address in testnet and live
ecrecover = 0x46a8d0b21b1336d83b06829f568d7450df36883f
# Message data parameters
id = msg.data[1] % 2^160
value = msg.data[2]
# Determine sender from signature
h = sha3([id, value], 2)
sender = call(ecrecover, [h, msg.data[3], msg.data[4], msg.data[5]], 4)
# Check sender matches and new value is greater than old
if sender == contract.storage[id]:
if value > contract.storage[id + 2] and value <= contract.storage[id + 3]:
# Update channel, increasing value and setting timeout
contract.storage[id + 2] = value
contract.storage[id + 4] = block.number + 1000
# Cash out channel: [2, id]
elif msg.data[0] == 2:
id = msg.data[1] % 2^160
# Check if timeout has run out
if block.number >= contract.storage[id + 3]:
# Send funds
send(contract.storage[id + 1], contract.storage[id + 2])
# Send refund
send(contract.storage[id], contract.storage[id + 3] - contract.storage[id + 2])
# Clear storage
contract.storage[id] = 0
contract.storage[id + 1] = 0
contract.storage[id + 2] = 0
contract.storage[id + 3] = 0
contract.storage[id + 4] = 0

View File

@ -0,0 +1,19 @@
# An implementation of a contract for storing a key/value binding
init:
# Set owner
contract.storage[0] = msg.sender
code:
# Check ownership
if msg.sender == contract.storage[0]:
# Get: returns (found, val)
if msg.data[0] == 0:
s = sha3(msg.data[1])
return([contract.storage[s], contract.storage[s+1]], 2)
# Set: sets map[k] = v
elif msg.data[0] == 1:
s = sha3(msg.data[1])
contract.storage[s] = 1
contract.storage[s + 1] = msg.data[2]
# Suicide
elif msg.data[2] == 1:
suicide(0)

View File

@ -0,0 +1,14 @@
init:
contract.storage[0] = msg.sender
code:
if msg.sender != contract.storage[0]:
stop
i = 0
while i < ~calldatasize():
to = ~calldataload(i)
value = ~calldataload(i+20) / 256^12
datasize = ~calldataload(i+32) / 256^30
data = alloc(datasize)
~calldatacopy(data, i+34, datasize)
~call(tx.gas - 25, to, value, data, datasize, 0, 0)
i += 34 + datasize

View File

@ -0,0 +1,166 @@
# Exists in state:
# (i) last committed block
# (ii) chain of uncommitted blocks (linear only)
# (iii) transactions, each tx with an associated block number
#
# Uncommitted block =
# [ numtxs, numkvs, tx1 (N words), tx2 (N words) ..., [k1, v1], [k2, v2], [k3, v3] ... ]
#
# Block checking process
#
# Suppose last committed state is m
# Last uncommitted state is n
# Contested block is b
#
# 1. Temporarily apply all state transitions from
# m to b
# 2. Run code, get list of changes
# 3. Check is list of changes matches deltas
# * if yes, do nothing
# * if no, set last uncommitted state to pre-b
#
# Storage variables:
#
# Last committed block: 0
# Last uncommitted block: 1
# Contract holding code: 2
# Uncommitted map: 3
# Transaction length (parameter): 4
# Block b: 2^160 + b * 2^40:
# + 1: submission blknum
# + 2: submitter
# + 3: data in uncommitted block format above
# Last committed storage:
# sha3(k): index k
# Initialize: [0, c, txlength], set address of the code-holding contract and the transaction
# length
if not contract.storage[2]:
contract.storage[2] = msg.data[1]
contract.storage[4] = msg.data[2]
stop
# Sequentially commit all uncommitted blocks that are more than 1000 mainchain-blocks old
last_committed_block = contract.storage[0]
last_uncommitted_block = contract.storage[1]
lcb_storage_index = 2^160 + last_committed_block * 2^40
while contract.storage[lcb_storage_index + 1] < block.number - 1000 and last_committed_block < last_uncommitted_block:
kvpairs = contract.storage[lcb_storage_index]
i = 0
while i < kvpairs:
k = contract.storage[lcb_storage_index + 3 + i * 2]
v = contract.storage[lcb_storage_index + 4 + i * 2]
contract.storage[sha3(k)] = v
i += 1
last_committed_block += 1
lcb_storage_index += 2^40
contract.storage[0] = last_committed_block
# Propose block: [ 0, block number, data in block format above ... ]
if msg.data[0] == 0:
blknumber = msg.data[1]
# Block number must be correct
if blknumber != contract.storage[1]:
stop
# Deposit requirement
if msg.value < 10^19:
stop
# Store the proposal in storage as
# [ 0, main-chain block number, sender, block data...]
start_index = 2^160 + blknumber * 2^40
numkvs = (msg.datasize - 2) / 2
contract.storage[start_index + 1] = block.number
1ontract.storage[start_index + 2] = msg.sender
i = 0
while i < msg.datasize - 2:
contract.storage[start_index + 3 + i] = msg.data[2 + i]
i += 1
contract.storage[1] = blknumber + 1
# Challenge block: [ 1, b ]
elif msg.data[0] == 1:
blknumber = msg.data[1]
txwidth = contract.storage[4]
last_uncommitted_block = contract.storage[1]
last_committed_block = contract.storage[0]
# Cannot challenge nonexistent or committed blocks
if blknumber <= last_uncommitted_block or blknumber > last_committed_block:
stop
# Create a contract to serve as a map that maintains keys and values
# temporarily
tempstore = create('map.se')
contract.storage[3] = tempstore
# Unquestioningly apply the state transitions from the last committed block
# up to b
b = last_committed_block
cur_storage_index = 2^160 + last_committed_block * 2^40
while b < blknumber:
numtxs = contract.storage[cur_storage_index + 3]
numkvs = contract.storage[cur_storage_index + 4]
kv0index = cur_storage_index + 5 + numtxs * txwidth
i = 0
while i < numkvs:
k = contract.storage[kv0index + i * 2]
v = contract.storage[kx0index + i * 2 + 1]
call(tempstore, [1, k, v], 3)
i += 1
b += 1
cur_storage_index += 2^40
# Run the actual code, and see what state transitions it outputs
# The way that the code is expected to work is to:
#
# (1) take as input the list of transactions (the contract should
# use msg.datasize to determine how many txs there are, and it should
# be aware of the value of txwidth)
# (2) call this contract with [2, k] to read current state data
# (3) call this contract with [3, k, v] to write current state data
# (4) return as output a list of all state transitions that it made
# in the form [kvcount, k1, v1, k2, v2 ... ]
#
# The reason for separating (2) from (3) is that sometimes the state
# transition may end up changing a given key many times, and we don't
# need to inefficiently store that in storage
numkvs = contract.storage[cur_storage_index + 3]
numtxs = contract.storage[cur_storage_index + 4]
# Populate input array
inpwidth = numtxs * txwidth
inp = array(inpwidth)
i = 0
while i < inpwidth:
inp[i] = contract.storage[cur_storage_index + 5 + i]
i += 1
out = call(contract.storage[2], inp, inpwidth, numkvs * 2 + 1)
# Check that the number of state transitions is the same
if out[0] != kvcount:
send(msg.sender, 10^19)
contract.storage[0] = last_committed_block
stop
kv0index = cur_storage_index + 5 + numtxs * txwidth
i = 0
while i < kvcount:
# Check that each individual state transition matches
k = contract.storage[kv0index + i * 2 + 1]
v = contract.storage[kv0index + i * 2 + 2]
if k != out[i * 2 + 1] or v != out[i * 2 + 2]:
send(msg.sender, 10^19)
contract.storage[0] = last_committed_block
stop
i += 1
# Suicide tempstore
call(tempstore, 2)
# Read data [2, k]
elif msg.data[0] == 2:
tempstore = contract.storage[3]
o = call(tempstore, [0, msg.data[1]], 2, 2)
if o[0]:
return(o[1])
else:
return contract.storage[sha3(msg.data[1])]
# Write data [3, k, v]
elif msg.data[0] == 3:
tempstore = contract.storage[3]
call(tempstore, [1, msg.data[1], msg.data[2]], 3, 2)

View File

@ -0,0 +1,31 @@
type f: [a, b, c, d, e]
macro f($a) + f($b):
f(add($a, $b))
macro f($a) - f($b):
f(sub($a, $b))
macro f($a) * f($b):
f(mul($a, $b) / 10000)
macro f($a) / f($b):
f(sdiv($a * 10000, $b))
macro f($a) % f($b):
f(smod($a, $b))
macro f($v) = f($w):
$v = $w
macro unfify(f($a)):
$a / 10000
macro fify($a):
f($a * 10000)
a = fify(5)
b = fify(2)
c = a / b
e = c + (a / b)
return(unfify(e))

View File

@ -0,0 +1,116 @@
macro smin($a, $b):
with $1 = $a:
with $2 = $b:
if(slt($1, $2), $1, $2)
macro smax($a, $b):
with $1 = $a:
with $2 = $b:
if(slt($1, $2), $2, $1)
def omul(x, y):
o = expose(mklong(x) * mklong(y))
return(slice(o, 1), o[0]+1)
def oadd(x, y):
o = expose(mklong(x) + mklong(y))
return(slice(o, 1), o[0]+1)
def osub(x, y):
o = expose(mklong(x) - mklong(y))
return(slice(o, 1), o[0]+1)
def odiv(x, y):
o = expose(mklong(x) / mklong(y))
return(slice(o, 1), o[0]+1)
def comb(a:a, b:a, sign):
sz = smax(a[0], b[0])
msz = smin(a[0], b[0])
c = array(sz + 2)
c[0] = sz
i = 0
carry = 0
while i < msz:
m = a[i + 1] + sign * b[i + 1] + carry
c[i + 1] = mod(m + 2^127, 2^128) - 2^127
carry = (div(m + 2^127, 2^128) + 2^127) % 2^128 - 2^127
i += 1
u = if(a[0] > msz, a, b)
s = if(a[0] > msz, 1, sign)
while i < sz:
m = s * u[i + 1] + carry
c[i + 1] = mod(m + 2^127, 2^128) - 2^127
carry = (div(m + 2^127, 2^128) + 2^127) % 2^128 - 2^127
i += 1
if carry:
c[0] += 1
c[sz + 1] = carry
return(c, c[0]+1)
def mul(a:a, b:a):
c = array(a[0] + b[0] + 2)
c[0] = a[0] + b[0]
i = 0
while i < a[0]:
j = 0
carry = 0
while j < b[0]:
m = c[i + j + 1] + a[i + 1] * b[j + 1] + carry
c[i + j + 1] = mod(m + 2^127, 2^128) - 2^127
carry = (div(m + 2^127, 2^128) + 2^127) % 2^128 - 2^127
j += 1
if carry:
c[0] = a[0] + b[0] + 1
c[i + j + 1] += carry
i += 1
return(c, c[0]+1)
macro long($a) + long($b):
long(self.comb($a:$a[0]+1, $b:$b[0]+1, 1, outsz=$a[0]+$b[0]+2))
macro long($a) - long($b):
long(self.comb($a:$a[0]+1, $b:$b[0]+1, -1, outsz=$a[0]+$b[0]+2))
macro long($a) * long($b):
long(self.mul($a:$a[0]+1, $b:$b[0]+1, outsz=$a[0]+$b[0]+2))
macro long($a) / long($b):
long(self.div($a:$a[0]+1, $b:$b[0]+1, outsz=$a[0]+$b[0]+2))
macro mulexpand(long($a), $k, $m):
long:
with $c = array($a[0]+k+2):
$c[0] = $a[0]+$k
with i = 0:
while i < $a[0]:
v = $a[i+1] * $m + $c[i+$k+1]
$c[i+$k+1] = mod(v + 2^127, 2^128) - 2^127
$c[i+$k+2] = div(v + 2^127, 2^128)
i += 1
$c
def div(a:a, b:a):
asz = a[0]
bsz = b[0]
while b[bsz] == 0 and bsz > 0:
bsz -= 1
c = array(asz+2)
c[0] = asz+1
while 1:
while a[asz] == 0 and asz > 0:
asz -= 1
if asz < bsz:
return(c, c[0]+1)
sub = expose(mulexpand(long(b), asz - bsz, a[asz] / b[bsz]))
c[asz - bsz+1] = a[asz] / b[bsz]
a = expose(long(a) - long(sub))
a[asz-1] += 2^128 * a[asz]
a[asz] = 0
macro mklong($i):
long([2, mod($i + 2^127, 2^128) - 2^127, div($i + 2^127, 2^128)])
macro expose(long($i)):
$i

View File

@ -0,0 +1,2 @@
def double(v):
return(v*2)

View File

@ -0,0 +1,187 @@
# mutuala - subcurrency
# We want to issue a currency that reduces in value as you store it through negative interest.
# That negative interest would be stored in a commons account. It's like the p2p version of a
# capital tax
# the same things goes for transactions - you pay as you use the currency. However, the more
# you pay, the more you get to say about what the tax is used for
# each participant can propose a recipient for a payout to be made out of the commons account,
# others can vote on it by awarding it tax_credits.
# TODO should proposal have expiration timestamp?, after which the tax_credits are refunded
# TODO multiple proposals can take more credits that available in the Commons, how to handle this
# TODO how to handle lost accounts, after which no longer possible to get 2/3 majority
shared:
COMMONS = 42
ADMIN = 666
CAPITAL_TAX_PER_DAY = 7305 # 5% per year
PAYMENT_TAX = 20 # 5%
ACCOUNT_LIST_OFFSET = 2^160
ACCOUNT_MAP_OFFSET = 2^161
PROPOSAL_LIST_OFFSET = 2^162
PROPOSAL_MAP_OFFSET = 2^163
init:
contract.storage[ADMIN] = msg.sender
contract.storage[ACCOUNT_LIST_OFFSET - 1] = 1
contract.storage[ACCOUNT_LIST_OFFSET] = msg.sender
contract.storage[ACCOUNT_MAP_OFFSET + msg.sender] = 10^12
contract.storage[ACCOUNT_MAP_OFFSET + msg.sender + 1] = block.timestamp
# contract.storage[COMMONS] = balance commons
# contract.storage[ACCOUNT_LIST_OFFSET - 1] = number of accounts
# contract.storage[ACCOUNT_LIST_OFFSET + n] = account n
# contract.storage[PROPOSAL_LIST_OFFSET - 1] contains the number of proposals
# contract.storage[PROPOSAL_LIST_OFFSET + n] = proposal n
# per account:
# contract.storage[ACCOUNT_MAP_OFFSET + account] = balance
# contract.storage[ACCOUNT_MAP_OFFSET + account+1] = timestamp_last_transaction
# contract.storage[ACCOUNT_MAP_OFFSET + account+2] = tax_credits
# per proposal:
# contract.storage[PROPOSAL_MAP_OFFSET + proposal_id] = recipient
# contract.storage[PROPOSAL_MAP_OFFSET + proposal_id+1] = amount
# contract.storage[PROPOSAL_MAP_OFFSET + proposal_id+2] = total vote credits
code:
if msg.data[0] == "suicide" and msg.sender == contract.storage[ADMIN]:
suicide(msg.sender)
elif msg.data[0] == "balance":
addr = msg.data[1]
return(contract.storage[ACCOUNT_MAP_OFFSET + addr])
elif msg.data[0] == "pay":
from = msg.sender
fromvalue = contract.storage[ACCOUNT_MAP_OFFSET + from]
to = msg.data[1]
if to == 0 or to >= 2^160:
return([0, "invalid address"], 2)
value = msg.data[2]
tax = value / PAYMENT_TAX
if fromvalue >= value + tax:
contract.storage[ACCOUNT_MAP_OFFSET + from] = fromvalue - (value + tax)
contract.storage[ACCOUNT_MAP_OFFSET + to] += value
# tax
contract.storage[COMMONS] += tax
contract.storage[ACCOUNT_MAP_OFFSET + from + 2] += tax
# check timestamp field to see if target account exists
if contract.storage[ACCOUNT_MAP_OFFSET + to + 1] == 0:
# register new account
nr_accounts = contract.storage[ACCOUNT_LIST_OFFSET - 1]
contract.storage[ACCOUNT_LIST_OFFSET + nr_accounts] = to
contract.storage[ACCOUNT_LIST_OFFSET - 1] += 1
contract.storage[ACCOUNT_MAP_OFFSET + to + 1] = block.timestamp
return(1)
else:
return([0, "insufficient balance"], 2)
elif msg.data[0] == "hash":
proposal_id = sha3(msg.data[1])
return(proposal_id)
elif msg.data[0] == "propose":
from = msg.sender
# check if sender has an account and has tax credits
if contract.storage[ACCOUNT_MAP_OFFSET + from + 2] == 0:
return([0, "sender has no tax credits"], 2)
proposal_id = sha3(msg.data[1])
# check if proposal doesn't already exist
if contract.storage[PROPOSAL_MAP_OFFSET + proposal_id]:
return([0, "proposal already exists"])
to = msg.data[2]
# check if recipient is a valid address and has an account (with timestamp)
if to == 0 or to >= 2^160:
return([0, "invalid address"], 2)
if contract.storage[ACCOUNT_MAP_OFFSET + to + 1] == 0:
return([0, "invalid to account"], 2)
value = msg.data[3]
# check if there is enough money in the commons account
if value > contract.storage[COMMONS]:
return([0, "not enough credits in commons"], 2)
# record proposal in list
nr_proposals = contract.storage[PROPOSAL_LIST_OFFSET - 1]
contract.storage[PROPOSAL_LIST_OFFSET + nr_proposals] = proposal_id
contract.storage[PROPOSAL_LIST_OFFSET - 1] += 1
# record proposal in map
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id] = to
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 1] = value
return(proposal_id)
elif msg.data[0] == "vote":
from = msg.sender
proposal_id = sha3(msg.data[1])
value = msg.data[2]
# check if sender has an account and has tax credits
if value < contract.storage[ACCOUNT_MAP_OFFSET + from + 2]:
return([0, "sender doesn't have enough tax credits"], 2)
# check if proposal exist
if contract.storage[PROPOSAL_MAP_OFFSET + proposal_id] == 0:
return([0, "proposal doesn't exist"], 2)
# increase votes
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 2] += value
# withdraw tax credits
contract.storage[ACCOUNT_MAP_OFFSET + from + 2] -= value
# did we reach 2/3 threshold?
if contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 2] >= contract.storage[COMMONS] * 2 / 3:
# got majority
to = contract.storage[PROPOSAL_MAP_OFFSET + proposal_id]
amount = contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 1]
# adjust balances
contract.storage[ACCOUNT_MAP_OFFSET + to] += amount
contract.storage[COMMONS] -= amount
# reset proposal
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id] = 0
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 1] = 0
contract.storage[PROPOSAL_MAP_OFFSET + proposal_id + 2] = 0
return(1)
return(proposal_id)
elif msg.data[0] == "tick":
nr_accounts = contract.storage[ACCOUNT_LIST_OFFSET - 1]
account_idx = 0
tax_paid = 0
# process all accounts and see if they have to pay their daily capital tax
while account_idx < nr_accounts:
cur_account = contract.storage[ACCOUNT_LIST_OFFSET + account_idx]
last_timestamp = contract.storage[ACCOUNT_MAP_OFFSET + cur_account + 1]
time_diff = block.timestamp - last_timestamp
if time_diff >= 86400:
tax_days = time_diff / 86400
balance = contract.storage[ACCOUNT_MAP_OFFSET + cur_account]
tax = tax_days * (balance / CAPITAL_TAX_PER_DAY)
if tax > 0:
# charge capital tax, but give tax credits in return
contract.storage[ACCOUNT_MAP_OFFSET + cur_account] -= tax
contract.storage[ACCOUNT_MAP_OFFSET + cur_account + 1] += tax_days * 86400
contract.storage[ACCOUNT_MAP_OFFSET + cur_account + 2] += tax
contract.storage[COMMONS] += tax
tax_paid += 1
account_idx += 1
return(tax_paid) # how many accounts did we charge tax on
else:
return([0, "unknown command"], 2)

View File

@ -0,0 +1,7 @@
def register(k, v):
if !self.storage[k]: # Is the key not yet taken?
# Then take it!
self.storage[k] = v
return(1)
else:
return(0) // Otherwise do nothing

View File

@ -0,0 +1,43 @@
macro padd($x, psuc($y)):
psuc(padd($x, $y))
macro padd($x, z()):
$x
macro dec(psuc($x)):
dec($x) + 1
macro dec(z()):
0
macro pmul($x, z()):
z()
macro pmul($x, psuc($y)):
padd(pmul($x, $y), $x)
macro pexp($x, z()):
one()
macro pexp($x, psuc($y)):
pmul($x, pexp($x, $y))
macro fac(z()):
one()
macro fac(psuc($x)):
pmul(psuc($x), fac($x))
macro one():
psuc(z())
macro two():
psuc(psuc(z()))
macro three():
psuc(psuc(psuc(z())))
macro five():
padd(three(), two())
return([dec(pmul(three(), pmul(three(), three()))), dec(fac(five()))], 2)

View File

@ -0,0 +1,4 @@
extern mul2: [double]
x = create("mul2.se")
return(x.double(5))

View File

@ -0,0 +1,33 @@
def kall():
argcount = ~calldatasize() / 32
if argcount == 1:
return(~calldataload(1))
args = array(argcount)
~calldatacopy(args, 1, argcount * 32)
low = array(argcount)
lsz = 0
high = array(argcount)
hsz = 0
i = 1
while i < argcount:
if args[i] < args[0]:
low[lsz] = args[i]
lsz += 1
else:
high[hsz] = args[i]
hsz += 1
i += 1
low = self.kall(data=low, datasz=lsz, outsz=lsz)
high = self.kall(data=high, datasz=hsz, outsz=hsz)
o = array(argcount)
i = 0
while i < lsz:
o[i] = low[i]
i += 1
o[lsz] = args[0]
j = 0
while j < hsz:
o[lsz + 1 + j] = high[j]
j += 1
return(o, argcount)

View File

@ -0,0 +1,46 @@
# Quicksort pairs
# eg. input of the form [ 30, 1, 90, 2, 70, 3, 50, 4]
# outputs [ 30, 1, 50, 4, 70, 3, 90, 2 ]
#
# Note: this can be used as a generalized sorting algorithm:
# map every object to [ key, ref ] where `ref` is the index
# in memory to all of the properties and `key` is the key to
# sort by
def kall():
argcount = ~calldatasize() / 64
if argcount == 1:
return([~calldataload(1), ~calldataload(33)], 2)
args = array(argcount * 2)
~calldatacopy(args, 1, argcount * 64)
low = array(argcount * 2)
lsz = 0
high = array(argcount * 2)
hsz = 0
i = 2
while i < argcount * 2:
if args[i] < args[0]:
low[lsz] = args[i]
low[lsz + 1] = args[i + 1]
lsz += 2
else:
high[hsz] = args[i]
high[hsz + 1] = args[i + 1]
hsz += 2
i = i + 2
low = self.kall(data=low, datasz=lsz, outsz=lsz)
high = self.kall(data=high, datasz=hsz, outsz=hsz)
o = array(argcount * 2)
i = 0
while i < lsz:
o[i] = low[i]
i += 1
o[lsz] = args[0]
o[lsz + 1] = args[1]
j = 0
while j < hsz:
o[lsz + 2 + j] = high[j]
j += 1
return(o, argcount * 2)

View File

@ -0,0 +1,94 @@
# SchellingCoin implementation
#
# Epoch length: 100 blocks
# Target savings depletion rate: 0.1% per epoch
data epoch
data hashes_submitted
data output
data quicksort_pairs
data accounts[2^160]
data submissions[2^80](hash, deposit, address, value)
extern any: [call]
def init():
self.epoch = block.number / 100
self.quicksort_pairs = create('quicksort_pairs.se')
def any():
if block.number / 100 > epoch:
# Sort all values submitted
N = self.hashes_submitted
o = array(N * 2)
i = 0
j = 0
while i < N:
v = self.submissions[i].value
if v:
o[j] = v
o[j + 1] = i
j += 2
i += 1
values = self.quicksort_pairs.call(data=o, datasz=j, outsz=j)
# Calculate total deposit, refund non-submitters and
# cleanup
deposits = array(j / 2)
addresses = array(j / 2)
i = 0
total_deposit = 0
while i < j / 2:
base_index = HASHES + values[i * 2 + 1] * 3
deposits[i] = self.submissions[i].deposit
addresses[i] = self.submissions[i].address
if self.submissions[values[i * 2 + 1]].value:
total_deposit += deposits[i]
else:
send(addresses[i], deposits[i] * 999 / 1000)
i += 1
inverse_profit_ratio = total_deposit / (contract.balance / 1000) + 1
# Reward everyone
i = 0
running_deposit_sum = 0
halfway_passed = 0
while i < j / 2:
new_deposit_sum = running_deposit_sum + deposits[i]
if new_deposit_sum > total_deposit / 4 and running_deposit_sum < total_deposit * 3 / 4:
send(addresses[i], deposits[i] + deposits[i] / inverse_profit_ratio * 2)
else:
send(addresses[i], deposits[i] - deposits[i] / inverse_profit_ratio)
if not halfway_passed and new_deposit_sum > total_deposit / 2:
self.output = self.submissions[i].value
halfway_passed = 1
self.submissions[i].value = 0
running_deposit_sum = new_deposit_sum
i += 1
self.epoch = block.number / 100
self.hashes_submitted = 0
def submit_hash(h):
if block.number % 100 < 50:
cur = self.hashes_submitted
pos = HASHES + cur * 3
self.submissions[cur].hash = h
self.submissions[cur].deposit = msg.value
self.submissions[cur].address = msg.sender
self.hashes_submitted = cur + 1
return(cur)
def submit_value(index, v):
if sha3([msg.sender, v], 2) == self.submissions[index].hash:
self.submissions[index].value = v
return(1)
def request_balance():
return(contract.balance)
def request_output():
return(self.output)

View File

@ -0,0 +1,171 @@
# Hedged zero-supply dollar implementation
# Uses SchellingCoin as price-determining backend
#
# Stored variables:
#
# 0: Schelling coin contract
# 1: Last epoch
# 2: Genesis block of contract
# 3: USD exposure
# 4: ETH exposure
# 5: Cached price
# 6: Last interest rate
# 2^160 + k: interest rate accumulator at k epochs
# 2^161 + ADDR * 3: eth-balance of a particular address
# 2^161 + ADDR * 3 + 1: usd-balance of a particular address
# 2^161 + ADDR * 3 + 1: last accessed epoch of a particular address
#
# Transaction types:
#
# [1, to, val]: send ETH
# [2, to, val]: send USD
# [3, wei_amount]: convert ETH to USD
# [4, usd_amount]: converts USD to ETH
# [5]: deposit
# [6, amount]: withdraw
# [7]: my balance query
# [7, acct]: balance query for any acct
# [8]: global state query
# [9]: liquidation test any account
#
# The purpose of the contract is to serve as a sort of cryptographic
# bank account where users can store both ETH and USD. ETH must be
# stored in zero or positive quantities, but USD balances can be
# positive or negative. If the USD balance is negative, the invariant
# usdbal * 10 >= ethbal * 9 must be satisfied; if any account falls
# below this value, then that account's balances are zeroed. Note
# that there is a 2% bounty to ping the app if an account does go
# below zero; one weakness is that if no one does ping then it is
# quite possible for accounts to go negative-net-worth, then zero
# themselves out, draining the reserves of the "bank" and potentially
# bankrupting it. A 0.1% fee on ETH <-> USD trade is charged to
# minimize this risk. Additionally, the bank itself will inevitably
# end up with positive or negative USD exposure; to mitigate this,
# it automatically updates interest rates on USD to keep exposure
# near zero.
data schelling_coin
data last_epoch
data starting_block
data usd_exposure
data eth_exposure
data price
data last_interest_rate
data interest_rate_accum[2^50]
data accounts[2^160](eth, usd, last_epoch)
extern sc: [submit_hash, submit_value, request_balance, request_output]
def init():
self.schelling_coin = create('schellingcoin.se')
self.price = self.schelling_coin.request_output()
self.interest_rate_accum[0] = 10^18
self.starting_block = block.number
def any():
sender = msg.sender
epoch = (block.number - self.starting_block) / 100
last_epoch = self.last_epoch
usdprice = self.price
# Update contract epochs
if epoch > last_epoch:
delta = epoch - last_epoch
last_interest_rate = self.last_interest_rate
usd_exposure - self.usd_exposure
last_accum = self.interest_rate_accum[last_epoch]
if usd_exposure < 0:
self.last_interest_rate = last_interest_rate - 10000 * delta
elif usd_exposure > 0:
self.last_interest_rate = last_interest_rate + 10000 * delta
self.interest_rate_accum[epoch] = last_accum + last_accum * last_interest_rate * delta / 10^9
# Proceeds go to support the SchellingCoin feeding it price data, ultimately providing the depositors
# of the SchellingCoin an interest rate
bal = max(self.balance - self.eth_exposure, 0) / 10000
usdprice = self.schelling_coin.request_output()
self.price = usdprice
self.last_epoch = epoch
ethbal = self.accounts[msg.sender].eth
usdbal = self.accounts[msg.sender].usd
# Apply interest rates to sender and liquidation-test self
if msg.sender != self:
self.ping(self)
def send_eth(to, value):
if value > 0 and value <= ethbal and usdbal * usdprice * 2 + (ethbal - value) >= 0:
self.accounts[msg.sender].eth = ethbal - value
self.ping(to)
self.accounts[to].eth += value
return(1)
def send_usd(to, value):
if value > 0 and value <= usdbal and (usdbal - value) * usdprice * 2 + ethbal >= 0:
self.accounts[msg.sender].usd = usdbal - value
self.ping(to)
self.accounts[to].usd += value
return(1)
def convert_to_eth(usdvalue):
ethplus = usdvalue * usdprice * 999 / 1000
if usdvalue > 0 and (usdbal - usdvalue) * usdprice * 2 + (ethbal + ethplus) >= 0:
self.accounts[msg.sender].eth = ethbal + ethplus
self.accounts[msg.sender].usd = usdbal - usdvalue
self.eth_exposure += ethplus
self.usd_exposure -= usdvalue
return([ethbal + ethplus, usdbal - usdvalue], 2)
def convert_to_usd(ethvalue):
usdplus = ethvalue / usdprice * 999 / 1000
if ethvalue > 0 and (usdbal + usdplus) * usdprice * 2 + (ethbal - ethvalue) >= 0:
self.accounts[msg.sender].eth = ethbal - ethvalue
self.accounts[msg.sender].usd = usdbal + usdplus
self.eth_exposure -= ethvalue
self.usd_exposure += usdplus
return([ethbal - ethvalue, usdbal + usdplus], 2)
def deposit():
self.accounts[msg.sender].eth = ethbal + msg.value
self.eth_exposure += msg.value
return(ethbal + msg.value)
def withdraw(value):
if value > 0 and value <= ethbal and usdbal * usdprice * 2 + (ethbal - value) >= 0:
self.accounts[msg.sender].eth -= value
self.eth_exposure -= value
return(ethbal - value)
def balance(acct):
self.ping(acct)
return([self.accounts[acct].eth, self.accounts[acct].usd], 2)
def global_state_query(acct):
interest = self.last_interest_rate
usd_exposure = self.usd_exposure
eth_exposure = self.eth_exposure
eth_balance = self.balance
return([epoch, usdprice, interest, usd_exposure, eth_exposure, eth_balance], 6)
def ping(acct):
account_last_epoch = self.accounts[acct].last_epoch
if account_last_epoch != epoch:
cur_usd_balance = self.accounts[acct].usd
new_usd_balance = cur_usd_balance * self.interest_rate_accum[epoch] / self.interest_rate_accum[account_last_epoch]
self.accounts[acct].usd = new_usd_balance
self.accounts[acct].last_epoch = epoch
self.usd_exposure += new_usd_balance - cur_usd_balance
ethbal = self.accounts[acct].eth
if new_usd_balance * usdval * 10 + ethbal * 9 < 0:
self.accounts[acct].eth = 0
self.accounts[acct].usd = 0
self.accounts[msg.sender].eth += ethbal / 50
self.eth_exposure += -ethbal + ethbal / 50
self.usd_exposure += new_usd_balance
return(1)
return(0)

View File

@ -0,0 +1 @@
return(sha3([msg.sender, msg.data[0]], 2))

View File

@ -0,0 +1,3 @@
def register(k, v):
if !self.storage[k]:
self.storage[k] = v

View File

@ -0,0 +1,11 @@
def init():
self.storage[msg.sender] = 1000000
def balance_query(k):
return(self.storage[addr])
def send(to, value):
fromvalue = self.storage[msg.sender]
if fromvalue >= value:
self.storage[from] = fromvalue - value
self.storage[to] += value

View File

@ -0,0 +1,35 @@
#include <stdio.h>
#include <iostream>
#include <vector>
#include "funcs.h"
#include "bignum.h"
#include "util.h"
#include "parser.h"
#include "lllparser.h"
#include "compiler.h"
#include "rewriter.h"
#include "tokenize.h"
Node compileToLLL(std::string input) {
return rewrite(parseSerpent(input));
}
Node compileChunkToLLL(std::string input) {
return rewriteChunk(parseSerpent(input));
}
std::string compile(std::string input) {
return compileLLL(compileToLLL(input));
}
std::vector<Node> prettyCompile(std::string input) {
return prettyCompileLLL(compileToLLL(input));
}
std::string compileChunk(std::string input) {
return compileLLL(compileChunkToLLL(input));
}
std::vector<Node> prettyCompileChunk(std::string input) {
return prettyCompileLLL(compileChunkToLLL(input));
}

View File

@ -0,0 +1,35 @@
#include <stdio.h>
#include <iostream>
#include <vector>
#include "bignum.h"
#include "util.h"
#include "parser.h"
#include "lllparser.h"
#include "compiler.h"
#include "rewriter.h"
#include "tokenize.h"
// Function listing:
//
// parseSerpent (serpent -> AST) std::string -> Node
// parseLLL (LLL -> AST) std::string -> Node
// rewrite (apply rewrite rules) Node -> Node
// compileToLLL (serpent -> LLL) std::string -> Node
// compileLLL (LLL -> EVMhex) Node -> std::string
// prettyCompileLLL (LLL -> EVMasm) Node -> std::vector<Node>
// prettyCompile (serpent -> EVMasm) std::string -> std::vector>Node>
// compile (serpent -> EVMhex) std::string -> std::string
// get_file_contents (filename -> file) std::string -> std::string
// exists (does file exist?) std::string -> bool
Node compileToLLL(std::string input);
Node compileChunkToLLL(std::string input);
std::string compile(std::string input);
std::vector<Node> prettyCompile(std::string input);
std::string compileChunk(std::string input);
std::vector<Node> prettyCompileChunk(std::string input);

View File

@ -0,0 +1,203 @@
#include <stdio.h>
#include <iostream>
#include <vector>
#include <map>
#include "util.h"
#include "lllparser.h"
#include "bignum.h"
#include "optimize.h"
#include "rewriteutils.h"
#include "preprocess.h"
#include "functions.h"
std::string getSignature(std::vector<Node> args) {
std::string o;
for (unsigned i = 0; i < args.size(); i++) {
if (args[i].val == ":" && args[i].args[1].val == "s")
o += "s";
else if (args[i].val == ":" && args[i].args[1].val == "a")
o += "a";
else
o += "i";
}
return o;
}
// Convert a list of arguments into a node containing a
// < datastart, datasz > pair
Node packArguments(std::vector<Node> args, std::string sig,
int funId, Metadata m) {
// Plain old 32 byte arguments
std::vector<Node> nargs;
// Variable-sized arguments
std::vector<Node> vargs;
// Variable sizes
std::vector<Node> sizes;
// Is a variable an array?
std::vector<bool> isArray;
// Fill up above three argument lists
int argCount = 0;
for (unsigned i = 0; i < args.size(); i++) {
Metadata m = args[i].metadata;
if (args[i].val == "=") {
// do nothing
}
else {
// Determine the correct argument type
char argType;
if (sig.size() > 0) {
if (argCount >= (signed)sig.size())
err("Too many args", m);
argType = sig[argCount];
}
else argType = 'i';
// Integer (also usable for short strings)
if (argType == 'i') {
if (args[i].val == ":")
err("Function asks for int, provided string or array", m);
nargs.push_back(args[i]);
}
// Long string
else if (argType == 's') {
if (args[i].val != ":")
err("Must specify string length", m);
vargs.push_back(args[i].args[0]);
sizes.push_back(args[i].args[1]);
isArray.push_back(false);
}
// Array
else if (argType == 'a') {
if (args[i].val != ":")
err("Must specify array length", m);
vargs.push_back(args[i].args[0]);
sizes.push_back(args[i].args[1]);
isArray.push_back(true);
}
else err("Invalid arg type in signature", m);
argCount++;
}
}
int static_arg_size = 1 + (vargs.size() + nargs.size()) * 32;
// Start off by saving the size variables and calculating the total
msn kwargs;
kwargs["funid"] = tkn(utd(funId), m);
std::string pattern =
"(with _sztot "+utd(static_arg_size)+" "
" (with _sizes (alloc "+utd(sizes.size() * 32)+") "
" (seq ";
for (unsigned i = 0; i < sizes.size(); i++) {
std::string sizeIncrement =
isArray[i] ? "(mul 32 _x)" : "_x";
pattern +=
"(with _x $sz"+utd(i)+"(seq "
" (mstore (add _sizes "+utd(i * 32)+") _x) "
" (set _sztot (add _sztot "+sizeIncrement+" )))) ";
kwargs["sz"+utd(i)] = sizes[i];
}
// Allocate memory, and set first data byte
pattern +=
"(with _datastart (alloc (add _sztot 32)) (seq "
" (mstore8 _datastart $funid) ";
// Copy over size variables
for (unsigned i = 0; i < sizes.size(); i++) {
int v = 1 + i * 32;
pattern +=
" (mstore "
" (add _datastart "+utd(v)+") "
" (mload (add _sizes "+utd(v-1)+"))) ";
}
// Store normal arguments
for (unsigned i = 0; i < nargs.size(); i++) {
int v = 1 + (i + sizes.size()) * 32;
pattern +=
" (mstore (add _datastart "+utd(v)+") $"+utd(i)+") ";
kwargs[utd(i)] = nargs[i];
}
// Loop through variable-sized arguments, store them
pattern +=
" (with _pos (add _datastart "+utd(static_arg_size)+") (seq";
for (unsigned i = 0; i < vargs.size(); i++) {
std::string copySize =
isArray[i] ? "(mul 32 (mload (add _sizes "+utd(i * 32)+")))"
: "(mload (add _sizes "+utd(i * 32)+"))";
pattern +=
" (unsafe_mcopy _pos $vl"+utd(i)+" "+copySize+") "
" (set _pos (add _pos "+copySize+")) ";
kwargs["vl"+utd(i)] = vargs[i];
}
// Return a 2-item array containing the start and size
pattern += " (array_lit _datastart _sztot))))))))";
std::string prefix = "_temp_"+mkUniqueToken();
// Fill in pattern, return triple
return subst(parseLLL(pattern), kwargs, prefix, m);
}
// Create a node for argument unpacking
Node unpackArguments(std::vector<Node> vars, Metadata m) {
std::vector<std::string> varNames;
std::vector<std::string> longVarNames;
std::vector<bool> longVarIsArray;
// Fill in variable and long variable names, as well as which
// long variables are arrays and which are strings
for (unsigned i = 0; i < vars.size(); i++) {
if (vars[i].val == ":") {
if (vars[i].args.size() != 2)
err("Malformed def!", m);
longVarNames.push_back(vars[i].args[0].val);
std::string tag = vars[i].args[1].val;
if (tag == "s")
longVarIsArray.push_back(false);
else if (tag == "a")
longVarIsArray.push_back(true);
else
err("Function value can only be string or array", m);
}
else {
varNames.push_back(vars[i].val);
}
}
std::vector<Node> sub;
if (!varNames.size() && !longVarNames.size()) {
// do nothing if we have no arguments
}
else {
std::vector<Node> varNodes;
for (unsigned i = 0; i < longVarNames.size(); i++)
varNodes.push_back(token(longVarNames[i], m));
for (unsigned i = 0; i < varNames.size(); i++)
varNodes.push_back(token(varNames[i], m));
// Copy over variable lengths and short variables
for (unsigned i = 0; i < varNodes.size(); i++) {
int pos = 1 + i * 32;
std::string prefix = (i < longVarNames.size()) ? "_len_" : "";
sub.push_back(asn("untyped", asn("set",
token(prefix+varNodes[i].val, m),
asn("calldataload", tkn(utd(pos), m), m),
m)));
}
// Copy over long variables
if (longVarNames.size() > 0) {
std::vector<Node> sub2;
int pos = varNodes.size() * 32 + 1;
Node tot = tkn("_tot", m);
for (unsigned i = 0; i < longVarNames.size(); i++) {
Node var = tkn(longVarNames[i], m);
Node varlen = longVarIsArray[i]
? asn("mul", tkn("32", m), tkn("_len_"+longVarNames[i], m))
: tkn("_len_"+longVarNames[i], m);
sub2.push_back(asn("untyped",
asn("set", var, asn("alloc", varlen))));
sub2.push_back(asn("calldatacopy", var, tot, varlen));
sub2.push_back(asn("set", tot, asn("add", tot, varlen)));
}
std::string prefix = "_temp_"+mkUniqueToken();
sub.push_back(subst(
astnode("with", tot, tkn(utd(pos), m), asn("seq", sub2)),
msn(),
prefix,
m));
}
}
return asn("seq", sub, m);
}

View File

@ -0,0 +1,39 @@
#ifndef ETHSERP_FUNCTIONS
#define ETHSERP_FUNCTIONS
#include <stdio.h>
#include <iostream>
#include <vector>
#include <map>
#include "util.h"
#include "lllparser.h"
#include "bignum.h"
#include "optimize.h"
#include "rewriteutils.h"
#include "preprocess.h"
class argPack {
public:
argPack(Node a, Node b, Node c) {
pre = a;
datastart = b;
datasz = c;
}
Node pre;
Node datastart;
Node datasz;
};
// Get a signature from a function
std::string getSignature(std::vector<Node> args);
// Convert a list of arguments into a <pre, mstart, msize> node
// triple, given the signature of a function
Node packArguments(std::vector<Node> args, std::string sig,
int funId, Metadata m);
// Create a node for argument unpacking
Node unpackArguments(std::vector<Node> vars, Metadata m);
#endif

Some files were not shown because too many files have changed in this diff Show More