summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2022-08-09 14:18:22 +0200
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2022-08-09 14:18:22 +0200
commit9c7d9958132bcea0aa5ff6ab4aaec67c73087408 (patch)
treeb3a85b1b340de7b8169542c32972a7f3dd2314aa /internal/server
parent0c6233ab691973859b6d636e6d9fdddd2a6acd5e (diff)
Refactor: Cleanup time calculations and usage
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/api.go19
-rw-r--r--internal/server/common.go24
2 files changed, 22 insertions, 21 deletions
diff --git a/internal/server/api.go b/internal/server/api.go
index a3d8e31..9f4a9fb 100644
--- a/internal/server/api.go
+++ b/internal/server/api.go
@@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
+ "time"
httpw "github.com/jwijenbergh/eduvpn-common/internal/http"
"github.com/jwijenbergh/eduvpn-common/internal/types"
@@ -76,7 +77,7 @@ func apiAuthorizedRetry(server Server, method string, endpoint string, opts *htt
// Only retry authorized if we get a HTTP 401
if errors.As(bodyErr, &error) && error.Status == 401 {
// Tell the method that the token is expired
- server.GetOAuth().Token.ExpiredTimestamp = util.GenerateTimeSeconds()
+ server.GetOAuth().Token.ExpiredTimestamp = util.GetCurrentTime()
retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts)
if retryErr != nil {
return nil, nil, &types.WrappedErrorMessage{Message: errorMessage, Err: retryErr}
@@ -115,7 +116,7 @@ func APIInfo(server Server) error {
return nil
}
-func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, int64, error) {
+func APIConnectWireguard(server Server, profile_id string, pubkey string, supportsOpenVPN bool) (string, string, time.Time, error) {
errorMessage := "failed obtaining a WireGuard configuration"
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
@@ -132,14 +133,14 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor
}
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
+ return "", "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
+ return "", "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
}
contentType := header.Get("content-type")
@@ -148,10 +149,10 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor
if contentType == "application/x-wireguard-profile" {
content = "wireguard"
}
- return string(connectBody), content, pTime.Unix(), nil
+ return string(connectBody), content, pTime, nil
}
-func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) {
+func APIConnectOpenVPN(server Server, profile_id string) (string, time.Time, error) {
errorMessage := "failed obtaining an OpenVPN configuration"
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
@@ -164,15 +165,15 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error)
header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm})
if connectErr != nil {
- return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
+ return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: connectErr}
}
expires := header.Get("expires")
pTime, pTimeErr := http.ParseTime(expires)
if pTimeErr != nil {
- return "", 0, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
+ return "", time.Time{}, &types.WrappedErrorMessage{Message: errorMessage, Err: pTimeErr}
}
- return string(connectBody), pTime.Unix(), nil
+ return string(connectBody), pTime, nil
}
// This needs no further return value as it's best effort
diff --git a/internal/server/common.go b/internal/server/common.go
index 5340b39..be6ed46 100644
--- a/internal/server/common.go
+++ b/internal/server/common.go
@@ -3,6 +3,7 @@ package server
import (
"encoding/json"
"fmt"
+ "time"
"github.com/jwijenbergh/eduvpn-common/internal/fsm"
"github.com/jwijenbergh/eduvpn-common/internal/oauth"
@@ -19,8 +20,8 @@ type ServerBase struct {
Endpoints ServerEndpoints `json:"endpoints"`
Profiles ServerProfileInfo `json:"profiles"`
ProfilesRaw string `json:"profiles_raw"`
- StartTime int64 `json:"start_time"`
- EndTime int64 `json:"expire_time"`
+ StartTime time.Time `json:"start_time"`
+ EndTime time.Time `json:"expire_time"`
Type string `json:"server_type"`
FSM *fsm.FSM `json:"-"`
}
@@ -133,7 +134,7 @@ func getServerInfoScreen(base ServerBase) (ServerInfoScreen) {
serverInfoScreen.DisplayName = base.DisplayName
serverInfoScreen.SupportContact = base.SupportContact
serverInfoScreen.Profiles = base.Profiles
- serverInfoScreen.ExpireTime = base.EndTime
+ serverInfoScreen.ExpireTime = base.EndTime.Unix()
serverInfoScreen.Type = base.Type
return serverInfoScreen
@@ -287,23 +288,22 @@ func ShouldRenewButton(server Server) bool {
}
// Get current time
- current := util.GenerateTimeSeconds()
+ current := util.GetCurrentTime()
// 30 minutes have not passed
- if current <= (base.StartTime + 30*60) {
+ if !current.After(base.StartTime.Add(30 * time.Minute)) {
return false
}
// Session will not expire today
- if current <= (base.EndTime - 24*60*60) {
+ if !current.Add(24 * time.Hour).After(base.EndTime) {
return false
}
// Session duration is less than 24 hours but not 75% has passed
- duration := base.EndTime - base.StartTime
-
- // TODO: Is converting to float64 okay here?
- if duration < 24*60*60 && float64(current) <= (float64(base.StartTime)+0.75*float64(duration)) {
+ duration := base.EndTime.Sub(base.StartTime)
+ percentTime := base.StartTime.Add((duration/4)*3)
+ if duration < time.Duration(24 * time.Hour) && !current.After(percentTime) {
return false
}
@@ -391,7 +391,7 @@ func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, er
}
// Store start and end time
- base.StartTime = util.GenerateTimeSeconds()
+ base.StartTime = util.GetCurrentTime()
base.EndTime = expires
if content == "wireguard" {
@@ -416,7 +416,7 @@ func openVPNGetConfig(server Server) (string, string, error) {
configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id)
// Store start and end time
- base.StartTime = util.GenerateTimeSeconds()
+ base.StartTime = util.GetCurrentTime()
base.EndTime = expires
if configErr != nil {