summaryrefslogtreecommitdiff
path: root/internal/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server.go')
-rw-r--r--internal/server.go40
1 files changed, 32 insertions, 8 deletions
diff --git a/internal/server.go b/internal/server.go
index 4452c6d..c9e31af 100644
--- a/internal/server.go
+++ b/internal/server.go
@@ -273,15 +273,23 @@ func getEndpoints(baseURL string) (*ServerEndpoints, error) {
return endpoints, nil
}
-func (profile *ServerProfile) supportsWireguard() bool {
+func (profile *ServerProfile) supportsProtocol(protocol string) bool {
for _, proto := range profile.VPNProtoList {
- if proto == "wireguard" {
+ if proto == protocol {
return true
}
}
return false
}
+func (profile *ServerProfile) supportsWireguard() bool {
+ return profile.supportsProtocol("wireguard")
+}
+
+func (profile *ServerProfile) supportsOpenVPN() bool {
+ return profile.supportsProtocol("openvpn")
+}
+
func getCurrentProfile(server Server) (*ServerProfile, error) {
base, baseErr := server.GetBase()
@@ -297,7 +305,7 @@ func getCurrentProfile(server Server) (*ServerProfile, error) {
return nil, &ServerGetCurrentProfileNotFoundError{ProfileID: profileID}
}
-func getConfigWithProfile(server Server) (string, error) {
+func getConfigWithProfile(server Server, forceTCP bool) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
@@ -312,8 +320,18 @@ func getConfigWithProfile(server Server) (string, error) {
return "", &ServerGetConfigWithProfileError{Err: profileErr}
}
- if profile.supportsWireguard() {
- return WireguardGetConfig(server)
+ supportsOpenVPN := profile.supportsOpenVPN()
+ supportsWireguard := profile.supportsWireguard()
+
+ // If forceTCP we must be able to get a config with OpenVPN
+ if forceTCP && supportsOpenVPN {
+ return "", &ServerGetConfigForceTCPError{}
+ }
+
+ if supportsWireguard {
+ // A wireguard connect call needs to generate a wireguard key and add it to the config
+ // Also the server could send back an OpenVPN config if it supports OpenVPN
+ return WireguardGetConfig(server, supportsOpenVPN)
}
return OpenVPNGetConfig(server)
}
@@ -331,7 +349,7 @@ func askForProfileID(server Server) error {
return nil
}
-func GetConfig(server Server) (string, error) {
+func GetConfig(server Server, forceTCP bool) (string, error) {
base, baseErr := server.GetBase()
if baseErr != nil {
@@ -363,7 +381,7 @@ func GetConfig(server Server) (string, error) {
if base.Profiles.Current == "" {
base.Profiles.Current = base.Profiles.Info.ProfileList[0].ID
}
- return getConfigWithProfile(server)
+ return getConfigWithProfile(server, forceTCP)
}
profileErr := askForProfileID(server)
@@ -372,7 +390,7 @@ func GetConfig(server Server) (string, error) {
return "", &ServerGetConfigError{Err: profileErr}
}
- return getConfigWithProfile(server)
+ return getConfigWithProfile(server, forceTCP)
}
type ServerGetCurrentProfileNotFoundError struct {
@@ -391,6 +409,12 @@ func (e *ServerGetConfigWithProfileError) Error() string {
return fmt.Sprintf("failed to get config including profile with error %v", e.Err)
}
+type ServerGetConfigForceTCPError struct{}
+
+func (e *ServerGetConfigForceTCPError) Error() string {
+ return fmt.Sprintf("failed to get config, force TCP is on but the server does not support OpenVPN")
+}
+
type ServerGetEndpointsError struct {
Err error
}