diff --git a/pkg/test/vpc_client/bastion.go b/pkg/test/vpc_client/bastion.go index 913ace6..1e1c260 100644 --- a/pkg/test/vpc_client/bastion.go +++ b/pkg/test/vpc_client/bastion.go @@ -6,11 +6,15 @@ import ( "fmt" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/openshift-online/ocm-common/pkg/file" - "net" + "golang.org/x/crypto/bcrypt" + "golang.org/x/crypto/ssh" + "os" + "path" "time" CON "github.com/openshift-online/ocm-common/pkg/aws/consts" "github.com/openshift-online/ocm-common/pkg/log" + "github.com/openshift-online/ocm-common/pkg/utils" ) // LaunchBastion will launch a bastion instance on the indicated zone. @@ -94,8 +98,8 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string, userData string, keyp return inst, nil } -func (vpc *VPC) PrepareBastionProxy(zone string, cidrBlock string, keypairName string, - privateKeyPath string) (*types.Instance, error) { +func (vpc *VPC) PrepareBastionProxy(zone string, keypairName string, privateKeyPath string) (*types.Instance, string, + string, error) { filters := []map[string][]string{ { "vpc-id": { @@ -111,38 +115,57 @@ func (vpc *VPC) PrepareBastionProxy(zone string, cidrBlock string, keypairName s insts, err := vpc.AWSClient.ListInstances([]string{}, filters...) if err != nil { - return nil, err + return nil, "", "", err } if len(insts) == 0 { log.LogInfo("Didn't found an existing bastion, going to launch one") - if cidrBlock == "" { - cidrBlock = CON.RouteDestinationCidrBlock - } - _, _, err = net.ParseCIDR(cidrBlock) - if err != nil { - log.LogError("CIDR IP address format is invalid") - return nil, err - } - userData := fmt.Sprintf(`#!/bin/bash + userData := `#!/bin/bash yum update -y - yum install -y squid + sudo dnf install squid -y cd /etc/squid/ sudo mv ./squid.conf ./squid.conf.bak sudo touch squid.conf echo http_port 3128 >> /etc/squid/squid.conf - echo acl allowed_ips src %s >> /etc/squid/squid.conf - echo http_access allow allowed_ips >> /etc/squid/squid.conf + echo auth_param basic program /usr/lib64/squid/basic_ncsa_auth /etc/squid/passwords >> /etc/squid/squid.conf + echo auth_param basic realm Squid Proxy Server >> /etc/squid/squid.conf + echo acl authenticated proxy_auth REQUIRED >> /etc/squid/squid.conf + echo http_access allow authenticated >> /etc/squid/squid.conf echo http_access deny all >> /etc/squid/squid.conf systemctl start squid - systemctl enable squid`, cidrBlock) + systemctl enable squid` encodeUserData := base64.StdEncoding.EncodeToString([]byte(userData)) - return vpc.LaunchBastion("", zone, encodeUserData, keypairName, privateKeyPath) + instance, err := vpc.LaunchBastion("", zone, encodeUserData, keypairName, privateKeyPath) + if err != nil { + log.LogError("Launch bastion failed") + } + + username := utils.RandomLabel(5) + password := utils.RandomLabel(5) + hashedPassword, err := generateBcryptPassword(password) + if err != nil { + return nil, "", "", err + } + + localFilePath := "./tmp/passwords" + err = writePasswordToFile(username, hashedPassword, localFilePath) + if err != nil { + return nil, "", "", err + } + + remoteFilePath := "/etc/squid/passwords" + err = uploadFileToBastion(*instance.PublicIpAddress, "22", "ec2-user", privateKeyPath, + keypairName, localFilePath, remoteFilePath) + + if err != nil { + return nil, "", "", err + } + return instance, username, password, nil } log.LogInfo("Found existing bastion: %s", *insts[0].InstanceId) - return &insts[0], nil + return &insts[0], "", "", nil } func (vpc *VPC) DestroyBastionProxy(instance types.Instance) error { @@ -155,3 +178,121 @@ func (vpc *VPC) DestroyBastionProxy(instance types.Instance) error { } return nil } + +func generateBcryptPassword(plainPassword string) (string, error) { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(plainPassword), bcrypt.DefaultCost) + if err != nil { + log.LogError("Generate hashed password failed") + return "", nil + } + return string(hashedPassword), nil +} + +func writePasswordToFile(username, hashedPassword, filePath string) error { + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + dir := path.Dir(filePath) + err = os.MkdirAll(dir, os.ModePerm) + if err != nil { + log.LogError("Create directory failed") + return err + } + _, err = os.Create(filePath) + if err != nil { + log.LogError("Create passed file failed") + return err + } + } else if err != nil { + log.LogError("Stat file failed") + return err + } + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600) + if err != nil { + log.LogError("Open password file failed") + return err + } + defer file.Close() + + line := fmt.Sprintf("%s:%s\n", username, hashedPassword) + _, err = file.WriteString(line) + if err != nil { + log.LogError("Write password file failed") + return err + } + return nil +} + +func loadPrivateKey(privateKeyPath, keypairName string) (ssh.Signer, error) { + privateKeyName := fmt.Sprintf("%s/%s-%s", privateKeyPath, keypairName, "keyPair.pem") + key, err := os.ReadFile(privateKeyName) + if err != nil { + log.LogError("Read privte key failed") + return nil, err + } + + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + log.LogError("Parse privte key failed") + return nil, err + } + return signer, nil +} + +func uploadFileToBastion(host, port, username, privateKeyPath, keypairName, localFilePath, remoteFilePath string) error { + signer, err := loadPrivateKey(privateKeyPath, keypairName) + if err != nil { + log.LogError("Load private key failed") + return err + } + + sshConfig := &ssh.ClientConfig{ + User: username, + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + address := fmt.Sprintf("%s:%s", host, port) + client, err := ssh.Dial("tcp", address, sshConfig) + if err != nil { + log.LogError("Connect to bastion failed") + return err + } + defer client.Close() + + fileContent, err := os.ReadFile(localFilePath) + if err != nil { + log.LogError("Open local passed file failed") + return err + } + + session, err := client.NewSession() + if err != nil { + log.LogError("Create ssh session failed") + return err + } + defer session.Close() + + remoteFile, err := session.StdinPipe() + if err != nil { + log.LogError("Create stdin pipe failed") + return err + } + + command := fmt.Sprintf("sudo tee %s", remoteFilePath) + err = session.Start(command) + if err != nil { + log.LogError("Start remote command failed") + return err + } + + _, err = remoteFile.Write(fileContent) + if err != nil { + log.LogError("Write file to bastion failed") + return err + } + + log.LogInfo("Upload file to bastion successfully") + return nil +}