Skip to content

Commit

Permalink
Merge pull request #2823 from rockwotj/main
Browse files Browse the repository at this point in the history
ollama: disable ollama subprocess on non-unix platforms
  • Loading branch information
Jeffail authored Aug 29, 2024
2 parents c0c4770 + 3b86c6a commit d49c1cd
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 123 deletions.
138 changes: 15 additions & 123 deletions internal/impl/ollama/base_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,14 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"sync"
"syscall"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/dustin/go-humanize"
"github.com/ollama/ollama/api"
"github.com/redpanda-data/benthos/v4/public/service"
Expand Down Expand Up @@ -249,6 +244,14 @@ func (co *commandOutput) Write(b []byte) (int, error) {
return len(b), nil
}

type baseOllamaProcessor struct {
model string
opts map[string]any
ticket singleton.Ticket
client *api.Client
logger *service.Logger
}

type key int

const (
Expand All @@ -262,124 +265,6 @@ type runOllamaConfig struct {
downloadURL string
}

// lookPath is very similar to exec.LookPath, except that it uses service.FS filesystem
// abstractions.
func (c *runOllamaConfig) lookPath(file string) (string, error) {
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
return "", errors.ErrUnsupported
}
path := os.Getenv("PATH")
for _, dir := range filepath.SplitList(path) {
if !filepath.IsAbs(path) {
continue
}
path := filepath.Join(dir, file)
p, err := c.fs.Stat(path)
if err != nil || p.IsDir() {
continue
}
// Check that the file is executable
if err := syscall.Access(path, 0x1); err == nil {
return path, nil
}
}
return "", exec.ErrNotFound
}

func (c *runOllamaConfig) downloadOllama(ctx context.Context, path string) error {
var url string
if c.downloadURL == "" {
const baseURL string = "https://github.com/ollama/ollama/releases/download/v0.3.6/ollama"
switch runtime.GOOS {
case "darwin":
// They ship an universal executable for darwin
url = baseURL + "-darwin"
case "linux":
url = fmt.Sprintf("%s-%s-%s", baseURL, runtime.GOOS, runtime.GOARCH)
default:
return fmt.Errorf("automatic download of ollama is not supported on %s, please download ollama manually", runtime.GOOS)
}
} else {
url = c.downloadURL
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to download ollama binary: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download ollama binary: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("failed to download ollama binary: status_code=%d", resp.StatusCode)
}
ollama, err := c.fs.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0755)
if err != nil {
return fmt.Errorf("unable to create file for ollama binary download: %w", err)
}
defer ollama.Close()
w, ok := (ollama).(io.Writer)
if !ok {
return errors.New("unable to download ollama binary to filesystem")
}
_, err = io.Copy(w, resp.Body)
return err
}

var ollamaProcess = singleton.New(singleton.Config[*exec.Cmd]{
Constructor: func(ctx context.Context) (*exec.Cmd, error) {
cfg, ok := ctx.Value(configKey).(runOllamaConfig)
if !ok {
return nil, errors.New("missing config")
}
serverPath, err := cfg.lookPath("ollama")
if errors.Is(err, exec.ErrNotFound) {
serverPath = path.Join(cfg.cacheDir, "ollama")
if err := os.MkdirAll(cfg.cacheDir, 0777); err != nil {
return nil, err
}
if _, err = os.Stat(serverPath); errors.Is(err, os.ErrNotExist) {
err = backoff.Retry(func() error {
cfg.logger.Infof("downloading ollama to %s", serverPath)
return cfg.downloadOllama(ctx, serverPath)
}, backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Second), 3))
}
if err != nil {
return nil, err
}
cfg.logger.Info("ollama download complete")
} else if err != nil {
return nil, err
}
cfg.logger.Tracef("starting ollama subprocess at %s", serverPath)
proc := exec.Command(serverPath, "serve")
proc.Env = append(os.Environ(), "OLLAMA_MODELS="+cfg.cacheDir)
proc.Stdout = &commandOutput{logger: cfg.logger}
proc.Stderr = &commandOutput{logger: cfg.logger}
if err = proc.Start(); err != nil {
return nil, err
}
return proc, nil
},
Destructor: func(ctx context.Context, cmd *exec.Cmd) error {
if cmd.Process == nil {
return nil
}
if err := cmd.Process.Kill(); err != nil {
return err
}
return cmd.Wait()
}})

type baseOllamaProcessor struct {
model string
opts map[string]any
ticket singleton.Ticket
client *api.Client
logger *service.Logger
}

