diff options
| author | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-03-18 13:58:08 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-04-05 12:26:16 +0200 |
| commit | 2d5c7dad599b3f8b70ab07382973c51d1de2193d (patch) | |
| tree | 3ca48a1104f958f896813a4d70093cdc27429133 | |
| parent | 343836597df3efd6f31a68e29ff82b6ec4979f69 (diff) | |
Refactor: Structures changed and added Token refresh function
| -rw-r--r-- | cli/main.go | 14 | ||||
| -rw-r--r-- | src/api.go | 40 | ||||
| -rw-r--r-- | src/discovery.go | 96 | ||||
| -rw-r--r-- | src/http.go | 17 | ||||
| -rw-r--r-- | src/oauth.go | 255 | ||||
| -rw-r--r-- | src/server.go | 107 | ||||
| -rw-r--r-- | src/state.go | 27 |
7 files changed, 303 insertions, 253 deletions
diff --git a/cli/main.go b/cli/main.go index 42c4b67..ccec3f4 100644 --- a/cli/main.go +++ b/cli/main.go @@ -34,17 +34,23 @@ func main() { state := eduvpn.GetVPNState() - eduvpn.Register(state, "org.eduvpn.app.linux", urlString, logState) - authURL, err := eduvpn.InitializeOAuth(state) + eduvpn.Register(state, "org.eduvpn.app.linux", logState) + state.Server = &eduvpn.Server{} + serverInitializeErr := state.Server.Initialize(urlString) + if serverInitializeErr != nil { + log.Fatal(serverInitializeErr) + } + + authURL, err := state.InitializeOAuth() if err != nil { log.Fatal(err) } openBrowser(authURL) - oauthErr := eduvpn.FinishOAuth(state) + oauthErr := state.FinishOAuth() if oauthErr != nil { log.Fatal(oauthErr) } - infoString, infoErr := eduvpn.APIAuthenticatedInfo(state) + infoString, infoErr := state.APIAuthenticatedInfo() if infoErr != nil { log.Fatal(infoErr) } @@ -1,47 +1,13 @@ package eduvpn import ( - "encoding/json" "net/http" ) -type endpointList struct { - API string `json:"api_endpoint"` - Authorization string `json:"authorization_endpoint"` - Token string `json:"token_endpoint"` -} - -// Struct that defines the json format for /.well-known/vpn-user-portal" -type EduVPNEndpoints struct { - API struct { - V2 endpointList `json:"http://eduvpn.org/api#2"` - V3 endpointList `json:"http://eduvpn.org/api#3"` - } `json:"api"` - V string `json:"v"` -} - -func APIGetEndpoints(vpnState *EduVPNState) (*EduVPNEndpoints, error) { - url := vpnState.Server + "/.well-known/vpn-user-portal" - body, bodyErr := HTTPGet(url) - - if bodyErr != nil { - return nil, bodyErr - } - - structure := &EduVPNEndpoints{} - jsonErr := json.Unmarshal(body, &structure) - - if jsonErr != nil { - return nil, jsonErr - } - - return structure, nil -} - -func APIAuthenticatedInfo(vpnState *EduVPNState) (string, error) { - url := vpnState.Endpoints.API.V3.API + "/info" +func (eduvpn *VPNState) APIAuthenticatedInfo() (string, error) { + url := eduvpn.Server.Endpoints.API.V3.API + "/info" - headers := &http.Header{"Authorization": {"Bearer " + vpnState.OAuthToken.Access}} + headers := &http.Header{"Authorization": {"Bearer " + eduvpn.Server.OAuth.Token.Access}} body, bodyErr := HTTPGetWithOptionalParams(url, &HTTPOptionalParams{Headers: headers}) if bodyErr != nil { return "", bodyErr diff --git a/src/discovery.go b/src/discovery.go new file mode 100644 index 0000000..ced7716 --- /dev/null +++ b/src/discovery.go @@ -0,0 +1,96 @@ +package eduvpn + +import ( + "fmt" +) + +type DiscoFileError struct { + URL string + Err error +} + +func (e *DiscoFileError) Error() string { + return fmt.Sprintf("failed obtaining disco file %s with error %v", e.URL, e.Err) +} + +type DiscoSigFileError struct { + URL string + Err error +} + +func (e *DiscoSigFileError) Error() string { + return fmt.Sprintf("failed obtaining disco signature file %s with error %v", e.URL, e.Err) +} + +type DiscoVerifyError struct { + File string + Sigfile string + Err error +} + +func (e *DiscoVerifyError) Error() string { + return fmt.Sprintf("failed verifying file %s with signature %s due to error %v", e.File, e.Sigfile, e.Err) +} + +// Helper function that gets a disco json +func getDiscoFile(jsonFile string) (string, error) { + // Get json data + discoURL := "https://disco.eduvpn.org/v2/" + fileURL := discoURL + jsonFile + fileBody, fileErr := HTTPGet(fileURL) + + if fileErr != nil { + return "", &DiscoFileError{fileURL, fileErr} + } + + // Get signature + sigFile := jsonFile + ".minisig" + sigURL := discoURL + sigFile + sigBody, sigFileErr := HTTPGet(sigURL) + + if sigFileErr != nil { + return "", &DiscoSigFileError{URL: sigURL, Err: sigFileErr} + } + + // Verify signature + // TODO: Handle this by keeping track of the previous sign time + // Wrappers must do this? + var previousSigTime uint64 = 0 + forcePrehash := false + verifySuccess, verifyErr := Verify(string(sigBody), fileBody, jsonFile, previousSigTime, forcePrehash) + + if !verifySuccess || verifyErr != nil { + return "", &DiscoVerifyError{File: jsonFile, Sigfile: sigFile, Err: verifyErr} + } + + return string(fileBody), nil +} + +type GetListError struct { + File string + Err error +} + +func (e *GetListError) Error() string { + return fmt.Sprintf("failed getting disco list file %s with error %v", e.File, e.Err) +} + +// Get the organization list +func GetOrganizationsList() (string, error) { + file := "organization_list.json" + body, err := getDiscoFile(file) + if err != nil { + return "", &GetListError{File: file, Err: err} + } + return body, nil +} + +// Get the server list +func GetServersList() (string, error) { + file := "server_list.json" + body, err := getDiscoFile("server_list.json") + if err != nil { + return "", &GetListError{File: file, Err: err} + } + return body, nil +} diff --git a/src/http.go b/src/http.go index 57e5939..1374eed 100644 --- a/src/http.go +++ b/src/http.go @@ -62,6 +62,23 @@ func HTTPGet(url string) ([]byte, error) { return HTTPGetWithOptionalParams(url, nil) } +func HTTPConstructURL(baseURL string, parameters map[string]string) (string, error) { + url, err := url.Parse(baseURL) + + if err != nil { + return "", err + } + + q := url.Query() + + for parameter, value := range parameters { + q.Set(parameter, value) + } + url.RawQuery = q.Encode() + return url.String(), nil +} + + func HTTPGetWithOptionalParams(url string, opts *HTTPOptionalParams) ([]byte, error) { client := &http.Client{} req, reqErr := http.NewRequest(http.MethodGet, url, nil) diff --git a/src/oauth.go b/src/oauth.go index 2da7af5..80f60d7 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -10,22 +10,6 @@ import ( "net/url" ) -type OAuthGenStateUnableError struct { - Err error -} - -func (e *OAuthGenStateUnableError) Error() string { - return fmt.Sprintf("failed generating state with error %v", e.Err) -} - -type OAuthGenVerifierUnableError struct { - Err error -} - -func (e *OAuthGenVerifierUnableError) Error() string { - return fmt.Sprintf("failed generating verifier with error %v", e.Err) -} - // Generates a random base64 string to be used for state // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-4.1.1 // "state": OPTIONAL. An opaque value used by the client to maintain @@ -68,65 +52,63 @@ func genVerifier() (string, error) { return base64.RawURLEncoding.EncodeToString(randomBytes), nil } +type OAuth struct { + Session *OAuthExchangeSession + Token *OAuthToken + TokenURL string +} + // This structure gets passed to the callback for easy access to the current state -type EduVPNOAuthSession struct { - // Public - AuthURL string - VPNState *EduVPNState - - // private - callbackError error - context context.Context - state string - server *http.Server - verifier string +type OAuthExchangeSession struct { + // returned from the callback + CallbackError error + + // filled in in initialize + ClientID string + State string + Verifier string + + // filled in when constructing the callback + Context context.Context + Server *http.Server } // Struct that defines the json format for /.well-known/vpn-user-portal" -type EduVPNOAuthToken struct { +type OAuthToken struct { Access string `json:"access_token"` Refresh string `json:"refresh_token"` Type string `json:"token_type"` Expires int `json:"expires_in"` } -type OAuthFailedCallbackError struct { - Addr string - Err error -} - -func (e *OAuthFailedCallbackError) Error() string { - return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err) -} - // Gets an authenticated HTTP client by obtaining refresh and access tokens -func (eduvpn *EduVPNOAuthSession) getHTTPTokenClient() error { - eduvpn.context = context.Background() +func (oauth *OAuth) getTokensWithCallback() error { + oauth.Session.Context = context.Background() mux := http.NewServeMux() addr := "127.0.0.1:8000" - eduvpn.server = &http.Server{ + oauth.Session.Server = &http.Server{ Addr: addr, Handler: mux, } - mux.HandleFunc("/callback", eduvpn.oauthCallback) - if err := eduvpn.server.ListenAndServe(); err != http.ErrServerClosed { + mux.HandleFunc("/callback", oauth.Callback) + if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { return &OAuthFailedCallbackError{Addr: addr, Err: err} } - return eduvpn.callbackError + return oauth.Session.CallbackError } // Get the access and refresh tokens // Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 // Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 -func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error { +func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code - reqURL := eduvpn.VPNState.Endpoints.API.V3.Token + reqURL := oauth.TokenURL data := url.Values{ - "client_id": {eduvpn.VPNState.Name}, + "client_id": {oauth.Session.ClientID}, "code": {authCode}, - "code_verifier": {eduvpn.verifier}, + "code_verifier": {oauth.Session.Verifier}, "grant_type": {"authorization_code"}, "redirect_uri": {"http://127.0.0.1:8000/callback"}, } @@ -138,52 +120,55 @@ func (eduvpn *EduVPNOAuthSession) getTokens(authCode string) error { return bodyErr } - tokenStructure := &EduVPNOAuthToken{} + tokenStructure := &OAuthToken{} jsonErr := json.Unmarshal(body, tokenStructure) if jsonErr != nil { return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} } - eduvpn.VPNState.OAuthToken = tokenStructure + oauth.Token = tokenStructure return nil } -type OAuthFailedCallbackParameterError struct { - Parameter string - URL string -} - -func (e *OAuthFailedCallbackParameterError) Error() string { - return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL) -} +// Get the access and refresh tokens with a previously received refresh token +// Access tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.4 +// Refresh tokens: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.2 +func (oauth *OAuth) getTokensWithRefresh() error { + reqURL := oauth.TokenURL + data := url.Values{ + "refresh_token": {oauth.Token.Refresh}, + "grant_type": {"refresh_token"}, + } + headers := &http.Header{ + "content-type": {"application/x-www-form-urlencoded"}} + opts := &HTTPOptionalParams{Headers: headers} + body, bodyErr := HTTPPostWithOptionalParams(reqURL, data, opts) + if bodyErr != nil { + return bodyErr + } -type OAuthFailedCallbackStateMatchError struct { - State string - ExpectedState string -} + tokenStructure := &OAuthToken{} + jsonErr := json.Unmarshal(body, tokenStructure) -func (e *OAuthFailedCallbackStateMatchError) Error() string { - return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState) -} + if jsonErr != nil { + return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + } -type OAuthFailedCallbackGetTokensError struct { - Err error -} + oauth.Token = tokenStructure -func (e *OAuthFailedCallbackGetTokensError) Error() string { - return fmt.Sprintf("failed getting tokens with error %v", e.Err) + return nil } // //// The callback to retrieve the authorization code: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-1.3.1 -func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http.Request) { +func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { // Extract the authorization code code, success := req.URL.Query()["code"] if !success { - eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()} - go eduvpn.server.Shutdown(eduvpn.context) + oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "code", URL: req.URL.String()} + go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // The code is the first entry @@ -193,64 +178,36 @@ func (eduvpn *EduVPNOAuthSession) oauthCallback(w http.ResponseWriter, req *http // https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-04#section-7.15 state, success := req.URL.Query()["state"] if !success { - eduvpn.callbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()} - go eduvpn.server.Shutdown(eduvpn.context) + oauth.Session.CallbackError = &OAuthFailedCallbackParameterError{Parameter: "state", URL: req.URL.String()} + go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // The state is the first entry extractedState := state[0] - if extractedState != eduvpn.state { - eduvpn.callbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: eduvpn.state} - go eduvpn.server.Shutdown(eduvpn.context) + if extractedState != oauth.Session.State { + oauth.Session.CallbackError = &OAuthFailedCallbackStateMatchError{State: extractedState, ExpectedState: oauth.Session.State} + go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // Now that we have obtained the authorization code, we can move to the next step: // Obtaining the access and refresh tokens - err := eduvpn.getTokens(extractedCode) + err := oauth.getTokensWithAuthCode(extractedCode) if err != nil { - eduvpn.callbackError = &OAuthFailedCallbackGetTokensError{Err: err} - go eduvpn.server.Shutdown(eduvpn.context) + oauth.Session.CallbackError = &OAuthFailedCallbackGetTokensError{Err: err} + go oauth.Session.Server.Shutdown(oauth.Session.Context) return } // Shutdown the server as we're done listening - go eduvpn.server.Shutdown(eduvpn.context) -} - -func constructURL(baseURL string, parameters map[string]string) (string, error) { - url, err := url.Parse(baseURL) - - if err != nil { - return "", err - } - - q := url.Query() - - for parameter, value := range parameters { - q.Set(parameter, value) - } - url.RawQuery = q.Encode() - return url.String(), nil -} - -type OAuthFailedInitializeError struct { - Err error -} - -func (e *OAuthFailedInitializeError) Error() string { - return fmt.Sprintf("failed initializing OAuth with error %v", e.Err) + go oauth.Session.Server.Shutdown(oauth.Session.Context) } // Initializes the OAuth for eduvpn. // It needs a vpn state that was gotten from `Register` // It returns the authurl for the browser and an error if present -func InitializeOAuth(vpnState *EduVPNState) (string, error) { - if vpnState == nil { - panic("invalid state") - } - +func (eduvpn *VPNState) InitializeOAuth() (string, error) { // Generate the state state, stateErr := genState() if stateErr != nil { @@ -265,7 +222,7 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) { challenge := genChallengeS256(verifier) parameters := map[string]string{ - "client_id": vpnState.Name, + "client_id": eduvpn.Name, "code_challenge_method": "S256", "code_challenge": challenge, "response_type": "code", @@ -274,24 +231,84 @@ func InitializeOAuth(vpnState *EduVPNState) (string, error) { "redirect_uri": "http://127.0.0.1:8000/callback", } - authURL, urlErr := constructURL(vpnState.Endpoints.API.V3.Authorization, parameters) + authURL, urlErr := HTTPConstructURL(eduvpn.Server.Endpoints.API.V3.Authorization, parameters) if urlErr != nil { // shouldn't happen panic(urlErr) } // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - vpnState.OAuthSession = &EduVPNOAuthSession{AuthURL: authURL, VPNState: vpnState, state: state, verifier: verifier} + oauthSession := &OAuthExchangeSession{ClientID: eduvpn.Name, State: state, Verifier: verifier} + eduvpn.Server.OAuth = &OAuth{TokenURL: eduvpn.Server.Endpoints.API.V3.Token, Session: oauthSession} return authURL, nil } -func FinishOAuth(vpnState *EduVPNState) error { - if vpnState == nil { - panic("invalid state") - } - if vpnState.OAuthSession == nil { +// Error definitions +func (eduvpn *VPNState) FinishOAuth() error { + oauth := eduvpn.Server.OAuth + if oauth == nil { panic("invalid oauth state") } - return vpnState.OAuthSession.getHTTPTokenClient() + return oauth.getTokensWithCallback() +} + +type OAuthGenStateUnableError struct { + Err error +} + +func (e *OAuthGenStateUnableError) Error() string { + return fmt.Sprintf("failed generating state with error %v", e.Err) +} + +type OAuthGenVerifierUnableError struct { + Err error +} + +func (e *OAuthGenVerifierUnableError) Error() string { + return fmt.Sprintf("failed generating verifier with error %v", e.Err) +} + + +type OAuthFailedCallbackError struct { + Addr string + Err error +} + +func (e *OAuthFailedCallbackError) Error() string { + return fmt.Sprintf("failed callback %s with error %v", e.Addr, e.Err) +} + +type OAuthFailedCallbackParameterError struct { + Parameter string + URL string +} + +func (e *OAuthFailedCallbackParameterError) Error() string { + return fmt.Sprintf("failed retrieving parameter %s in url %s", e.Parameter, e.URL) +} + +type OAuthFailedCallbackStateMatchError struct { + State string + ExpectedState string +} + +func (e *OAuthFailedCallbackStateMatchError) Error() string { + return fmt.Sprintf("failed matching state, got %s, want %s", e.State, e.ExpectedState) +} + +type OAuthFailedCallbackGetTokensError struct { + Err error +} + +func (e *OAuthFailedCallbackGetTokensError) Error() string { + return fmt.Sprintf("failed getting tokens with error %v", e.Err) +} + +type OAuthFailedInitializeError struct { + Err error +} + +func (e *OAuthFailedInitializeError) Error() string { + return fmt.Sprintf("failed initializing OAuth with error %v", e.Err) } diff --git a/src/server.go b/src/server.go index ced7716..bf1fb3d 100644 --- a/src/server.go +++ b/src/server.go @@ -1,96 +1,57 @@ package eduvpn import ( - "fmt" + "encoding/json" ) -type DiscoFileError struct { - URL string - Err error +type Server struct { + BaseURL string + Endpoints *ServerEndpoints + OAuth *OAuth } -func (e *DiscoFileError) Error() string { - return fmt.Sprintf("failed obtaining disco file %s with error %v", e.URL, e.Err) +type ServerEndpointList struct { + API string `json:"api_endpoint"` + Authorization string `json:"authorization_endpoint"` + Token string `json:"token_endpoint"` } -type DiscoSigFileError struct { - URL string - Err error +// Struct that defines the json format for /.well-known/vpn-user-portal" +type ServerEndpoints struct { + API struct { + V2 ServerEndpointList `json:"http://eduvpn.org/api#2"` + V3 ServerEndpointList `json:"http://eduvpn.org/api#3"` + } `json:"api"` + V string `json:"v"` } -func (e *DiscoSigFileError) Error() string { - return fmt.Sprintf("failed obtaining disco signature file %s with error %v", e.URL, e.Err) -} - -type DiscoVerifyError struct { - File string - Sigfile string - Err error -} -func (e *DiscoVerifyError) Error() string { - return fmt.Sprintf("failed verifying file %s with signature %s due to error %v", e.File, e.Sigfile, e.Err) +func (server *Server) Initialize(url string) error { + server.BaseURL = url + endpointsErr := server.GetEndpoints() + if endpointsErr != nil { + return endpointsErr + } + return nil } -// Helper function that gets a disco json -func getDiscoFile(jsonFile string) (string, error) { - // Get json data - discoURL := "https://disco.eduvpn.org/v2/" - fileURL := discoURL + jsonFile - fileBody, fileErr := HTTPGet(fileURL) - if fileErr != nil { - return "", &DiscoFileError{fileURL, fileErr} - } - - // Get signature - sigFile := jsonFile + ".minisig" - sigURL := discoURL + sigFile - sigBody, sigFileErr := HTTPGet(sigURL) +func (server *Server) GetEndpoints() error { + url := server.BaseURL + "/.well-known/vpn-user-portal" + body, bodyErr := HTTPGet(url) - if sigFileErr != nil { - return "", &DiscoSigFileError{URL: sigURL, Err: sigFileErr} + if bodyErr != nil { + return bodyErr } - // Verify signature - // TODO: Handle this by keeping track of the previous sign time - // Wrappers must do this? - var previousSigTime uint64 = 0 - forcePrehash := false - verifySuccess, verifyErr := Verify(string(sigBody), fileBody, jsonFile, previousSigTime, forcePrehash) + endpoints := &ServerEndpoints{} + jsonErr := json.Unmarshal(body, &endpoints) - if !verifySuccess || verifyErr != nil { - return "", &DiscoVerifyError{File: jsonFile, Sigfile: sigFile, Err: verifyErr} + if jsonErr != nil { + return jsonErr } - return string(fileBody), nil -} + server.Endpoints = endpoints -type GetListError struct { - File string - Err error -} - -func (e *GetListError) Error() string { - return fmt.Sprintf("failed getting disco list file %s with error %v", e.File, e.Err) -} - -// Get the organization list -func GetOrganizationsList() (string, error) { - file := "organization_list.json" - body, err := getDiscoFile(file) - if err != nil { - return "", &GetListError{File: file, Err: err} - } - return body, nil -} - -// Get the server list -func GetServersList() (string, error) { - file := "server_list.json" - body, err := getDiscoFile("server_list.json") - if err != nil { - return "", &GetListError{File: file, Err: err} - } - return body, nil + return nil } diff --git a/src/state.go b/src/state.go index be85a45..272bbc6 100644 --- a/src/state.go +++ b/src/state.go @@ -1,38 +1,25 @@ package eduvpn -type EduVPNState struct { - // The endpoints - Endpoints *EduVPNEndpoints - +type VPNState struct { // Info passed by the client Name string - Server string - // OAuth - OAuthToken *EduVPNOAuthToken - OAuthSession *EduVPNOAuthSession + // The chosen server + Server *Server } -func Register(state *EduVPNState, name string, server string, stateCallback func(string, string)) error { +func Register(state *VPNState, name string, stateCallback func(string, string)) error { state.Name = name - state.Server = server - - endpoints, err := APIGetEndpoints(state) - - if err != nil { - return err - } - state.Endpoints = endpoints stateCallback("START", "REGISTER") return nil } -var VPNStateInstance *EduVPNState +var VPNStateInstance *VPNState -func GetVPNState() *EduVPNState { +func GetVPNState() *VPNState { if VPNStateInstance == nil { - VPNStateInstance = &EduVPNState{} + VPNStateInstance = &VPNState{} } return VPNStateInstance } |
