diff options
| -rw-r--r-- | internal/api/api.go | 18 | ||||
| -rw-r--r-- | internal/api/cache.go | 64 |
2 files changed, 73 insertions, 9 deletions
diff --git a/internal/api/api.go b/internal/api/api.go index fd17e02..d923d13 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -65,19 +65,19 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t cr := customRedirect(clientID) // Construct OAuth o := eduoauth.OAuth{ - ClientID: clientID, + ClientID: clientID, EndpointFunc: func(ctx context.Context) (*eduoauth.EndpointResponse, error) { - ep, err := getEndpoints(ctx, sd.BaseAuthWK) + ep, err := GetEndpointCache().Get(ctx, sd.BaseAuthWK) if err != nil { return nil, err } return &eduoauth.EndpointResponse{ AuthorizationURL: ep.API.V3.Authorization, - TokenURL: ep.API.V3.Token, + TokenURL: ep.API.V3.Token, }, nil }, - CustomRedirect: cr, - RedirectPath: "/callback", + CustomRedirect: cr, + RedirectPath: "/callback", TokensUpdated: func(tok eduoauth.Token) { cb.TokensUpdated(sd.ID, sd.Type, tok) }, @@ -88,9 +88,9 @@ func NewAPI(ctx context.Context, clientID string, sd ServerData, cb Callbacks, t } api := &API{ - cb: cb, - oauth: &o, - Data: sd, + cb: cb, + oauth: &o, + Data: sd, } err := api.authorize(ctx) if err != nil { @@ -141,7 +141,7 @@ func (a *API) authorize(ctx context.Context) (err error) { } func (a *API) authorized(ctx context.Context, method string, endpoint string, opts *httpw.OptionalParams) (http.Header, []byte, error) { - ep, err := getEndpoints(ctx, a.Data.BaseWK) + ep, err := GetEndpointCache().Get(ctx, a.Data.BaseWK) if err != nil { return nil, nil, err } diff --git a/internal/api/cache.go b/internal/api/cache.go new file mode 100644 index 0000000..2baefc8 --- /dev/null +++ b/internal/api/cache.go @@ -0,0 +1,64 @@ +package api + +import ( + "context" + "sync" + "time" + + "github.com/eduvpn/eduvpn-common/internal/api/endpoints" +) + +// EndpointCache is a struct that caches well-known API endpoints +type EndpointCache struct { + lastUpdate map[string]time.Time + lastEP map[string]*endpoints.Endpoints + mu sync.Mutex +} + +// Get() returns a cached or fresh endpoint cache copy +func (ec *EndpointCache) Get(ctx context.Context, wk string) (*endpoints.Endpoints, error) { + ec.mu.Lock() + defer ec.mu.Unlock() + + // get the last update time + lu := time.Time{} + if v, ok := ec.lastUpdate[wk]; ok { + lu = v + } + + // if not 10 minutes have passed, return cached copy + if !lu.IsZero() && !time.Now().After(lu.Add(10*time.Minute)) { + v, ok := ec.lastEP[wk] + if ok { + return v, nil + } + } + + // get fresh API endpoints + ep, err := getEndpoints(ctx, wk) + if err != nil { + return nil, err + } + + // update endpoints + ec.lastUpdate[wk] = time.Now() + ec.lastEP[wk] = ep + + return ep, nil +} + +var ( + epCache *EndpointCache + epCacheOnce sync.Once +) + +func GetEndpointCache() *EndpointCache { + epCacheOnce.Do(func() { + epCache = &EndpointCache{ + lastUpdate: make(map[string]time.Time), + lastEP: make(map[string]*endpoints.Endpoints), + } + }) + + return epCache +} |
