summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/eduvpn-cli/main.go77
1 files changed, 46 insertions, 31 deletions
diff --git a/cmd/eduvpn-cli/main.go b/cmd/eduvpn-cli/main.go
index 354682a..d8942c5 100644
--- a/cmd/eduvpn-cli/main.go
+++ b/cmd/eduvpn-cli/main.go
@@ -24,8 +24,6 @@ func openBrowser(data interface{}) {
if !ok {
return
}
- fmt.Printf("OAuth: Authorization URL: %s\n", str)
- fmt.Println("Opening browser...")
go func() {
err := browser.OpenURL(str)
if err != nil {
@@ -35,24 +33,12 @@ func openBrowser(data interface{}) {
}()
}
-// Ask for a profile in the command line.
-func sendProfile(state *client.Client, data interface{}) {
+func getProfileInteractive(profiles *srvtypes.Profiles, data interface{}) (string, error) {
fmt.Printf("Multiple VPN profiles found. Please select a profile by entering e.g. 1")
- d, ok := data.(*srvtypes.RequiredAskTransition)
- if !ok {
- fmt.Fprintf(os.Stderr, "\ninvalid data type: %v\n", reflect.TypeOf(data))
- return
- }
- sps, ok := d.Data.(*srvtypes.Profiles)
- if !ok {
- fmt.Fprintf(os.Stderr, "\ninvalid data type for profiles: %v\n", reflect.TypeOf(d.Data))
- return
- }
-
ps := ""
var options []string
i := 0
- for k, v := range sps.Map {
+ for k, v := range profiles.Map {
ps += fmt.Sprintf("\n%d - %s", i+1, util.GetLanguageMatched(v.DisplayName, "en"))
options = append(options, k)
i++
@@ -63,16 +49,39 @@ func sendProfile(state *client.Client, data interface{}) {
var idx int
if _, err := fmt.Scanf("%d", &idx); err != nil || idx <= 0 ||
- idx > len(sps.Map) {
+ idx > len(profiles.Map) {
fmt.Fprintln(os.Stderr, "invalid profile chosen, please retry")
- sendProfile(state, data)
- return
+ return getProfileInteractive(profiles, data)
}
p := options[idx-1]
fmt.Println("Sending profile ID", p)
- if err := d.C.Send(p); err != nil {
- fmt.Fprintln(os.Stderr, "failed setting profile with error", err)
+ return p, nil
+}
+
+func sendProfile(profile string, data interface{}) {
+ d, ok := data.(*srvtypes.RequiredAskTransition)
+ if !ok {
+ fmt.Fprintf(os.Stderr, "\ninvalid data type: %v\n", reflect.TypeOf(data))
+ os.Exit(1)
+ }
+ sps, ok := d.Data.(*srvtypes.Profiles)
+ if !ok {
+ fmt.Fprintf(os.Stderr, "\ninvalid data type for profiles: %v\n", reflect.TypeOf(d.Data))
+ os.Exit(1)
+ }
+
+ if profile == "" {
+ gprof, err := getProfileInteractive(sps, data)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "failed getting profile interactively: %v\n", err)
+ os.Exit(1)
+ }
+ profile = gprof
+ }
+ if err := d.C.Send(profile); err != nil {
+ fmt.Fprintf(os.Stderr, "failed setting profile with error: %v\n", err)
+ os.Exit(1)
}
}
@@ -80,13 +89,13 @@ func sendProfile(state *client.Client, data interface{}) {
// If OAuth is started we open the browser with the Auth URL
// If we ask for a profile, we send the profile using command line input
// Note that this has an additional argument, the vpn state which was wrapped into this callback function below.
-func stateCallback(state *client.Client, _ client.FSMStateID, newState client.FSMStateID, data interface{}, dir string) {
+func stateCallback(_ client.FSMStateID, newState client.FSMStateID, data interface{}, prof string, dir string) {
if newState == client.StateOAuthStarted {
openBrowser(data)
}
if newState == client.StateAskProfile {
- sendProfile(state, data)
+ sendProfile(prof, data)
}
if newState == client.StateAskLocation {
@@ -98,7 +107,7 @@ func stateCallback(state *client.Client, _ client.FSMStateID, newState client.FS
}
// Get a config for Institute Access or Secure Internet Server.
-func getConfig(state *client.Client, url string, srvType srvtypes.Type, cc string) (*srvtypes.Configuration, error) {
+func getConfig(state *client.Client, url string, srvType srvtypes.Type, cc string, prof string) (*srvtypes.Configuration, error) {
if !strings.HasPrefix(url, "http") {
url = "https://" + url
}
@@ -118,11 +127,16 @@ func getConfig(state *client.Client, url string, srvType srvtypes.Type, cc strin
}
}
+ if prof != "" {
+ // this is best effort, e.g. if no server was chosen before this fails
+ _ = state.SetProfileID(prof) //nolint:errcheck
+ }
+
return state.GetConfig(ck, url, srvType, false, false)
}
// Get a config for a single server, Institute Access or Secure Internet.
-func printConfig(url string, cc string, srvType srvtypes.Type, debug bool) error {
+func printConfig(url string, cc string, srvType srvtypes.Type, prof string, debug bool) error {
var c *client.Client
var err error
var dir string
@@ -136,7 +150,7 @@ func printConfig(url string, cc string, srvType srvtypes.Type, debug bool) error
fmt.Sprintf("%s-cli", version.Version),
dir,
func(oldState client.FSMStateID, newState client.FSMStateID, data interface{}) bool {
- stateCallback(c, oldState, newState, data, dir)
+ stateCallback(oldState, newState, data, prof, dir)
return true
},
debug,
@@ -158,11 +172,11 @@ func printConfig(url string, cc string, srvType srvtypes.Type, debug bool) error
defer c.Deregister()
- cfg, err := getConfig(c, url, srvType, cc)
+ cfg, err := getConfig(c, url, srvType, cc, prof)
if err != nil {
return err
}
- fmt.Printf("Obtained config:\n%s\n", cfg.VPNConfig)
+ fmt.Println(cfg.VPNConfig)
return nil
}
@@ -173,6 +187,7 @@ func main() {
u := flag.String("get-institute", "", "The url of an institute to connect to")
sec := flag.String("get-secure", "", "Gets secure internet servers")
cc := flag.String("country-code", "", "The country code to use in case of a secure internet server")
+ prof := flag.String("profile", "", "The profile ID to choose")
debug := flag.Bool("debug", false, "Whether or not to enable debugging")
flag.Parse()
@@ -180,11 +195,11 @@ func main() {
var err error
switch {
case *cu != "":
- err = printConfig(*cu, "", srvtypes.TypeCustom, *debug)
+ err = printConfig(*cu, "", srvtypes.TypeCustom, *prof, *debug)
case *u != "":
- err = printConfig(*u, "", srvtypes.TypeInstituteAccess, *debug)
+ err = printConfig(*u, "", srvtypes.TypeInstituteAccess, *prof, *debug)
case *sec != "":
- err = printConfig(*sec, *cc, srvtypes.TypeSecureInternet, *debug)
+ err = printConfig(*sec, *cc, srvtypes.TypeSecureInternet, *prof, *debug)
default:
flag.PrintDefaults()
}