diff --git a/cmd/vbantxt/main.go b/cmd/vbantxt/main.go index 83baf96..e4993a9 100644 --- a/cmd/vbantxt/main.go +++ b/cmd/vbantxt/main.go @@ -47,12 +47,28 @@ func (f *Flags) String() string { ) } -func exitOnError(err error) { - fmt.Fprintf(os.Stderr, "Error: %s\n", err) - os.Exit(1) +func main() { + var exitCode int + + // Defer exit with the final exit code + defer func() { + if exitCode != 0 { + os.Exit(exitCode) + } + }() + + closer, err := run() + if closer != nil { + defer closer() + } + if err != nil { + log.Error(err) + exitCode = 1 + } } -func main() { +// run contains the main application logic and returns a closer function and any error. +func run() (func(), error) { var flags Flags // VBAN specific flags @@ -66,7 +82,7 @@ func main() { configDir, err := os.UserConfigDir() if err != nil { - exitOnError(fmt.Errorf("failed to get user config directory: %w", err)) + return nil, fmt.Errorf("failed to get user config directory: %w", err) } defaultConfigPath := filepath.Join(configDir, "vbantxt", "config.toml") @@ -98,7 +114,7 @@ func main() { fmt.Fprintf(os.Stderr, "%s\n", ffhelp.Flags(fs, "vbantxt [flags] ")) os.Exit(0) case err != nil: - exitOnError(fmt.Errorf("failed to parse flags: %w", err)) + return nil, fmt.Errorf("failed to parse flags: %w", err) } if flags.Version { @@ -108,7 +124,7 @@ func main() { level, err := log.ParseLevel(flags.Loglevel) if err != nil { - exitOnError(fmt.Errorf("invalid log level: %s", flags.Loglevel)) + return nil, fmt.Errorf("invalid log level %q", flags.Loglevel) } log.SetLevel(level) @@ -116,16 +132,18 @@ func main() { client, closer, err := createClient(&flags) if err != nil { - exitOnError(err) + return nil, fmt.Errorf("failed to create VBAN client: %w", err) } - defer closer() commands := fs.GetArgs() if len(commands) == 0 { - exitOnError(errors.New("no VBAN commands provided")) + return closer, errors.New( + "no VBAN commands provided; please provide at least one command as an argument", + ) } sendCommands(client, commands) + return closer, nil } // versionFromBuild retrieves the version information from the build metadata. @@ -133,7 +151,7 @@ func versionFromBuild() string { if version == "" { info, ok := debug.ReadBuildInfo() if !ok { - exitOnError(errors.New("failed to read build info")) + return "(unable to read build info)" } version = strings.Split(info.Main.Version, "-")[0] }