mirror of
https://gitlab.com/psuapp/psu.git
synced 2024-08-30 18:12:34 +00:00
Refactor auth token setting in Portainer client
This commit is contained in:
parent
2cc0faa625
commit
14043a8bf8
@ -78,7 +78,6 @@ type portainerClientImp struct {
|
||||
password string
|
||||
token string
|
||||
userAgent string
|
||||
doNotUseToken bool
|
||||
beforeRequestHooks []func(req *http.Request) (err error)
|
||||
afterResponseHooks []func(resp *http.Response) (err error)
|
||||
}
|
||||
@ -125,16 +124,6 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy
|
||||
req.Header.Set("User-Agent", n.userAgent)
|
||||
}
|
||||
|
||||
if !n.doNotUseToken {
|
||||
if n.token == "" {
|
||||
n.token, err = n.Authenticate()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+n.token)
|
||||
}
|
||||
|
||||
// Run all "before request" hooks
|
||||
for i := 0; i < len(n.beforeRequestHooks); i++ {
|
||||
err = n.beforeRequestHooks[i](req)
|
||||
@ -165,7 +154,7 @@ func (n *portainerClientImp) do(uri, method string, request io.Reader, requestTy
|
||||
}
|
||||
|
||||
// Do a JSON http request
|
||||
func (n *portainerClientImp) doJSON(uri, method string, request interface{}, response interface{}) error {
|
||||
func (n *portainerClientImp) doJSON(uri, method string, headers http.Header, request interface{}, response interface{}) error {
|
||||
var body io.Reader
|
||||
|
||||
if request != nil {
|
||||
@ -176,7 +165,7 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res
|
||||
body = bytes.NewReader(reqBodyBytes)
|
||||
}
|
||||
|
||||
resp, err := n.do(uri, method, body, "application/json", nil)
|
||||
resp, err := n.do(uri, method, body, "application/json", headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -192,6 +181,20 @@ func (n *portainerClientImp) doJSON(uri, method string, request interface{}, res
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do a JSON http request with an auth token
|
||||
func (n *portainerClientImp) doJSONWithToken(uri, method string, headers http.Header, request interface{}, response interface{}) (err error) {
|
||||
// Ensure there is an auth token
|
||||
if n.token == "" {
|
||||
n.token, err = n.Authenticate()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
headers.Set("Authorization", "Bearer "+n.token)
|
||||
|
||||
return n.doJSON(uri, method, headers, request, response)
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) BeforeRequest(hook func(req *http.Request) (err error)) {
|
||||
n.beforeRequestHooks = append(n.beforeRequestHooks, hook)
|
||||
}
|
||||
@ -208,28 +211,23 @@ func (n *portainerClientImp) Authenticate() (token string, err error) {
|
||||
|
||||
respBody := AuthenticateUserResponse{}
|
||||
|
||||
previousDoNotUseTokenValue := n.doNotUseToken
|
||||
n.doNotUseToken = true
|
||||
|
||||
err = n.doJSON("auth", http.MethodPost, &reqBody, &respBody)
|
||||
err = n.doJSON("auth", http.MethodPost, http.Header{}, &reqBody, &respBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
n.doNotUseToken = previousDoNotUseTokenValue
|
||||
|
||||
token = respBody.Jwt
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) GetEndpoints() (endpoints []portainer.Endpoint, err error) {
|
||||
err = n.doJSON("endpoints", http.MethodGet, nil, &endpoints)
|
||||
err = n.doJSONWithToken("endpoints", http.MethodGet, http.Header{}, nil, &endpoints)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) GetEndpointGroups() (endpointGroups []portainer.EndpointGroup, err error) {
|
||||
err = n.doJSON("endpoint_groups", http.MethodGet, nil, &endpointGroups)
|
||||
err = n.doJSONWithToken("endpoint_groups", http.MethodGet, http.Header{}, nil, &endpointGroups)
|
||||
return
|
||||
}
|
||||
|
||||
@ -242,7 +240,7 @@ func (n *portainerClientImp) GetStacks(swarmID string, endpointID portainer.Endp
|
||||
filterJSONBytes, _ := json.Marshal(filter)
|
||||
filterJSONString := string(filterJSONBytes)
|
||||
|
||||
err = n.doJSON(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, nil, &stacks)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks?filters=%s", filterJSONString), http.MethodGet, http.Header{}, nil, &stacks)
|
||||
return
|
||||
}
|
||||
|
||||
@ -254,7 +252,7 @@ func (n *portainerClientImp) CreateSwarmStack(stackName string, environmentVaria
|
||||
StackFileContent: stackFileContent,
|
||||
}
|
||||
|
||||
err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, &reqBody, &stack)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 1, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack)
|
||||
return
|
||||
}
|
||||
|
||||
@ -265,7 +263,7 @@ func (n *portainerClientImp) CreateComposeStack(stackName string, environmentVar
|
||||
StackFileContent: stackFileContent,
|
||||
}
|
||||
|
||||
err = n.doJSON(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, &reqBody, &stack)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks?type=%v&method=%s&endpointId=%v", 2, "string", endpointID), http.MethodPost, http.Header{}, &reqBody, &stack)
|
||||
return
|
||||
}
|
||||
|
||||
@ -276,19 +274,19 @@ func (n *portainerClientImp) UpdateStack(stack portainer.Stack, environmentVaria
|
||||
Prune: prune,
|
||||
}
|
||||
|
||||
err = n.doJSON(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, &reqBody, nil)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks/%v?endpointId=%v", stack.ID, endpointID), http.MethodPut, http.Header{}, &reqBody, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) DeleteStack(stackID portainer.StackID) (err error) {
|
||||
err = n.doJSON(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, nil, nil)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks/%d", stackID), http.MethodDelete, http.Header{}, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (content string, err error) {
|
||||
var respBody StackFileInspectResponse
|
||||
|
||||
err = n.doJSON(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, nil, &respBody)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("stacks/%v/file", stackID), http.MethodGet, http.Header{}, nil, &respBody)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -299,12 +297,12 @@ func (n *portainerClientImp) GetStackFileContent(stackID portainer.StackID) (con
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) GetEndpointDockerInfo(endpointID portainer.EndpointID) (info map[string]interface{}, err error) {
|
||||
err = n.doJSON(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, nil, &info)
|
||||
err = n.doJSONWithToken(fmt.Sprintf("endpoints/%v/docker/info", endpointID), http.MethodGet, http.Header{}, nil, &info)
|
||||
return
|
||||
}
|
||||
|
||||
func (n *portainerClientImp) GetStatus() (status portainer.Status, err error) {
|
||||
err = n.doJSON("status", http.MethodGet, nil, &status)
|
||||
err = n.doJSONWithToken("status", http.MethodGet, http.Header{}, nil, &status)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -79,7 +79,6 @@ func Test_portainerClientImp_do(t *testing.T) {
|
||||
password string
|
||||
token string
|
||||
userAgent string
|
||||
doNotUseToken bool
|
||||
beforeRequestHooks []func(req *http.Request) (err error)
|
||||
afterResponseHooks []func(resp *http.Response) (err error)
|
||||
server *httptest.Server
|
||||
@ -122,7 +121,6 @@ func Test_portainerClientImp_do(t *testing.T) {
|
||||
{
|
||||
name: "extra headers are added",
|
||||
fields: fields{
|
||||
token: "somerandomtoken",
|
||||
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equal(t, req.Header.Get("Some-Header"), "value")
|
||||
})),
|
||||
@ -135,24 +133,6 @@ func Test_portainerClientImp_do(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Authorization header is added when doNotUseToken is false",
|
||||
fields: fields{
|
||||
token: "token",
|
||||
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.NotEmpty(t, req.Header.Get("Authorization"))
|
||||
})),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Authorization header is not added when doNotUseToken is true",
|
||||
fields: fields{
|
||||
doNotUseToken: true,
|
||||
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
assert.Equal(t, req.Header.Get("Authorization"), "")
|
||||
})),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns error on http error",
|
||||
fields: fields{
|
||||
@ -167,7 +147,6 @@ func Test_portainerClientImp_do(t *testing.T) {
|
||||
{
|
||||
name: "returns error on response error",
|
||||
fields: fields{
|
||||
token: "token",
|
||||
server: httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
})),
|
||||
@ -189,7 +168,6 @@ func Test_portainerClientImp_do(t *testing.T) {
|
||||
password: tt.fields.password,
|
||||
token: tt.fields.token,
|
||||
userAgent: tt.fields.userAgent,
|
||||
doNotUseToken: tt.fields.doNotUseToken,
|
||||
beforeRequestHooks: tt.fields.beforeRequestHooks,
|
||||
afterResponseHooks: tt.fields.afterResponseHooks,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user