diff options
| author | Jeroen Wijenbergh <jeroenwijenbergh@protonmail.com> | 2022-03-21 14:58:58 +0100 |
|---|---|---|
| committer | jwijenbergh <jeroenwijenbergh@protonmail.com> | 2022-09-20 20:29:52 +0200 |
| commit | fc56f8770923ec1997444a8318a18be0a8397520 (patch) | |
| tree | 3c6522b9b6e44ca2ad6cd94b074da78eed2c1028 /src | |
| parent | d45f5df4dc5fa9ad8abdc47c940f6baf96fdbe45 (diff) | |
Wireguard: Add basic support
Diffstat (limited to 'src')
| -rw-r--r-- | src/api.go | 42 | ||||
| -rw-r--r-- | src/http.go | 98 | ||||
| -rw-r--r-- | src/oauth.go | 35 | ||||
| -rw-r--r-- | src/server.go | 6 | ||||
| -rw-r--r-- | src/state.go | 2 | ||||
| -rw-r--r-- | src/wireguard.go | 26 |
6 files changed, 148 insertions, 61 deletions
@@ -1,23 +1,57 @@ package eduvpn import ( + "fmt" "net/http" + "net/url" ) -func (eduvpn *VPNState) APIAuthenticatedGet(endpoint string) (string, error) { +// Authenticated wrappers on top of HTTP +func (eduvpn *VPNState) apiAuthenticatedWithOpts(method string, endpoint string, opts *HTTPOptionalParams) ([]byte, error) { + // Ensure optional is not nil as we will fill it with headers + if opts == nil { + opts = &HTTPOptionalParams{} + } url := eduvpn.Server.Endpoints.API.V3.API + endpoint // Ensure we have non-expired tokens oauthErr := eduvpn.EnsureTokensOAuth() if oauthErr != nil { - return "", oauthErr + return nil, oauthErr } - headers := &http.Header{"Authorization": {"Bearer " + eduvpn.Server.OAuth.Token.Access}} - body, bodyErr := HTTPGetWithOptionalParams(url, &HTTPOptionalParams{Headers: headers}) + headerKey := "Authorization" + headerValue := fmt.Sprintf("Bearer %s", eduvpn.Server.OAuth.Token.Access) + if opts.Headers != nil { + opts.Headers.Add(headerKey, headerValue) + } else { + opts.Headers = &http.Header{headerKey: {headerValue}} + } + body, bodyErr := HTTPMethodWithOpts(method, url, opts) + if bodyErr != nil { + return nil, bodyErr + } + return body, nil +} + +func (eduvpn *VPNState) APIConnectWireguard(pubkey string) (string, error) { + headers := &http.Header{ + "content-type": {"application/x-www-form-urlencoded"}, + "accept": {"application/x-wireguard-profile"}, + } + + urlForm := url.Values{ + "profile_id": {"default"}, + "public_key": {pubkey}, + } + body, bodyErr := eduvpn.apiAuthenticatedWithOpts(http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) if bodyErr != nil { return "", bodyErr } return string(body), nil } + +func (eduvpn *VPNState) APIInfo() ([]byte, error) { + return eduvpn.apiAuthenticatedWithOpts(http.MethodGet, "/info", nil) +} diff --git a/src/http.go b/src/http.go index 1374eed..5366c7e 100644 --- a/src/http.go +++ b/src/http.go @@ -2,6 +2,7 @@ package eduvpn import ( "fmt" + "io" "io/ioutil" "net/http" "net/url" @@ -54,15 +55,16 @@ func (e *HTTPRequestCreateError) Error() string { return fmt.Sprintf("failed to create HTTP request with url %s and error %v", e.URL, e.Err) } -type HTTPOptionalParams struct { - Headers *http.Header -} +type URLParameters map[string]string -func HTTPGet(url string) ([]byte, error) { - return HTTPGetWithOptionalParams(url, nil) +type HTTPOptionalParams struct { + Headers *http.Header + URLParameters *URLParameters + Body url.Values } -func HTTPConstructURL(baseURL string, parameters map[string]string) (string, error) { +// Construct an URL including on parameters +func HTTPConstructURL(baseURL string, parameters URLParameters) (string, error) { url, err := url.Parse(baseURL) if err != nil { @@ -78,59 +80,89 @@ func HTTPConstructURL(baseURL string, parameters map[string]string) (string, err return url.String(), nil } +// Convenience functions +func HTTPGet(url string) ([]byte, error) { + return HTTPMethodWithOpts(http.MethodGet, url, nil) +} -func HTTPGetWithOptionalParams(url string, opts *HTTPOptionalParams) ([]byte, error) { - client := &http.Client{} - req, reqErr := http.NewRequest(http.MethodGet, url, nil) - if reqErr != nil { - return nil, &HTTPRequestCreateError{URL: url, Err: reqErr} +func HTTPPost(url string, body url.Values) ([]byte, error) { + return HTTPMethodWithOpts(http.MethodGet, url, &HTTPOptionalParams{Body: body}) +} + +func HTTPGetWithOpts(url string, opts *HTTPOptionalParams) ([]byte, error) { + return HTTPMethodWithOpts(http.MethodGet, url, opts) +} + +func HTTPPostWithOpts(url string, opts *HTTPOptionalParams) ([]byte, error) { + return HTTPMethodWithOpts(http.MethodPost, url, opts) +} + +func httpOptionalURL(url string, opts *HTTPOptionalParams) (string, error) { + if opts != nil && opts.URLParameters != nil { + url, urlErr := HTTPConstructURL(url, *opts.URLParameters) + + if urlErr != nil { + return url, &HTTPRequestCreateError{URL: url, Err: urlErr} + } + return url, nil } - if opts != nil && opts.Headers != nil { + return url, nil +} + +func httpOptionalHeaders(req *http.Request, opts *HTTPOptionalParams) { + // Add headers + if opts != nil && opts.Headers != nil && req != nil { for k, v := range *opts.Headers { req.Header.Add(k, v[0]) } } - resp, respErr := client.Do(req) - if respErr != nil { - return nil, &HTTPResourceError{URL: url, Err: respErr} - } - defer resp.Body.Close() +} - body, readErr := ioutil.ReadAll(resp.Body) - if readErr != nil { - return nil, &HTTPReadError{URL: url, Err: readErr} +func httpOptionalBodyReader(opts *HTTPOptionalParams) io.Reader { + if opts != nil && opts.Body != nil { + return strings.NewReader(opts.Body.Encode()) } - - return body, nil + return nil } -func HTTPPost(url string, body url.Values) ([]byte, error) { - return HTTPPostWithOptionalParams(url, body, nil) -} +func HTTPMethodWithOpts(method string, url string, opts *HTTPOptionalParams) ([]byte, error) { -func HTTPPostWithOptionalParams(url string, data url.Values, opts *HTTPOptionalParams) ([]byte, error) { + // Make sure the url contains all the parameters + // This can return an error, + // it already has the right error so so we don't wrap it further + url, urlErr := httpOptionalURL(url, opts) + if urlErr != nil { + return nil, urlErr + } + + // Create a client client := &http.Client{} - req, reqErr := http.NewRequest(http.MethodPost, url, strings.NewReader(data.Encode())) + + // Create request object with the body reader generated from the optional arguments + req, reqErr := http.NewRequest(method, url, httpOptionalBodyReader(opts)) if reqErr != nil { return nil, &HTTPRequestCreateError{URL: url, Err: reqErr} } - if opts != nil && opts.Headers != nil { - for k, v := range *opts.Headers { - req.Header.Add(k, v[0]) - } - } - resp, respErr := client.Do(req) + // Make sure the headers contain all the parameters + httpOptionalHeaders(req, opts) + + // Do request + resp, respErr := client.Do(req) if respErr != nil { return nil, &HTTPResourceError{URL: url, Err: respErr} } + + // Request successful, make sure body is closed at the end defer resp.Body.Close() + // Return a string body, readErr := ioutil.ReadAll(resp.Body) if readErr != nil { return nil, &HTTPReadError{URL: url, Err: readErr} } + // Return the body in bytes and signal that there was no error return body, nil } diff --git a/src/oauth.go b/src/oauth.go index bbe34af..063034b 100644 --- a/src/oauth.go +++ b/src/oauth.go @@ -54,8 +54,8 @@ func genVerifier() (string, error) { } type OAuth struct { - Session *OAuthExchangeSession - Token *OAuthToken + Session *OAuthExchangeSession + Token *OAuthToken TokenURL string } @@ -65,13 +65,13 @@ type OAuthExchangeSession struct { CallbackError error // filled in in initialize - ClientID string - State string - Verifier string + ClientID string + State string + Verifier string // filled in when constructing the callback - Context context.Context - Server *http.Server + Context context.Context + Server *http.Server } func generateTimeSeconds() int64 { @@ -81,10 +81,10 @@ func generateTimeSeconds() int64 { // Struct that defines the json format for /.well-known/vpn-user-portal" type OAuthToken struct { - Access string `json:"access_token"` - Refresh string `json:"refresh_token"` - Type string `json:"token_type"` - Expires int64 `json:"expires_in"` + Access string `json:"access_token"` + Refresh string `json:"refresh_token"` + Type string `json:"token_type"` + Expires int64 `json:"expires_in"` ExpiredTimestamp int64 } @@ -121,9 +121,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { } headers := &http.Header{ "content-type": {"application/x-www-form-urlencoded"}} - opts := &HTTPOptionalParams{Headers: headers} + opts := &HTTPOptionalParams{Headers: headers, Body: data} current_time := generateTimeSeconds() - body, bodyErr := HTTPPostWithOptionalParams(reqURL, data, opts) + body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return bodyErr } @@ -158,9 +158,9 @@ func (oauth *OAuth) getTokensWithRefresh() error { } headers := &http.Header{ "content-type": {"application/x-www-form-urlencoded"}} - opts := &HTTPOptionalParams{Headers: headers} + opts := &HTTPOptionalParams{Headers: headers, Body: data} current_time := generateTimeSeconds() - body, bodyErr := HTTPPostWithOptionalParams(reqURL, data, opts) + body, bodyErr := HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return bodyErr } @@ -260,7 +260,6 @@ func (eduvpn *VPNState) InitializeOAuth() (string, error) { return authURL, nil } - // Error definitions func (eduvpn *VPNState) FinishOAuth() error { oauth := eduvpn.Server.OAuth @@ -277,12 +276,11 @@ func (eduvpn *VPNState) EnsureTokensOAuth() error { } if oauth.isTokensExpired() { - return oauth.getTokensWithRefresh(); + return oauth.getTokensWithRefresh() } return nil } - type OAuthGenStateUnableError struct { Err error } @@ -299,7 +297,6 @@ func (e *OAuthGenVerifierUnableError) Error() string { return fmt.Sprintf("failed generating verifier with error %v", e.Err) } - type OAuthFailedCallbackError struct { Addr string Err error diff --git a/src/server.go b/src/server.go index bf1fb3d..6f809c6 100644 --- a/src/server.go +++ b/src/server.go @@ -5,9 +5,9 @@ import ( ) type Server struct { - BaseURL string + BaseURL string Endpoints *ServerEndpoints - OAuth *OAuth + OAuth *OAuth } type ServerEndpointList struct { @@ -25,7 +25,6 @@ type ServerEndpoints struct { V string `json:"v"` } - func (server *Server) Initialize(url string) error { server.BaseURL = url endpointsErr := server.GetEndpoints() @@ -35,7 +34,6 @@ func (server *Server) Initialize(url string) error { return nil } - func (server *Server) GetEndpoints() error { url := server.BaseURL + "/.well-known/vpn-user-portal" body, bodyErr := HTTPGet(url) diff --git a/src/state.go b/src/state.go index 272bbc6..582dd5a 100644 --- a/src/state.go +++ b/src/state.go @@ -2,7 +2,7 @@ package eduvpn type VPNState struct { // Info passed by the client - Name string + Name string // The chosen server Server *Server diff --git a/src/wireguard.go b/src/wireguard.go new file mode 100644 index 0000000..9441c51 --- /dev/null +++ b/src/wireguard.go @@ -0,0 +1,26 @@ +package eduvpn + +import ( + "fmt" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "regexp" +) + +func WireguardGenerateKey() (wgtypes.Key, error) { + key, error := wgtypes.GeneratePrivateKey() + return key, error +} + +// FIXME: Instead of doing a regex replace, decide if we should use a parser +func WireguardConfigAddKey(config string, key wgtypes.Key) string { + interface_section := "[Interface]" + interface_section_escaped := regexp.QuoteMeta(interface_section) + + // (?m) enables multi line mode + // ^ match from beginning of line + // $ match till end of line + // So it matches [Interface] section exactly + interface_re := regexp.MustCompile(fmt.Sprintf("(?m)^%s$", interface_section_escaped)) + to_replace := fmt.Sprintf("%s\nPrivateKey = %s", interface_section, key.String()) + return interface_re.ReplaceAllString(config, to_replace) +} |
