diff options
| -rw-r--r-- | client_test.go | 11 | ||||
| -rw-r--r-- | internal/oauth/oauth.go | 83 |
2 files changed, 75 insertions, 19 deletions
diff --git a/client_test.go b/client_test.go index 1c7211f..b52d61c 100644 --- a/client_test.go +++ b/client_test.go @@ -106,7 +106,16 @@ func test_connect_oauth_parameter( "en", func(oldState FSMStateID, newState FSMStateID, data interface{}) { if newState == STATE_OAUTH_STARTED { - baseURL := "http://127.0.0.1:8000/callback" + current, currentErr := state.Servers.GetCurrentServer() + if currentErr != nil { + t.Fatalf("No current server with error: %v", currentErr) + } + port, portErr := current.GetOAuth().GetListenerPort() + + if portErr != nil { + t.Fatalf("No port with error: %v", portErr) + } + baseURL := fmt.Sprintf("http://127.0.0.1:%d/callback", port) url, err := httpw.HTTPConstructURL(baseURL, parameters) if err != nil { t.Fatalf( diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index d256dbc..246f1e8 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -5,14 +5,16 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" + "net" "net/http" "net/url" "time" httpw "github.com/eduvpn/eduvpn-common/internal/http" - "github.com/eduvpn/eduvpn-common/types" "github.com/eduvpn/eduvpn-common/internal/util" + "github.com/eduvpn/eduvpn-common/types" ) // Generates a random base64 string to be used for state @@ -87,6 +89,7 @@ type OAuthExchangeSession struct { // filled in when constructing the callback Context context.Context Server *http.Server + Listener net.Listener } // Struct that defines the json format for /.well-known/vpn-user-portal" @@ -98,18 +101,34 @@ type OAuthToken struct { ExpiredTimestamp time.Time `json:"expires_in_timestamp"` } -// Gets an authorized HTTP client by obtaining refresh and access tokens -func (oauth *OAuth) getTokensWithCallback() error { +// Sets up a listener +func (oauth *OAuth) setupListener() error { + errorMessage := "failed setting up listener" oauth.Session.Context = context.Background() + + // create a listener + listener, listenerErr := net.Listen("tcp", ":0") + if listenerErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: listenerErr} + } + oauth.Session.Listener = listener + return nil +} + +func (oauth *OAuth) getTokensWithCallback() error { + errorMessage := "failed getting tokens with callback" + if oauth.Session.Listener == nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No listener")} + } mux := http.NewServeMux() - addr := "127.0.0.1:8000" + // server /callback over the listener address oauth.Session.Server = &http.Server{ - Addr: addr, Handler: mux, } mux.HandleFunc("/callback", oauth.Callback) - if err := oauth.Session.Server.ListenAndServe(); err != http.ErrServerClosed { - return &types.WrappedErrorMessage{Message: "failed getting tokens with callback", Err: err} + + if err := oauth.Session.Server.Serve(oauth.Session.Listener); err != http.ErrServerClosed { + return &types.WrappedErrorMessage{Message: errorMessage, Err: err} } return oauth.Session.CallbackError } @@ -122,12 +141,18 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { // Make sure the verifier is set as the parameter // so that the server can verify that we are the actual owner of the authorization code reqURL := oauth.TokenURL + + port, portErr := oauth.GetListenerPort() + if portErr != nil { + return &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + } + data := url.Values{ "client_id": {oauth.Session.ClientID}, "code": {authCode}, "code_verifier": {oauth.Session.Verifier}, "grant_type": {"authorization_code"}, - "redirect_uri": {"http://127.0.0.1:8000/callback"}, + "redirect_uri": {fmt.Sprintf("http://127.0.0.1:%d/callback", port)}, } headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, @@ -263,14 +288,18 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string) { oauth.TokenURL = tokenURL } +func (oauth OAuth) GetListenerPort() (int, error) { + errorMessage := "failed to get listener port" + + if oauth.Session.Listener == nil { + return 0, &types.WrappedErrorMessage{Message: errorMessage, Err: errors.New("No OAuth listener")} + } + return oauth.Session.Listener.Addr().(*net.TCPAddr).Port, nil +} + // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) (string, error) { errorMessage := "failed starting OAuth exchange" - // Generate the state - state, stateErr := genState() - if stateErr != nil { - return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} - } // Generate the verifier and challenge verifier, verifierErr := genVerifier() @@ -279,6 +308,28 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) } challenge := genChallengeS256(verifier) + // Generate the state + state, stateErr := genState() + if stateErr != nil { + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + } + + // Fill the struct with the necessary fields filled for the next call to getting the HTTP client + oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} + oauth.Session = oauthSession + + // set up the listener to get the redirect URI + listenerErr := oauth.setupListener() + if listenerErr != nil { + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: stateErr} + } + + // Get the listener port + port, portErr := oauth.GetListenerPort() + if portErr != nil { + return "", &types.WrappedErrorMessage{Message: errorMessage, Err: portErr} + } + parameters := map[string]string{ "client_id": name, "code_challenge_method": "S256", @@ -286,7 +337,7 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) "response_type": "code", "scope": "config", "state": state, - "redirect_uri": "http://127.0.0.1:8000/callback", + "redirect_uri": fmt.Sprintf("http://127.0.0.1:%d/callback", port), } authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) @@ -295,10 +346,6 @@ func (oauth *OAuth) GetAuthURL(name string, postProcessAuth func(string) string) return "", &types.WrappedErrorMessage{Message: errorMessage, Err: urlErr} } - // Fill the struct with the necessary fields filled for the next call to getting the HTTP client - oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} - oauth.Session = oauthSession - // Return the url processed return postProcessAuth(authURL), nil } |
