diff options
| -rw-r--r-- | src/api.go | 10 | ||||
| -rw-r--r-- | src/discovery.go | 24 | ||||
| -rw-r--r-- | src/fsm.go | 10 | ||||
| -rw-r--r-- | src/http.go | 12 | ||||
| -rw-r--r-- | src/log.go | 9 | ||||
| -rw-r--r-- | src/oauth.go | 39 | ||||
| -rw-r--r-- | src/server.go | 19 | ||||
| -rw-r--r-- | src/state.go | 16 |
8 files changed, 63 insertions, 76 deletions
@@ -28,7 +28,7 @@ func (server *Server) apiAuthenticated(method string, endpoint string, opts *HTT if opts.Headers != nil { opts.Headers.Add(headerKey, headerValue) } else { - opts.Headers = &http.Header{headerKey: {headerValue}} + opts.Headers = http.Header{headerKey: {headerValue}} } return HTTPMethodWithOpts(method, url, opts) } @@ -54,8 +54,8 @@ func (server *Server) APIInfo() error { if bodyErr != nil { return bodyErr } - structure := &ServerProfileInfo{} - jsonErr := json.Unmarshal(body, structure) + structure := ServerProfileInfo{} + jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { return jsonErr @@ -67,7 +67,7 @@ func (server *Server) APIInfo() error { } func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (string, string, error) { - headers := &http.Header{ + headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-wireguard-profile"}, } @@ -86,7 +86,7 @@ func (server *Server) APIConnectWireguard(profile_id string, pubkey string) (str } func (server *Server) APIConnectOpenVPN(profile_id string) (string, string, error) { - headers := &http.Header{ + headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, "accept": {"application/x-openvpn-profile"}, } diff --git a/src/discovery.go b/src/discovery.go index 4498779..c7682f0 100644 --- a/src/discovery.go +++ b/src/discovery.go @@ -33,8 +33,8 @@ func (e *DiscoVerifyError) Error() string { } type DiscoList struct { - Organizations *string `json:"organizations"` - Servers *string `json:"servers"` + Organizations string `json:"organizations"` + Servers string `json:"servers"` } // Helper function that gets a disco json @@ -82,45 +82,37 @@ func (e *GetListError) Error() string { // FIXME: Implement these properly based on version and time info func (eduvpn *VPNState) DetermineOrganizationsUpdate() bool { - return eduvpn.DiscoList == nil || eduvpn.DiscoList.Organizations == nil + return eduvpn.DiscoList.Organizations == "" } func (eduvpn *VPNState) DetermineServersUpdate() bool { - return eduvpn.DiscoList == nil || eduvpn.DiscoList.Servers == nil -} - -func (eduvpn *VPNState) EnsureDisco() { - if eduvpn.DiscoList == nil { - eduvpn.DiscoList = &DiscoList{} - } + return eduvpn.DiscoList.Servers == "" } // Get the organization list func (eduvpn *VPNState) GetOrganizationsList() (string, error) { if !eduvpn.DetermineOrganizationsUpdate() { - return *eduvpn.DiscoList.Organizations, nil + return eduvpn.DiscoList.Organizations, nil } file := "organization_list.json" body, err := getDiscoFile(file) if err != nil { return "", &GetListError{File: file, Err: err} } - eduvpn.EnsureDisco() - eduvpn.DiscoList.Organizations = &body + eduvpn.DiscoList.Organizations = body return body, nil } // Get the server list func (eduvpn *VPNState) GetServersList() (string, error) { if !eduvpn.DetermineServersUpdate() { - return *eduvpn.DiscoList.Servers, nil + return eduvpn.DiscoList.Servers, nil } file := "server_list.json" body, err := getDiscoFile("server_list.json") if err != nil { return "", &GetListError{File: file, Err: err} } - eduvpn.EnsureDisco() - eduvpn.DiscoList.Servers = &body + eduvpn.DiscoList.Servers = body return body, nil } @@ -93,10 +93,6 @@ type FSM struct { } func (eduvpn *VPNState) HasTransition(check FSMStateID) bool { - // No fsm - if eduvpn.FSM == nil { - return false - } for _, transition_state := range eduvpn.FSM.States[eduvpn.FSM.Current] { if transition_state.To == check { return true @@ -107,10 +103,6 @@ func (eduvpn *VPNState) HasTransition(check FSMStateID) bool { } func (eduvpn *VPNState) InState(check FSMStateID) bool { - // No fsm - if eduvpn.FSM == nil { - return false - } return check == eduvpn.FSM.Current } @@ -184,7 +176,7 @@ func (eduvpn *VPNState) GenerateGraph() string { } func (eduvpn *VPNState) InitializeFSM() { - eduvpn.FSM = &FSM{ + eduvpn.FSM = FSM{ States: FSMStates{ DEREGISTERED: {{NO_SERVER, "Client registers"}}, NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, diff --git a/src/http.go b/src/http.go index b247dbb..bbc866b 100644 --- a/src/http.go +++ b/src/http.go @@ -58,8 +58,8 @@ func (e *HTTPRequestCreateError) Error() string { type URLParameters map[string]string type HTTPOptionalParams struct { - Headers *http.Header - URLParameters *URLParameters + Headers http.Header + URLParameters URLParameters Body url.Values } @@ -97,8 +97,8 @@ func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) (http.Header, []byte } func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { - if opts != nil && opts.URLParameters != nil { - url, urlErr := HTTPConstructURL(url, *opts.URLParameters) + if opts != nil { + url, urlErr := HTTPConstructURL(url, opts.URLParameters) if urlErr != nil { return url, &HTTPRequestCreateError{URL: url, Err: urlErr} @@ -110,8 +110,8 @@ func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { // Add headers - if opts != nil && opts.Headers != nil && req != nil { - for k, v := range *opts.Headers { + if opts != nil && req != nil { + for k, v := range opts.Headers { req.Header.Add(k, v[0]) } } @@ -15,13 +15,16 @@ type FileLogger struct { type LogLevel int8 const ( - LOG_INFO LogLevel = iota + LOG_NOTSET LogLevel = iota + LOG_INFO LOG_WARNING LOG_ERROR ) func (e LogLevel) String() string { switch e { + case LOG_NOTSET: + return "NOTSET" case LOG_INFO: return "INFO" case LOG_WARNING: @@ -48,12 +51,12 @@ func (eduvpn *VPNState) InitLog(level LogLevel) error { return logOpenErr } log.SetOutput(logFile) - eduvpn.LogFile = &FileLogger{Level: level, File: logFile} + eduvpn.LogFile = FileLogger{Level: level, File: logFile} return nil } func (eduvpn *VPNState) Log(level LogLevel, str string) { - if level >= eduvpn.LogFile.Level { + if level >= eduvpn.LogFile.Level && eduvpn.LogFile.Level != LOG_NOTSET { log.Printf("[%s]: %s", level.String(), str) } } diff --git a/src/oauth.go b/src/oauth.go index 831fa0f..96ce8b2 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -54,9 +54,9 @@ func genVerifier() (string, error) { } type OAuth struct { - Session *OAuthExchangeSession `json:"-"` - Token *OAuthToken `json:"token"` - TokenURL string `json:"token_url"` + Session OAuthExchangeSession `json:"-"` + Token OAuthToken `json:"token"` + TokenURL string `json:"token_url"` } // This structure gets passed to the callback for easy access to the current state @@ -71,7 +71,7 @@ type OAuthExchangeSession struct { // filled in when constructing the callback Context context.Context - Server *http.Server + Server http.Server } func GenerateTimeSeconds() int64 { @@ -93,7 +93,7 @@ func (oauth *OAuth) getTokensWithCallback() error { oauth.Session.Context = context.Background() mux := http.NewServeMux() addr := "127.0.0.1:8000" - oauth.Session.Server = &http.Server{ + oauth.Session.Server = http.Server{ Addr: addr, Handler: mux, } @@ -119,7 +119,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { "grant_type": {"authorization_code"}, "redirect_uri": {"http://127.0.0.1:8000/callback"}, } - headers := &http.Header{ + headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } opts := &HTTPOptionalParams{Headers: headers, Body: data} @@ -129,9 +129,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { return bodyErr } - tokenStructure := &OAuthToken{} + tokenStructure := OAuthToken{} - jsonErr := json.Unmarshal(body, tokenStructure) + jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} @@ -139,7 +139,6 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires oauth.Token = tokenStructure - return nil } @@ -158,7 +157,7 @@ func (oauth *OAuth) getTokensWithRefresh() error { "refresh_token": {oauth.Token.Refresh}, "grant_type": {"refresh_token"}, } - headers := &http.Header{ + headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } opts := &HTTPOptionalParams{Headers: headers, Body: data} @@ -168,8 +167,8 @@ func (oauth *OAuth) getTokensWithRefresh() error { return bodyErr } - tokenStructure := &OAuthToken{} - jsonErr := json.Unmarshal(body, tokenStructure) + tokenStructure := OAuthToken{} + jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} @@ -177,7 +176,6 @@ func (oauth *OAuth) getTokensWithRefresh() error { tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires oauth.Token = tokenStructure - return nil } @@ -260,8 +258,8 @@ func (eduvpn *VPNState) InitializeOAuth() error { } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - oauthSession := &OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier} - eduvpn.Server.OAuth = &OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession} + oauthSession := OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier} + eduvpn.Server.OAuth = OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession} eduvpn.GoTransition(OAUTH_STARTED, authURL) return nil } @@ -271,8 +269,7 @@ func (eduvpn *VPNState) FinishOAuth() error { if !eduvpn.HasTransition(AUTHENTICATED) { return errors.New("invalid state to finish oauth") } - oauth := eduvpn.Server.OAuth - tokenErr := oauth.getTokensWithCallback() + tokenErr := eduvpn.Server.OAuth.getTokensWithCallback() if tokenErr != nil { return tokenErr } @@ -300,6 +297,14 @@ func (oauth *OAuth) Login() error { } func (oauth *OAuth) NeedsRelogin() bool { + // Access Token or Refresh Tokens empty, definitely needs a relogin + if oauth.Token.Access == "" || oauth.Token.Refresh == "" { + GetVPNState().Log(LOG_INFO, "OAuth: Tokens are empty") + return true + } + + // We have tokens... + // The tokens are not expired yet // No relogin is needed if !oauth.isTokensExpired() { diff --git a/src/server.go b/src/server.go index b7d55cb..7e323f6 100644 --- a/src/server.go +++ b/src/server.go @@ -6,11 +6,11 @@ import ( ) type Server struct { - BaseURL string `json:"base_url"` - Endpoints *ServerEndpoints `json:"endpoints"` - OAuth *OAuth `json:"oauth"` - Profiles *ServerProfileInfo `json:"profiles"` - ProfilesRaw string `json:"profiles_raw"` + BaseURL string `json:"base_url"` + Endpoints ServerEndpoints `json:"endpoints"` + OAuth OAuth `json:"oauth"` + Profiles ServerProfileInfo `json:"profiles"` + ProfilesRaw string `json:"profiles_raw"` } type ServerProfile struct { @@ -56,12 +56,7 @@ func (server *Server) Initialize(url string) error { } func (server *Server) NeedsRelogin() bool { - // Server has no oauth tokens - if server.OAuth == nil || server.OAuth.Token == nil { - return true - } - - // Server has oauth tokens, check if they need a relogin + // Check if OAuth needs relogin return server.OAuth.NeedsRelogin() } @@ -73,7 +68,7 @@ func (server *Server) GetEndpoints() error { return bodyErr } - endpoints := &ServerEndpoints{} + endpoints := ServerEndpoints{} jsonErr := json.Unmarshal(body, &endpoints) if jsonErr != nil { diff --git a/src/state.go b/src/state.go index 6c740c1..76cd635 100644 --- a/src/state.go +++ b/src/state.go @@ -12,16 +12,16 @@ type VPNState struct { StateCallbackData string `json:"-"` // The chosen server - Server *Server `json:"server"` + Server Server `json:"server"` // The list of servers and organizations from disco - DiscoList *DiscoList `json:"disco"` + DiscoList DiscoList `json:"disco"` // The file we keep open for logging - LogFile *FileLogger `json:"-"` + LogFile FileLogger `json:"-"` // The fsm - FSM *FSM `json:"-"` + FSM FSM `json:"-"` // Whether to enable debugging Debug bool `json:"-"` @@ -66,21 +66,21 @@ func (state *VPNState) Deregister() error { state.WriteConfig() // Re-initialize the server and FSM - state.Server = &Server{} + state.Server = Server{} state.InitializeFSM() return nil } func (state *VPNState) Connect(url string) (string, error) { - if state.Server == nil || state.Server.BaseURL != url { - state.Server = &Server{} + // New server chosen, ensure the server is fresh + if state.Server.BaseURL != url { + state.Server = Server{} } initializeErr := state.Server.Initialize(url) if initializeErr != nil { return "", initializeErr } - // Relogin with oauth // This moves the state to authenticated if state.Server.NeedsRelogin() { |