func newBaseProcessor(conf *service.ParsedConfig, mgr *service.Resources) (p *baseOllamaProcessor, err error) {
p = &baseOllamaProcessor{}
p.logger = mgr.Logger()
Expand Down Expand Up @@ -428,6 +313,10 @@ func newBaseProcessor(conf *service.ParsedConfig, mgr *service.Resources) (p *ba
fs: mgr.FS(),
downloadURL: downloadURL,
})
if ollamaProcess == nil {
err = fmt.Errorf("running a local ollama process is not supported on %s, please specify a `%s`", runtime.GOOS, bopFieldServerAddress)
return
}
_, p.ticket, err = ollamaProcess.Acquire(ctx)
if err != nil {
return
Expand Down Expand Up @@ -477,5 +366,8 @@ func (o *baseOllamaProcessor) pullModel(ctx context.Context) error {
}

func (o *baseOllamaProcessor) Close(ctx context.Context) error {
if ollamaProcess == nil {
return nil
}
return ollamaProcess.Close(ctx, o.ticket)
}
19 changes: 19 additions & 0 deletions internal/impl/ollama/subprocess.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed as a Redpanda Enterprise file under the Redpanda Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md

//go:build !unix

package ollama

import (
"os/exec"

"github.com/redpanda-data/connect/v4/internal/singleton"
)

var ollamaProcess *singleton.Singleton[*exec.Cmd]
136 changes: 136 additions & 0 deletions internal/impl/ollama/subprocess_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2024 Redpanda Data, Inc.
//
// Licensed as a Redpanda Enterprise file under the Redpanda Community
// License (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md

//go:build unix

package ollama

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path"
"path/filepath"
"runtime"
"syscall"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/redpanda-data/connect/v4/internal/singleton"
)

// lookPath is very similar to exec.LookPath, except that it uses service.FS filesystem
// abstractions.
func (c *runOllamaConfig) lookPath(file string) (string, error) {
path := os.Getenv("PATH")
for _, dir := range filepath.SplitList(path) {
if !filepath.IsAbs(path) {
continue
}
path := filepath.Join(dir, file)
p, err := c.fs.Stat(path)
if err != nil || p.IsDir() {
continue
}
// Check that the file is executable
if err := syscall.Access(path, 0x1); err == nil {
return path, nil
}
}
return "", exec.ErrNotFound
}

func (c *runOllamaConfig) downloadOllama(ctx context.Context, path string) error {
var url string
if c.downloadURL == "" {
const baseURL string = "https://github.com/ollama/ollama/releases/download/v0.3.6/ollama"
switch runtime.GOOS {
case "darwin":
// They ship an universal executable for darwin
url = baseURL + "-darwin"
case "linux":
url = fmt.Sprintf("%s-%s-%s", baseURL, runtime.GOOS, runtime.GOARCH)
default:
return fmt.Errorf("automatic download of ollama is not supported on %s, please download ollama manually", runtime.GOOS)
}
} else {
url = c.downloadURL
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to download ollama binary: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download ollama binary: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return fmt.Errorf("failed to download ollama binary: status_code=%d", resp.StatusCode)
}
ollama, err := c.fs.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0755)
if err != nil {
return fmt.Errorf("unable to create file for ollama binary download: %w", err)
}
defer ollama.Close()
w, ok := (ollama).(io.Writer)
if !ok {
return errors.New("unable to download ollama binary to filesystem")
}
_, err = io.Copy(w, resp.Body)
return err
}

var ollamaProcess = singleton.New(singleton.Config[*exec.Cmd]{
Constructor: func(ctx context.Context) (*exec.Cmd, error) {
cfg, ok := ctx.Value(configKey).(runOllamaConfig)
if !ok {
return nil, errors.New("missing config")
}
serverPath, err := cfg.lookPath("ollama")
if errors.Is(err, exec.ErrNotFound) {
serverPath = path.Join(cfg.cacheDir, "ollama")
if err := os.MkdirAll(cfg.cacheDir, 0777); err != nil {
return nil, err
}
if _, err = os.Stat(serverPath); errors.Is(err, os.ErrNotExist) {
err = backoff.Retry(func() error {
cfg.logger.Infof("downloading ollama to %s", serverPath)
return cfg.downloadOllama(ctx, serverPath)
}, backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Second), 3))
}
if err != nil {
return nil, err
}
cfg.logger.Info("ollama download complete")
} else if err != nil {
return nil, err
}
cfg.logger.Tracef("starting ollama subprocess at %s", serverPath)
proc := exec.Command(serverPath, "serve")
proc.Env = append(os.Environ(), "OLLAMA_MODELS="+cfg.cacheDir)
proc.Stdout = &commandOutput{logger: cfg.logger}
proc.Stderr = &commandOutput{logger: cfg.logger}
if err = proc.Start(); err != nil {
return nil, err
}
return proc, nil
},
Destructor: func(ctx context.Context, cmd *exec.Cmd) error {
if cmd.Process == nil {
return nil
}
if err := cmd.Process.Kill(); err != nil {
return err
}
return cmd.Wait()
}})

0 comments on commit d49c1cd

Please sign in to comment.