package main
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"log"
"os"
"os/exec"
"os/user"
"regexp"
"strings"
errors "git.sequentialread.com/forest/pkg-errors"
)
// SSHService runs commands on remote hosts through the local `ssh` binary and
// maintains the current user's known_hosts file.
type SSHService struct {
	// SSHPrivateKeyFile is the path handed to `ssh -i` when connecting.
	SSHPrivateKeyFile string
}
// NewSSHService returns an SSHService whose private key file is taken from
// config.SSHPrivateKeyFile. config must not be nil.
func NewSSHService(config *Config) *SSHService {
	service := new(SSHService)
	service.SSHPrivateKeyFile = config.SSHPrivateKeyFile
	return service
}
// RunScriptOnRemoteHost executes script on the host at ipv4 as username by
// running `ssh -i <SSHPrivateKeyFile> username@ipv4` and feeding it a shell
// command on stdin.
//
// The script is base64-encoded before being embedded in the remote command so
// that quotes, newlines and other shell metacharacters in it survive intact;
// the remote side decodes it with `base64 -d` and pipes it into `sh`.
//
// It returns the remote command's stdout, then its stderr, then an error. The
// error is non-nil when ssh could not be run or exited with a non-zero code;
// stdout and stderr are returned in either case so callers can inspect them.
func (service *SSHService) RunScriptOnRemoteHost(script, username, ipv4 string) (string, string, error) {
	remoteCommand := fmt.Sprintf(
		"sh -c 'echo %s | base64 -d | sh'",
		base64.StdEncoding.EncodeToString([]byte(script)),
	)
	userAtHost := fmt.Sprintf("%s@%s", username, ipv4)

	// shellExecInputPipe returns (exitCode, stdout, stderr, err) in that order.
	// Binding the two byte slices the other way round hands callers stderr in
	// place of stdout.
	exitCode, stdout, stderr, err := shellExecInputPipe(
		&remoteCommand,
		"ssh", "-i", service.SSHPrivateKeyFile, userAtHost,
	)

	// The real command embeds the whole encoded script; use a placeholder so
	// the error message stays readable.
	commandForErrorMessage := fmt.Sprintf(
		"echo \"sh -c 'echo <base64InstallScript> | base64 -d | sh'\" | ssh -i %s %s",
		service.SSHPrivateKeyFile, userAtHost,
	)
	err = errorFromShellExecResult(commandForErrorMessage, exitCode, stdout, stderr, err)

	return string(stdout), string(stderr), err
}
// knownHostsHostPattern accepts a host field consisting solely of
// dot-separated groups of digits (for example "203.0.113.7"). It is anchored
// so that patterns such as "*,1.1" cannot slip through on a partial match.
var knownHostsHostPattern = regexp.MustCompile(`^(\d+\.)+\d+$`)

// knownHostsKeyPattern accepts a key field consisting solely of base64
// characters.
var knownHostsKeyPattern = regexp.MustCompile(`^[A-Za-z0-9+/=]+$`)

