diff options
| -rw-r--r-- | internal/config/config.go (renamed from internal/config.go) | 5 | ||||
| -rw-r--r-- | internal/discovery/discovery.go (renamed from internal/discovery.go) | 25 | ||||
| -rw-r--r-- | internal/fsm/fsm.go (renamed from internal/fsm.go) | 11 | ||||
| -rw-r--r-- | internal/http/http.go (renamed from internal/http.go) | 2 | ||||
| -rw-r--r-- | internal/log/log.go (renamed from internal/log.go) | 5 | ||||
| -rw-r--r-- | internal/oauth/oauth.go (renamed from internal/oauth.go) | 56 | ||||
| -rw-r--r-- | internal/openvpn.go | 31 | ||||
| -rw-r--r-- | internal/server/api.go (renamed from internal/api.go) | 49 | ||||
| -rw-r--r-- | internal/server/server.go (renamed from internal/server.go) | 169 | ||||
| -rw-r--r-- | internal/util/util.go (renamed from internal/util.go) | 2 | ||||
| -rw-r--r-- | internal/verify/test_data/empty (renamed from internal/test_data/empty) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/generate.sh (renamed from internal/test_data/generate.sh) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/generate_forged.py (renamed from internal/test_data/generate_forged.py) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/organization_list.json (renamed from internal/test_data/organization_list.json) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/organization_list.json.minisig (renamed from internal/test_data/organization_list.json.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/organization_list.json.tc_servlist.minisig (renamed from internal/test_data/organization_list.json.tc_servlist.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/other_list.json (renamed from internal/test_data/other_list.json) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/other_list.json.minisig (renamed from internal/test_data/other_list.json.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/public.key (renamed from internal/test_data/public.key) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/random.txt (renamed from internal/test_data/random.txt) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/secret.key (renamed from internal/test_data/secret.key) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json (renamed from internal/test_data/server_list.json) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.blake2b (renamed from internal/test_data/server_list.json.blake2b) | bin | 64 -> 64 bytes | |||
| -rw-r--r-- | internal/verify/test_data/server_list.json.forged_keyid.minisig (renamed from internal/test_data/server_list.json.forged_keyid.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.forged_pure.minisig (renamed from internal/test_data/server_list.json.forged_pure.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.large_time.minisig (renamed from internal/test_data/server_list.json.large_time.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.minisig (renamed from internal/test_data/server_list.json.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.pure.minisig (renamed from internal/test_data/server_list.json.pure.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_earliertime.minisig (renamed from internal/test_data/server_list.json.tc_earliertime.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_emptyfile.minisig (renamed from internal/test_data/server_list.json.tc_emptyfile.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_emptytime.minisig (renamed from internal/test_data/server_list.json.tc_emptytime.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_latertime.minisig (renamed from internal/test_data/server_list.json.tc_latertime.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_nofile.minisig (renamed from internal/test_data/server_list.json.tc_nofile.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_nohashed.minisig (renamed from internal/test_data/server_list.json.tc_nohashed.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_notime.minisig (renamed from internal/test_data/server_list.json.tc_notime.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_orglist.minisig (renamed from internal/test_data/server_list.json.tc_orglist.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_otherfile.minisig (renamed from internal/test_data/server_list.json.tc_otherfile.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.tc_random.minisig (renamed from internal/test_data/server_list.json.tc_random.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/server_list.json.wrong_key.minisig (renamed from internal/test_data/server_list.json.wrong_key.minisig) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/wrong_public.key (renamed from internal/test_data/wrong_public.key) | 0 | ||||
| -rw-r--r-- | internal/verify/test_data/wrong_secret.key (renamed from internal/test_data/wrong_secret.key) | 0 | ||||
| -rw-r--r-- | internal/verify/verify.go (renamed from internal/verify.go) | 2 | ||||
| -rw-r--r-- | internal/verify/verify_test.go (renamed from internal/verify_test.go) | 2 | ||||
| -rw-r--r-- | internal/wireguard.go | 82 | ||||
| -rw-r--r-- | internal/wireguard/wireguard.go | 38 | ||||
| -rw-r--r-- | state.go | 92 | ||||
| -rw-r--r-- | state_test.go | 43 |
47 files changed, 326 insertions, 288 deletions
diff --git a/internal/config.go b/internal/config/config.go index 0f13165..a9ebec7 100644 --- a/internal/config.go +++ b/internal/config/config.go @@ -1,10 +1,11 @@ -package internal +package config import ( "encoding/json" "fmt" "io/ioutil" "path" + "github.com/jwijenbergh/eduvpn-common/internal/util" ) type Config struct { @@ -23,7 +24,7 @@ func (config *Config) GetFilename() string { } func (config *Config) Save(readStruct interface{}) error { - configDirErr := EnsureDirectory(config.Directory) + configDirErr := util.EnsureDirectory(config.Directory) if configDirErr != nil { return &ConfigSaveError{Err: configDirErr} } diff --git a/internal/discovery.go b/internal/discovery/discovery.go index 59281bd..d72b4a6 100644 --- a/internal/discovery.go +++ b/internal/discovery/discovery.go @@ -1,8 +1,13 @@ -package internal +package discovery import ( "encoding/json" "fmt" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + "github.com/jwijenbergh/eduvpn-common/internal/http" + "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/util" + "github.com/jwijenbergh/eduvpn-common/internal/verify" ) type DiscoFileError struct { @@ -57,8 +62,8 @@ type ServersList struct { type Discovery struct { Organizations OrganizationList Servers ServersList - FSM *FSM - Logger *FileLogger + FSM *fsm.FSM + Logger *log.FileLogger } // Helper function that gets a disco json @@ -66,7 +71,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} // Get json data discoURL := "https://disco.eduvpn.org/v2/" fileURL := discoURL + jsonFile - _, fileBody, fileErr := HTTPGet(fileURL) + _, fileBody, fileErr := http.HTTPGet(fileURL) if fileErr != nil { return &DiscoFileError{fileURL, fileErr} @@ -75,7 +80,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} // Get signature sigFile := jsonFile + ".minisig" sigURL := discoURL + sigFile - _, sigBody, sigFileErr := HTTPGet(sigURL) + _, sigBody, sigFileErr := http.HTTPGet(sigURL) if sigFileErr != nil { return &DiscoSigFileError{URL: sigURL, Err: sigFileErr} @@ -84,7 +89,7 @@ func getDiscoFile(jsonFile string, previousVersion uint64, structure interface{} // Verify signature // Set this to true when we want to force prehash forcePrehash := false - verifySuccess, verifyErr := Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) + verifySuccess, verifyErr := verify.Verify(string(sigBody), fileBody, jsonFile, previousVersion, forcePrehash) if !verifySuccess || verifyErr != nil { return &DiscoVerifyError{File: jsonFile, Sigfile: sigFile, Err: verifyErr} @@ -109,7 +114,7 @@ func (e *GetListError) Error() string { return fmt.Sprintf("failed getting disco list file %s with error %v", e.File, e.Err) } -func (discovery *Discovery) Init(fsm *FSM, logger *FileLogger) { +func (discovery *Discovery) Init(fsm *fsm.FSM, logger *log.FileLogger) { discovery.FSM = fsm discovery.Logger = logger } @@ -133,11 +138,11 @@ func (discovery *Discovery) DetermineServersUpdate() bool { } // 1 hour from the last update should_update_time := discovery.Servers.Timestamp + 3600 - now := GenerateTimeSeconds() + now := util.GenerateTimeSeconds() if now >= should_update_time { return true } - discovery.Logger.Log(LOG_INFO, "No update needed for servers, 1h is not passed yet") + discovery.Logger.Log(log.LOG_INFO, "No update needed for servers, 1h is not passed yet") return false } @@ -167,6 +172,6 @@ func (discovery *Discovery) GetServersList() (string, error) { return string(discovery.Servers.JSON), &GetListError{File: file, Err: err} } // Update servers timestamp - discovery.Servers.Timestamp = GenerateTimeSeconds() + discovery.Servers.Timestamp = util.GenerateTimeSeconds() return string(discovery.Servers.JSON), nil } diff --git a/internal/fsm.go b/internal/fsm/fsm.go index 4df24a0..bb7f330 100644 --- a/internal/fsm.go +++ b/internal/fsm/fsm.go @@ -1,4 +1,4 @@ -package internal +package fsm import ( "fmt" @@ -6,6 +6,7 @@ import ( "os/exec" "path" "sort" + "github.com/jwijenbergh/eduvpn-common/internal/log" ) type ( @@ -96,12 +97,12 @@ type FSM struct { // Info to be passed from the parent state Name string StateCallback func(string, string, string) - Logger *FileLogger + Logger *log.FileLogger Directory string Debug bool } -func (fsm *FSM) Init(name string, callback func(string, string, string), logger *FileLogger, directory string, debug bool) { +func (fsm *FSM) Init(name string, callback func(string, string, string), logger *log.FileLogger, directory string, debug bool) { fsm.States = FSMStates{ DEREGISTERED: {{NO_SERVER, "Client registers"}}, NO_SERVER: {{CHOSEN_SERVER, "User chooses a server"}}, @@ -146,7 +147,7 @@ func (fsm *FSM) writeGraph() { graphImgFile := fsm.getGraphFilename(".png") f, err := os.Create(graphFile) if err != nil { - fsm.Logger.Log(LOG_INFO, fmt.Sprintf("Failed to write debug fsm graph with error %v", err)) + fsm.Logger.Log(log.LOG_INFO, fmt.Sprintf("Failed to write debug fsm graph with error %v", err)) return } @@ -167,7 +168,7 @@ func (fsm *FSM) GoTransitionWithData(newState FSMStateID, data string, backgroun fsm.writeGraph() } - fsm.Logger.Log(LOG_INFO, fmt.Sprintf("State: %s -> State: %s with data %s\n", oldState.String(), newState.String(), data)) + fsm.Logger.Log(log.LOG_INFO, fmt.Sprintf("State: %s -> State: %s with data %s\n", oldState.String(), newState.String(), data)) if background { go fsm.StateCallback(oldState.String(), newState.String(), data) diff --git a/internal/http.go b/internal/http/http.go index 0b1eda4..87346f1 100644 --- a/internal/http.go +++ b/internal/http/http.go @@ -1,4 +1,4 @@ -package internal +package http import ( "fmt" diff --git a/internal/log.go b/internal/log/log.go index 5109ba2..cba3364 100644 --- a/internal/log.go +++ b/internal/log/log.go @@ -1,10 +1,11 @@ -package internal +package log import ( "fmt" "log" "os" "path" + "github.com/jwijenbergh/eduvpn-common/internal/util" ) type FileLogger struct { @@ -37,7 +38,7 @@ func (e LogLevel) String() string { } func (logger *FileLogger) Init(level LogLevel, name string, directory string) error { - configDirErr := EnsureDirectory(directory) + configDirErr := util.EnsureDirectory(directory) if configDirErr != nil { return &LogInitializeError{Name: name, Directory: directory, Err: configDirErr} } diff --git a/internal/oauth.go b/internal/oauth/oauth.go index c566425..f6ed916 100644 --- a/internal/oauth.go +++ b/internal/oauth/oauth.go @@ -1,4 +1,4 @@ -package internal +package oauth import ( "context" @@ -8,6 +8,10 @@ import ( "fmt" "net/http" "net/url" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" + "github.com/jwijenbergh/eduvpn-common/internal/util" + "github.com/jwijenbergh/eduvpn-common/internal/log" ) // Generates a random base64 string to be used for state @@ -17,7 +21,7 @@ import ( // includes this value when redirecting the user agent back to the // client. func genState() (string, error) { - randomBytes, err := MakeRandomByteSlice(32) + randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { return "", &OAuthGenStateError{Err: err} } @@ -43,7 +47,7 @@ func genChallengeS256(verifier string) string { // minimum length of 43 characters and a maximum length of 128 // characters. func genVerifier() (string, error) { - randomBytes, err := MakeRandomByteSlice(32) + randomBytes, err := util.MakeRandomByteSlice(32) if err != nil { return "", &OAuthGenVerifierError{Err: err} } @@ -56,8 +60,8 @@ type OAuth struct { Token OAuthToken `json:"token"` BaseAuthorizationURL string `json:"base_authorization_url"` TokenURL string `json:"token_url"` - Logger *FileLogger `json:"-"` - FSM *FSM `json:"-"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // This structure gets passed to the callback for easy access to the current state @@ -118,9 +122,9 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &HTTPOptionalParams{Headers: headers, Body: data} - current_time := GenerateTimeSeconds() - _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) + opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + current_time := util.GenerateTimeSeconds() + _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return &OAuthAuthError{Err: bodyErr} } @@ -130,7 +134,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -140,7 +144,7 @@ func (oauth *OAuth) getTokensWithAuthCode(authCode string) error { func (oauth *OAuth) isTokensExpired() bool { expired_time := oauth.Token.ExpiredTimestamp - current_time := GenerateTimeSeconds() + current_time := util.GenerateTimeSeconds() return current_time >= expired_time } @@ -156,9 +160,9 @@ func (oauth *OAuth) getTokensWithRefresh() error { headers := http.Header{ "content-type": {"application/x-www-form-urlencoded"}, } - opts := &HTTPOptionalParams{Headers: headers, Body: data} - current_time := GenerateTimeSeconds() - _, body, bodyErr := HTTPPostWithOpts(reqURL, opts) + opts := &httpw.HTTPOptionalParams{Headers: headers, Body: data} + current_time := util.GenerateTimeSeconds() + _, body, bodyErr := httpw.HTTPPostWithOpts(reqURL, opts) if bodyErr != nil { return &OAuthRefreshError{Err: bodyErr} } @@ -167,7 +171,7 @@ func (oauth *OAuth) getTokensWithRefresh() error { jsonErr := json.Unmarshal(body, &tokenStructure) if jsonErr != nil { - return &HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} + return &httpw.HTTPParseJsonError{URL: reqURL, Body: string(body), Err: jsonErr} } tokenStructure.ExpiredTimestamp = current_time + tokenStructure.Expires @@ -217,12 +221,12 @@ func (oauth *OAuth) Callback(w http.ResponseWriter, req *http.Request) { go oauth.Session.Server.Shutdown(oauth.Session.Context) } -func (oauth *OAuth) Update(fsm *FSM, logger *FileLogger) { +func (oauth *OAuth) Update(fsm *fsm.FSM, logger *log.FileLogger) { oauth.FSM = fsm oauth.Logger = logger } -func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *FSM, logger *FileLogger) { +func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *fsm.FSM, logger *log.FileLogger) { oauth.BaseAuthorizationURL = baseAuthorizationURL oauth.TokenURL = tokenURL oauth.FSM = fsm @@ -231,8 +235,8 @@ func (oauth *OAuth) Init(baseAuthorizationURL string, tokenURL string, fsm *FSM, // Starts the OAuth exchange for eduvpn. func (oauth *OAuth) start(name string) error { - if !oauth.FSM.HasTransition(OAUTH_STARTED) { - return &FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: OAUTH_STARTED} + if !oauth.FSM.HasTransition(fsm.OAUTH_STARTED) { + return &fsm.FSMWrongStateTransitionError{Got: oauth.FSM.Current, Want: fsm.OAUTH_STARTED} } // Generate the state state, stateErr := genState() @@ -257,7 +261,7 @@ func (oauth *OAuth) start(name string) error { "redirect_uri": "http://127.0.0.1:8000/callback", } - authURL, urlErr := HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) + authURL, urlErr := httpw.HTTPConstructURL(oauth.BaseAuthorizationURL, parameters) if urlErr != nil { return &OAuthInitializeError{Err: urlErr} @@ -267,21 +271,21 @@ func (oauth *OAuth) start(name string) error { oauthSession := OAuthExchangeSession{ClientID: name, State: state, Verifier: verifier} oauth.Session = oauthSession // Run the state callback in the background so that the user can login while we start the callback server - oauth.FSM.GoTransitionWithData(OAUTH_STARTED, authURL, true) + oauth.FSM.GoTransitionWithData(fsm.OAUTH_STARTED, authURL, true) return nil } // Error definitions func (oauth *OAuth) Finish() error { - if !oauth.FSM.HasTransition(AUTHORIZED) { - return &FSMWrongStateError{Got: oauth.FSM.Current, Want: AUTHORIZED} + if !oauth.FSM.HasTransition(fsm.AUTHORIZED) { + return &fsm.FSMWrongStateError{Got: oauth.FSM.Current, Want: fsm.AUTHORIZED} } tokenErr := oauth.getTokensWithCallback() if tokenErr != nil { return &OAuthFinishError{Err: tokenErr} } - oauth.FSM.GoTransition(AUTHORIZED) + oauth.FSM.GoTransition(fsm.AUTHORIZED) return nil } @@ -308,7 +312,7 @@ func (oauth *OAuth) Login(name string) error { func (oauth *OAuth) NeedsRelogin() bool { // Access Token or Refresh Tokens empty, definitely needs a relogin if oauth.Token.Access == "" || oauth.Token.Refresh == "" { - oauth.Logger.Log(LOG_INFO, "OAuth: Tokens are empty") + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens are empty") return true } @@ -317,14 +321,14 @@ func (oauth *OAuth) NeedsRelogin() bool { // The tokens are not expired yet // No relogin is needed if !oauth.isTokensExpired() { - oauth.Logger.Log(LOG_INFO, "OAuth: Tokens are not expired, re-login not needed") + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens are not expired, re-login not needed") return false } refreshErr := oauth.getTokensWithRefresh() // We have obtained new tokens with refresh if refreshErr == nil { - oauth.Logger.Log(LOG_INFO, "OAuth: Tokens could be re-acquired using the refresh token, re-login not needed") + oauth.Logger.Log(log.LOG_INFO, "OAuth: Tokens could be re-acquired using the refresh token, re-login not needed") return false } diff --git a/internal/openvpn.go b/internal/openvpn.go deleted file mode 100644 index 8f684ba..0000000 --- a/internal/openvpn.go +++ /dev/null @@ -1,31 +0,0 @@ -package internal - -import "fmt" - -func OpenVPNGetConfig(server Server) (string, string, error) { - base, baseErr := server.GetBase() - - if baseErr != nil { - return "", "", &OpenVPNGetConfigError{Err: baseErr} - } - profile_id := base.Profiles.Current - configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id) - - // Store start and end time - base.StartTime = GenerateTimeSeconds() - base.EndTime = expires - - if configErr != nil { - return "", "", &OpenVPNGetConfigError{Err: configErr} - } - - return configOpenVPN, "openvpn", nil -} - -type OpenVPNGetConfigError struct { - Err error -} - -func (e *OpenVPNGetConfigError) Error() string { - return fmt.Sprintf("failed getting OpenVPN config with error: %v", e.Err) -} diff --git a/internal/api.go b/internal/server/api.go index 5c2cf6d..96bd641 100644 --- a/internal/api.go +++ b/internal/server/api.go @@ -1,4 +1,4 @@ -package internal +package server import ( "encoding/json" @@ -6,14 +6,35 @@ import ( "fmt" "net/http" "net/url" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" + "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/util" ) +func APIGetEndpoints(baseURL string) (*ServerEndpoints, error) { + url := fmt.Sprintf("%s/%s", baseURL, WellKnownPath) + _, body, bodyErr := httpw.HTTPGet(url) + + if bodyErr != nil { + return nil, &APIGetEndpointsError{Err: bodyErr} + } + + endpoints := &ServerEndpoints{} + jsonErr := json.Unmarshal(body, endpoints) + + if jsonErr != nil { + return nil, &APIGetEndpointsError{Err: jsonErr} + } + + return endpoints, nil +} + // Authorized wrappers on top of HTTP // the errors will not be wrapped here so that the caller can check if we got a status error, to retry oauth -func apiAuthorized(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { +func apiAuthorized(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { // Ensure optional is not nil as we will fill it with headers if opts == nil { - opts = &HTTPOptionalParams{} + opts = &httpw.HTTPOptionalParams{} } base, baseErr := server.GetBase() @@ -41,10 +62,10 @@ func apiAuthorized(server Server, method string, endpoint string, opts *HTTPOpti } else { opts.Headers = http.Header{headerKey: {headerValue}} } - return HTTPMethodWithOpts(method, url, opts) + return httpw.HTTPMethodWithOpts(method, url, opts) } -func apiAuthorizedRetry(server Server, method string, endpoint string, opts *HTTPOptionalParams) (http.Header, []byte, error) { +func apiAuthorizedRetry(server Server, method string, endpoint string, opts *httpw.HTTPOptionalParams) (http.Header, []byte, error) { header, body, bodyErr := apiAuthorized(server, method, endpoint, opts) base, baseErr := server.GetBase() @@ -52,13 +73,13 @@ func apiAuthorizedRetry(server Server, method string, endpoint string, opts *HTT return nil, nil, &APIAuthorizedError{Err: baseErr} } if bodyErr != nil { - var error *HTTPStatusError + var error *httpw.HTTPStatusError // Only retry authorized if we get a HTTP 401 if errors.As(bodyErr, &error) && error.Status == 401 { - base.Logger.Log(LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error)) + base.Logger.Log(log.LOG_INFO, fmt.Sprintf("API: Got HTTP error %v, retrying authorized", error)) // Tell the method that the token is expired - server.GetOAuth().Token.ExpiredTimestamp = GenerateTimeSeconds() + server.GetOAuth().Token.ExpiredTimestamp = util.GenerateTimeSeconds() retryHeader, retryBody, retryErr := apiAuthorized(server, method, endpoint, opts) if retryErr != nil { return nil, nil, &APIAuthorizedError{Err: retryErr} @@ -110,7 +131,7 @@ func APIConnectWireguard(server Server, profile_id string, pubkey string, suppor "profile_id": {profile_id}, "public_key": {pubkey}, } - header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", "", 0, &APIConnectWireguardError{Err: connectErr} } @@ -141,7 +162,7 @@ func APIConnectOpenVPN(server Server, profile_id string) (string, int64, error) "profile_id": {profile_id}, } - header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &HTTPOptionalParams{Headers: headers, Body: urlForm}) + header, connectBody, connectErr := apiAuthorizedRetry(server, http.MethodPost, "/connect", &httpw.HTTPOptionalParams{Headers: headers, Body: urlForm}) if connectErr != nil { return "", 0, &APIConnectOpenVPNError{Err: connectErr} } @@ -190,3 +211,11 @@ type APIInfoError struct { func (e *APIInfoError) Error() string { return fmt.Sprintf("failed api /info call with error: %v", e.Err) } + +type APIGetEndpointsError struct { + Err error +} + +func (e *APIGetEndpointsError) Error() string { + return fmt.Sprintf("failed to get server endpoint with error %v", e.Err) +} diff --git a/internal/server.go b/internal/server/server.go index d1fc433..a1fb749 100644 --- a/internal/server.go +++ b/internal/server/server.go @@ -1,8 +1,12 @@ -package internal +package server import ( - "encoding/json" "fmt" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/oauth" + "github.com/jwijenbergh/eduvpn-common/internal/util" + "github.com/jwijenbergh/eduvpn-common/internal/wireguard" ) // The base type for servers @@ -11,16 +15,16 @@ type ServerBase struct { Endpoints ServerEndpoints `json:"endpoints"` Profiles ServerProfileInfo `json:"profiles"` ProfilesRaw string `json:"profiles_raw"` - Logger *FileLogger `json:"-"` - FSM *FSM `json:"-"` StartTime int64 `json:"start-time"` EndTime int64 `json:"end-time"` + Logger *log.FileLogger `json:"-"` + FSM *fsm.FSM `json:"-"` } // An instute access server type InstituteAccessServer struct { // An instute access server has its own OAuth - OAuth OAuth `json:"oauth"` + OAuth oauth.OAuth `json:"oauth"` // Embed the server base Base ServerBase `json:"base"` @@ -29,7 +33,7 @@ type InstituteAccessServer struct { // A secure internet server which has its own OAuth tokens // It specifies the current location url it is connected to type SecureInternetHomeServer struct { - OAuth OAuth `json:"oauth"` + OAuth oauth.OAuth `json:"oauth"` // The home server has a list of info for each configured server BaseMap map[string]*ServerBase `json:"base_map"` @@ -69,21 +73,21 @@ type Servers struct { type Server interface { // Gets the current OAuth object - GetOAuth() *OAuth + GetOAuth() *oauth.OAuth // Gets the server base GetBase() (*ServerBase, error) // initialize method - init(url string, fsm *FSM, logger *FileLogger) error + init(url string, fsm *fsm.FSM, logger *log.FileLogger) error } // For an institute, we can simply get the OAuth -func (institute *InstituteAccessServer) GetOAuth() *OAuth { +func (institute *InstituteAccessServer) GetOAuth() *oauth.OAuth { return &institute.OAuth } -func (secure *SecureInternetHomeServer) GetOAuth() *OAuth { +func (secure *SecureInternetHomeServer) GetOAuth() *oauth.OAuth { return &secure.OAuth } @@ -104,11 +108,11 @@ func (server *SecureInternetHomeServer) GetBase() (*ServerBase, error) { return base, nil } -func (institute *InstituteAccessServer) init(url string, fsm *FSM, logger *FileLogger) error { +func (institute *InstituteAccessServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { institute.Base.URL = url institute.Base.FSM = fsm institute.Base.Logger = logger - endpoints, endpointsErr := getEndpoints(url) + endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { return &ServerInitializeError{URL: url, Err: endpointsErr} } @@ -117,7 +121,7 @@ func (institute *InstituteAccessServer) init(url string, fsm *FSM, logger *FileL return nil } -func (secure *SecureInternetHomeServer) init(url string, fsm *FSM, logger *FileLogger) error { +func (secure *SecureInternetHomeServer) init(url string, fsm *fsm.FSM, logger *log.FileLogger) error { // Initialize the base map if it is non-nil if secure.BaseMap == nil { secure.BaseMap = make(map[string]*ServerBase) @@ -130,7 +134,7 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *FSM, logger *FileL // Create the base to be added to the map base = &ServerBase{} base.URL = url - endpoints, endpointsErr := getEndpoints(url) + endpoints, endpointsErr := APIGetEndpoints(url) if endpointsErr != nil { return &ServerInitializeError{URL: url, Err: endpointsErr} } @@ -158,6 +162,37 @@ func (secure *SecureInternetHomeServer) init(url string, fsm *FSM, logger *FileL return nil } +func ShouldRenewButton(server Server) (bool, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + //return false, &GetRenewButtonTimeError{Err: baseErr} + return false, nil + } + + // Get current time + current := util.GenerateTimeSeconds() + + // 30 minutes have not passed + if current <= (base.StartTime + 30*60) { + return false, nil + } + + // Session will not expire today + if current <= (base.EndTime - 24*60*60) { + return false, nil + } + + // Session duration is less than 24 hours but not 75% has passed + duration := base.EndTime - base.StartTime + // TODO: Is converting to float64 okay here? + if duration < 24*60*60 && float64(current) <= (float64(base.StartTime) + 0.75*float64(duration)) { + return false, nil + } + + return true, nil +} + func Login(server Server) error { return server.GetOAuth().Login("org.eduvpn.app.linux") } @@ -169,7 +204,7 @@ func EnsureTokens(server Server) error { return &ServerEnsureTokensError{Err: baseErr} } if server.GetOAuth().NeedsRelogin() { - base.Logger.Log(LOG_INFO, "OAuth: Tokens are invalid, relogging in") + base.Logger.Log(log.LOG_INFO, "OAuth: Tokens are invalid, relogging in") loginErr := Login(server) if loginErr != nil { @@ -187,7 +222,7 @@ func CancelOAuth(server Server) { server.GetOAuth().Cancel() } -func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *FSM, logger *FileLogger) (Server, error) { +func (servers *Servers) EnsureServer(url string, isSecureInternet bool, fsm *fsm.FSM, logger *log.FileLogger) (Server, error) { // Intialize the secure internet server // This calls the init method which takes care of the rest if isSecureInternet { @@ -257,24 +292,6 @@ type ServerEndpoints struct { // Make this a var which we can overwrite in the tests var WellKnownPath string = ".well-known/vpn-user-portal" -func getEndpoints(baseURL string) (*ServerEndpoints, error) { - url := fmt.Sprintf("%s/%s", baseURL, WellKnownPath) - _, body, bodyErr := HTTPGet(url) - - if bodyErr != nil { - return nil, &ServerGetEndpointsError{Err: bodyErr} - } - - endpoints := &ServerEndpoints{} - jsonErr := json.Unmarshal(body, endpoints) - - if jsonErr != nil { - return nil, &ServerGetEndpointsError{Err: jsonErr} - } - - return endpoints, nil -} - func (profile *ServerProfile) supportsProtocol(protocol string) bool { for _, proto := range profile.VPNProtoList { if proto == protocol { @@ -307,14 +324,70 @@ func getCurrentProfile(server Server) (*ServerProfile, error) { return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID} } +func wireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", "", baseErr + } + + profile_id := base.Profiles.Current + wireguardKey, wireguardErr := wireguard.GenerateKey() + + if wireguardErr != nil { + return "", "", wireguardErr + } + + wireguardPublicKey := wireguardKey.PublicKey().String() + config, content, expires, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN) + + if configErr != nil { + return "", "", wireguardErr + } + + // Store start and end time + base.StartTime = util.GenerateTimeSeconds() + base.EndTime = expires + + if content == "wireguard" { + // This needs the go code a way to identify a connection + // Use the uuid of the connection e.g. on Linux + // This needs the client code to call the go code + + config = wireguard.ConfigAddKey(config, wireguardKey) + } + + return config, content, nil +} + +func openVPNGetConfig(server Server) (string, string, error) { + base, baseErr := server.GetBase() + + if baseErr != nil { + return "", "", baseErr + } + profile_id := base.Profiles.Current + configOpenVPN, expires, configErr := APIConnectOpenVPN(server, profile_id) + + // Store start and end time + base.StartTime = util.GenerateTimeSeconds() + base.EndTime = expires + + if configErr != nil { + return "", "", configErr + } + + return configOpenVPN, "openvpn", nil +} + func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) { base, baseErr := server.GetBase() if baseErr != nil { return "", "", &ServerGetConfigWithProfileError{Err: baseErr} } - if !base.FSM.HasTransition(HAS_CONFIG) { - return "", "", &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: HAS_CONFIG} + if !base.FSM.HasTransition(fsm.HAS_CONFIG) { + return "", "", &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.HAS_CONFIG} } profile, profileErr := getCurrentProfile(server) @@ -337,9 +410,9 @@ func getConfigWithProfile(server Server, forceTCP bool) (string, string, error) if supportsWireguard { // A wireguard connect call needs to generate a wireguard key and add it to the config // Also the server could send back an OpenVPN config if it supports OpenVPN - config, configType, configErr = WireguardGetConfig(server, supportsOpenVPN) + config, configType, configErr = wireguardGetConfig(server, supportsOpenVPN) } else { - config, configType, configErr = OpenVPNGetConfig(server) + config, configType, configErr = openVPNGetConfig(server) } if configErr != nil { @@ -355,10 +428,10 @@ func askForProfileID(server Server) error { if baseErr != nil { return &ServerAskForProfileIDError{Err: baseErr} } - if !base.FSM.HasTransition(ASK_PROFILE) { - return &FSMWrongStateTransitionError{Got: base.FSM.Current, Want: ASK_PROFILE} + if !base.FSM.HasTransition(fsm.ASK_PROFILE) { + return &fsm.FSMWrongStateTransitionError{Got: base.FSM.Current, Want: fsm.ASK_PROFILE} } - base.FSM.GoTransitionWithData(ASK_PROFILE, base.ProfilesRaw, false) + base.FSM.GoTransitionWithData(fsm.ASK_PROFILE, base.ProfilesRaw, false) return nil } @@ -368,8 +441,8 @@ func GetConfig(server Server, forceTCP bool) (string, string, error) { if baseErr != nil { return "", "", &ServerGetConfigError{Err: baseErr} } - if !base.FSM.InState(REQUEST_CONFIG) { - return "", "", &FSMWrongStateError{Got: base.FSM.Current, Want: REQUEST_CONFIG} + if !base.FSM.InState(fsm.REQUEST_CONFIG) { + return "", "", &fsm.FSMWrongStateError{Got: base.FSM.Current, Want: fsm.REQUEST_CONFIG} } // Get new profiles using the info call @@ -383,7 +456,7 @@ func GetConfig(server Server, forceTCP bool) (string, string, error) { if base.Profiles.Current != "" { _, existsProfileErr := getCurrentProfile(server) if existsProfileErr != nil { - base.Logger.Log(LOG_INFO, fmt.Sprintf("Profile %s no longer exists, resetting the profile", base.Profiles.Current)) + base.Logger.Log(log.LOG_INFO, fmt.Sprintf("Profile %s no longer exists, resetting the profile", base.Profiles.Current)) base.Profiles.Current = "" } } @@ -428,14 +501,6 @@ func (e *ServerGetConfigForceTCPError) Error() string { return fmt.Sprintf("failed to get config, force TCP is on but the server does not support OpenVPN") } -type ServerGetEndpointsError struct { - Err error -} - -func (e *ServerGetEndpointsError) Error() string { - return fmt.Sprintf("failed to get server endpoint with error %v", e.Err) -} - type ServerGetSecureInternetHomeError struct{} func (e *ServerGetSecureInternetHomeError) Error() string { diff --git a/internal/util.go b/internal/util/util.go index 02855c2..4bdd1b5 100644 --- a/internal/util.go +++ b/internal/util/util.go @@ -1,4 +1,4 @@ -package internal +package util import ( "crypto/rand" diff --git a/internal/test_data/empty b/internal/verify/test_data/empty index e69de29..e69de29 100644 --- a/internal/test_data/empty +++ b/internal/verify/test_data/empty diff --git a/internal/test_data/generate.sh b/internal/verify/test_data/generate.sh index b1b4545..b1b4545 100644 --- a/internal/test_data/generate.sh +++ b/internal/verify/test_data/generate.sh diff --git a/internal/test_data/generate_forged.py b/internal/verify/test_data/generate_forged.py index 9d42adc..9d42adc 100644 --- a/internal/test_data/generate_forged.py +++ b/internal/verify/test_data/generate_forged.py diff --git a/internal/test_data/organization_list.json b/internal/verify/test_data/organization_list.json index 8c53044..8c53044 100644 --- a/internal/test_data/organization_list.json +++ b/internal/verify/test_data/organization_list.json diff --git a/internal/test_data/organization_list.json.minisig b/internal/verify/test_data/organization_list.json.minisig index 1fa546e..1fa546e 100644 --- a/internal/test_data/organization_list.json.minisig +++ b/internal/verify/test_data/organization_list.json.minisig diff --git a/internal/test_data/organization_list.json.tc_servlist.minisig b/internal/verify/test_data/organization_list.json.tc_servlist.minisig index a7fe41f..a7fe41f 100644 --- a/internal/test_data/organization_list.json.tc_servlist.minisig +++ b/internal/verify/test_data/organization_list.json.tc_servlist.minisig diff --git a/internal/test_data/other_list.json b/internal/verify/test_data/other_list.json index 25ba1a8..25ba1a8 100644 --- a/internal/test_data/other_list.json +++ b/internal/verify/test_data/other_list.json diff --git a/internal/test_data/other_list.json.minisig b/internal/verify/test_data/other_list.json.minisig index eaa2248..eaa2248 100644 --- a/internal/test_data/other_list.json.minisig +++ b/internal/verify/test_data/other_list.json.minisig diff --git a/internal/test_data/public.key b/internal/verify/test_data/public.key index 72676d3..72676d3 100644 --- a/internal/test_data/public.key +++ b/internal/verify/test_data/public.key diff --git a/internal/test_data/random.txt b/internal/verify/test_data/random.txt index b6fc4c6..b6fc4c6 100644 --- a/internal/test_data/random.txt +++ b/internal/verify/test_data/random.txt diff --git a/internal/test_data/secret.key b/internal/verify/test_data/secret.key index 6e4af37..6e4af37 100644 --- a/internal/test_data/secret.key +++ b/internal/verify/test_data/secret.key diff --git a/internal/test_data/server_list.json b/internal/verify/test_data/server_list.json index 67c4c8d..67c4c8d 100644 --- a/internal/test_data/server_list.json +++ b/internal/verify/test_data/server_list.json diff --git a/internal/test_data/server_list.json.blake2b b/internal/verify/test_data/server_list.json.blake2b Binary files differindex 5d2ca5a..5d2ca5a 100644 --- a/internal/test_data/server_list.json.blake2b +++ b/internal/verify/test_data/server_list.json.blake2b diff --git a/internal/test_data/server_list.json.forged_keyid.minisig b/internal/verify/test_data/server_list.json.forged_keyid.minisig index efa349d..efa349d 100644 --- a/internal/test_data/server_list.json.forged_keyid.minisig +++ b/internal/verify/test_data/server_list.json.forged_keyid.minisig diff --git a/internal/test_data/server_list.json.forged_pure.minisig b/internal/verify/test_data/server_list.json.forged_pure.minisig index a362504..a362504 100644 --- a/internal/test_data/server_list.json.forged_pure.minisig +++ b/internal/verify/test_data/server_list.json.forged_pure.minisig diff --git a/internal/test_data/server_list.json.large_time.minisig b/internal/verify/test_data/server_list.json.large_time.minisig index 79a2a52..79a2a52 100644 --- a/internal/test_data/server_list.json.large_time.minisig +++ b/internal/verify/test_data/server_list.json.large_time.minisig diff --git a/internal/test_data/server_list.json.minisig b/internal/verify/test_data/server_list.json.minisig index 143585b..143585b 100644 --- a/internal/test_data/server_list.json.minisig +++ b/internal/verify/test_data/server_list.json.minisig diff --git a/internal/test_data/server_list.json.pure.minisig b/internal/verify/test_data/server_list.json.pure.minisig index 57dccfc..57dccfc 100644 --- a/internal/test_data/server_list.json.pure.minisig +++ b/internal/verify/test_data/server_list.json.pure.minisig diff --git a/internal/test_data/server_list.json.tc_earliertime.minisig b/internal/verify/test_data/server_list.json.tc_earliertime.minisig index 03da710..03da710 100644 --- a/internal/test_data/server_list.json.tc_earliertime.minisig +++ b/internal/verify/test_data/server_list.json.tc_earliertime.minisig diff --git a/internal/test_data/server_list.json.tc_emptyfile.minisig b/internal/verify/test_data/server_list.json.tc_emptyfile.minisig index a7aa3ed..a7aa3ed 100644 --- a/internal/test_data/server_list.json.tc_emptyfile.minisig +++ b/internal/verify/test_data/server_list.json.tc_emptyfile.minisig diff --git a/internal/test_data/server_list.json.tc_emptytime.minisig b/internal/verify/test_data/server_list.json.tc_emptytime.minisig index d3ef01e..d3ef01e 100644 --- a/internal/test_data/server_list.json.tc_emptytime.minisig +++ b/internal/verify/test_data/server_list.json.tc_emptytime.minisig diff --git a/internal/test_data/server_list.json.tc_latertime.minisig b/internal/verify/test_data/server_list.json.tc_latertime.minisig index 8237123..8237123 100644 --- a/internal/test_data/server_list.json.tc_latertime.minisig +++ b/internal/verify/test_data/server_list.json.tc_latertime.minisig diff --git a/internal/test_data/server_list.json.tc_nofile.minisig b/internal/verify/test_data/server_list.json.tc_nofile.minisig index 3c1dcbe..3c1dcbe 100644 --- a/internal/test_data/server_list.json.tc_nofile.minisig +++ b/internal/verify/test_data/server_list.json.tc_nofile.minisig diff --git a/internal/test_data/server_list.json.tc_nohashed.minisig b/internal/verify/test_data/server_list.json.tc_nohashed.minisig index 1d140c1..1d140c1 100644 --- a/internal/test_data/server_list.json.tc_nohashed.minisig +++ b/internal/verify/test_data/server_list.json.tc_nohashed.minisig diff --git a/internal/test_data/server_list.json.tc_notime.minisig b/internal/verify/test_data/server_list.json.tc_notime.minisig index 39625c3..39625c3 100644 --- a/internal/test_data/server_list.json.tc_notime.minisig +++ b/internal/verify/test_data/server_list.json.tc_notime.minisig diff --git a/internal/test_data/server_list.json.tc_orglist.minisig b/internal/verify/test_data/server_list.json.tc_orglist.minisig index 7c2a3a8..7c2a3a8 100644 --- a/internal/test_data/server_list.json.tc_orglist.minisig +++ b/internal/verify/test_data/server_list.json.tc_orglist.minisig diff --git a/internal/test_data/server_list.json.tc_otherfile.minisig b/internal/verify/test_data/server_list.json.tc_otherfile.minisig index 58a29b2..58a29b2 100644 --- a/internal/test_data/server_list.json.tc_otherfile.minisig +++ b/internal/verify/test_data/server_list.json.tc_otherfile.minisig diff --git a/internal/test_data/server_list.json.tc_random.minisig b/internal/verify/test_data/server_list.json.tc_random.minisig index 7240980..7240980 100644 --- a/internal/test_data/server_list.json.tc_random.minisig +++ b/internal/verify/test_data/server_list.json.tc_random.minisig diff --git a/internal/test_data/server_list.json.wrong_key.minisig b/internal/verify/test_data/server_list.json.wrong_key.minisig index 5a83c0e..5a83c0e 100644 --- a/internal/test_data/server_list.json.wrong_key.minisig +++ b/internal/verify/test_data/server_list.json.wrong_key.minisig diff --git a/internal/test_data/wrong_public.key b/internal/verify/test_data/wrong_public.key index aa794d4..aa794d4 100644 --- a/internal/test_data/wrong_public.key +++ b/internal/verify/test_data/wrong_public.key diff --git a/internal/test_data/wrong_secret.key b/internal/verify/test_data/wrong_secret.key index 68e9092..68e9092 100644 --- a/internal/test_data/wrong_secret.key +++ b/internal/verify/test_data/wrong_secret.key diff --git a/internal/verify.go b/internal/verify/verify.go index 713e4d7..2d53b2e 100644 --- a/internal/verify.go +++ b/internal/verify/verify.go @@ -1,4 +1,4 @@ -package internal +package verify import ( "errors" diff --git a/internal/verify_test.go b/internal/verify/verify_test.go index f980dc2..7d577dd 100644 --- a/internal/verify_test.go +++ b/internal/verify/verify_test.go @@ -1,4 +1,4 @@ -package internal +package verify import ( "bufio" diff --git a/internal/wireguard.go b/internal/wireguard.go deleted file mode 100644 index 00c9467..0000000 --- a/internal/wireguard.go +++ /dev/null @@ -1,82 +0,0 @@ -package internal - -import ( - "fmt" - "regexp" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -func wireguardGenerateKey() (wgtypes.Key, error) { - key, keyErr := wgtypes.GeneratePrivateKey() - - if keyErr != nil { - return key, &WireguardGenerateKeyError{Err: keyErr} - } - return key, nil -} - -// FIXME: Instead of doing a regex replace, decide if we should use a parser -func wireguardConfigAddKey(config string, key wgtypes.Key) string { - interface_section := "[Interface]" - interface_section_escaped := regexp.QuoteMeta(interface_section) - - // (?m) enables multi line mode - // ^ match from beginning of line - // $ match till end of line - // So it matches [Interface] section exactly - interface_re := regexp.MustCompile(fmt.Sprintf("(?m)^%s$", interface_section_escaped)) - to_replace := fmt.Sprintf("%s\nPrivateKey = %s", interface_section, key.String()) - return interface_re.ReplaceAllString(config, to_replace) -} - -func WireguardGetConfig(server Server, supportsOpenVPN bool) (string, string, error) { - base, baseErr := server.GetBase() - - if baseErr != nil { - return "", "", &WireguardGetConfigError{Err: baseErr} - } - - profile_id := base.Profiles.Current - wireguardKey, wireguardErr := wireguardGenerateKey() - - if wireguardErr != nil { - return "", "", &WireguardGetConfigError{Err: wireguardErr} - } - - wireguardPublicKey := wireguardKey.PublicKey().String() - config, content, expires, configErr := APIConnectWireguard(server, profile_id, wireguardPublicKey, supportsOpenVPN) - - if configErr != nil { - return "", "", &WireguardGetConfigError{Err: wireguardErr} - } - - // Store start and end time - base.StartTime = GenerateTimeSeconds() - base.EndTime = expires - - if content == "wireguard" { - // This needs the go code a way to identify a connection - // Use the uuid of the connection e.g. on Linux - // This needs the client code to call the go code - - config = wireguardConfigAddKey(config, wireguardKey) - } - - return config, content, nil -} - -type WireguardGenerateKeyError struct { - Err error -} - -func (e *WireguardGenerateKeyError) Error() string { - return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err) -} - -type WireguardGetConfigError struct { - Err error -} - -func (e *WireguardGetConfigError) Error() string { - return fmt.Sprintf("failed getting Wireguard config with error: %v", e.Err) -} diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go new file mode 100644 index 0000000..db20067 --- /dev/null +++ b/internal/wireguard/wireguard.go @@ -0,0 +1,38 @@ +package wireguard + +import ( + "fmt" + "regexp" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func GenerateKey() (wgtypes.Key, error) { + key, keyErr := wgtypes.GeneratePrivateKey() + + if keyErr != nil { + return key, &WireguardGenerateKeyError{Err: keyErr} + } + return key, nil +} + +// FIXME: Instead of doing a regex replace, decide if we should use a parser +func ConfigAddKey(config string, key wgtypes.Key) string { + interface_section := "[Interface]" + interface_section_escaped := regexp.QuoteMeta(interface_section) + + // (?m) enables multi line mode + // ^ match from beginning of line + // $ match till end of line + // So it matches [Interface] section exactly + interface_re := regexp.MustCompile(fmt.Sprintf("(?m)^%s$", interface_section_escaped)) + to_replace := fmt.Sprintf("%s\nPrivateKey = %s", interface_section, key.String()) + return interface_re.ReplaceAllString(config, to_replace) +} + +type WireguardGenerateKeyError struct { + Err error +} + +func (e *WireguardGenerateKeyError) Error() string { + return fmt.Sprintf("failed generating Wireguard key with error: %v", e.Err) +} @@ -3,38 +3,42 @@ package eduvpn import ( "fmt" - "github.com/jwijenbergh/eduvpn-common/internal" + "github.com/jwijenbergh/eduvpn-common/internal/config" + "github.com/jwijenbergh/eduvpn-common/internal/discovery" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + "github.com/jwijenbergh/eduvpn-common/internal/log" + "github.com/jwijenbergh/eduvpn-common/internal/server" ) type VPNState struct { // The chosen server - Servers internal.Servers `json:"servers"` + Servers server.Servers `json:"servers"` // The list of servers and organizations from disco - Discovery internal.Discovery `json:"-"` + Discovery discovery.Discovery `json:"-"` // The fsm - FSM internal.FSM `json:"-"` + FSM fsm.FSM `json:"-"` // The logger - Logger internal.FileLogger `json:"-"` + Logger log.FileLogger `json:"-"` // The config - Config internal.Config `json:"-"` + Config config.Config `json:"-"` // Whether to enable debugging Debug bool `json:"-"` } func (state *VPNState) Register(name string, directory string, stateCallback func(string, string, string), debug bool) error { - if !state.FSM.InState(internal.DEREGISTERED) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED} + if !state.FSM.InState(fsm.DEREGISTERED) { + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED} } // Initialize the logger - logLevel := internal.LOG_WARNING + logLevel := log.LOG_WARNING if debug { - logLevel = internal.LOG_INFO + logLevel = log.LOG_INFO } loggerErr := state.Logger.Init(logLevel, name, directory) @@ -55,9 +59,9 @@ func (state *VPNState) Register(name string, directory string, stateCallback fun // Try to load the previous configuration if state.Config.Load(&state) != nil { // This error can be safely ignored, as when the config does not load, the struct will not be filled - state.Logger.Log(internal.LOG_INFO, "Previous configuration not found") + state.Logger.Log(log.LOG_INFO, "Previous configuration not found") } - state.FSM.GoTransition(internal.NO_SERVER) + state.FSM.GoTransition(fsm.NO_SERVER) return nil } @@ -74,20 +78,20 @@ func (state *VPNState) Deregister() error { } func (state *VPNState) CancelOAuth() error { - if !state.FSM.InState(internal.OAUTH_STARTED) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.OAUTH_STARTED} + if !state.FSM.InState(fsm.OAUTH_STARTED) { + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.OAUTH_STARTED} } - server, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { return &StateOAuthCancelError{Err: serverErr} } - internal.CancelOAuth(server) + server.CancelOAuth(currentServer) return nil } -func (state *VPNState) chooseServer(url string, isSecureInternet bool) (internal.Server, error) { +func (state *VPNState) chooseServer(url string, isSecureInternet bool) (server.Server, error) { // New server chosen, ensure the server is fresh server, serverErr := state.Servers.EnsureServer(url, isSecureInternet, &state.FSM, &state.Logger) @@ -96,51 +100,51 @@ func (state *VPNState) chooseServer(url string, isSecureInternet bool) (internal } // Make sure we are in the chosen state if available - state.FSM.GoTransition(internal.CHOSEN_SERVER) + state.FSM.GoTransition(fsm.CHOSEN_SERVER) return server, nil } func (state *VPNState) getConfigWithOptions(url string, isSecureInternet bool, forceTCP bool) (string, string, error) { - if state.FSM.InState(internal.DEREGISTERED) { + if state.FSM.InState(fsm.DEREGISTERED) { return "", "", &StateFSMNotRegisteredError{} } // Go to no server if possible, else return an error - if !state.FSM.InState(internal.NO_SERVER) && !state.FSM.GoTransition(internal.NO_SERVER) { - return "", "", &internal.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: internal.NO_SERVER} + if !state.FSM.InState(fsm.NO_SERVER) && !state.FSM.GoTransition(fsm.NO_SERVER) { + return "", "", &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.NO_SERVER} } // Make sure the server is chosen - server, serverErr := state.chooseServer(url, isSecureInternet) + chosenServer, serverErr := state.chooseServer(url, isSecureInternet) if serverErr != nil { return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: serverErr} } // Relogin with oauth // This moves the state to authorized - if internal.NeedsRelogin(server) { - loginErr := internal.Login(server) + if server.NeedsRelogin(chosenServer) { + loginErr := server.Login(chosenServer) if loginErr != nil { // We are possibly in oauth started // So go to no server - state.FSM.GoTransition(internal.NO_SERVER) + state.FSM.GoTransition(fsm.NO_SERVER) return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: loginErr} } } else { // OAuth was valid, ensure we are in the authorized state - state.FSM.GoTransition(internal.AUTHORIZED) + state.FSM.GoTransition(fsm.AUTHORIZED) } - state.FSM.GoTransition(internal.REQUEST_CONFIG) + state.FSM.GoTransition(fsm.REQUEST_CONFIG) - config, configType, configErr := internal.GetConfig(server, forceTCP) + config, configType, configErr := server.GetConfig(chosenServer, forceTCP) if configErr != nil { // Go back to no server if possible - state.FSM.GoTransition(internal.NO_SERVER) + state.FSM.GoTransition(fsm.NO_SERVER) return "", "", &StateConnectError{URL: url, IsSecureInternet: isSecureInternet, Err: configErr} } else { - state.FSM.GoTransition(internal.HAS_CONFIG) + state.FSM.GoTransition(fsm.HAS_CONFIG) } return config, configType, nil @@ -155,22 +159,22 @@ func (state *VPNState) GetConfigSecureInternet(url string, forceTCP bool) (strin } func (state *VPNState) GetDiscoOrganizations() (string, error) { - if state.FSM.InState(internal.DEREGISTERED) { - return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.DEREGISTERED} + if state.FSM.InState(fsm.DEREGISTERED) { + return "", &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.DEREGISTERED} } return state.Discovery.GetOrganizationsList() } func (state *VPNState) GetDiscoServers() (string, error) { - if state.FSM.InState(internal.DEREGISTERED) { + if state.FSM.InState(fsm.DEREGISTERED) { return "", &StateFSMNotRegisteredError{} } return state.Discovery.GetServersList() } func (state *VPNState) SetProfileID(profileID string) error { - if !state.FSM.InState(internal.ASK_PROFILE) { - return &StateWrongFSMStateError{Got: state.FSM.Current, Want: internal.ASK_PROFILE} + if !state.FSM.InState(fsm.ASK_PROFILE) { + return &StateWrongFSMStateError{Got: state.FSM.Current, Want: fsm.ASK_PROFILE} } server, serverErr := state.Servers.GetCurrentServer() @@ -188,20 +192,20 @@ func (state *VPNState) SetProfileID(profileID string) error { } func (state *VPNState) SetConnected() error { - if !state.FSM.HasTransition(internal.CONNECTED) { - return &internal.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: internal.CONNECTED} + if !state.FSM.HasTransition(fsm.CONNECTED) { + return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.CONNECTED} } - state.FSM.GoTransition(internal.CONNECTED) + state.FSM.GoTransition(fsm.CONNECTED) return nil } func (state *VPNState) SetDisconnected() error { - if !state.FSM.HasTransition(internal.HAS_CONFIG) { - return &internal.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: internal.HAS_CONFIG} + if !state.FSM.HasTransition(fsm.HAS_CONFIG) { + return &fsm.FSMWrongStateTransitionError{Got: state.FSM.Current, Want: fsm.HAS_CONFIG} } - state.FSM.GoTransition(internal.HAS_CONFIG) + state.FSM.GoTransition(fsm.HAS_CONFIG) return nil } @@ -225,12 +229,12 @@ func (e *StateRegisterError) Error() string { type StateFSMNotRegisteredError struct{} func (e *StateFSMNotRegisteredError) Error() string { - return fmt.Sprintf("state is not registered. Current FSM state: %s", internal.DEREGISTERED.String()) + return fmt.Sprintf("state is not registered. Current FSM state: %s", fsm.DEREGISTERED.String()) } type StateWrongFSMStateError struct { - Got internal.FSMStateID - Want internal.FSMStateID + Got fsm.FSMStateID + Want fsm.FSMStateID } func (e *StateWrongFSMStateError) Error() string { diff --git a/state_test.go b/state_test.go index b515ffb..521513a 100644 --- a/state_test.go +++ b/state_test.go @@ -11,14 +11,17 @@ import ( "testing" "time" - "github.com/jwijenbergh/eduvpn-common/internal" + "github.com/jwijenbergh/eduvpn-common/internal/fsm" + httpw "github.com/jwijenbergh/eduvpn-common/internal/http" + "github.com/jwijenbergh/eduvpn-common/internal/oauth" + "github.com/jwijenbergh/eduvpn-common/internal/server" ) func ensureLocalWellKnown() { wellKnown := os.Getenv("SERVER_IS_LOCAL") if wellKnown == "1" { - internal.WellKnownPath = "well-known.php" + server.WellKnownPath = "well-known.php" } } @@ -75,7 +78,7 @@ func Test_server(t *testing.T) { } } -func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameters, expectedErr interface{}) { +func test_connect_oauth_parameter(t *testing.T, parameters httpw.URLParameters, expectedErr interface{}) { serverURI := getServerURI(t) state := &VPNState{} configDirectory := "test_oauth_parameters" @@ -83,7 +86,7 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter state.Register("org.eduvpn.app.linux", configDirectory, func(oldState string, newState string, data string) { if newState == "OAuth_Started" { baseURL := "http://127.0.0.1:8000/callback" - url, err := internal.HTTPConstructURL(baseURL, parameters) + url, err := httpw.HTTPConstructURL(baseURL, parameters) if err != nil { t.Fatalf("Error: Constructing url %s with parameters %s", baseURL, fmt.Sprint(parameters)) } @@ -94,8 +97,8 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter _, _, configErr := state.GetConfigInstituteAccess(serverURI, false) var stateErr *StateConnectError - var loginErr *internal.OAuthLoginError - var finishErr *internal.OAuthFinishError + var loginErr *oauth.OAuthLoginError + var finishErr *oauth.OAuthFinishError // We go through the chain of errors by unwrapping them one by one @@ -128,17 +131,17 @@ func test_connect_oauth_parameter(t *testing.T, parameters internal.URLParameter func Test_connect_oauth_parameters(t *testing.T) { var ( - failedCallbackParameterError *internal.OAuthCallbackParameterError - failedCallbackStateMatchError *internal.OAuthCallbackStateMatchError + failedCallbackParameterError *oauth.OAuthCallbackParameterError + failedCallbackStateMatchError *oauth.OAuthCallbackStateMatchError ) tests := []struct { expectedErr interface{} - parameters internal.URLParameters + parameters httpw.URLParameters }{ - {&failedCallbackParameterError, internal.URLParameters{}}, - {&failedCallbackParameterError, internal.URLParameters{"code": "42"}}, - {&failedCallbackStateMatchError, internal.URLParameters{"code": "42", "state": "21"}}, + {&failedCallbackParameterError, httpw.URLParameters{}}, + {&failedCallbackParameterError, httpw.URLParameters{"code": "42"}}, + {&failedCallbackStateMatchError, httpw.URLParameters{"code": "42", "state": "21"}}, } ensureLocalWellKnown() @@ -177,12 +180,12 @@ func Test_token_expired(t *testing.T) { t.Fatalf("Connect error before expired: %v", configErr) } - server, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { t.Fatalf("No server found") } - oauth := server.GetOAuth() + oauth := currentServer.GetOAuth() accessToken := oauth.Token.Access refreshToken := oauth.Token.Refresh @@ -190,7 +193,7 @@ func Test_token_expired(t *testing.T) { // Wait for TTL so that the tokens expire time.Sleep(time.Duration(expiredInt) * time.Second) - infoErr := internal.APIInfo(server) + infoErr := server.APIInfo(currentServer) if infoErr != nil { t.Fatalf("Info error after expired: %v", infoErr) @@ -228,16 +231,16 @@ func Test_token_invalid(t *testing.T) { // Go to request_config so we can re-authorize // This is needed as the only actual authenticated requests we do in request_config (for profiles) and /connect // /disconnect is best effort so this does not need re-auth - state.FSM.GoTransition(internal.REQUEST_CONFIG) + state.FSM.GoTransition(fsm.REQUEST_CONFIG) dummy_value := "37" - server, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { t.Fatalf("No server found") } - oauth := server.GetOAuth() + oauth := currentServer.GetOAuth() // Override tokens with invalid values oauth.Token.Access = dummy_value @@ -275,12 +278,12 @@ func Test_invalid_profile_corrected(t *testing.T) { t.Fatalf("First connect error: %v", configErr) } - server, serverErr := state.Servers.GetCurrentServer() + currentServer, serverErr := state.Servers.GetCurrentServer() if serverErr != nil { t.Fatalf("No server found") } - base, baseErr := server.GetBase() + base, baseErr := currentServer.GetBase() if baseErr != nil { t.Fatalf("No base found") } |
