summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjwijenbergh <jeroenwijenbergh@protonmail.com>2023-01-04 21:06:52 +0100
committerjwijenbergh <jeroenwijenbergh@protonmail.com>2023-01-04 21:06:52 +0100
commitac41bda1e059d24da08e584b81b4507cf781beb8 (patch)
treeaa190d9c18f4bade3669306ee79e82eb2d396eb1
parentc2cda99ff02e9632c00500a1bc1810131f9e03eb (diff)
OAuth: Add auth url test
-rw-r--r--internal/oauth/oauth_test.go68
1 files changed, 68 insertions, 0 deletions
diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go
index b6adcd6..bafb7e5 100644
--- a/internal/oauth/oauth_test.go
+++ b/internal/oauth/oauth_test.go
@@ -2,6 +2,7 @@ package oauth
import (
"encoding/json"
+ "fmt"
"net/url"
"strings"
"testing"
@@ -126,3 +127,70 @@ func Test_secretJSON(t *testing.T) {
t.Fatalf("Serialized OAuth contains Refresh Token! Serialized: %v, Refresh Token: %v", s, a)
}
}
+
+func Test_AuthURL(t *testing.T) {
+ iss := "local"
+ auth := "https://127.0.0.1/auth"
+ token := "https://127.0.0.1/token"
+ id := "client_id"
+ o := OAuth{ISS: iss, BaseAuthorizationURL: auth, TokenURL: token}
+ s, err := o.AuthURL(id, func(s string) string {
+ // We do nothing here are this function is for skipping WAYF
+ return s
+ })
+ if err != nil {
+ t.Fatalf("Error in getting OAuth URL: %v", err)
+ }
+
+ // Check if the OAuth session has valid values
+ if o.session.ClientID != id {
+ t.Fatalf("OAuth ClientID not equal, want: %v, got: %v", o.session.ClientID, id)
+ }
+ if o.session.ISS != iss {
+ t.Fatalf("OAuth ISS not equal, want: %v, got: %v", o.session.ISS, iss)
+ }
+ if o.session.State == "" {
+ t.Fatal("No OAuth session state paremeter found")
+ }
+ if o.session.Verifier == "" {
+ t.Fatal("No OAuth session state paremeter found")
+ }
+ if o.session.ErrChan == nil {
+ t.Fatal("No OAuth session error channel found")
+ }
+
+ u, err := url.Parse(s)
+ if err != nil {
+ t.Fatalf("Returned Auth URL cannot be parsed with error: %v", err)
+ }
+
+ port, err := o.ListenerPort()
+ if err != nil {
+ t.Fatalf("Listener port cannot be found with error: %v", err)
+ }
+
+ c := []struct {
+ query string
+ want string
+ }{
+ {query: "client_id", want: id},
+ {query: "code_challenge_method", want: "S256"},
+ {query: "response_type", want: "code"},
+ {query: "scope", want: "config"},
+ {query: "redirect_uri", want: fmt.Sprintf("http://127.0.0.1:%d/callback", port)},
+ }
+
+ q := u.Query()
+
+ // We should have 7 parameters: client_id, challenge method, challenge, response type, scope, state and redirect uri
+ if len(q) != 7 {
+ t.Fatalf("Total query parameters is not 7, url: %v, total params: %v", u, len(q))
+ }
+
+ for _, v := range c {
+ p := q.Get(v.query)
+ if p != v.want {
+ t.Fatalf("Parameter: %v, not equal, want: %v, got: %v", v.query, v.want, p)
+ }
+ }
+}