// AppendToKnownHostsFile validates knownHostsFileContent line by line and
// appends every valid "<ip> <key type> <base64 key>" entry to the current
// user's ~/.ssh/known_hosts.
//
// Before appending, any existing entries for each IP address found in the
// valid lines are removed with `ssh-keygen -R`, so a re-created host does not
// leave a conflicting stale key behind. The ~/.ssh folder and the known_hosts
// file are created (modes 0700 and 0600) when they are missing.
//
// An error is returned when no line passes validation, when the current user
// cannot be determined, when ssh-keygen fails, or when the file cannot be
// created, written or closed.
func (service *SSHService) AppendToKnownHostsFile(knownHostsFileContent string) error {
	validLines, ipAddresses := parseKnownHostsContent(knownHostsFileContent)
	if len(validLines) == 0 {
		return fmt.Errorf("knownHostsFileContent (%s) did not contain any valid lines", knownHostsFileContent)
	}

	currentUser, err := user.Current()
	if err != nil {
		return errors.Wrap(err, "getCurrentUserSSHFolder")
	}

	sshFolder := fmt.Sprintf("%s/.ssh", currentUser.HomeDir)
	if _, err := os.Stat(sshFolder); err != nil {
		if err := os.Mkdir(sshFolder, 0700); err != nil {
			return err
		}
	}

	knownHostsFilename := fmt.Sprintf("%s/known_hosts", sshFolder)
	if _, err := os.Stat(knownHostsFilename); err == nil {
		// Remove the old keys of every address we are about to add, not just
		// the last one seen.
		for _, ip := range ipAddresses {
			log.Printf("Removing %s from %s:\n", ip, knownHostsFilename)
			log.Printf("ssh-keygen -f %s -R %s\n", knownHostsFilename, ip)
			output, err := exec.Command("ssh-keygen", "-f", knownHostsFilename, "-R", ip).CombinedOutput()
			if err != nil {
				return errors.Wrapf(err, "ssh-keygen -f %s -R %s failed, output: %s", knownHostsFilename, ip, output)
			}
		}
	} else {
		log.Printf("%s doesn't exist yet, creating it...\n", knownHostsFilename)
	}

	file, err := os.OpenFile(knownHostsFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
	if err != nil {
		return err
	}

	log.Printf("Writing to %s:\n", knownHostsFilename)
	for _, line := range validLines {
		log.Println(line)
		// The leading newline protects against an existing file that does not
		// end with one.
		if _, err := fmt.Fprintf(file, "\n%s", line); err != nil {
			// Best effort: the write error is the one worth reporting.
			file.Close()
			return err
		}
	}

	// Close explicitly so a failure to flush the appended lines is reported.
	return file.Close()
}

// parseKnownHostsContent extracts the valid known_hosts entries from content.
// It returns the normalized "<ip> <key type> <base64 key>" lines in input
// order, and the distinct IP addresses they refer to in first-seen order.
//
// A line is valid when it has at least three whitespace-separated fields, the
// first is made up only of dot-separated digit groups, the second is one of
// the accepted key types, and the third contains only base64 characters. Any
// further fields (such as a trailing comment) are dropped.
func parseKnownHostsContent(content string) ([]string, []string) {
	var validLines, ipAddresses []string
	seenIPs := map[string]bool{}

	for _, line := range strings.Split(content, "\n") {
		// Fields also strips a trailing "\r" and tolerates tabs or repeated
		// spaces between columns.
		fields := strings.Fields(line)
		if len(fields) < 3 {
			continue
		}
		ip, hostKeyType, base64PublicKey := fields[0], fields[1], fields[2]

		switch hostKeyType {
		case "ecdsa-sha2-nistp256", "ssh-rsa", "ssh-ed25519":
		default:
			continue
		}
		if !knownHostsHostPattern.MatchString(ip) || !knownHostsKeyPattern.MatchString(base64PublicKey) {
			continue
		}

		validLines = append(validLines, fmt.Sprintf("%s %s %s", ip, hostKeyType, base64PublicKey))
		if !seenIPs[ip] {
			seenIPs[ip] = true
			ipAddresses = append(ipAddresses, ip)
		}
	}

	return validLines, ipAddresses
}
// shellExecInputPipe runs executable with arguments, optionally writing *input
// to the process's stdin, and waits for it to finish.
//
// input may be nil, in which case nothing is connected to stdin.
//
// It returns, in order: the process exit code, the captured stdout, the
// captured stderr, and an error. The exit code is -1 when the process could
// not be started or when the stdin pipe could not be created. The error is
// non-nil when the pipe could not be created, the process could not be
// started, or Wait reported a failure (which includes a non-zero exit code).
func shellExecInputPipe(input *string, executable string, arguments ...string) (int, []byte, []byte, error) {
	process := exec.Command(executable, arguments...)

	// Human-readable form of the command, used only in error messages. Built
	// up front so that a nil input is never dereferenced on an error path.
	commandDescription := executable + " " + strings.Join(arguments, " ")

	if input != nil {
		commandDescription = "echo '" + *input + "' | " + commandDescription

		stdin, err := process.StdinPipe()
		if err != nil {
			return -1, []byte{}, []byte{}, errors.Wrap(err, "process.StdinPipe() returned")
		}
		payload := *input
		go func() {
			// Closing stdin signals end-of-input to the child process.
			defer stdin.Close()
			// The write error is deliberately ignored: a process that exits
			// without draining stdin makes this fail, and that outcome is
			// already reported through the exit code and Wait error.
			_, _ = io.WriteString(stdin, payload)
		}()
	}

	var processStdoutBuffer, processStderrBuffer bytes.Buffer
	process.Stdout = &processStdoutBuffer
	process.Stderr = &processStderrBuffer

	if err := process.Start(); err != nil {
		err = errors.Wrapf(err, "can't shellExecInputPipe(%s), process.Start() returned", commandDescription)
		// The process never ran, so there is no ProcessState to read an exit
		// code from.
		return -1, []byte(""), []byte(""), err
	}

	err := process.Wait()
	if err != nil {
		err = errors.Wrapf(err, "can't shellExecInputPipe(%s), process.Wait() returned", commandDescription)
	}

	return process.ProcessState.ExitCode(), processStdoutBuffer.Bytes(), processStderrBuffer.Bytes(), err
}
// errorFromShellExecResult converts the outcome of a shell command into a
// single descriptive error, or nil when the command succeeded.
//
// The command counts as failed when exitCode is non-zero or err is non-nil.
// The resulting message contains command, the exit code, stdout, stderr and a
// "stack" section. The stack section is the text "nil" when err is nil;
// otherwise it is err's message with every line mentioning
// "can't shellExecInputPipe" removed.
func errorFromShellExecResult(command string, exitCode int, stdout []byte, stderr []byte, err error) error {
	if exitCode == 0 && err == nil {
		return nil
	}

	stack := "nil"
	if err != nil {
		var kept []string
		for _, messageLine := range strings.Split(err.Error(), "\n") {
			if strings.Contains(messageLine, "can't shellExecInputPipe") {
				continue
			}
			kept = append(kept, messageLine)
		}
		stack = strings.Join(kept, "\n")
	}

	return fmt.Errorf(
		"%s failed with exit code %d, stdout: \n----\n%s\n----\nstderr: \n----\n%s\n----\nstack: %s",
		command, exitCode, stdout, stderr, stack,
	)
}