diff --git a/accounts/vault/vault_backend.go b/accounts/vault/vault_backend.go index ef5210def..f596238fc 100644 --- a/accounts/vault/vault_backend.go +++ b/accounts/vault/vault_backend.go @@ -10,15 +10,17 @@ import ( type VaultBackend struct { wallets []accounts.Wallet updateScope event.SubscriptionScope - updateFeed event.Feed + updateFeed *event.Feed // Other backend impls require mutexes for safety as their wallets can change at any time (e.g. if a file/usb is added/removed). vaultWallets can only be created at startup so there is no danger of concurrent reads and writes. } func NewHashicorpBackend(walletConfigs []hashicorpWalletConfig) VaultBackend { wallets := []accounts.Wallet{} + var updateFeed event.Feed + for _, conf := range walletConfigs { - w, err := newHashicorpWallet(conf) + w, err := newHashicorpWallet(conf, &updateFeed) if err != nil { log.Error("unable to create Hashicorp wallet from config", "err", err) continue @@ -30,6 +32,7 @@ func NewHashicorpBackend(walletConfigs []hashicorpWalletConfig) VaultBackend { return VaultBackend{ wallets: wallets, + updateFeed: &updateFeed, } } diff --git a/accounts/vault/vault_backend_test.go b/accounts/vault/vault_backend_test.go index 080245d82..8d7cdc69b 100644 --- a/accounts/vault/vault_backend_test.go +++ b/accounts/vault/vault_backend_test.go @@ -2,12 +2,13 @@ package vault import ( "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/event" "reflect" "strings" "testing" ) -func TestNewHashicorpBackend_CreatesWalletsFromConfig(t *testing.T) { +func TestNewHashicorpBackend_CreatesWalletsWithUrlsFromConfig(t *testing.T) { makeConfs := func (url string, urls... string) []hashicorpWalletConfig { var confs []hashicorpWalletConfig @@ -20,47 +21,55 @@ func TestNewHashicorpBackend_CreatesWalletsFromConfig(t *testing.T) { return confs } - // makeWlts crudely splits the urls to get them as accounts.URLs so as to not use the same parsing method as in the production code. This is fine for simple urls but may not be suitable for tests that require more complex urls. - makeWlts := func(url string, urls... string) []accounts.Wallet { - var wlts []accounts.Wallet + makeUrls := func(strUrl string, strUrls... string) []accounts.URL { + var urls []accounts.URL - s := strings.Split(url, "://") + s := strings.Split(strUrl, "://") scheme, path := s[0], s[1] - wlts = append(wlts, vaultWallet{url: accounts.URL{Scheme: scheme, Path: path}}) + urls = append(urls, accounts.URL{Scheme: scheme, Path: path}) - for _, u := range urls { + for _, u := range strUrls { s := strings.Split(u, "://") scheme, path := s[0], s[1] - wlts = append(wlts, vaultWallet{url: accounts.URL{Scheme: scheme, Path: path}}) + urls = append(urls, accounts.URL{Scheme: scheme, Path: path}) } - return wlts + return urls } tests := map[string]struct{ in []hashicorpWalletConfig - want []accounts.Wallet + wantUrls []accounts.URL }{ - "no config": {in: []hashicorpWalletConfig{}, want: []accounts.Wallet{}}, - "single": {in: makeConfs("http://url:1"), want: makeWlts("http://url:1")}, - "multiple": {in: makeConfs("http://url:1", "http://url:2"), want: makeWlts("http://url:1", "http://url:2")}, + "no config": {in: []hashicorpWalletConfig{}, wantUrls: []accounts.URL(nil)}, + "single": {in: makeConfs("http://url:1"), wantUrls: makeUrls("http://url:1")}, + "multiple": {in: makeConfs("http://url:1", "http://url:2"), wantUrls: makeUrls("http://url:1", "http://url:2")}, "orders by url": { in: makeConfs("https://url:1", "https://a:9", "http://url:2", "http://url:1"), - want: makeWlts("http://url:1", "http://url:2", "https://a:9", "https://url:1")}, + wantUrls: makeUrls("http://url:1", "http://url:2", "https://a:9", "https://url:1")}, } for name, tt := range tests { t.Run(name, func(t *testing.T) { b := NewHashicorpBackend(tt.in) - if !reflect.DeepEqual(tt.want, b.wallets) { - t.Fatalf("\nwant: %v, \ngot : %v", tt.want, b.wallets) + if len(tt.wantUrls) != len(b.wallets) { + t.Fatalf("wallets created with incorrect urls or incorrectly ordered by url: want: %v, got: %v", len(tt.wantUrls), len(b.wallets)) + } + + var gotUrls []accounts.URL + + for _, wlt := range b.wallets { + gotUrls = append(gotUrls, wlt.URL()) + } + + if !reflect.DeepEqual(tt.wantUrls, gotUrls) { + t.Fatalf("incorrect wallets created/wallets incorrectly ordered\nwant: %v\ngot : %v", tt.wantUrls, gotUrls) } }) } - } func TestVaultBackend_Wallets_ReturnsWallets(t *testing.T) { @@ -103,9 +112,9 @@ func TestVaultBackend_Wallets_ReturnsCopy(t *testing.T) { } func TestVaultBackend_Subscribe_SubscriberReceivesEventsAddedToFeed(t *testing.T) { - subscriber := make(chan accounts.WalletEvent, 1) - b := VaultBackend{} + b := VaultBackend{updateFeed: &event.Feed{}} + subscriber := make(chan accounts.WalletEvent, 1) b.Subscribe(subscriber) if b.updateScope.Count() != 1 { @@ -113,8 +122,8 @@ func TestVaultBackend_Subscribe_SubscriberReceivesEventsAddedToFeed(t *testing.T } // mock an event - event := accounts.WalletEvent{Wallet: vaultWallet{}, Kind: accounts.WalletOpened} - b.updateFeed.Send(event) + e := accounts.WalletEvent{Wallet: vaultWallet{}, Kind: accounts.WalletOpened} + b.updateFeed.Send(e) if len(subscriber) != 1 { t.Fatal("event not added to subscriber") @@ -122,7 +131,40 @@ func TestVaultBackend_Subscribe_SubscriberReceivesEventsAddedToFeed(t *testing.T got := <-subscriber - if !reflect.DeepEqual(event, got) { - t.Fatalf("want: %v, got: %v", event, got) + if !reflect.DeepEqual(e, got) { + t.Fatalf("want: %v, got: %v", e, got) } } + +func TestVaultBackend_Subscribe_SubscriberReceivesEventsAddedToFeedByHashicorpWallet(t *testing.T) { + conf := hashicorpWalletConfig{Client: hashicorpClientConfig{Url: "http://url:1"}} + b := NewHashicorpBackend([]hashicorpWalletConfig{conf}) + + if len(b.wallets) != 1 { + t.Fatalf("incorrect number of wallets: want: %v, got: %v", 1, len(b.wallets)) + } + + w := b.wallets[0].(vaultWallet) + + subscriber := make(chan accounts.WalletEvent, 1) + b.Subscribe(subscriber) + + if b.updateScope.Count() != 1 { + t.Fatalf("incorrect number of subscribers for backend: want: %v, got: %v", 1, b.updateScope.Count()) + } + + // mock an event + e := accounts.WalletEvent{Wallet: vaultWallet{}, Kind: accounts.WalletOpened} + w.updateFeed.Send(e) + + if len(subscriber) != 1 { + t.Fatal("event not added to subscriber") + } + + got := <-subscriber + + if !reflect.DeepEqual(e, got) { + t.Fatalf("want: %v, got: %v", e, got) + } + +} diff --git a/accounts/vault/vault_wallet.go b/accounts/vault/vault_wallet.go index 98cbf2b35..be5b114db 100644 --- a/accounts/vault/vault_wallet.go +++ b/accounts/vault/vault_wallet.go @@ -6,6 +6,7 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts" "github.com/ethereum/go-ethereum/core/types" + "github.com/ethereum/go-ethereum/event" "github.com/hashicorp/vault/api" "math/big" "os" @@ -14,6 +15,7 @@ import ( type vaultWallet struct { url accounts.URL vault vaultService + updateFeed *event.Feed } // vault related behaviour that will be specific to each vault type @@ -22,7 +24,7 @@ type vaultService interface { open() error } -func newHashicorpWallet(config hashicorpWalletConfig) (vaultWallet, error) { +func newHashicorpWallet(config hashicorpWalletConfig, updateFeed *event.Feed) (vaultWallet, error) { var url accounts.URL //to parse a string url as an accounts.URL it must first be in json format @@ -32,7 +34,7 @@ func newHashicorpWallet(config hashicorpWalletConfig) (vaultWallet, error) { return vaultWallet{}, err } - return vaultWallet{url: url, vault: &hashicorpService{config: config.Client}}, nil + return vaultWallet{url: url, vault: &hashicorpService{config: config.Client}, updateFeed: updateFeed}, nil } func (w vaultWallet) URL() accounts.URL { @@ -45,7 +47,11 @@ func (w vaultWallet) Status() (string, error) { } func (w vaultWallet) Open(passphrase string) error { - return w.vault.open() + if err := w.vault.open(); err != nil { + return err + } + w.updateFeed.Send(accounts.WalletEvent{Wallet: w, Kind: accounts.WalletOpened}) + return nil } func (w vaultWallet) Close() error { diff --git a/accounts/vault/vault_wallet_test.go b/accounts/vault/vault_wallet_test.go index 8c7273c18..5294978e8 100644 --- a/accounts/vault/vault_wallet_test.go +++ b/accounts/vault/vault_wallet_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "github.com/ethereum/go-ethereum/accounts" + "github.com/ethereum/go-ethereum/event" "github.com/hashicorp/vault/api" "io/ioutil" "net/http" @@ -150,7 +151,7 @@ func TestVaultWallet_Open_Hashicorp_ReturnsErrIfAlreadyOpen(t *testing.T) { } } -func TestVaultWallet_Open_Hashicorp_CreatesClientFromConfig(t *testing.T) { +func TestVaultWallet_Open_Hashicorp_CreatesClientUsingConfig(t *testing.T) { if err := os.Setenv(api.EnvVaultToken, "mytoken"); err != nil { t.Fatal(err) } @@ -172,7 +173,7 @@ func TestVaultWallet_Open_Hashicorp_CreatesClientFromConfig(t *testing.T) { Url: vaultServer.URL, } - w := vaultWallet{vault: &hashicorpService{config: config}} + w := vaultWallet{vault: &hashicorpService{config: config}, updateFeed: &event.Feed{}} if err := w.Open(""); err != nil { t.Fatalf("error: %v", err) @@ -206,7 +207,7 @@ func TestVaultWallet_Open_Hashicorp_CreatesClientFromConfig(t *testing.T) { } } -func TestVaultWallet_Open_Hashicorp_CreatesTLSClientFromConfig(t *testing.T) { +func TestVaultWallet_Open_Hashicorp_CreatesTLSClientUsingConfig(t *testing.T) { if err := os.Setenv(api.EnvVaultToken, "mytoken"); err != nil { t.Fatal(err) } @@ -266,15 +267,12 @@ func TestVaultWallet_Open_Hashicorp_CreatesTLSClientFromConfig(t *testing.T) { // create wallet with config and open config := hashicorpClientConfig{ Url: vaultServer.URL, - //Approle: "myapprole", CaCert: "testdata/caRoot.pem", ClientCert: "testdata/quorum-client-chain.pem", ClientKey: "testdata/quorum-client.key", - //EnvVarPrefix: "prefix", - //UseSecretCache: false, } - w := vaultWallet{vault: &hashicorpService{config: config}} + w := vaultWallet{vault: &hashicorpService{config: config}, updateFeed: &event.Feed{}} if err := w.Open(""); err != nil { t.Fatalf("error: %v", err) @@ -303,7 +301,7 @@ func TestVaultWallet_Open_Hashicorp_CreatesTLSClientFromConfig(t *testing.T) { } } -func TestVaultWallet_Open_Hashicorp_CreatesAuthenticatedClient(t *testing.T) { +func TestVaultWallet_Open_Hashicorp_ClientAuthenticatesUsingEnvVars(t *testing.T) { const ( myToken = "myToken" myRoleId = "myRoleId" @@ -381,7 +379,7 @@ func TestVaultWallet_Open_Hashicorp_CreatesAuthenticatedClient(t *testing.T) { Approle: tt.approle, } - w := vaultWallet{vault: &hashicorpService{config: config}} + w := vaultWallet{vault: &hashicorpService{config: config}, updateFeed: &event.Feed{}} if err := w.Open(""); err != nil { t.Fatalf("error: %v", err) @@ -407,7 +405,7 @@ func TestVaultWallet_Open_Hashicorp_CreatesAuthenticatedClient(t *testing.T) { } } -func TestVaultWallet_Open_Hashicorp_ErrCreatingAuthenticatedClient(t *testing.T) { +func TestVaultWallet_Open_Hashicorp_ErrAuthenticatingClient(t *testing.T) { const ( myToken = "myToken" myRoleId = "myRoleId" @@ -464,3 +462,44 @@ func TestVaultWallet_Open_Hashicorp_ErrCreatingAuthenticatedClient(t *testing.T) } } +// Note: This is an integration test, as such the scope of the test is large, covering the VaultBackend, VaultWallet and HashicorpService +func TestVaultWallet_Open_Hashicorp_SendsEventToBackendSubscribers(t *testing.T) { + if err := os.Setenv(api.EnvVaultToken, "mytoken"); err != nil { + t.Fatal(err) + } + + walletConfig := hashicorpWalletConfig{ + Client: hashicorpClientConfig{ + Url: "http://url:1", + }, + } + + b := NewHashicorpBackend([]hashicorpWalletConfig{walletConfig}) + + if len(b.wallets) != 1 { + t.Fatalf("NewHashicorpBackend: incorrect number of wallets created: want 1, got: %v", len(b.wallets)) + } + + subscriber := make(chan accounts.WalletEvent, 1) + b.Subscribe(subscriber) + + if b.updateScope.Count() != 1 { + t.Fatalf("incorrect number of subscribers for backend: want: %v, got: %v", 1, b.updateScope.Count()) + } + + if err := b.wallets[0].Open(""); err != nil { + t.Fatalf("error: %v", err) + } + + if len(subscriber) != 1 { + t.Fatal("event not added to subscriber") + } + + got := <-subscriber + + want := accounts.WalletEvent{Wallet: b.wallets[0], Kind: accounts.WalletOpened} + + if !reflect.DeepEqual(want, got) { + t.Fatalf("want: %v, got: %v", want, got) + } +}