diff --git a/client/client_test.go b/client/client_test.go index 41576c6..ca7c980 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -3,6 +3,7 @@ package client import ( "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" @@ -71,3 +72,137 @@ func TestClientAuthenticates(t *testing.T) { assert.Nil(t, err) assert.Equal(t, token, "somerandomtoken") } + +func Test_portainerClientImp_do(t *testing.T) { + type fields struct { + user string + password string + token string + userAgent string + doNotUseToken bool + beforeRequestHooks []func(req *http.Request) (err error) + afterResponseHooks []func(resp *http.Response) (err error) + server *httptest.Server + beforeFunctionCall func(t *testing.T, tt *fields) + } + type args struct { + uri string + method string + request io.Reader + requestType string + headers http.Header + } + tests := []struct { + name string + fields fields + args args + wantRespCheck func(resp *http.Response) bool + wantErr bool + }{ + { + name: "error on bad URI", + fields: fields{ + server: httptest.NewUnstartedServer(nil), + }, + args: args{ + uri: string(0x7f), + }, + wantErr: true, + }, + { + name: "error on bad method", + fields: fields{ + server: httptest.NewUnstartedServer(nil), + }, + args: args{ + method: "WOLOLO?", + }, + wantErr: true, + }, + { + name: "extra headers are added", + fields: fields{ + token: "somerandomtoken", + server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, req.Header.Get("Some-Header"), "value") + })), + }, + args: args{ + headers: http.Header{ + "Some-Header": []string{ + "value", + }, + }, + }, + }, + { + name: "Authorization header is added when doNotUseToken is false", + fields: fields{ + token: "token", + server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.NotEmpty(t, req.Header.Get("Authorization")) + })), + }, + }, + { + name: "Authorization header is not added when doNotUseToken is true", + fields: fields{ + doNotUseToken: true, + server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + assert.Equal(t, req.Header.Get("Authorization"), "") + })), + }, + }, + { + name: "returns error on http error", + fields: fields{ + token: "token", + server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {})), + beforeFunctionCall: func(t *testing.T, tt *fields) { + tt.server.Close() + }, + }, + wantErr: true, + }, + { + name: "returns error on response error", + fields: fields{ + token: "token", + server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })), + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.fields.server.Start() + defer tt.fields.server.Close() + + apiURL, _ := url.Parse(tt.fields.server.URL + "/api/") + + n := &portainerClientImp{ + httpClient: tt.fields.server.Client(), + url: apiURL, + user: tt.fields.user, + password: tt.fields.password, + token: tt.fields.token, + userAgent: tt.fields.userAgent, + doNotUseToken: tt.fields.doNotUseToken, + beforeRequestHooks: tt.fields.beforeRequestHooks, + afterResponseHooks: tt.fields.afterResponseHooks, + } + + if tt.fields.beforeFunctionCall != nil { + tt.fields.beforeFunctionCall(t, &tt.fields) + } + gotResp, err := n.do(tt.args.uri, tt.args.method, tt.args.request, tt.args.requestType, tt.args.headers) + + assert.Equal(t, tt.wantErr, err != nil) + if tt.wantRespCheck != nil { + assert.True(t, tt.wantRespCheck(gotResp)) + } + }) + } +}