Skip to content

Commit

Permalink
chore(vad): try to hook vad to received data from the API
Browse files Browse the repository at this point in the history
Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Nov 14, 2024
1 parent 3379a29 commit 38e3959
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 3 deletions.
17 changes: 14 additions & 3 deletions core/http/endpoints/openai/realtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@ import (
"strings"
"sync"

"github.com/go-audio/audio"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/sound"

"google.golang.org/grpc"

"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -456,9 +459,17 @@ func handleVAD(session *Session, conversation *Conversation, c *websocket.Conn,
// Check if there's audio data to process
session.AudioBufferLock.Lock()
if len(session.InputAudioBuffer) > 0 {
// TODO: what to put in the VADRequest request?
// Data is received as buffer, but we want PCM as float32 here...
resp, err := session.ModelInterface.VAD(context.Background(), &proto.VADRequest{})

adata := sound.BytesToInt16sLE(session.InputAudioBuffer)

soundIntBuffer := &audio.IntBuffer{
Format: &audio.Format{SampleRate: 16000, NumChannels: 1},
}
soundIntBuffer.Data = sound.ConvertInt16ToInt(adata)

resp, err := session.ModelInterface.VAD(context.Background(), &proto.VADRequest{
Audio: soundIntBuffer.AsFloat32Buffer().Data,
})
if err != nil {
log.Error().Msgf("failed to process audio: %s", err.Error())
sendError(c, "processing_error", "Failed to process audio", "", "")
Expand Down
20 changes: 20 additions & 0 deletions pkg/sound/float32.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package sound

import (
"encoding/binary"
"math"
)

// BytesToFloat32Array decodes aBytes as a sequence of little-endian IEEE 754
// 32-bit floats and returns them in input order.
//
// One element is produced for every complete 4-byte group in aBytes; any
// trailing bytes that do not form a full group are ignored. An empty or
// too-short input yields an empty slice instead of panicking.
func BytesToFloat32Array(aBytes []byte) []float32 {
	// Size the result from the input rather than assuming a fixed element
	// count, so short buffers don't panic and long buffers aren't truncated.
	aArr := make([]float32, len(aBytes)/4)
	for i := range aArr {
		bits := binary.LittleEndian.Uint32(aBytes[i*4:])
		aArr[i] = math.Float32frombits(bits)
	}
	return aArr
}

// BytesFloat32 interprets the first four bytes of the given slice as a
// little-endian IEEE 754 bit pattern and returns the corresponding float32.
// Bytes beyond the fourth are not read.
func BytesFloat32(bytes []byte) float32 {
	return math.Float32frombits(binary.LittleEndian.Uint32(bytes))
}
65 changes: 65 additions & 0 deletions pkg/sound/int16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package sound

/*
MIT License
Copyright (c) 2024 Xbozon
*/

// ResampleInt16 converts input from inputRate to outputRate using linear
// interpolation between neighbouring samples.
//
// The result holds int(len(input) * outputRate / inputRate) samples. Its final
// sample is always a copy of the final input sample. nil is returned when
// there is nothing meaningful to produce: input is empty, either rate is not
// positive, or the computed output length is zero (for example when
// downsampling a buffer shorter than one output period).
func ResampleInt16(input []int16, inputRate, outputRate int) []int16 {
	// Degenerate arguments would otherwise lead to division by zero or to
	// indexing an empty slice below.
	if len(input) == 0 || inputRate <= 0 || outputRate <= 0 {
		return nil
	}

	// Number of input samples consumed per output sample.
	ratio := float64(inputRate) / float64(outputRate)

	outputLength := int(float64(len(input)) / ratio)
	if outputLength <= 0 {
		return nil
	}

	output := make([]int16, outputLength)
	lastIn := len(input) - 1

	// Interpolate every output sample except the last one.
	for i := 0; i < outputLength-1; i++ {
		// Fractional position of this output sample in the input.
		pos := float64(i) * ratio

		indexBefore := int(pos)
		indexAfter := indexBefore + 1
		if indexAfter > lastIn {
			// No sample to the right; reuse the final one.
			indexAfter = lastIn
		}

		frac := pos - float64(indexBefore)
		output[i] = int16((1-frac)*float64(input[indexBefore]) + frac*float64(input[indexAfter]))
	}

	// The final sample is copied directly rather than interpolated.
	output[outputLength-1] = input[lastIn]

	return output
}

// ConvertInt16ToInt returns a new slice of the same length as input in which
// every element has been widened from int16 to int.
func ConvertInt16ToInt(input []int16) []int {
	widened := make([]int, len(input))
	for idx := 0; idx < len(input); idx++ {
		widened[idx] = int(input[idx])
	}
	return widened
}

// BytesToInt16sLE decodes the given bytes as consecutive little-endian 16-bit
// signed integers: for each pair, the first byte is the low half and the
// second byte is the high half.
//
// It panics if the input has an odd number of bytes.
func BytesToInt16sLE(bytes []byte) []int16 {
	if len(bytes)&1 == 1 {
		panic("bytesToInt16sLE: input bytes slice has odd length, must be even")
	}

	samples := make([]int16, len(bytes)/2)
	for n := range samples {
		lo := uint16(bytes[2*n])
		hi := uint16(bytes[2*n+1])
		samples[n] = int16(hi<<8 | lo)
	}
	return samples
}

0 comments on commit 38e3959

Please sign in to comment.