diff --git a/lib/client/sshagent/agent.go b/lib/client/sshagent/agent.go index 1bbc8da..79745af 100644 --- a/lib/client/sshagent/agent.go +++ b/lib/client/sshagent/agent.go @@ -50,12 +50,13 @@ func deleteDuplicateEntries(comment string, agentClient agent.ExtendedAgent, log return deletedCount, nil } -func upsertCertIntoAgent( +func upsertCertIntoAgentConnection( certText []byte, privateKey interface{}, comment string, lifeTimeSecs uint32, confirmBeforeUse bool, + conn net.Conn, logger log.DebugLogger) error { pubKey, _, _, _, err := ssh.ParseAuthorizedKey(certText) if err != nil { @@ -72,19 +73,33 @@ func upsertCertIntoAgent( Comment: comment, ConfirmBeforeUse: confirmBeforeUse, } - return withAddedKeyUpsertCertIntoAgent(keyToAdd, logger) + return withAddedKeyUpsertCertIntoAgentConnection(keyToAdd, conn, logger) } -func withAddedKeyUpsertCertIntoAgent(certToAdd agent.AddedKey, logger log.DebugLogger) error { +func upsertCertIntoAgent( + certText []byte, + privateKey interface{}, + comment string, + lifeTimeSecs uint32, + confirmBeforeUse bool, + logger log.DebugLogger) error { + return upsertCertIntoAgentConnection(certText, privateKey, comment, lifeTimeSecs, confirmBeforeUse, nil, logger) +} + +func withAddedKeyUpsertCertIntoAgentConnection(certToAdd agent.AddedKey, conn net.Conn, logger log.DebugLogger) error { if certToAdd.Certificate == nil { return fmt.Errorf("Needs a certificate to be added") } - conn, err := connectToDefaultSSHAgentLocation() - if err != nil { - return err + var err error = nil + if conn == nil { + conn, err = connectToDefaultSSHAgentLocation() + if err != nil { + return err + } + defer conn.Close() } - defer conn.Close() + agentClient := agent.NewClient(conn) //delete certs in agent with the same comment @@ -102,3 +117,7 @@ func withAddedKeyUpsertCertIntoAgent(certToAdd agent.AddedKey, logger log.DebugL return agentClient.Add(certToAdd) } + +func withAddedKeyUpsertCertIntoAgent(certToAdd agent.AddedKey, logger log.DebugLogger) error { + return withAddedKeyUpsertCertIntoAgentConnection(certToAdd, nil, logger) +}