UNCLASSIFIED - NO CUI

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • big-bang/apps/sandbox/holocron/collector-jira-workflow
1 result
Show changes
Commits on Source (2)
...@@ -21,8 +21,8 @@ type JiraAgileServerRESTAPI interface { ...@@ -21,8 +21,8 @@ type JiraAgileServerRESTAPI interface {
type apiAttributes = map[string]string type apiAttributes = map[string]string
type apiImpl struct { type apiImpl struct {
onRequest func() client utils.HTTPClient
apiUrl string apiUrl string
} }
func copyAttrs(attrs apiAttributes) apiAttributes { func copyAttrs(attrs apiAttributes) apiAttributes {
...@@ -36,9 +36,7 @@ func copyAttrs(attrs apiAttributes) apiAttributes { ...@@ -36,9 +36,7 @@ func copyAttrs(attrs apiAttributes) apiAttributes {
//nolint:unparam // In case we might need the header from response. //nolint:unparam // In case we might need the header from response.
func (a *apiImpl) getItems(path string, attrs apiAttributes, func (a *apiImpl) getItems(path string, attrs apiAttributes,
dstPtr interface{}) (http.Header, error) { dstPtr interface{}) error {
a.onRequest()
url := a.apiUrl + path url := a.apiUrl + path
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
...@@ -46,7 +44,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -46,7 +44,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
if err != nil { if err != nil {
utils.Logger.Errorf("Request Creation Error : %s", err.Error()) utils.Logger.Errorf("Request Creation Error : %s", err.Error())
return nil, fmt.Errorf("Request Creation Error : %w", err) return fmt.Errorf("Request Creation Error : %w", err)
} }
// add the headers // add the headers
...@@ -59,7 +57,23 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -59,7 +57,23 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
} }
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
return utils.FetchJSON(req, dstPtr) resp, err := a.client.Request(req)
if err != nil {
utils.Logger.Errorf("Request Error : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
err = a.client.ParseBody(resp, dstPtr)
if err != nil {
utils.Logger.Errorf("Error Parsing Body : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
return nil
} }
func (a *apiImpl) GetAllBoards(attrs apiAttributes) []*Board { func (a *apiImpl) GetAllBoards(attrs apiAttributes) []*Board {
...@@ -76,7 +90,7 @@ func (a *apiImpl) GetAllBoards(attrs apiAttributes) []*Board { ...@@ -76,7 +90,7 @@ func (a *apiImpl) GetAllBoards(attrs apiAttributes) []*Board {
for startAt, hasNext := 0, true; hasNext; startAt += perPage { for startAt, hasNext := 0, true; hasNext; startAt += perPage {
var pageBoards pageResults var pageBoards pageResults
attrsCpy["startAt"] = strconv.Itoa(startAt) attrsCpy["startAt"] = strconv.Itoa(startAt)
if _, err := a.getItems("/board", attrsCpy, &pageBoards); err != nil { if err := a.getItems("/board", attrsCpy, &pageBoards); err != nil {
return boards return boards
} }
boards = append(boards, pageBoards.Values...) boards = append(boards, pageBoards.Values...)
...@@ -101,7 +115,7 @@ func (a *apiImpl) GetAllIssuesForBoard(board *Board, attrs apiAttributes) []*Iss ...@@ -101,7 +115,7 @@ func (a *apiImpl) GetAllIssuesForBoard(board *Board, attrs apiAttributes) []*Iss
var pageIssues pageResults var pageIssues pageResults
attrsCpy["startAt"] = strconv.Itoa(startAt) attrsCpy["startAt"] = strconv.Itoa(startAt)
path := fmt.Sprintf("/board/%d/issue", board.ID) path := fmt.Sprintf("/board/%d/issue", board.ID)
if _, err := a.getItems(path, attrsCpy, &pageIssues); err != nil { if err := a.getItems(path, attrsCpy, &pageIssues); err != nil {
return issues return issues
} }
issues = append(issues, pageIssues.Issues...) issues = append(issues, pageIssues.Issues...)
...@@ -111,9 +125,9 @@ func (a *apiImpl) GetAllIssuesForBoard(board *Board, attrs apiAttributes) []*Iss ...@@ -111,9 +125,9 @@ func (a *apiImpl) GetAllIssuesForBoard(board *Board, attrs apiAttributes) []*Iss
return issues return issues
} }
func NewJiraAgileServerRESTAPI(onRequest func()) JiraAgileServerRESTAPI { func NewJiraAgileServerRESTAPI(client utils.HTTPClient) JiraAgileServerRESTAPI {
return &apiImpl{ return &apiImpl{
apiUrl: fmt.Sprintf("%s/rest/agile/1.0", config.JIRA_URL), apiUrl: fmt.Sprintf("%s/rest/agile/1.0", config.JIRA_URL),
onRequest: onRequest, client: client,
} }
} }
...@@ -19,14 +19,12 @@ type JiraServerPlatformRESTAPI interface { ...@@ -19,14 +19,12 @@ type JiraServerPlatformRESTAPI interface {
type apiAttributes = map[string]string type apiAttributes = map[string]string
type apiImpl struct { type apiImpl struct {
onRequest func() client utils.HTTPClient
apiUrl string apiUrl string
} }
func (a *apiImpl) getItems(path string, attrs apiAttributes, func (a *apiImpl) getItems(path string, attrs apiAttributes,
dstPtr interface{}) (http.Header, error) { dstPtr interface{}) error {
a.onRequest()
url := a.apiUrl + path url := a.apiUrl + path
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
...@@ -34,7 +32,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -34,7 +32,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
if err != nil { if err != nil {
utils.Logger.Errorf("Request Creation Error : %s", err.Error()) utils.Logger.Errorf("Request Creation Error : %s", err.Error())
return nil, fmt.Errorf("Request Creation Error : %w", err) return fmt.Errorf("Request Creation Error : %w", err)
} }
// add the headers // add the headers
...@@ -47,13 +45,29 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -47,13 +45,29 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
} }
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
return utils.FetchJSON(req, dstPtr) resp, err := a.client.Request(req)
if err != nil {
utils.Logger.Errorf("Request Error : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
err = a.client.ParseBody(resp, dstPtr)
if err != nil {
utils.Logger.Errorf("Error Parsing Body : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
return nil
} }
func (a *apiImpl) GetAllStatuses() []*Status { func (a *apiImpl) GetAllStatuses() []*Status {
var statuses []*Status var statuses []*Status
if _, err := a.getItems("/status", if err := a.getItems("/status",
// empty map because this rest api path takes no // empty map because this rest api path takes no
// parameters // parameters
map[string]string{}, &statuses); err != nil { map[string]string{}, &statuses); err != nil {
...@@ -63,9 +77,9 @@ func (a *apiImpl) GetAllStatuses() []*Status { ...@@ -63,9 +77,9 @@ func (a *apiImpl) GetAllStatuses() []*Status {
return statuses return statuses
} }
func NewJiraServerPlatformRESTAPI(onRequest func()) JiraServerPlatformRESTAPI { func NewJiraServerPlatformRESTAPI(client utils.HTTPClient) JiraServerPlatformRESTAPI {
return &apiImpl{ return &apiImpl{
apiUrl: fmt.Sprintf("%s/rest/api/2", config.JIRA_URL), apiUrl: fmt.Sprintf("%s/rest/api/2", config.JIRA_URL),
onRequest: onRequest, client: client,
} }
} }
...@@ -18,8 +18,8 @@ type JiraServiceDeskRESTAPI interface { ...@@ -18,8 +18,8 @@ type JiraServiceDeskRESTAPI interface {
type apiAttributes = map[string]string type apiAttributes = map[string]string
type apiImpl struct { type apiImpl struct {
onRequest func() client utils.HTTPClient
apiUrl string apiUrl string
} }
func copyAttrs(attrs apiAttributes) apiAttributes { func copyAttrs(attrs apiAttributes) apiAttributes {
...@@ -33,9 +33,7 @@ func copyAttrs(attrs apiAttributes) apiAttributes { ...@@ -33,9 +33,7 @@ func copyAttrs(attrs apiAttributes) apiAttributes {
//nolint:unparam // In case we might need the header from response. //nolint:unparam // In case we might need the header from response.
func (a *apiImpl) getItems(path string, attrs apiAttributes, func (a *apiImpl) getItems(path string, attrs apiAttributes,
dstPtr interface{}) (http.Header, error) { dstPtr interface{}) error {
a.onRequest()
url := a.apiUrl + path url := a.apiUrl + path
req, err := http.NewRequest(http.MethodGet, url, nil) req, err := http.NewRequest(http.MethodGet, url, nil)
...@@ -43,7 +41,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -43,7 +41,7 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
if err != nil { if err != nil {
utils.Logger.Errorf("Request Creation Error : %s", err.Error()) utils.Logger.Errorf("Request Creation Error : %s", err.Error())
return nil, fmt.Errorf("Request Creation Error :%w", err) return fmt.Errorf("Request Creation Error :%w", err)
} }
// add the headers // add the headers
...@@ -56,7 +54,23 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes, ...@@ -56,7 +54,23 @@ func (a *apiImpl) getItems(path string, attrs apiAttributes,
} }
req.URL.RawQuery = q.Encode() req.URL.RawQuery = q.Encode()
return utils.FetchJSON(req, dstPtr) resp, err := a.client.Request(req)
if err != nil {
utils.Logger.Errorf("Request Error : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
err = a.client.ParseBody(resp, dstPtr)
if err != nil {
utils.Logger.Errorf("Error Parsing Body : %s", err.Error())
return fmt.Errorf("Request Error : %w", err)
}
return nil
} }
func (a *apiImpl) GetAllServiceDesks(attrs apiAttributes) []*ServiceDesk { func (a *apiImpl) GetAllServiceDesks(attrs apiAttributes) []*ServiceDesk {
...@@ -75,7 +89,7 @@ func (a *apiImpl) GetAllServiceDesks(attrs apiAttributes) []*ServiceDesk { ...@@ -75,7 +89,7 @@ func (a *apiImpl) GetAllServiceDesks(attrs apiAttributes) []*ServiceDesk {
for startAt, hasNext := 0, true; hasNext; startAt += perPage { for startAt, hasNext := 0, true; hasNext; startAt += perPage {
var pageServiceDesks pageResults var pageServiceDesks pageResults
attrsCpy["start"] = strconv.Itoa(startAt) attrsCpy["start"] = strconv.Itoa(startAt)
if _, err := a.getItems("/servicedesk", attrsCpy, &pageServiceDesks); err != nil { if err := a.getItems("/servicedesk", attrsCpy, &pageServiceDesks); err != nil {
return serviceDesks return serviceDesks
} }
serviceDesks = append(serviceDesks, pageServiceDesks.Values...) serviceDesks = append(serviceDesks, pageServiceDesks.Values...)
...@@ -106,7 +120,7 @@ func (a *apiImpl) GetAllRequestsForServiceDesk(serviceDesk *ServiceDesk, ...@@ -106,7 +120,7 @@ func (a *apiImpl) GetAllRequestsForServiceDesk(serviceDesk *ServiceDesk,
var pageRequests pageResults var pageRequests pageResults
attrsCpy["start"] = strconv.Itoa(startAt) attrsCpy["start"] = strconv.Itoa(startAt)
path := "/request" path := "/request"
if _, err := a.getItems(path, attrsCpy, &pageRequests); err != nil { if err := a.getItems(path, attrsCpy, &pageRequests); err != nil {
return requests return requests
} }
requests = append(requests, pageRequests.Requests...) requests = append(requests, pageRequests.Requests...)
...@@ -132,7 +146,7 @@ func (a *apiImpl) GetExtraStatusForRequest(request *Request, attrs apiAttributes ...@@ -132,7 +146,7 @@ func (a *apiImpl) GetExtraStatusForRequest(request *Request, attrs apiAttributes
var pageStatuses pageResults var pageStatuses pageResults
attrsCpy["start"] = strconv.Itoa(startAt) attrsCpy["start"] = strconv.Itoa(startAt)
path := fmt.Sprintf("/request/%s/status", request.Key) path := fmt.Sprintf("/request/%s/status", request.Key)
if _, err := a.getItems(path, attrsCpy, &pageStatuses); err != nil { if err := a.getItems(path, attrsCpy, &pageStatuses); err != nil {
return history return history
} }
history = append(history, pageStatuses.Values...) history = append(history, pageStatuses.Values...)
...@@ -142,9 +156,9 @@ func (a *apiImpl) GetExtraStatusForRequest(request *Request, attrs apiAttributes ...@@ -142,9 +156,9 @@ func (a *apiImpl) GetExtraStatusForRequest(request *Request, attrs apiAttributes
return history return history
} }
func NewJiraServiceDeskRESTAPI(onRequest func()) JiraServiceDeskRESTAPI { func NewJiraServiceDeskRESTAPI(client utils.HTTPClient) JiraServiceDeskRESTAPI {
return &apiImpl{ return &apiImpl{
apiUrl: fmt.Sprintf("%s/rest/servicedeskapi", config.JIRA_URL), apiUrl: fmt.Sprintf("%s/rest/servicedeskapi", config.JIRA_URL),
onRequest: onRequest, client: client,
} }
} }
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"holocron/collector-jira-workflow/pkg/httpClient/api/restagile1" "holocron/collector-jira-workflow/pkg/httpClient/api/restagile1"
"holocron/collector-jira-workflow/pkg/httpClient/api/restapi2" "holocron/collector-jira-workflow/pkg/httpClient/api/restapi2"
"holocron/collector-jira-workflow/pkg/httpClient/api/servicedeskapi" "holocron/collector-jira-workflow/pkg/httpClient/api/servicedeskapi"
"holocron/collector-jira-workflow/pkg/utils"
) )
type JiraClient struct { type JiraClient struct {
...@@ -18,16 +19,14 @@ type JiraClient struct { ...@@ -18,16 +19,14 @@ type JiraClient struct {
func NewJiraClient() *JiraClient { func NewJiraClient() *JiraClient {
jiraClient := JiraClient{} jiraClient := JiraClient{}
throttler := requestThrottler{ httpClient := utils.HTTPClient{
maxRequestsPerMinute: config.MAX_REQUESTS_PER_MINUTE, MaxRequestsPerMinute: config.MAX_REQUESTS_PER_MINUTE,
mutex: sync.Mutex{}, Mutex: &sync.Mutex{},
sleep: time.Sleep, Sleep: time.Sleep,
} }
jiraClient.RestAgile1 = restagile1.NewJiraAgileServerRESTAPI(throttler.makeRequest) jiraClient.RestAgile1 = restagile1.NewJiraAgileServerRESTAPI(httpClient)
jiraClient.RestAPI2 = restapi2.NewJiraServerPlatformRESTAPI(throttler.makeRequest) jiraClient.RestAPI2 = restapi2.NewJiraServerPlatformRESTAPI(httpClient)
jiraClient.ServiceDeskAPI = servicedeskapi.NewJiraServiceDeskRESTAPI( jiraClient.ServiceDeskAPI = servicedeskapi.NewJiraServiceDeskRESTAPI(httpClient)
throttler.makeRequest,
)
return &jiraClient return &jiraClient
} }
package httpClient
import (
"sync"
"time"
)
type requestThrottler struct {
sleep func(time.Duration)
maxRequestsPerMinute int
mutex sync.Mutex
}
func (rt *requestThrottler) makeRequest() {
rt.mutex.Lock()
defer rt.mutex.Unlock()
rt.sleep(time.Minute / time.Duration(rt.maxRequestsPerMinute))
}
package httpClient
import (
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestMakeRequest(t *testing.T) {
const TEST_MAX_REQUEST_PER_MINUTE = 5
t.Run(`it should wait to make requests based on the
maximum per minute`, func(t *testing.T) {
var TestSleep = func(duration time.Duration) {
assert.Equal(t, duration, time.Minute/TEST_MAX_REQUEST_PER_MINUTE)
}
testThrottler := requestThrottler{
maxRequestsPerMinute: TEST_MAX_REQUEST_PER_MINUTE,
mutex: sync.Mutex{},
sleep: TestSleep,
}
testThrottler.makeRequest()
})
t.Run(`it should only allow one goroutine to access the function
at a time`, func(t *testing.T) {
madeRequests := map[int]bool{}
count := 0
wg := sync.WaitGroup{}
var TestSleep = func(duration time.Duration) {
assert.False(t, madeRequests[count])
madeRequests[count] = true
count += 1
wg.Done()
}
testThrottler := requestThrottler{
maxRequestsPerMinute: TEST_MAX_REQUEST_PER_MINUTE,
mutex: sync.Mutex{},
sleep: TestSleep,
}
wg.Add(100)
for i := 0; i < 100; i++ {
go testThrottler.makeRequest()
}
wg.Wait()
for i := 0; i < 100; i++ {
assert.True(t, madeRequests[i])
}
})
}
...@@ -3,12 +3,27 @@ package utils ...@@ -3,12 +3,27 @@ package utils
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io"
"net/http" "net/http"
"sync"
"time" "time"
) )
func Request(req *http.Request) (*http.Response, error) { type HTTPClient struct {
Sleep func(time.Duration)
Mutex *sync.Mutex
MaxRequestsPerMinute int
}
func (c *HTTPClient) throttleRequest() {
c.Mutex.Lock()
defer c.Mutex.Unlock()
c.Sleep(time.Minute / time.Duration(c.MaxRequestsPerMinute))
}
func (c *HTTPClient) Request(req *http.Request) (*http.Response, error) {
c.throttleRequest()
const timeoutSeconds = 90 const timeoutSeconds = 90
client := &http.Client{Timeout: timeoutSeconds * time.Second} client := &http.Client{Timeout: timeoutSeconds * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
...@@ -28,33 +43,13 @@ func Request(req *http.Request) (*http.Response, error) { ...@@ -28,33 +43,13 @@ func Request(req *http.Request) (*http.Response, error) {
return resp, nil return resp, nil
} }
func ParseBody(resp *http.Response, dstPtr interface{}) error { func (c *HTTPClient) ParseBody(resp *http.Response, dstPtr interface{}) error {
defer resp.Body.Close() defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return fmt.Errorf("Error Reading Body : %w", err) return fmt.Errorf("Error Reading Body : %w", err)
} }
return json.Unmarshal(body, dstPtr) return json.Unmarshal(body, dstPtr)
} }
func FetchJSON(req *http.Request, dstPtr interface{}) (http.Header, error) {
resp, err := Request(req)
if err != nil {
Logger.Errorf("Request Error : %s", err.Error())
return nil, fmt.Errorf("Request Error : %w", err)
}
err = ParseBody(resp, dstPtr)
if err != nil {
Logger.Errorf("Erroring Parsing body : %s", err.Error())
return nil, fmt.Errorf("Request Error : %w", err)
}
return resp.Header, nil
}
...@@ -2,12 +2,10 @@ package utils ...@@ -2,12 +2,10 @@ package utils
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"io/ioutil"
"log"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time" "time"
...@@ -33,13 +31,12 @@ func (errBody) Close() (err error) { ...@@ -33,13 +31,12 @@ func (errBody) Close() (err error) {
return nil return nil
} }
func mockLogger() {
Logger.InfoLogger = log.New(ioutil.Discard, "TEST: ", log.Ldate|log.Ltime)
Logger.WarningLogger = log.New(ioutil.Discard, "TEST: ", log.Ldate|log.Ltime)
Logger.ErrorLogger = log.New(ioutil.Discard, "TEST: ", log.Ldate|log.Ltime)
}
func TestRequest(t *testing.T) { func TestRequest(t *testing.T) {
testClient := HTTPClient{
Sleep: time.Sleep,
MaxRequestsPerMinute: 1000,
Mutex: &sync.Mutex{},
}
t.Run("it should make an http request", func(t *testing.T) { t.Run("it should make an http request", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc( server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) { func(w http.ResponseWriter, r *http.Request) {
...@@ -53,7 +50,7 @@ func TestRequest(t *testing.T) { ...@@ -53,7 +50,7 @@ func TestRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil) req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil)
assert.Nil(t, err) assert.Nil(t, err)
resp, err := Request(req) resp, err := testClient.Request(req)
assert.Nil(t, err) assert.Nil(t, err)
defer func() { defer func() {
...@@ -73,7 +70,7 @@ func TestRequest(t *testing.T) { ...@@ -73,7 +70,7 @@ func TestRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "bad-url", nil) req, err := http.NewRequest(http.MethodGet, "bad-url", nil)
assert.Nil(t, err) assert.Nil(t, err)
//nolint:bodyclose // There should be no body to close. //nolint:bodyclose // There should be no body to close.
_, err = Request(req) _, err = testClient.Request(req)
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Contains(t, err.Error(), "Response Error : ") assert.Contains(t, err.Error(), "Response Error : ")
}) })
...@@ -90,9 +87,9 @@ func TestRequest(t *testing.T) { ...@@ -90,9 +87,9 @@ func TestRequest(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil) req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil)
assert.Nil(t, err) assert.Nil(t, err)
//nolint:bodyclose // The body is //nolint: bodyclose // The body is
// closed within the function in case of bad status code. // closed within the function in case of bad status code.
_, err = Request(req) _, err = testClient.Request(req)
defer server.Close() defer server.Close()
...@@ -102,6 +99,11 @@ func TestRequest(t *testing.T) { ...@@ -102,6 +99,11 @@ func TestRequest(t *testing.T) {
} }
func TestParseBody(t *testing.T) { func TestParseBody(t *testing.T) {
testClient := HTTPClient{
Sleep: time.Sleep,
MaxRequestsPerMinute: 1000,
Mutex: &sync.Mutex{},
}
t.Run("it should parse a response body into the proper JSON", func(t *testing.T) { t.Run("it should parse a response body into the proper JSON", func(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
buffer := bytes.Buffer{} buffer := bytes.Buffer{}
...@@ -122,7 +124,7 @@ func TestParseBody(t *testing.T) { ...@@ -122,7 +124,7 @@ func TestParseBody(t *testing.T) {
defer response.Body.Close() defer response.Body.Close()
var result TestJSONStruct var result TestJSONStruct
err := ParseBody(response, &result) err := testClient.ParseBody(response, &result)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "test", result.Text) assert.Equal(t, "test", result.Text)
...@@ -146,7 +148,7 @@ func TestParseBody(t *testing.T) { ...@@ -146,7 +148,7 @@ func TestParseBody(t *testing.T) {
defer response.Body.Close() defer response.Body.Close()
var result TestJSONStruct var result TestJSONStruct
err := ParseBody(response, &result) err := testClient.ParseBody(response, &result)
assert.NotNil(t, err) assert.NotNil(t, err)
}) })
...@@ -155,92 +157,8 @@ func TestParseBody(t *testing.T) { ...@@ -155,92 +157,8 @@ func TestParseBody(t *testing.T) {
response.Body = &errBody{} response.Body = &errBody{}
var result TestJSONStruct var result TestJSONStruct
err := ParseBody(response, &result) err := testClient.ParseBody(response, &result)
assert.NotNil(t, err)
})
}
func TestFetchJSON(t *testing.T) {
mockLogger()
t.Run("it errors if the request fails", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/path":
w.WriteHeader(http.StatusInternalServerError)
default:
t.Fatalf("Unexpected path %s", r.URL.Path)
}
}))
buffer := bytes.Buffer{}
Logger.ErrorLogger.SetOutput(&buffer)
req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil)
assert.Nil(t, err)
var result TestJSONStruct
_, err = FetchJSON(req, &result)
defer server.Close()
assert.NotNil(t, err)
assert.Contains(t, buffer.String(), "failed with status code : 500")
})
t.Run("it errors if the body can't be read", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/path":
_, err := w.Write([]byte("test"))
assert.Nil(t, err)
default:
t.Fatalf("Unexpected path %s", r.URL.Path)
}
}))
buffer := bytes.Buffer{}
Logger.ErrorLogger.SetOutput(&buffer)
req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil)
assert.Nil(t, err)
var result TestJSONStruct
_, err = FetchJSON(req, &result)
defer server.Close()
assert.NotNil(t, err) assert.NotNil(t, err)
assert.Contains(t, buffer.String(), "Erroring Parsing body")
})
t.Run(`it should return the headers with no error
and have data in dstPtr if it succeeds`, func(t *testing.T) {
type test struct {
Value string `json:"value"`
}
server := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/path":
bytes, err := json.Marshal(&test{Value: "test"})
assert.Nil(t, err)
w.Header().Add("key", "value")
_, err = w.Write(bytes)
assert.Nil(t, err)
default:
t.Fatalf("Unexpected path %s", r.URL.Path)
}
}))
req, err := http.NewRequest(http.MethodGet, server.URL+"/path", nil)
assert.Nil(t, err)
var result test
header, err := FetchJSON(req, &result)
defer server.Close()
assert.Equal(t, "value", header.Get("key"))
assert.Nil(t, err)
assert.Equal(t, result.Value, "test")
}) })
} }