Refactor auth token setting in Portainer client

This commit is contained in:
Juan Carlos Mejías Rodríguez 2019-08-24 09:01:27 -04:00
parent 2cc0faa625
commit 14043a8bf8
2 changed files with 27 additions and 51 deletions

View File

@ -78,7 +78,6 @@ type portainerClientImp struct {
password string
token string
userAgent string
doNotUseToken bool
beforeRequestHooks []func(req *http.Request) (err error)
afterResponseHooks []func(resp *http.Response) (err error)
}
@ -125,16 +124,6 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy
req.Header.Set("User-Agent", n.userAgent)
}
if !n.doNotUseToken {
if n.token == "" {
n.token, err = n.Authenticate()
if err != nil {
return
}
}
req.Header.Set("Authorization", "Bearer "+n.token)
}
// Run all "before request" hooks
for i := 0; i < len(n.beforeRequestHooks); i++ {
err = n.beforeRequestHooks[i](req)
@ -165,7 +154,7 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy
}
// Do a JSON http request
func (n *portainerClientImp) doJSON(uri, method string, request interface{}, response interface{}) error {
func (n *portainerClientImp) doJSON(uri, method string, headers http.Header, request interface{}, response interface{}) error {
var body io.Reader
if request != nil {
@ -176,7 +165,7 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res
body = bytes.NewReader(reqBodyBytes)
}
resp, err := n.do(uri, method, body, "application/json", nil)
resp, err := n.do(uri, method, body, "application/json", headers)
if err != nil {
return err
}
@ -192,6 +181,20 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res
return nil
}
// Do a JSON http request with an auth token
func (n *portainerClientImp) doJSONWithToken(uri, method string, headers http.Header, request interface{}, response interface{}) (err error) {
// Ensure there is an auth token
if n.token == "" {
n.token, err = n.Authenticate()
if err != nil {
return
}
}
headers.Set("Authorization", "Bearer "+n.token)
return n.doJSON(uri, method, headers, request, response)
}
func (n *portainerClientImp) BeforeRequest(hook func(req *http.Request) (err error)) {
n.beforeRequestHooks = append(n.beforeRequestHooks, hook)
}
@ -208,28 +211,23 @@ func (n *portainerClientImp) Authenticate() (token string, err error) {
respBody := AuthenticateUserResponse{}
previousDoNotUseTokenValue := n.doNotUseToken
n.doNotUseToken = true
err = n.doJSON("auth", http.MethodPost, &reqBody, &respBody)
err = n.doJSON("auth", http.MethodPost, http.Header{}, &reqBody, &respBody)
if err != nil {
return
}
n.doNotUseToken = previousDoNotUseTokenValue
token = respBody.Jwt
return
}
func (n *portainerClientImp) GetEndpoints() (endpoints []portainer.Endpoint, err error) {
err = n.doJSON("endpoints", http.MethodGet, nil, &endpoints)
err = n.doJSONWithToken("endpoints", http.MethodGet, http.Header{}, nil, &endpoints)
return
}
func (n *portainerClientImp) GetEndpointGroups() (endpointGroups []portainer.EndpointGroup, err error) {
err = n.doJSON("endpoint_groups", http.MethodGet, nil, &endpointGroups)
err = n.doJSONWithToken("endpoint_groups", http.MethodGet, http.Header{}, nil, &endpointGroups)
return
}
@ -242,7 +240,7 @@ func (n *portainerClientImp) GetStacks(swarmID string, endpointID portainer.Endp
filterJSONBytes, _ := json.Marshal(filter)
filterJSONString := string(filterJSONBytes)
err = n.doJSON(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, nil, &stacks)
err = n.doJSONWithToken(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, http.Header{}, nil, &stacks)
return
}
@ -254,7 +252,7 @@ func (n *portainerClientImp) CreateSwarmStack(stackName string, environmentVaria
StackFileContent: stackFileContent,
}
err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, &reqBody, &stack)
err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack)
return
}
@ -265,7 +263,7 @@ func (n *portainerClientImp) CreateComposeStack(stackName string, environmentVar
StackFileContent: stackFileContent,
}
err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, &reqBody, &stack)
err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack)
return
}
@ -276,19 +274,19 @@ func (n *portainerClientImp) UpdateStack(stack portainer.Stack, environmentVaria
Prune: prune,
}
err = n.doJSON(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, &reqBody, nil)
err = n.doJSONWithToken(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, http.Header{}, &reqBody, nil)
return
}
func (n *portainerClientImp) DeleteStack(stackID portainer.StackID) (err error) {
err = n.doJSON(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, nil, nil)
err = n.doJSONWithToken(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, http.Header{}, nil, nil)
return
}
func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (content string, err error) {
var respBody StackFileInspectResponse
err = n.doJSON(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, nil, &respBody)
err = n.doJSONWithToken(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, http.Header{}, nil, &respBody)
if err != nil {
return
}
@ -299,12 +297,12 @@ func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (con
}
func (n *portainerClientImp) GetEndpointDockerInfo(endpointID portainer.EndpointID) (info map[string]interface{}, err error) {
err = n.doJSON(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, nil, &info)
err = n.doJSONWithToken(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, http.Header{}, nil, &info)
return
}
func (n *portainerClientImp) GetStatus() (status portainer.Status, err error) {
err = n.doJSON("status", http.MethodGet, nil, &status)
err = n.doJSONWithToken("status", http.MethodGet, http.Header{}, nil, &status)
return
}

View File

@ -79,7 +79,6 @@ func Test_portainerClientImp_do(t *testing.T) {
password string
token string
userAgent string
doNotUseToken bool
beforeRequestHooks []func(req *http.Request) (err error)
afterResponseHooks []func(resp *http.Response) (err error)
server *httptest.Server
@ -122,7 +121,6 @@ func Test_portainerClientImp_do(t *testing.T) {
{
name: "extra headers are added",
fields: fields{
token: "somerandomtoken",
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, req.Header.Get("Some-Header"), "value")
})),
@ -135,24 +133,6 @@ func Test_portainerClientImp_do(t *testing.T) {
},
},
},
{
name: "Authorization header is added when doNotUseToken is false",
fields: fields{
token: "token",
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.NotEmpty(t, req.Header.Get("Authorization"))
})),
},
},
{
name: "Authorization header is not added when doNotUseToken is true",
fields: fields{
doNotUseToken: true,
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, req.Header.Get("Authorization"), "")
})),
},
},
{
name: "returns error on http error",
fields: fields{
@ -167,7 +147,6 @@ func Test_portainerClientImp_do(t *testing.T) {
{
name: "returns error on response error",
fields: fields{
token: "token",
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
})),
@ -189,7 +168,6 @@ func Test_portainerClientImp_do(t *testing.T) {
password: tt.fields.password,
token: tt.fields.token,
userAgent: tt.fields.userAgent,
doNotUseToken: tt.fields.doNotUseToken,
beforeRequestHooks: tt.fields.beforeRequestHooks,
afterResponseHooks: tt.fields.afterResponseHooks,
}