From ac41bda1e059d24da08e584b81b4507cf781beb8 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Wed, 4 Jan 2023 21:06:52 +0100 Subject: OAuth: Add auth url test --- internal/oauth/oauth_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) (limited to 'internal') diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go index b6adcd6..bafb7e5 100644 --- a/internal/oauth/oauth_test.go +++ b/internal/oauth/oauth_test.go @@ -2,6 +2,7 @@ package oauth import ( "encoding/json" + "fmt" "net/url" "strings" "testing" @@ -126,3 +127,70 @@ func Test_secretJSON(t *testing.T) { t.Fatalf("Serialized OAuth contains Refresh Token! Serialized: %v, Refresh Token: %v", s, a) } } + +func Test_AuthURL(t *testing.T) { + iss := "local" + auth := "https://127.0.0.1/auth" + token := "https://127.0.0.1/token" + id := "client_id" + o := OAuth{ISS: iss, BaseAuthorizationURL: auth, TokenURL: token} + s, err := o.AuthURL(id, func(s string) string { + // We do nothing here are this function is for skipping WAYF + return s + }) + if err != nil { + t.Fatalf("Error in getting OAuth URL: %v", err) + } + + // Check if the OAuth session has valid values + if o.session.ClientID != id { + t.Fatalf("OAuth ClientID not equal, want: %v, got: %v", o.session.ClientID, id) + } + if o.session.ISS != iss { + t.Fatalf("OAuth ISS not equal, want: %v, got: %v", o.session.ISS, iss) + } + if o.session.State == "" { + t.Fatal("No OAuth session state paremeter found") + } + if o.session.Verifier == "" { + t.Fatal("No OAuth session state paremeter found") + } + if o.session.ErrChan == nil { + t.Fatal("No OAuth session error channel found") + } + + u, err := url.Parse(s) + if err != nil { + t.Fatalf("Returned Auth URL cannot be parsed with error: %v", err) + } + + port, err := o.ListenerPort() + if err != nil { + t.Fatalf("Listener port cannot be found with error: %v", err) + } + + c := []struct { + query string + want string + }{ + {query: "client_id", want: id}, + {query: "code_challenge_method", want: "S256"}, + {query: "response_type", want: "code"}, + {query: "scope", want: "config"}, + {query: "redirect_uri", want: fmt.Sprintf("http://127.0.0.1:%d/callback", port)}, + } + + q := u.Query() + + // We should have 7 parameters: client_id, challenge method, challenge, response type, scope, state and redirect uri + if len(q) != 7 { + t.Fatalf("Total query parameters is not 7, url: %v, total params: %v", u, len(q)) + } + + for _, v := range c { + p := q.Get(v.query) + if p != v.want { + t.Fatalf("Parameter: %v, not equal, want: %v, got: %v", v.query, v.want, p) + } + } +} -- cgit v1.2.3