support unix domain websockets
This commit is contained in:
parent
74130008f7
commit
1410693eae
|
@ -16,24 +16,18 @@ import (
|
|||
|
||||
// Set the net.Dial manually so we can do http over tcp or unix.
|
||||
// Get/Post require a dummyDomain but it's over written by the Transport
|
||||
var dummyDomain = "http://dummyDomain/"
|
||||
var dummyDomain = "http://dummyDomain"
|
||||
|
||||
func dialFunc(sockType, remote string) func(string, string) (net.Conn, error) {
|
||||
func dialer(remote string) func(string, string) (net.Conn, error) {
|
||||
return func(proto, addr string) (conn net.Conn, err error) {
|
||||
return net.Dial(sockType, remote)
|
||||
return net.Dial(rpctypes.SocketType(remote), remote)
|
||||
}
|
||||
}
|
||||
|
||||
// remote is IP:PORT or /path/to/socket
|
||||
func socketTransport(remote string) *http.Transport {
|
||||
if rpctypes.SocketType(remote) == "unix" {
|
||||
return &http.Transport{
|
||||
Dial: dialFunc("unix", remote),
|
||||
}
|
||||
} else {
|
||||
return &http.Transport{
|
||||
Dial: dialFunc("tcp", remote),
|
||||
}
|
||||
Dial: dialer(remote),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -105,7 +99,7 @@ func (c *ClientURI) call(method string, params map[string]interface{}, result in
|
|||
return nil, err
|
||||
}
|
||||
log.Info(Fmt("URI request to %v (%v): %v", c.remote, method, values))
|
||||
resp, err := c.client.PostForm(dummyDomain+method, values)
|
||||
resp, err := c.client.PostForm(dummyDomain+"/"+method, values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -19,16 +19,18 @@ const (
|
|||
|
||||
type WSClient struct {
|
||||
QuitService
|
||||
Address string
|
||||
Address string // IP:PORT or /path/to/socket
|
||||
Endpoint string // /websocket/url/endpoint
|
||||
*websocket.Conn
|
||||
ResultsCh chan json.RawMessage // closes upon WSClient.Stop()
|
||||
ErrorsCh chan error // closes upon WSClient.Stop()
|
||||
}
|
||||
|
||||
// create a new connection
|
||||
func NewWSClient(addr string) *WSClient {
|
||||
func NewWSClient(addr, endpoint string) *WSClient {
|
||||
wsClient := &WSClient{
|
||||
Address: addr,
|
||||
Endpoint: endpoint,
|
||||
Conn: nil,
|
||||
ResultsCh: make(chan json.RawMessage, wsResultsChannelCapacity),
|
||||
ErrorsCh: make(chan error, wsErrorsChannelCapacity),
|
||||
|
@ -38,7 +40,7 @@ func NewWSClient(addr string) *WSClient {
|
|||
}
|
||||
|
||||
func (wsc *WSClient) String() string {
|
||||
return wsc.Address
|
||||
return wsc.Address + ", " + wsc.Endpoint
|
||||
}
|
||||
|
||||
func (wsc *WSClient) OnStart() error {
|
||||
|
@ -52,10 +54,14 @@ func (wsc *WSClient) OnStart() error {
|
|||
}
|
||||
|
||||
func (wsc *WSClient) dial() error {
|
||||
|
||||
// Dial
|
||||
dialer := websocket.DefaultDialer
|
||||
dialer := &websocket.Dialer{
|
||||
NetDial: dialer(wsc.Address),
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
||||
rHeader := http.Header{}
|
||||
con, _, err := dialer.Dial(wsc.Address, rHeader)
|
||||
con, _, err := dialer.Dial("ws://"+dummyDomain+wsc.Endpoint, rHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
52
rpc_test.go
52
rpc_test.go
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/tendermint/go-rpc/client"
|
||||
"github.com/tendermint/go-rpc/server"
|
||||
"github.com/tendermint/go-rpc/types"
|
||||
"github.com/tendermint/go-wire"
|
||||
)
|
||||
|
||||
|
@ -14,6 +15,8 @@ import (
|
|||
var (
|
||||
tcpAddr = "0.0.0.0:46657"
|
||||
unixAddr = "/tmp/go-rpc.sock" // NOTE: must remove file for test to run again
|
||||
|
||||
websocketEndpoint = "/websocket/endpoint"
|
||||
)
|
||||
|
||||
// Define a type for results and register concrete versions
|
||||
|
@ -42,6 +45,8 @@ func StatusResult(v string) (Result, error) {
|
|||
func init() {
|
||||
mux := http.NewServeMux()
|
||||
rpcserver.RegisterRPCFuncs(mux, Routes)
|
||||
wm := rpcserver.NewWebsocketManager(Routes, nil)
|
||||
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
|
||||
go func() {
|
||||
_, err := rpcserver.StartHTTPServer(tcpAddr, mux)
|
||||
if err != nil {
|
||||
|
@ -51,6 +56,8 @@ func init() {
|
|||
|
||||
mux = http.NewServeMux()
|
||||
rpcserver.RegisterRPCFuncs(mux, Routes)
|
||||
wm = rpcserver.NewWebsocketManager(Routes, nil)
|
||||
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
|
||||
go func() {
|
||||
_, err := rpcserver.StartHTTPServer(unixAddr, mux)
|
||||
if err != nil {
|
||||
|
@ -93,6 +100,33 @@ func testJSONRPC(t *testing.T, cl *rpcclient.ClientJSONRPC) {
|
|||
}
|
||||
}
|
||||
|
||||
func testWS(t *testing.T, cl *rpcclient.WSClient) {
|
||||
val := "acbd"
|
||||
params := []interface{}{val}
|
||||
err := cl.WriteJSON(rpctypes.RPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: "",
|
||||
Method: "status",
|
||||
Params: params,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
msg := <-cl.ResultsCh
|
||||
result := new(Result)
|
||||
wire.ReadJSONPtr(result, msg, &err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := (*result).(*ResultStatus).Value
|
||||
if got != val {
|
||||
t.Fatalf("Got: %v .... Expected: %v \n", got, val)
|
||||
}
|
||||
}
|
||||
|
||||
//-------------
|
||||
|
||||
func TestURI_TCP(t *testing.T) {
|
||||
cl := rpcclient.NewClientURI(tcpAddr)
|
||||
testURI(t, cl)
|
||||
|
@ -112,3 +146,21 @@ func TestJSONRPC_UNIX(t *testing.T) {
|
|||
cl := rpcclient.NewClientJSONRPC(unixAddr)
|
||||
testJSONRPC(t, cl)
|
||||
}
|
||||
|
||||
func TestWS_TCP(t *testing.T) {
|
||||
cl := rpcclient.NewWSClient(tcpAddr, websocketEndpoint)
|
||||
_, err := cl.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testWS(t, cl)
|
||||
}
|
||||
|
||||
func TestWS_UNIX(t *testing.T) {
|
||||
cl := rpcclient.NewWSClient(unixAddr, websocketEndpoint)
|
||||
_, err := cl.Start()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testWS(t, cl)
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ type WSRPCContext struct {
|
|||
// If tcp, must specify the port; `0.0.0.0` will return incorrectly as "unix" since there's no port
|
||||
func SocketType(listenAddr string) string {
|
||||
socketType := "unix"
|
||||
if len(strings.Split(listenAddr, ":")) == 2 {
|
||||
if len(strings.Split(listenAddr, ":")) >= 2 {
|
||||
socketType = "tcp"
|
||||
}
|
||||
return socketType
|
||||
|
|
Loading…
Reference in New Issue