summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/oauth/oauth.go83
1 files changed, 65 insertions, 18 deletions
diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go
index d256dbc..246f1e8 100644
--- a/internal/oauth/oauth.go
+++ b/internal/oauth/oauth.go
@@ -5,14 +5,16 @@ import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
+ "errors"
"fmt"
+ "net"
"net/http"
"net/url"
"time"
httpw "github.com/eduvpn/eduvpn-common/internal/http"
- "github.com/eduvpn/eduvpn-common/types"
"github.com/eduvpn/eduvpn-common/internal/util"
+ "github.com/eduvpn/eduvpn-common/types"
)
// Generates a random base64 string to be used for state
@@ -87,6 +89,7 @@ type OAuthExchangeSession struct {
// filled in when constructing the callback
Context context.Context
Server *http.Server
+ Listener net.Listener
}
// Struct that defines the json format for /.well-known/vpn-user-portal"
@@ -98,18 +101,34 @@ type OAuthToken struct {
ExpiredTimestamp time.Time `json:"expires_in_timestamp"`
}
-// Gets an authorized HTTP client by obtaining refresh and access tokens
-func (oauth *OAuth) getTokensWithCallback() error {
+// Sets up a listener
+func (oauth *OAuth) setupListener() error {
+ errorMessage := "failed setting up listener"
oauth.Session.Context = context.Background()
+
+ // create a listener
+ listener, listenerErr := net.Listen("tcp", ":0")
+ if listenerErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr}
+ }
+ oauth.Session.Listener = listener
+ return nil
+}
+
+func (oauth *OAuth) getTokensWithCallback() error {
+ errorMessage := "failed getting tokens with callback"
+ if oauth.Session.Listener == nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")}
+ }
mux := http.NewServeMux()
- addr := "127.0.0.1:8000"
+ // server /callback over the listener address
oauth.Session.Server = &http.Server{
- Addr: addr,
Handler: mux,
}
mux.HandleFunc("/callback", oauth.Callback)
- if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed {
- return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err}
+
+ if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: err}
}
return oauth.Session.CallbackError
}
@@ -122,12 +141,18 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error {
// Make sure the verifier is set as the parameter
// so that the server can verify that we are the actual owner of the authorization code
reqURL := oauth.TokenURL
+
+ port, portErr := oauth.GetListenerPort()
+ if portErr != nil {
+ return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ }
+
data := url.Values{
"client_id": {oauth.Session.ClientID},
"code": {authCode},
"code_verifier": {oauth.Session.Verifier},
"grant_type": {"authorization_code"},
- "redirect_uri": {"http://127.0.0.1:8000/callback"},
+ "redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
}
headers := http.Header{
"content-type": {"application/x-www-form-urlencoded"},
@@ -263,14 +288,18 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) {
oauth.TokenURL = tokenURL
}
+func (oauth OAuth) GetListenerPort() (int, error) {
+ errorMessage := "failed to get listener port"
+
+ if oauth.Session.Listener == nil {
+ return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")}
+ }
+ return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil
+}
+
// Starts the OAuth exchange for eduvpn.
func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) {
errorMessage := "failed starting OAuth exchange"
- // Generate the state
- state, stateErr := genState()
- if stateErr != nil {
- return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
- }
// Generate the verifier and challenge
verifier, verifierErr := genVerifier()
@@ -279,6 +308,28 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
}
challenge := genChallengeS256(verifier)
+ // Generate the state
+ state, stateErr := genState()
+ if stateErr != nil {
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ }
+
+ // Fill the struct with the necessary fields filled for the next call to getting the HTTP client
+ oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
+ oauth.Session = oauthSession
+
+ // set up the listener to get the redirect URI
+ listenerErr := oauth.setupListener()
+ if listenerErr != nil {
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr}
+ }
+
+ // Get the listener port
+ port, portErr := oauth.GetListenerPort()
+ if portErr != nil {
+ return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr}
+ }
+
parameters := map[string]string{
"client_id": name,
"code_challenge_method": "S256",
@@ -286,7 +337,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
"response_type": "code",
"scope": "config",
"state": state,
- "redirect_uri": "http://127.0.0.1:8000/callback",
+ "redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port),
}
authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters)
@@ -295,10 +346,6 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string)
return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr}
}
- // Fill the struct with the necessary fields filled for the next call to getting the HTTP client
- oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier}
- oauth.Session = oauthSession
-
// Return the url processed
return postProcessAuth(authURL), nil
}