From dd61cd1f935930850986510675a2c37f0e85ef27 Mon Sep 17 00:00:00 2001 From: jwijenbergh Date: Mon, 8 Jul 2024 09:18:10 +0200 Subject: Client + API: Mark organizations expired *before* processing url --- client/client.go | 1 - internal/api/api.go | 7 +++++-- internal/api/api_test.go | 4 ++-- internal/server/secureinternet.go | 26 +++++++++++++++++++++----- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/client/client.go b/client/client.go index 12e0ea5..962fe23 100644 --- a/client/client.go +++ b/client/client.go @@ -217,7 +217,6 @@ func (c *Client) AuthDone(id string, t srvtypes.Type) { if err != nil { log.Logger.Debugf("unhandled auth done main transition: %v", err) } - c.MarkOrganizationsExpired(t) c.TrySave() } diff --git a/internal/api/api.go b/internal/api/api.go index 16f86af..6cba35c 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -44,7 +44,7 @@ type ServerData struct { // BaseAuthWK is the base well-known endpoint for authorization. This is only different in case of secure internet BaseAuthWK string // ProcessAuth processes the OAuth authorization - ProcessAuth func(string) string + ProcessAuth func(context.Context, string) (string, error) // DisableAuthorize indicates whether or not new authorization requests should be disabled DisableAuthorize bool // Transport is the HTTP transport, only used for testing currently @@ -134,7 +134,10 @@ func (a *API) authorize(ctx context.Context) (err error) { return err } if a.Data.ProcessAuth != nil { - url = a.Data.ProcessAuth(url) + url, err = a.Data.ProcessAuth(ctx, url) + if err != nil { + return err + } } // We expect an uri if custom redirect is non empty uri, err := a.cb.TriggerAuth(ctx, url, a.oauth.CustomRedirect != "") diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 5de1d7b..397dd3c 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -196,8 +196,8 @@ func createTestAPI(t *testing.T, tok *eduoauth.Token, gt []string, hps []test.Ha Type: server.TypeCustom, BaseWK: serv.URL, BaseAuthWK: serv.URL, - ProcessAuth: func(in string) string { - return in + ProcessAuth: func(ctx context.Context, in string) (string, error) { + return in, nil }, DisableAuthorize: false, Transport: servc.Client.Transport, diff --git a/internal/server/secureinternet.go b/internal/server/secureinternet.go index 0571e5f..f167756 100644 --- a/internal/server/secureinternet.go +++ b/internal/server/secureinternet.go @@ -29,12 +29,19 @@ func (s *Servers) AddSecure(ctx context.Context, disco *discovery.Discovery, org } sd := api.ServerData{ - ID: orgID, + ID: dorg.OrgID, Type: server.TypeSecureInternet, BaseWK: dsrv.BaseURL, BaseAuthWK: dsrv.BaseURL, - ProcessAuth: func(url string) string { - return util.ReplaceWAYF(dsrv.AuthenticationURLTemplate, url, dorg.OrgID) + ProcessAuth: func(ctx context.Context, url string) (string, error) { + disco.Servers(ctx) + disco.Organizations(ctx) + updorg, updsrv, err := disco.SecureHomeArgs(orgID) + if err != nil { + return "", err + } + ret := util.ReplaceWAYF(updsrv.AuthenticationURLTemplate, url, updorg.OrgID) + return ret, nil }, } @@ -96,8 +103,17 @@ func (s *Servers) GetSecure(ctx context.Context, orgID string, disco *discovery. Type: server.TypeSecureInternet, BaseWK: dloc.BaseURL, BaseAuthWK: dhome.BaseURL, - ProcessAuth: func(url string) string { - return util.ReplaceWAYF(dhome.AuthenticationURLTemplate, url, dorg.OrgID) + ProcessAuth: func(ctx context.Context, url string) (string, error) { + disco.MarkServersExpired() + disco.Servers(ctx) + disco.MarkOrganizationsExpired() + disco.Organizations(ctx) + updorg, updsrv, err := disco.SecureHomeArgs(orgID) + if err != nil { + return "", err + } + ret := util.ReplaceWAYF(updsrv.AuthenticationURLTemplate, url, updorg.OrgID) + return ret, nil }, DisableAuthorize: disableAuth, } -- cgit v1.2.3