support unix domain websockets
This commit is contained in:
parent
74130008f7
commit
1410693eae
|
@ -16,24 +16,18 @@ import (
|
||||||
|
|
||||||
// Set the net.Dial manually so we can do http over tcp or unix.
|
// Set the net.Dial manually so we can do http over tcp or unix.
|
||||||
// Get/Post require a dummyDomain but it's over written by the Transport
|
// Get/Post require a dummyDomain but it's over written by the Transport
|
||||||
var dummyDomain = "http://dummyDomain/"
|
var dummyDomain = "http://dummyDomain"
|
||||||
|
|
||||||
func dialFunc(sockType, remote string) func(string, string) (net.Conn, error) {
|
func dialer(remote string) func(string, string) (net.Conn, error) {
|
||||||
return func(proto, addr string) (conn net.Conn, err error) {
|
return func(proto, addr string) (conn net.Conn, err error) {
|
||||||
return net.Dial(sockType, remote)
|
return net.Dial(rpctypes.SocketType(remote), remote)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remote is IP:PORT or /path/to/socket
|
// remote is IP:PORT or /path/to/socket
|
||||||
func socketTransport(remote string) *http.Transport {
|
func socketTransport(remote string) *http.Transport {
|
||||||
if rpctypes.SocketType(remote) == "unix" {
|
|
||||||
return &http.Transport{
|
return &http.Transport{
|
||||||
Dial: dialFunc("unix", remote),
|
Dial: dialer(remote),
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return &http.Transport{
|
|
||||||
Dial: dialFunc("tcp", remote),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -105,7 +99,7 @@ func (c *ClientURI) call(method string, params map[string]interface{}, result in
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
log.Info(Fmt("URI request to %v (%v): %v", c.remote, method, values))
|
log.Info(Fmt("URI request to %v (%v): %v", c.remote, method, values))
|
||||||
resp, err := c.client.PostForm(dummyDomain+method, values)
|
resp, err := c.client.PostForm(dummyDomain+"/"+method, values)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,16 +19,18 @@ const (
|
||||||
|
|
||||||
type WSClient struct {
|
type WSClient struct {
|
||||||
QuitService
|
QuitService
|
||||||
Address string
|
Address string // IP:PORT or /path/to/socket
|
||||||
|
Endpoint string // /websocket/url/endpoint
|
||||||
*websocket.Conn
|
*websocket.Conn
|
||||||
ResultsCh chan json.RawMessage // closes upon WSClient.Stop()
|
ResultsCh chan json.RawMessage // closes upon WSClient.Stop()
|
||||||
ErrorsCh chan error // closes upon WSClient.Stop()
|
ErrorsCh chan error // closes upon WSClient.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a new connection
|
// create a new connection
|
||||||
func NewWSClient(addr string) *WSClient {
|
func NewWSClient(addr, endpoint string) *WSClient {
|
||||||
wsClient := &WSClient{
|
wsClient := &WSClient{
|
||||||
Address: addr,
|
Address: addr,
|
||||||
|
Endpoint: endpoint,
|
||||||
Conn: nil,
|
Conn: nil,
|
||||||
ResultsCh: make(chan json.RawMessage, wsResultsChannelCapacity),
|
ResultsCh: make(chan json.RawMessage, wsResultsChannelCapacity),
|
||||||
ErrorsCh: make(chan error, wsErrorsChannelCapacity),
|
ErrorsCh: make(chan error, wsErrorsChannelCapacity),
|
||||||
|
@ -38,7 +40,7 @@ func NewWSClient(addr string) *WSClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *WSClient) String() string {
|
func (wsc *WSClient) String() string {
|
||||||
return wsc.Address
|
return wsc.Address + ", " + wsc.Endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *WSClient) OnStart() error {
|
func (wsc *WSClient) OnStart() error {
|
||||||
|
@ -52,10 +54,14 @@ func (wsc *WSClient) OnStart() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (wsc *WSClient) dial() error {
|
func (wsc *WSClient) dial() error {
|
||||||
|
|
||||||
// Dial
|
// Dial
|
||||||
dialer := websocket.DefaultDialer
|
dialer := &websocket.Dialer{
|
||||||
|
NetDial: dialer(wsc.Address),
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
}
|
||||||
rHeader := http.Header{}
|
rHeader := http.Header{}
|
||||||
con, _, err := dialer.Dial(wsc.Address, rHeader)
|
con, _, err := dialer.Dial("ws://"+dummyDomain+wsc.Endpoint, rHeader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
52
rpc_test.go
52
rpc_test.go
|
@ -7,6 +7,7 @@ import (
|
||||||
|
|
||||||
"github.com/tendermint/go-rpc/client"
|
"github.com/tendermint/go-rpc/client"
|
||||||
"github.com/tendermint/go-rpc/server"
|
"github.com/tendermint/go-rpc/server"
|
||||||
|
"github.com/tendermint/go-rpc/types"
|
||||||
"github.com/tendermint/go-wire"
|
"github.com/tendermint/go-wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,6 +15,8 @@ import (
|
||||||
var (
|
var (
|
||||||
tcpAddr = "0.0.0.0:46657"
|
tcpAddr = "0.0.0.0:46657"
|
||||||
unixAddr = "/tmp/go-rpc.sock" // NOTE: must remove file for test to run again
|
unixAddr = "/tmp/go-rpc.sock" // NOTE: must remove file for test to run again
|
||||||
|
|
||||||
|
websocketEndpoint = "/websocket/endpoint"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Define a type for results and register concrete versions
|
// Define a type for results and register concrete versions
|
||||||
|
@ -42,6 +45,8 @@ func StatusResult(v string) (Result, error) {
|
||||||
func init() {
|
func init() {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
rpcserver.RegisterRPCFuncs(mux, Routes)
|
rpcserver.RegisterRPCFuncs(mux, Routes)
|
||||||
|
wm := rpcserver.NewWebsocketManager(Routes, nil)
|
||||||
|
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := rpcserver.StartHTTPServer(tcpAddr, mux)
|
_, err := rpcserver.StartHTTPServer(tcpAddr, mux)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -51,6 +56,8 @@ func init() {
|
||||||
|
|
||||||
mux = http.NewServeMux()
|
mux = http.NewServeMux()
|
||||||
rpcserver.RegisterRPCFuncs(mux, Routes)
|
rpcserver.RegisterRPCFuncs(mux, Routes)
|
||||||
|
wm = rpcserver.NewWebsocketManager(Routes, nil)
|
||||||
|
mux.HandleFunc(websocketEndpoint, wm.WebsocketHandler)
|
||||||
go func() {
|
go func() {
|
||||||
_, err := rpcserver.StartHTTPServer(unixAddr, mux)
|
_, err := rpcserver.StartHTTPServer(unixAddr, mux)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -93,6 +100,33 @@ func testJSONRPC(t *testing.T, cl *rpcclient.ClientJSONRPC) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testWS(t *testing.T, cl *rpcclient.WSClient) {
|
||||||
|
val := "acbd"
|
||||||
|
params := []interface{}{val}
|
||||||
|
err := cl.WriteJSON(rpctypes.RPCRequest{
|
||||||
|
JSONRPC: "2.0",
|
||||||
|
ID: "",
|
||||||
|
Method: "status",
|
||||||
|
Params: params,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := <-cl.ResultsCh
|
||||||
|
result := new(Result)
|
||||||
|
wire.ReadJSONPtr(result, msg, &err)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
got := (*result).(*ResultStatus).Value
|
||||||
|
if got != val {
|
||||||
|
t.Fatalf("Got: %v .... Expected: %v \n", got, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//-------------
|
||||||
|
|
||||||
func TestURI_TCP(t *testing.T) {
|
func TestURI_TCP(t *testing.T) {
|
||||||
cl := rpcclient.NewClientURI(tcpAddr)
|
cl := rpcclient.NewClientURI(tcpAddr)
|
||||||
testURI(t, cl)
|
testURI(t, cl)
|
||||||
|
@ -112,3 +146,21 @@ func TestJSONRPC_UNIX(t *testing.T) {
|
||||||
cl := rpcclient.NewClientJSONRPC(unixAddr)
|
cl := rpcclient.NewClientJSONRPC(unixAddr)
|
||||||
testJSONRPC(t, cl)
|
testJSONRPC(t, cl)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWS_TCP(t *testing.T) {
|
||||||
|
cl := rpcclient.NewWSClient(tcpAddr, websocketEndpoint)
|
||||||
|
_, err := cl.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
testWS(t, cl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWS_UNIX(t *testing.T) {
|
||||||
|
cl := rpcclient.NewWSClient(unixAddr, websocketEndpoint)
|
||||||
|
_, err := cl.Start()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
testWS(t, cl)
|
||||||
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ type WSRPCContext struct {
|
||||||
// If tcp, must specify the port; `0.0.0.0` will return incorrectly as "unix" since there's no port
|
// If tcp, must specify the port; `0.0.0.0` will return incorrectly as "unix" since there's no port
|
||||||
func SocketType(listenAddr string) string {
|
func SocketType(listenAddr string) string {
|
||||||
socketType := "unix"
|
socketType := "unix"
|
||||||
if len(strings.Split(listenAddr, ":")) == 2 {
|
if len(strings.Split(listenAddr, ":")) >= 2 {
|
||||||
socketType = "tcp"
|
socketType = "tcp"
|
||||||
}
|
}
|
||||||
return socketType
|
return socketType
|
||||||
|
|
Loading…
Reference in New Issue