diff options
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/api_test.go | 532 |
1 files changed, 532 insertions, 0 deletions
diff --git a/internal/api/api_test.go b/internal/api/api_test.go new file mode 100644 index 0000000..28b5be1 --- /dev/null +++ b/internal/api/api_test.go @@ -0,0 +1,532 @@ +package api + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "strings" + "testing" + "time" + + "github.com/eduvpn/eduvpn-common/internal/api/profiles" + httpw "github.com/eduvpn/eduvpn-common/internal/http" + "github.com/eduvpn/eduvpn-common/internal/test" + "github.com/eduvpn/eduvpn-common/internal/wireguard" + "github.com/eduvpn/eduvpn-common/types/protocol" + "github.com/eduvpn/eduvpn-common/types/server" + "github.com/jwijenbergh/eduoauth-go" +) + +func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for token handler: %v", r.Method) + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + grant := parsed.Get("grant_type") + + for _, v := range gt { + if v == grant { + w.Write([]byte(` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +} + `)) + return + } + } + t.Fatalf("grant type: %v, not allowed", grant) + } +} + +func checkAuthBearer(t *testing.T, r *http.Request) { + authh := r.Header.Get("Authorization") + if !strings.HasPrefix(authh, "Bearer ") { + t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh) + } +} + +func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for connect handler: %v", r.Method) + } + checkAuthBearer(t, r) + w.Header().Set("expires", exp.Format(http.TimeFormat)) + w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto)) + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + // the wireguard config we parse + cfg := proto + if proto == "openvpn" { + cfg = "openvpnconfig" + } else { + if parsed.Get("public_key") == "" { + t.Fatalf("no public_key given") + } + if proto == "wireguard+tcp" { + ptcp := parsed.Get("prefer_tcp") + if ptcp != "yes" { + t.Fatalf("prefer TCP is not yes: %s", ptcp) + } + cfg = ` +[Interface] +[Peer] +ProxyEndpoint = https://proxyendpoint +` + } else { + cfg = "[Interface]" + } + } + w.Write([]byte(cfg)) + } +} + +func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method) + } + checkAuthBearer(t, r) + } +} + +type TestCallback struct { + t *testing.T +} + +func (tc *TestCallback) TriggerAuth(ctx context.Context, str string, cr bool) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + tc.t.Fatalf("cannot parse TriggerAuth uri: %v", err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + tc.t.Fatalf("failed to parse redirect uri: %v", err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + tc.t.Fatalf("failed to create HTTP request: %v", err) + } + c.Do(req) + }() + return "", nil +} +func (tc *TestCallback) AuthDone(string, server.Type) {} +func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {} + +// create a API struct with allowed grant types +func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) { + // Create a simple API client and check if the fields are created correctly + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + fp := append(hps, []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.0" +} +`, listen.Addr().String()), + ResponseCode: 200, + }, + { + Path: "/test-token-endpoint", + ResponseHandler: tokenHandler(t, gt), + }, + }...) + // start server + serv := test.NewServerWithHandles(fp, listen) + + sd := ServerData{ + ID: "randomidentifier", + Type: server.TypeCustom, + BaseWK: serv.URL, + BaseAuthWK: serv.URL, + ProcessAuth: func(in string) string { + return in + }, + DisableAuthorize: false, + } + servc, err := serv.Client() + if err != nil { + t.Fatalf("failed to setup HTTP test server client: %v", servc) + } + // TODO: Mock underlying clients instead + http.DefaultTransport = servc.Client.Transport + + tc := &TestCallback{t: t} + + a, err := NewAPI(context.Background(), "testclient", sd, tc, tok) + if err != nil { + t.Fatalf("failed creating API: %v", err) + } + return a, serv +} + +func TestNewAPI(t *testing.T) { + gts := []string{"refresh_token"} + tok := &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "expiredrefresh", + // tokens are expired, let's try authorizing + ExpiredTimestamp: time.Now(), + } + a, srv := createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } + + gts = []string{"authorization_code"} + tok = &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "", + ExpiredTimestamp: time.Now(), + } + a, srv = createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } +} + +func TestAPIInfo(t *testing.T) { + // auth should not be triggered + var gts []string = nil + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + statErr := &httpw.StatusError{} + cases := []struct { + hp test.HandlerPath + info *profiles.Info + err interface{} + }{ + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "default_gateway": false, + "display_name": "test profile 1", + "profile_id": "test1", + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + } + ] + } +} +`, + ResponseCode: 200, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test1", + DisplayName: "test profile 1", + VPNProtoList: []string{"openvpn", "wireguard"}, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "display_name": "test profile 2", + "profile_id": "test2", + "vpn_proto_list": [ + "wireguard" + ] + } + ] + } +} +`, + ResponseCode: 200, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test2", + DisplayName: "test profile 2", + VPNProtoList: []string{"wireguard"}, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: "", + ResponseCode: 404, + }, + info: nil, + err: &statErr, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gprfs, err := a.Info(context.Background()) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed profiles info: %v but want no error", err) + } + + if !errors.As(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error()) + } + } else { + if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + } + + if !reflect.DeepEqual(gprfs, c.info) { + t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info) + } + } +} + +func TestAPIConnect(t *testing.T) { + // auth should not be triggered + var gts []string = nil + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + cases := []struct { + hp test.HandlerPath + cd *ConnectData + prof profiles.Profile + protos []protocol.Protocol + ptcp bool + err error + }{ + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + ResponseCode: 200, + }, + cd: nil, + err: ErrNoProtocols, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + ResponseCode: 200, + }, + cd: nil, + protos: []protocol.Protocol{protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + ResponseCode: 200, + }, + cd: nil, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: "openvpnconfig\nscript-security 0", + Protocol: protocol.OpenVPN, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + Proxy: nil, + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuard, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + Proxy: nil, + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuardProxy, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + // proxy will be manually checked + Proxy: &wireguard.Proxy{}, + }, + ptcp: true, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed connect: %v but want no error", err) + } + + if !errors.Is(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err) + } + } else { + if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + } + + if gcd != nil && c.cd != nil { + m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration) + if err != nil { + t.Fatalf("failed matching regexp: %v", err) + } + if !m { + t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration) + } + // we have already checked the config using a regex + c.cd.Configuration = gcd.Configuration + + // check proxy manually + if c.cd.Proxy != nil && gcd.Proxy != nil { + if gcd.Proxy.Peer != "https://proxyendpoint" { + t.Fatalf("config data proxy peer is no proxyendpoint with HTTPS scheme: %s", gcd.Proxy.Peer) + } + if gcd.Proxy.SourcePort <= 0 { + t.Fatalf("got proxy source port is smaller or equal to 0: %v", gcd.Proxy.SourcePort) + } + if !strings.Contains(gcd.Proxy.Listen, "127.0.0.1") { + t.Fatalf("proxy listen does not contain 127.0.0.1: %s", gcd.Proxy.Listen) + } + c.cd.Proxy = gcd.Proxy + } + } + if !reflect.DeepEqual(gcd, c.cd) { + t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd) + } + } +} + +func TestDisconnect(t *testing.T) { + var gts []string = nil + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{ + { + Path: "/test-api-endpoint/disconnect", + ResponseHandler: disconnectHandler(t), + }, + }) + defer srv.Close() + err := a.Disconnect(context.Background()) + if err != nil { + t.Fatalf("failed /disconnect: %v", err) + } +} |
