From 56548c511163b4dd22d9a96a2f5ae647f1627a7b Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Mon, 7 Mar 2022 15:43:07 +0100 Subject: Refactor: Simplify API by using a state as context --- src/api.go | 30 +++++++++++++++++--- src/oauth.go | 93 ++++++++++++++++++++++++++++++++++++++++++------------------ src/state.go | 24 ++++++++++++++++ 3 files changed, 115 insertions(+), 32 deletions(-) create mode 100644 src/state.go (limited to 'src') diff --git a/src/api.go b/src/api.go index 9e082cc..8bd2847 100644 --- a/src/api.go +++ b/src/api.go @@ -2,6 +2,7 @@ package eduvpn import ( "encoding/json" + "errors" "io/ioutil" "net/http" ) @@ -13,7 +14,7 @@ type endpointList struct { } // Struct that defines the json format for /.well-known/vpn-user-portal" -type PortalEndpoints struct { +type EduVPNEndpoints struct { API struct { V2 endpointList `json:"http://eduvpn.org/api#2"` V3 endpointList `json:"http://eduvpn.org/api#3"` @@ -21,8 +22,8 @@ type PortalEndpoints struct { V string `json:"v"` } -func APIGetEndpoints(baseURL string) (*PortalEndpoints, error) { - url := baseURL + "/.well-known/vpn-user-portal" +func APIGetEndpoints(vpnState *EduVPNState) (*EduVPNEndpoints, error) { + url := vpnState.Server + "/.well-known/vpn-user-portal" resp, reqErr := http.Get(url) if reqErr != nil { return nil, reqErr @@ -40,7 +41,7 @@ func APIGetEndpoints(baseURL string) (*PortalEndpoints, error) { return nil, readErr } - structure := &PortalEndpoints{} + structure := &EduVPNEndpoints{} jsonErr := json.Unmarshal(body, &structure) if jsonErr != nil { @@ -49,3 +50,24 @@ func APIGetEndpoints(baseURL string) (*PortalEndpoints, error) { return structure, nil } + +func APIAuthenticatedInfo(vpnState *EduVPNState) (string, error) { + url := vpnState.Endpoints.API.V3.Endpoint + "/info" + resp, reqErr := vpnState.OAuth.client.Get(url) + if reqErr != nil { + return "", reqErr + } + // Close the response body at the end + defer resp.Body.Close() + + // Check if http response code is ok + if resp.StatusCode != http.StatusOK { + return "", errors.New("HTTP code not ok") + } + // Read the body + body, readErr := ioutil.ReadAll(resp.Body) + if readErr != nil { + return "", readErr + } + return string(body), nil +} diff --git a/src/oauth.go b/src/oauth.go index 9eb2272..87284a2 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -66,33 +66,8 @@ type EduVPNOauth struct { verifier string } -// Initializes the OAuth eduvpn class. It returns a tuple of the class and error. -// If the error is non-nil, the class will be nil. -func InitializeOAuth(config *oauth2.Config) (*EduVPNOauth, error) { - // Generate the state - state, stateErr := genState() - if stateErr != nil { - return nil, detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} - } - - // Generate the verifier and challenge - verifier, err := genVerifier() - if err != nil { - return nil, detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} - } - challenge := genChallengeS256(verifier) - - // Update the auth url with the challenge and state - codeChallengeMethod := oauth2.SetAuthURLParam("code_challenge_method", "S256") - codeChallenge := oauth2.SetAuthURLParam("code_challenge", challenge) - authURL := config.AuthCodeURL(state, codeChallengeMethod, codeChallenge) - - // Return the struct with the necessary fields filled for the next call to getting the HTTP client - return &EduVPNOauth{AuthURL: authURL, Config: config, state: state, verifier: verifier}, nil -} - // Gets an authenticated HTTP client by obtaining refresh and access tokens -func (eduvpn *EduVPNOauth) GetHTTPTokenClient() (*http.Client, error) { +func (eduvpn *EduVPNOauth) getHTTPTokenClient() error { eduvpn.context = context.Background() mux := http.NewServeMux() eduvpn.server = &http.Server{ @@ -101,9 +76,9 @@ func (eduvpn *EduVPNOauth) GetHTTPTokenClient() (*http.Client, error) { } mux.HandleFunc("/callback", eduvpn.oauthCallback) if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed { - return nil, detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} + return detailedOAuthError{errCallbackServerError, fmt.Sprintf("oauth callback server error"), err} } - return eduvpn.client, eduvpn.callbackError + return eduvpn.callbackError } // Get the access and refresh tokens @@ -169,6 +144,68 @@ func (eduvpn *EduVPNOauth) oauthCallback(w http.ResponseWriter, req *http.Reques go eduvpn.server.Shutdown(eduvpn.context) } +// Generate a config for oauth +// It uses the state to get the server and the name +func genConfig(vpnState *EduVPNState) (*oauth2.Config, error) { + config := &oauth2.Config{ + RedirectURL: "http://127.0.0.1:8000/callback", + ClientID: vpnState.Name, + Scopes: []string{"config"}, + Endpoint: oauth2.Endpoint{ + AuthURL: vpnState.Endpoints.API.V3.AuthorizationEndpoint, + TokenURL: vpnState.Endpoints.API.V3.TokenEndpoint, + }, + } + return config, nil +} + +// Initializes the OAuth for eduvpn. +// It needs a vpn state that was gotten from `Register` +// It returns the authurl for the browser and an error if present +func InitializeOAuth(vpnState *EduVPNState) (string, error) { + if vpnState == nil { + panic("invalid state") + } + + config, configErr := genConfig(vpnState) + if configErr != nil { + return "", configErr + } + + // Generate the state + state, stateErr := genState() + if stateErr != nil { + return "", detailedOAuthError{errGenStateError, fmt.Sprintf("oauth failed to gen random bytes for state"), stateErr} + } + + // Generate the verifier and challenge + verifier, err := genVerifier() + if err != nil { + return "", detailedOAuthError{errGenVerifierError, fmt.Sprintf("oauth failed to verifier"), err} + } + challenge := genChallengeS256(verifier) + + // Update the auth url with the challenge and state + codeChallengeMethod := oauth2.SetAuthURLParam("code_challenge_method", "S256") + codeChallenge := oauth2.SetAuthURLParam("code_challenge", challenge) + authURL := config.AuthCodeURL(state, codeChallengeMethod, codeChallenge) + + // Fill the struct with the necessary fields filled for the next call to getting the HTTP client + vpnState.OAuth = &EduVPNOauth{AuthURL: authURL, Config: config, state: state, verifier: verifier} + return authURL, nil +} + +func FinishOAuth(vpnState *EduVPNState) error { + if vpnState == nil { + panic("invalid state") + } + + if vpnState.OAuth == nil { + panic("invalid oauth state") + } + return vpnState.OAuth.getHTTPTokenClient() +} + // OAuthErrorCode Simplified error code for public interface. type OAuthErrorCode = VPNErrorCode type OAuthError = VPNError diff --git a/src/state.go b/src/state.go new file mode 100644 index 0000000..a03733a --- /dev/null +++ b/src/state.go @@ -0,0 +1,24 @@ +package eduvpn + +type EduVPNState struct { + // The struct used for oauth + OAuth *EduVPNOauth + + // The endpoints + Endpoints *EduVPNEndpoints + + // Info passed by the client + Name string + Server string +} + +func Register(name string, server string) *EduVPNState { + state := &EduVPNState{Name: name, Server: server} + endpoints, err := APIGetEndpoints(state) + + if err != nil { + panic(err) + } + state.Endpoints = endpoints + return state +} -- cgit v1.2.3