diff --git a/encoding/kzg/prover/icicle/multiframe_proof.go b/encoding/kzg/prover/icicle/multiframe_proof.go index 09194d8e7..0ec4ef7d5 100644 --- a/encoding/kzg/prover/icicle/multiframe_proof.go +++ b/encoding/kzg/prover/icicle/multiframe_proof.go @@ -16,7 +16,7 @@ import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/ingonyama-zk/icicle/v3/wrappers/golang/core" iciclebn254 "github.com/ingonyama-zk/icicle/v3/wrappers/golang/curves/bn254" - icicleRuntime "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" + "github.com/ingonyama-zk/icicle/v3/wrappers/golang/runtime" ) type KzgMultiProofIcicleBackend struct { @@ -28,7 +28,7 @@ type KzgMultiProofIcicleBackend struct { Srs *kzg.SRS NttCfg core.NTTConfig[[iciclebn254.SCALAR_LIMBS]uint32] MsmCfg core.MSMConfig - Device icicleRuntime.Device + Device runtime.Device GpuLock sync.Mutex } @@ -63,85 +63,87 @@ func (p *KzgMultiProofIcicleBackend) ComputeMultiFrameProof(polyFr []fr.Element, flattenCoeffStoreSf := icicle.ConvertFrToScalarFieldsBytes(flattenCoeffStoreFr) flattenCoeffStoreCopy := core.HostSliceFromElements[iciclebn254.ScalarField](flattenCoeffStoreSf) - var msmDone, fft1Done, fft2Done time.Time - flatProofsBatchHost := make(core.HostSlice[iciclebn254.Projective], int(numPoly)*int(dimE)) + var icicleFFTBatch []bn254.G1Affine + var icicleErr error - // Set device - icicleErr := icicleRuntime.SetDevice(&p.Device) - if icicleErr != icicleRuntime.Success { - return nil, fmt.Errorf("failed to set device: %v", icicleErr.AsString()) - } - - // Channel just for error handling - errChan := make(chan error, 1) + // GPU operations + p.GpuLock.Lock() + defer p.GpuLock.Unlock() wg := sync.WaitGroup{} wg.Add(1) + errChan := make(chan error, 1) - icicleRuntime.RunOnDevice(&p.Device, func(args ...any) { + var msmDone, firstECNttDone, secondECNttDone time.Time + runtime.RunOnDevice(&p.Device, func(args ...any) { defer wg.Done() defer close(errChan) defer func() { if r := recover(); r != nil { - errChan <- fmt.Errorf("panic in GPU operations: %v", r) + icicleErr = fmt.Errorf("GPU operation panic: %v", r) } }() - // Copy data to device + // Copy the flatten coeff to device var flattenStoreCopyToDevice core.DeviceSlice flattenCoeffStoreCopy.CopyToDevice(&flattenStoreCopyToDevice, true) - // 1. MSM Batch Operation sumVec, err := p.MsmBatchOnDevice(flattenStoreCopyToDevice, p.FlatFFTPointsT, int(numPoly)*int(dimE)*2) if err != nil { - errChan <- fmt.Errorf("msm batch error: %v", err) + icicleErr = fmt.Errorf("msm error: %w", err) return } + + // Free the flatten coeff store + flattenStoreCopyToDevice.Free() + msmDone = time.Now() - // 2. First ECNtt + // Compute the first ecntt, and set new batch size for ntt p.NttCfg.BatchSize = int32(numPoly) sumVecInv, err := p.ECNttOnDevice(sumVec, true, int(dimE)*2*int(numPoly)) if err != nil { - errChan <- fmt.Errorf("first ECNtt error: %v", err) + icicleErr = fmt.Errorf("first ECNtt error: %w", err) return } - fft1Done = time.Now() - // 3. Prune and Second ECNtt + sumVec.Free() + + firstECNttDone = time.Now() + prunedSumVecInv := sumVecInv.Range(0, int(dimE), false) + + // Compute the second ecntt on the reduced size array flatProofsBatch, err := p.ECNttToGnarkOnDevice(prunedSumVecInv, false, int(numPoly)*int(dimE)) if err != nil { - errChan <- fmt.Errorf("second ECNtt error: %v", err) + icicleErr = fmt.Errorf("second ECNtt error: %w", err) return } - fft2Done = time.Now() - // 4. Copy results back to host - flatProofsBatchHost.CopyFromDevice(&flatProofsBatch) + prunedSumVecInv.Free() + + secondECNttDone = time.Now() - // Free memory + flatProofsBatchHost := make(core.HostSlice[iciclebn254.Projective], int(numPoly)*int(dimE)) + flatProofsBatchHost.CopyFromDevice(&flatProofsBatch) flatProofsBatch.Free() - sumVecInv.Free() - sumVec.Free() - flattenStoreCopyToDevice.Free() + icicleFFTBatch = icicle.HostSliceIcicleProjectiveToGnarkAffine(flatProofsBatchHost, int(p.NumWorker)) }) wg.Wait() - // Check for errors - if err := <-errChan; err != nil { - return nil, err + if icicleErr != nil { + return nil, icicleErr } - // Convert to final format - icicleFFTBatch := icicle.HostSliceIcicleProjectiveToGnarkAffine(flatProofsBatchHost, int(p.NumWorker)) + end := time.Now() slog.Info("Multiproof Time Decomp", + "total", end.Sub(begin), "preproc", preprocessDone.Sub(begin), "msm", msmDone.Sub(preprocessDone), - "fft1", fft1Done.Sub(msmDone), - "fft2", fft2Done.Sub(fft1Done), + "fft1", firstECNttDone.Sub(msmDone), + "fft2", secondECNttDone.Sub(firstECNttDone), ) return icicleFFTBatch, nil diff --git a/encoding/rs/icicle/extend_poly.go b/encoding/rs/icicle/extend_poly.go index 067bdfd24..f6df44a80 100644 --- a/encoding/rs/icicle/extend_poly.go +++ b/encoding/rs/icicle/extend_poly.go @@ -38,18 +38,15 @@ func (g *RsIcicleBackend) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, err return nil, fmt.Errorf("failed to set device: %v", err.AsString()) } - // Channel to receive any errors from the GPU operation - errChan := make(chan error, 1) - + var icicleErr error // Perform NTT wg := sync.WaitGroup{} wg.Add(1) icicleRuntime.RunOnDevice(&g.Device, func(args ...any) { defer wg.Done() - defer close(errChan) defer func() { if r := recover(); r != nil { - errChan <- fmt.Errorf("GPU operation panic: %v", r) + icicleErr = fmt.Errorf("GPU operation panic: %v", r) } }() @@ -59,8 +56,8 @@ func (g *RsIcicleBackend) ExtendPolyEval(coeffs []fr.Element) ([]fr.Element, err wg.Wait() // Check if there was a panic - if err := <-errChan; err != nil { - return nil, err + if icicleErr != nil { + return nil, icicleErr } evals := icicle.ConvertScalarFieldsToFrBytes(outputDevice)