diff --git a/client/client.go b/client/client.go index f0f81ff..1ee628d 100644 --- a/client/client.go +++ b/client/client.go @@ -78,7 +78,6 @@ type portainerClientImp struct { password string token string userAgent string - doNotUseToken bool beforeRequestHooks []func(req *http.Request) (err error) afterResponseHooks []func(resp *http.Response) (err error) } @@ -125,16 +124,6 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy req.Header.Set("User-Agent", n.userAgent) } - if !n.doNotUseToken { - if n.token == "" { - n.token, err = n.Authenticate() - if err != nil { - return - } - } - req.Header.Set("Authorization", "Bearer "+n.token) - } - // Run all "before request" hooks for i := 0; i < len(n.beforeRequestHooks); i++ { err = n.beforeRequestHooks[i](req) @@ -165,7 +154,7 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy } // Do a JSON http request -func (n *portainerClientImp) doJSON(uri, method string, request interface{}, response interface{}) error { +func (n *portainerClientImp) doJSON(uri, method string, headers http.Header, request interface{}, response interface{}) error { var body io.Reader if request != nil { @@ -176,7 +165,7 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res body = bytes.NewReader(reqBodyBytes) } - resp, err := n.do(uri, method, body, "application/json", nil) + resp, err := n.do(uri, method, body, "application/json", headers) if err != nil { return err } @@ -192,6 +181,20 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res return nil } +// Do a JSON http request with an auth token +func (n *portainerClientImp) doJSONWithToken(uri, method string, headers http.Header, request interface{}, response interface{}) (err error) { + // Ensure there is an auth token + if n.token == "" { + n.token, err = n.Authenticate() + if err != nil { + return + } + } + headers.Set("Authorization", "Bearer "+n.token) + + return n.doJSON(uri, method, headers, request, response) +} + func (n *portainerClientImp) BeforeRequest(hook func(req *http.Request) (err error)) { n.beforeRequestHooks = append(n.beforeRequestHooks, hook) } @@ -208,28 +211,23 @@ func (n *portainerClientImp) Authenticate() (token string, err error) { respBody := AuthenticateUserResponse{} - previousDoNotUseTokenValue := n.doNotUseToken - n.doNotUseToken = true - - err = n.doJSON("auth", http.MethodPost, &reqBody, &respBody) + err = n.doJSON("auth", http.MethodPost, http.Header{}, &reqBody, &respBody) if err != nil { return } - n.doNotUseToken = previousDoNotUseTokenValue - token = respBody.Jwt return } func (n *portainerClientImp) GetEndpoints() (endpoints []portainer.Endpoint, err error) { - err = n.doJSON("endpoints", http.MethodGet, nil, &endpoints) + err = n.doJSONWithToken("endpoints", http.MethodGet, http.Header{}, nil, &endpoints) return } func (n *portainerClientImp) GetEndpointGroups() (endpointGroups []portainer.EndpointGroup, err error) { - err = n.doJSON("endpoint_groups", http.MethodGet, nil, &endpointGroups) + err = n.doJSONWithToken("endpoint_groups", http.MethodGet, http.Header{}, nil, &endpointGroups) return } @@ -242,7 +240,7 @@ func (n *portainerClientImp) GetStacks(swarmID string, endpointID portainer.Endp filterJSONBytes, _ := json.Marshal(filter) filterJSONString := string(filterJSONBytes) - err = n.doJSON(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, nil, &stacks) + err = n.doJSONWithToken(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, http.Header{}, nil, &stacks) return } @@ -254,7 +252,7 @@ func (n *portainerClientImp) CreateSwarmStack(stackName string, environmentVaria StackFileContent: stackFileContent, } - err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, &reqBody, &stack) + err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack) return } @@ -265,7 +263,7 @@ func (n *portainerClientImp) CreateComposeStack(stackName string, environmentVar StackFileContent: stackFileContent, } - err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, &reqBody, &stack) + err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack) return } @@ -276,19 +274,19 @@ func (n *portainerClientImp) UpdateStack(stack portainer.Stack, environmentVaria Prune: prune, } - err = n.doJSON(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, &reqBody, nil) + err = n.doJSONWithToken(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, http.Header{}, &reqBody, nil) return } func (n *portainerClientImp) DeleteStack(stackID portainer.StackID) (err error) { - err = n.doJSON(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, nil, nil) + err = n.doJSONWithToken(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, http.Header{}, nil, nil) return } func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (content string, err error) { var respBody StackFileInspectResponse - err = n.doJSON(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, nil, &respBody) + err = n.doJSONWithToken(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, http.Header{}, nil, &respBody) if err != nil { return } @@ -299,12 +297,12 @@ func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (con } func (n *portainerClientImp) GetEndpointDockerInfo(endpointID portainer.EndpointID) (info map[string]interface{}, err error) { - err = n.doJSON(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, nil, &info) + err = n.doJSONWithToken(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, http.Header{}, nil, &info) return } func (n *portainerClientImp) GetStatus() (status portainer.Status, err error) { - err = n.doJSON("status", http.MethodGet, nil, &status) + err = n.doJSONWithToken("status", http.MethodGet, http.Header{}, nil, &status) return } diff --git a/client/client_test.go b/client/client_test.go index ca7c980..7ed88d0 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -79,7 +79,6 @@ func Test_portainerClientImp_do(t *testing.T) { password string token string userAgent string - doNotUseToken bool beforeRequestHooks []func(req *http.Request) (err error) afterResponseHooks []func(resp *http.Response) (err error) server *httptest.Server @@ -122,7 +121,6 @@ func Test_portainerClientImp_do(t *testing.T) { { name: "extra headers are added", fields: fields{ - token: "somerandomtoken", server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { assert.Equal(t, req.Header.Get("Some-Header"), "value") })), @@ -135,24 +133,6 @@ func Test_portainerClientImp_do(t *testing.T) { }, }, }, - { - name: "Authorization header is added when doNotUseToken is false", - fields: fields{ - token: "token", - server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.NotEmpty(t, req.Header.Get("Authorization")) - })), - }, - }, - { - name: "Authorization header is not added when doNotUseToken is true", - fields: fields{ - doNotUseToken: true, - server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - assert.Equal(t, req.Header.Get("Authorization"), "") - })), - }, - }, { name: "returns error on http error", fields: fields{ @@ -167,7 +147,6 @@ func Test_portainerClientImp_do(t *testing.T) { { name: "returns error on response error", fields: fields{ - token: "token", server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { w.WriteHeader(http.StatusInternalServerError) })), @@ -189,7 +168,6 @@ func Test_portainerClientImp_do(t *testing.T) { password: tt.fields.password, token: tt.fields.token, userAgent: tt.fields.userAgent, - doNotUseToken: tt.fields.doNotUseToken, beforeRequestHooks: tt.fields.beforeRequestHooks, afterResponseHooks: tt.fields.afterResponseHooks, }