From a30ef6b27e578a4cf0a674b24f5b52b4c1516c63 Mon Sep 17 00:00:00 2001 From: Jeroen Wijenbergh Date: Thu, 12 Feb 2026 12:34:08 +0100 Subject: All: Rename packages that sound useless or clash with std --- internal/eduvpnapi/eduvpnapi_test.go | 513 +++++++++++++++++++++++++++++++++++ 1 file changed, 513 insertions(+) create mode 100644 internal/eduvpnapi/eduvpnapi_test.go (limited to 'internal/eduvpnapi/eduvpnapi_test.go') diff --git a/internal/eduvpnapi/eduvpnapi_test.go b/internal/eduvpnapi/eduvpnapi_test.go new file mode 100644 index 0000000..23f895b --- /dev/null +++ b/internal/eduvpnapi/eduvpnapi_test.go @@ -0,0 +1,513 @@ +package eduvpnapi + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "reflect" + "regexp" + "slices" + "strings" + "testing" + "time" + + "codeberg.org/eduVPN/eduvpn-common/internal/eduvpnapi/profiles" + "codeberg.org/eduVPN/eduvpn-common/internal/httpwrap" + "codeberg.org/eduVPN/eduvpn-common/internal/test" + "codeberg.org/eduVPN/eduvpn-common/types/protocol" + "codeberg.org/eduVPN/eduvpn-common/types/server" + "codeberg.org/jwijenbergh/eduoauth-go/v2" +) + +func tokenHandler(t *testing.T, gt []string) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for token handler: %v", r.Method) + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + grant := parsed.Get("grant_type") + + if slices.Contains(gt, grant) { + _, err = w.Write([]byte(` +{ + "access_token": "validaccess", + "refresh_token": "validrefresh", + "expires_in": 3600 +} + `)) + if err != nil { + t.Fatalf("failed writing in token handler: %v", err) + } + return + } + t.Fatalf("grant type: %v, not allowed", grant) + } +} + +func checkAuthBearer(t *testing.T, r *http.Request) { + authh := r.Header.Get("Authorization") + if !strings.HasPrefix(authh, "Bearer ") { + t.Fatalf("API call is not given with an authorization Bearer header, got: %v", authh) + } +} + +func connectHandler(t *testing.T, proto string, exp time.Time) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for connect handler: %v", r.Method) + } + checkAuthBearer(t, r) + w.Header().Set("expires", exp.Format(http.TimeFormat)) + w.Header().Set("content-type", fmt.Sprintf("application/x-%s-profile", proto)) + b, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("failed reading token endpoint body: %v", err) + } + parsed, err := url.ParseQuery(string(b)) + if err != nil { + t.Fatalf("failed parsing query body: %v", err) + } + // the wireguard config we parse + var cfg string + if proto == "openvpn" { + cfg = "openvpnconfig" + } else { + if parsed.Get("public_key") == "" { + t.Fatalf("no public_key given") + } + if proto == "wireguard+tcp" { + ptcp := parsed.Get("prefer_tcp") + if ptcp != "yes" { + t.Fatalf("prefer TCP is not yes: %s", ptcp) + } + cfg = ` +[Interface] +[Peer] +ProxyEndpoint = https://proxyendpoint +` + } else { + cfg = "[Interface]" + } + } + _, err = w.Write([]byte(cfg)) + if err != nil { + t.Fatalf("failed writing /connect response: %v", err) + } + } +} + +func disconnectHandler(t *testing.T) func(http.ResponseWriter, *http.Request) { + return func(_ http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("invalid HTTP method for disconnect handler: %v", r.Method) + } + checkAuthBearer(t, r) + } +} + +type TestCallback struct { + t *testing.T +} + +func (tc *TestCallback) TriggerAuth(_ context.Context, str string, _ bool) (string, error) { + go func() { + u, err := url.Parse(str) + if err != nil { + panic(err) + } + ru, err := url.Parse(u.Query().Get("redirect_uri")) + if err != nil { + panic(err) + } + oq := u.Query() + q := ru.Query() + q.Set("state", oq.Get("state")) + q.Set("code", "fakeauthcode") + ru.RawQuery = q.Encode() + + c := http.Client{} + req, err := http.NewRequest("GET", ru.String(), nil) + if err != nil { + panic(err) + } + _, err = c.Do(req) + if err != nil { + panic(err) + } + }() + return "", nil +} +func (tc *TestCallback) AuthDone(string, server.Type) {} +func (tc *TestCallback) TokensUpdated(string, server.Type, eduoauth.Token) {} + +// create a API struct with allowed grant types +func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.HandlerPath) (*API, *test.Server) { + // Create a simple API client and check if the fields are created correctly + listen, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to setup listener for test server: %v", err) + } + + hps = append(hps, []test.HandlerPath{ + { + Method: http.MethodGet, + Path: "/.well-known/vpn-user-portal", + Response: fmt.Sprintf(` +{ + "api": { + "http://eduvpn.org/api#3": { + "api_endpoint": "https://%[1]s/test-api-endpoint", + "authorization_endpoint": "https://%[1]s/test-authorization-endpoint", + "token_endpoint": "https://%[1]s/test-token-endpoint" + } + }, + "v": "0.0.0" +} +`, listen.Addr().String()), + }, + { + Path: "/test-token-endpoint", + ResponseHandler: tokenHandler(t, gt), + }, + }...) + // start server + serv := test.NewServerWithHandles(hps, listen) + servc, err := serv.Client() + if err != nil { + t.Fatalf("failed to setup HTTP test server client: %v", servc) + } + + sd := ServerData{ + ID: "randomidentifier", + Type: server.TypeCustom, + BaseWK: serv.URL, + BaseAuthWK: serv.URL, + ProcessAuth: func(_ context.Context, in string) (string, error) { + return in, nil + }, + DisableAuthorize: false, + transport: servc.Client.Transport, + } + + tc := &TestCallback{t: t} + + a, err := NewAPI(context.Background(), "testclient", sd, tc, tok) + if err != nil { + t.Fatalf("failed creating API: %v", err) + } + return a, serv +} + +func TestNewAPI(t *testing.T) { + gts := []string{"refresh_token"} + tok := &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "expiredrefresh", + // tokens are expired, let's try authorizing + ExpiredTimestamp: time.Now(), + } + a, srv := createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } + + gts = []string{"authorization_code"} + tok = &eduoauth.Token{ + Access: "expiredaccess", + Refresh: "", + ExpiredTimestamp: time.Now(), + } + a, srv = createTestAPI(t, tok, gts, nil) + srv.Close() + + // now the tokens should be the new access tokens + if a.oauth.Token().Access != "validaccess" { + t.Fatalf("access token is not valid access") + } + if a.oauth.Token().Refresh != "validrefresh" { + t.Fatalf("refresh token is not valid refresh") + } +} + +func TestAPIInfo(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + statErr := &httpwrap.StatusError{} + cases := []struct { + hp test.HandlerPath + info *profiles.Info + err any + }{ + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "default_gateway": false, + "display_name": "test profile 1", + "profile_id": "test1", + "profile_priority": 3, + "vpn_proto_list": [ + "openvpn", + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test1", + DisplayName: "test profile 1", + VPNProtoList: []string{"openvpn", "wireguard"}, + Priority: 3, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: ` +{ + "info": { + "profile_list": [ + { + "display_name": "test profile 2", + "profile_id": "test2", + "vpn_proto_list": [ + "wireguard" + ] + } + ] + } +} +`, + }, + info: &profiles.Info{ + Info: profiles.ListInfo{ + ProfileList: []profiles.Profile{ + { + ID: "test2", + DisplayName: "test profile 2", + VPNProtoList: []string{"wireguard"}, + DefaultGateway: false, + }, + }, + }, + }, + }, + { + hp: test.HandlerPath{ + Method: http.MethodGet, + Path: "/test-api-endpoint/info", + Response: "", + ResponseCode: 404, + }, + info: nil, + err: &statErr, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gprfs, err := a.Info(context.Background()) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed profiles info: %v but want no error", err) + } + + if !errors.As(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err.Error()) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if !reflect.DeepEqual(gprfs, c.info) { + t.Fatalf("got info: %v, not equal to want: %v", gprfs, c.info) + } + } +} + +func TestAPIConnect(t *testing.T) { + // auth should not be triggered + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + cases := []struct { + hp test.HandlerPath + cd *ConnectData + prof profiles.Profile + protos []protocol.Protocol + ptcp bool + err error + }{ + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + err: ErrNoProtocols, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + Response: ``, + }, + cd: nil, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard, protocol.Unknown}, + err: ErrUnknownProtocol, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "openvpn", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: "openvpnconfig\nscript-security 0", + Protocol: protocol.OpenVPN, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuard, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + { + hp: test.HandlerPath{ + Method: http.MethodPost, + Path: "/test-api-endpoint/connect", + ResponseHandler: connectHandler(t, "wireguard+tcp", time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC)), + }, + cd: &ConnectData{ + Configuration: `\[Interface\] +PrivateKey = .*`, + Protocol: protocol.WireGuardProxy, + Expires: time.Date(2000, time.January, 0, 0, 0, 0, 0, time.UTC), + }, + ptcp: true, + protos: []protocol.Protocol{protocol.OpenVPN, protocol.WireGuard}, + err: nil, + }, + } + + for _, c := range cases { + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{c.hp}) + defer srv.Close() + gcd, err := a.Connect(context.Background(), c.prof, c.protos, c.ptcp) + // got error but the want error is nil + if err != nil { + if c.err == nil { + t.Fatalf("failed connect: %v but want no error", err) + } + + if !errors.Is(err, c.err) { + t.Fatalf("error type not equal: %T, want: %T, error string: %s", err, c.err, err) + } + } else if c.err != nil { + t.Fatalf("got no error but want error: %T", c.err) + } + + if gcd != nil && c.cd != nil { + m, err := regexp.MatchString(c.cd.Configuration, gcd.Configuration) + if err != nil { + t.Fatalf("failed matching regexp: %v", err) + } + if !m { + t.Fatalf("regex:\n%s\ndoes not match config:\n%s", c.cd.Configuration, gcd.Configuration) + } + // we have already checked the config using a regex + c.cd.Configuration = gcd.Configuration + + } + if !reflect.DeepEqual(gcd, c.cd) { + t.Fatalf("got connect data: %v, not equal to want: %v", gcd, c.cd) + } + } +} + +func TestDisconnect(t *testing.T) { + var gts []string + tok := &eduoauth.Token{ + Access: "validaccess", + Refresh: "validrefresh", + ExpiredTimestamp: time.Now().Add(1 * time.Hour), + } + a, srv := createTestAPI(t, tok, gts, []test.HandlerPath{ + { + Path: "/test-api-endpoint/disconnect", + ResponseHandler: disconnectHandler(t), + }, + }) + defer srv.Close() + err := a.Disconnect(context.Background()) + if err != nil { + t.Fatalf("failed /disconnect: %v", err) + } +} -- cgit v1.2.3