summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/api.go10
-rw-r--r--src/discovery.go24
-rw-r--r--src/fsm.go10
-rw-r--r--src/http.go12
-rw-r--r--src/log.go9
-rw-r--r--src/oauth.go39
-rw-r--r--src/server.go19
-rw-r--r--src/state.go16
8 files changed, 63 insertions, 76 deletions
diff --git a/src/api.go b/src/api.go
index a11c907..bb7d86b 100644
--- a/src/api.go
+++ b/src/api.go
@@ -28,7 +28,7 @@ func (server *Server) apiAuthenticated(method string, endpoint string, opts *HTT
if opts.Headers != nil {
opts.Headers.Add(headerKey, headerValue)
} else {
- opts.Headers = &http.Header{headerKey: {headerValue}}
+ opts.Headers = http.Header{headerKey: {headerValue}}
}
return HTTPMethodWithOpts(method, url, opts)
}
@@ -54,8 +54,8 @@ func (server *Server) APIInfo() error {
if bodyErr != nil {
return bodyErr
}
- structure := &ServerProfileInfo{}
- jsonErr := json.Unmarshal(body, structure)
+ structure := ServerProfileInfo{}
+ jsonErr := json.Unmarshal(body, &structure)
if jsonErr != nil {
return jsonErr
@@ -67,7 +67,7 @@ func (server *Server) APIInfo() error {
}
func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (string, string, error) {
- headers := &http.Header{
+ headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-wireguard-profile"},
}
@@ -86,7 +86,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str
}
func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, error) {
- headers := &http.Header{
+ headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
"accept": {"application/x-openvpn-profile"},
}
diff --git a/src/discovery.go b/src/discovery.go
index 4498779..c7682f0 100644
--- a/src/discovery.go
+++ b/src/discovery.go
@@ -33,8 +33,8 @@ func (e *DiscoVerifyError) Error() string {
}
type DiscoList struct {
- Organizations *string `json:"organizations"`
- Servers *string `json:"servers"`
+ Organizations string `json:"organizations"`
+ Servers string `json:"servers"`
}
// Helper function that gets a disco json
@@ -82,45 +82,37 @@ func (e *GetListError) Error() string {
// FIXME: Implement these properly based on version and time info
func (eduvpn *VPNState) DetermineOrganizationsUpdate() bool {
- return eduvpn.DiscoList == nil || eduvpn.DiscoList.Organizations == nil
+ return eduvpn.DiscoList.Organizations == ""
}
func (eduvpn *VPNState) DetermineServersUpdate() bool {
- return eduvpn.DiscoList == nil || eduvpn.DiscoList.Servers == nil
-}
-
-func (eduvpn *VPNState) EnsureDisco() {
- if eduvpn.DiscoList == nil {
- eduvpn.DiscoList = &DiscoList{}
- }
+ return eduvpn.DiscoList.Servers == ""
}
// Get the organization list
func (eduvpn *VPNState) GetOrganizationsList() (string, error) {
if !eduvpn.DetermineOrganizationsUpdate() {
- return *eduvpn.DiscoList.Organizations, nil
+ return eduvpn.DiscoList.Organizations, nil
}
file := "organization_list.json"
body, err := getDiscoFile(file)
if err != nil {
return "", &GetListError{File: file, Err: err}
}
- eduvpn.EnsureDisco()
- eduvpn.DiscoList.Organizations = &body
+ eduvpn.DiscoList.Organizations = body
return body, nil
}
// Get the server list
func (eduvpn *VPNState) GetServersList() (string, error) {
if !eduvpn.DetermineServersUpdate() {
- return *eduvpn.DiscoList.Servers, nil
+ return eduvpn.DiscoList.Servers, nil
}
file := "server_list.json"
body, err := getDiscoFile("server_list.json")
if err != nil {
return "", &GetListError{File: file, Err: err}
}
- eduvpn.EnsureDisco()
- eduvpn.DiscoList.Servers = &body
+ eduvpn.DiscoList.Servers = body
return body, nil
}
diff --git a/src/fsm.go b/src/fsm.go
index 2e778d9..c51d345 100644
--- a/src/fsm.go
+++ b/src/fsm.go
@@ -93,10 +93,6 @@ type FSM struct {
}
func (eduvpn *VPNState) HasTransition(check FSMStateID) bool {
- // No fsm
- if eduvpn.FSM == nil {
- return false
- }
for _, transition_state := range eduvpn.FSM.States[eduvpn.FSM.Current] {
if transition_state.To == check {
return true
@@ -107,10 +103,6 @@ func (eduvpn *VPNState) HasTransition(check FSMStateID) bool {
}
func (eduvpn *VPNState) InState(check FSMStateID) bool {
- // No fsm
- if eduvpn.FSM == nil {
- return false
- }
return check == eduvpn.FSM.Current
}
@@ -184,7 +176,7 @@ func (eduvpn *VPNState) GenerateGraph() string {
}
func (eduvpn *VPNState) InitializeFSM() {
- eduvpn.FSM = &FSM{
+ eduvpn.FSM = FSM{
States: FSMStates{
DEREGISTERED: {{NO_SERVER, "Client registers"}},
NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}},
diff --git a/src/http.go b/src/http.go
index b247dbb..bbc866b 100644
--- a/src/http.go
+++ b/src/http.go
@@ -58,8 +58,8 @@ func (e *HTTPRequestCreateError) Error() string {
type URLParameters map[string]string
type HTTPOptionalParams struct {
- Headers *http.Header
- URLParameters *URLParameters
+ Headers http.Header
+ URLParameters URLParameters
Body url.Values
}
@@ -97,8 +97,8 @@ func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte
}
func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
- if opts != nil && opts.URLParameters != nil {
- url, urlErr := HTTPConstructURL(url, *opts.URLParameters)
+ if opts != nil {
+ url, urlErr := HTTPConstructURL(url, opts.URLParameters)
if urlErr != nil {
return url, &HTTPRequestCreateError{URL: url, Err: urlErr}
@@ -110,8 +110,8 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) {
func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) {
// Add headers
- if opts != nil && opts.Headers != nil && req != nil {
- for k, v := range *opts.Headers {
+ if opts != nil && req != nil {
+ for k, v := range opts.Headers {
req.Header.Add(k, v[0])
}
}
diff --git a/src/log.go b/src/log.go
index 6ee81e3..7402e31 100644
--- a/src/log.go
+++ b/src/log.go
@@ -15,13 +15,16 @@ type FileLogger struct {
type LogLevel int8
const (
- LOG_INFO LogLevel = iota
+ LOG_NOTSET LogLevel = iota
+ LOG_INFO
LOG_WARNING
LOG_ERROR
)
func (e LogLevel) String() string {
switch e {
+ case LOG_NOTSET:
+ return "NOTSET"
case LOG_INFO:
return "INFO"
case LOG_WARNING:
@@ -48,12 +51,12 @@ func (eduvpn *VPNState) InitLog(level LogLevel) error {
return logOpenErr
}
log.SetOutput(logFile)
- eduvpn.LogFile = &FileLogger{Level: level, File: logFile}
+ eduvpn.LogFile = FileLogger{Level: level, File: logFile}
return nil
}
func (eduvpn *VPNState) Log(level LogLevel, str string) {
- if level >= eduvpn.LogFile.Level {
+ if level >= eduvpn.LogFile.Level && eduvpn.LogFile.Level != LOG_NOTSET {
log.Printf("[%s]: %s", level.String(), str)
}
}
diff --git a/src/oauth.go b/src/oauth.go
index 831fa0f..96ce8b2 100644
--- a/src/oauth.go
+++ b/src/oauth.go
@@ -54,9 +54,9 @@ func genVerifier() (string, error) {
}
type OAuth struct {
- Session *OAuthExchangeSession `json:"-"`
- Token *OAuthToken `json:"token"`
- TokenURL string `json:"token_url"`
+ Session OAuthExchangeSession `json:"-"`
+ Token OAuthToken `json:"token"`
+ TokenURL string `json:"token_url"`
}
// This structure gets passed to the callback for easy access to the current state
@@ -71,7 +71,7 @@ type OAuthExchangeSession struct {
// filled in when constructing the callback
Context context.Context
- Server *http.Server
+ Server http.Server
}
func GenerateTimeSeconds() int64 {
@@ -93,7 +93,7 @@ func (oauth *OAuth) getTokensWithCallback() error {
oauth.Session.Context = context.Background()
mux := http.NewServeMux()
addr := "127.0.0.1:8000"
- oauth.Session.Server = &http.Server{
+ oauth.Session.Server = http.Server{
Addr: addr,
Handler: mux,
}
@@ -119,7 +119,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
"grant_type": {"authorization_code"},
"redirect_uri": {"http://127.0.0.1:8000/callback"},
}
- headers := &http.Header{
+ headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
opts := &HTTPOptionalParams{Headers: headers, Body: data}
@@ -129,9 +129,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
return bodyErr
}
- tokenStructure := &OAuthToken{}
+ tokenStructure := OAuthToken{}
- jsonErr := json.Unmarshal(body, tokenStructure)
+ jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
@@ -139,7 +139,6 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
oauth.Token = tokenStructure
-
return nil
}
@@ -158,7 +157,7 @@ func (oauth *OAuth) getTokensWithRefresh() error {
"refresh_token": {oauth.Token.Refresh},
"grant_type": {"refresh_token"},
}
- headers := &http.Header{
+ headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
}
opts := &HTTPOptionalParams{Headers: headers, Body: data}
@@ -168,8 +167,8 @@ func (oauth *OAuth) getTokensWithRefresh() error {
return bodyErr
}
- tokenStructure := &OAuthToken{}
- jsonErr := json.Unmarshal(body, tokenStructure)
+ tokenStructure := OAuthToken{}
+ jsonErr := json.Unmarshal(body, &tokenStructure)
if jsonErr != nil {
return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr}
@@ -177,7 +176,6 @@ func (oauth *OAuth) getTokensWithRefresh() error {
tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires
oauth.Token = tokenStructure
-
return nil
}
@@ -260,8 +258,8 @@ func (eduvpn *VPNState) InitializeOAuth() error {
}
// Fill the struct with the necessary fields filled for the next call to getting the HTTP client
- oauthSession := &OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier}
- eduvpn.Server.OAuth = &OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession}
+ oauthSession := OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier}
+ eduvpn.Server.OAuth = OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession}
eduvpn.GoTransition(OAUTH_STARTED, authURL)
return nil
}
@@ -271,8 +269,7 @@ func (eduvpn *VPNState) FinishOAuth() error {
if !eduvpn.HasTransition(AUTHENTICATED) {
return errors.New("invalid state to finish oauth")
}
- oauth := eduvpn.Server.OAuth
- tokenErr := oauth.getTokensWithCallback()
+ tokenErr := eduvpn.Server.OAuth.getTokensWithCallback()
if tokenErr != nil {
return tokenErr
}
@@ -300,6 +297,14 @@ func (oauth *OAuth) Login() error {
}
func (oauth *OAuth) NeedsRelogin() bool {
+ // Access Token or Refresh Tokens empty, definitely needs a relogin
+ if oauth.Token.Access == "" || oauth.Token.Refresh == "" {
+ GetVPNState().Log(LOG_INFO, "OAuth: Tokens are empty")
+ return true
+ }
+
+ // We have tokens...
+
// The tokens are not expired yet
// No relogin is needed
if !oauth.isTokensExpired() {
diff --git a/src/server.go b/src/server.go
index b7d55cb..7e323f6 100644
--- a/src/server.go
+++ b/src/server.go
@@ -6,11 +6,11 @@ import (
)
type Server struct {
- BaseURL string `json:"base_url"`
- Endpoints *ServerEndpoints `json:"endpoints"`
- OAuth *OAuth `json:"oauth"`
- Profiles *ServerProfileInfo `json:"profiles"`
- ProfilesRaw string `json:"profiles_raw"`
+ BaseURL string `json:"base_url"`
+ Endpoints ServerEndpoints `json:"endpoints"`
+ OAuth OAuth `json:"oauth"`
+ Profiles ServerProfileInfo `json:"profiles"`
+ ProfilesRaw string `json:"profiles_raw"`
}
type ServerProfile struct {
@@ -56,12 +56,7 @@ func (server *Server) Initialize(url string) error {
}
func (server *Server) NeedsRelogin() bool {
- // Server has no oauth tokens
- if server.OAuth == nil || server.OAuth.Token == nil {
- return true
- }
-
- // Server has oauth tokens, check if they need a relogin
+ // Check if OAuth needs relogin
return server.OAuth.NeedsRelogin()
}
@@ -73,7 +68,7 @@ func (server *Server) GetEndpoints() error {
return bodyErr
}
- endpoints := &ServerEndpoints{}
+ endpoints := ServerEndpoints{}
jsonErr := json.Unmarshal(body, &endpoints)
if jsonErr != nil {
diff --git a/src/state.go b/src/state.go
index 6c740c1..76cd635 100644
--- a/src/state.go
+++ b/src/state.go
@@ -12,16 +12,16 @@ type VPNState struct {
StateCallbackData string `json:"-"`
// The chosen server
- Server *Server `json:"server"`
+ Server Server `json:"server"`
// The list of servers and organizations from disco
- DiscoList *DiscoList `json:"disco"`
+ DiscoList DiscoList `json:"disco"`
// The file we keep open for logging
- LogFile *FileLogger `json:"-"`
+ LogFile FileLogger `json:"-"`
// The fsm
- FSM *FSM `json:"-"`
+ FSM FSM `json:"-"`
// Whether to enable debugging
Debug bool `json:"-"`
@@ -66,21 +66,21 @@ func (state *VPNState) Deregister() error {
state.WriteConfig()
// Re-initialize the server and FSM
- state.Server = &Server{}
+ state.Server = Server{}
state.InitializeFSM()
return nil
}
func (state *VPNState) Connect(url string) (string, error) {
- if state.Server == nil || state.Server.BaseURL != url {
- state.Server = &Server{}
+ // New server chosen, ensure the server is fresh
+ if state.Server.BaseURL != url {
+ state.Server = Server{}
}
initializeErr := state.Server.Initialize(url)
if initializeErr != nil {
return "", initializeErr
}
-
// Relogin with oauth
// This moves the state to authenticated
if state.Server.NeedsRelogin() {