// Copyright 2015 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package rpc import ( "bytes" "context" "crypto/tls" "encoding/base64" "encoding/json" "fmt" "net" "net/http" "net/url" "os" "strings" "time" mapset "github.com/deckarep/golang-set" "github.com/ethereum/go-ethereum/log" "golang.org/x/net/websocket" ) // websocketJSONCodec is a custom JSON codec with payload size enforcement and // special number parsing. var websocketJSONCodec = websocket.Codec{ // Marshal is the stock JSON marshaller used by the websocket library too. Marshal: func(v interface{}) ([]byte, byte, error) { msg, err := json.Marshal(v) return msg, websocket.TextFrame, err }, // Unmarshal is a specialized unmarshaller to properly convert numbers. Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { dec := json.NewDecoder(bytes.NewReader(msg)) dec.UseNumber() return dec.Decode(v) }, } // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. // // allowedOrigins should be a comma-separated list of allowed origin URLs. // To allow connections with any origin, pass "*". func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { return websocket.Server{ Handshake: wsHandshakeValidator(allowedOrigins), Handler: func(conn *websocket.Conn) { // Create a custom encode/decode pair to enforce payload size and number encoding conn.MaxPayloadBytes = maxRequestContentLength encoder := func(v interface{}) error { return websocketJSONCodec.Send(conn, v) } decoder := func(v interface{}) error { return websocketJSONCodec.Receive(conn, v) } srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) }, } } // NewWSServer creates a new websocket RPC server around an API provider. // // Deprecated: use Server.WebsocketHandler func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} } // wsHandshakeValidator returns a handler that verifies the origin during the // websocket upgrade process. When a '*' is specified as an allowed origins all // connections are accepted. func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { origins := mapset.NewSet() allowAllOrigins := false for _, origin := range allowedOrigins { if origin == "*" { allowAllOrigins = true } if origin != "" { origins.Add(strings.ToLower(origin)) } } // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") if hostname, err := os.Hostname(); err == nil { origins.Add("http://" + strings.ToLower(hostname)) } } log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) f := func(cfg *websocket.Config, req *http.Request) error { origin := strings.ToLower(req.Header.Get("Origin")) if allowAllOrigins || origins.Contains(origin) { return nil } log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) return fmt.Errorf("origin %s not allowed", origin) } return f } func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { if origin == "" { var err error if origin, err = os.Hostname(); err != nil { return nil, err } if strings.HasPrefix(endpoint, "wss") { origin = "https://" + strings.ToLower(origin) } else { origin = "http://" + strings.ToLower(origin) } } config, err := websocket.NewConfig(endpoint, origin) if err != nil { return nil, err } if config.Location.User != nil { b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) config.Header.Add("Authorization", "Basic "+b64auth) config.Location.User = nil } return config, nil } // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint. // // The context is used for the initial connection establishment. It does not // affect subsequent interactions with the client. func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { config, err := wsGetConfig(endpoint, origin) if err != nil { return nil, err } return newClient(ctx, func(ctx context.Context) (net.Conn, error) { return wsDialContext(ctx, config) }) } func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { var conn net.Conn var err error switch config.Location.Scheme { case "ws": conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) case "wss": dialer := contextDialer(ctx) conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) default: err = websocket.ErrBadScheme } if err != nil { return nil, err } ws, err := websocket.NewClient(config, conn) if err != nil { conn.Close() return nil, err } return ws, err } var wsPortMap = map[string]string{"ws": "80", "wss": "443"} func wsDialAddress(location *url.URL) string { if _, ok := wsPortMap[location.Scheme]; ok { if _, _, err := net.SplitHostPort(location.Host); err != nil { return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) } } return location.Host } func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} return d.DialContext(ctx, network, addr) } func contextDialer(ctx context.Context) *net.Dialer { dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} if deadline, ok := ctx.Deadline(); ok { dialer.Deadline = deadline } else { dialer.Deadline = time.Now().Add(defaultDialTimeout) } return dialer }