feat(commits): Implement cached response for chain responses (#530)
- Added `DocumentCache` class to handle caching of chain step responses based on workspace, config, and conversation.
- Introduced `getCachedResponse` and `setCachedResponse` functions to retrieve and store cached responses.
- Updated `runChain` function to utilize cached responses if available, reducing redundant calls to the AI module.
- Added tests to verify the caching behavior and ensure correct functionality with various configurations and scenarios.

This change improves performance by avoiding repeated AI calls for the same conversation and configuration, leveraging cached responses when appropriate.
geclos authored Nov 4, 2024
1 parent 24f7b51 commit fe53763
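
In practice, the change wraps each chain step's provider call in a cache lookup keyed by workspace, config, and conversation. A condensed sketch of the pattern, paraphrasing the `runStep` changes in the `run.ts` diff below (the parameter shapes and the import path are simplified placeholders for illustration, not the committed signatures):

// Condensed sketch of the caching pattern this commit adds to runStep (see the run.ts diff below).
// getCachedResponse/setCachedResponse are the functions introduced here; the parameter types and
// the import path are simplified placeholders, not the real ones.
import { getCachedResponse, setCachedResponse } from './services/commits/promptCache'

async function stepWithPromptCache(params: {
  workspace: { id: number }
  config: Record<string, unknown>
  conversation: unknown
  callAi: () => Promise<unknown>
}) {
  const { workspace, config, conversation, callAi } = params

  // Reuse a previous response for the same workspace + config + conversation, if one exists.
  const cached = await getCachedResponse({ workspace, config, conversation })
  if (cached) return cached

  // Otherwise call the provider as before and remember the result for identical future steps.
  const response = await callAi()
  await setCachedResponse({ workspace, config, conversation, response })

  return response
}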
Showing 4 changed files with 414 additions and 2 deletions.
183 changes: 181 additions & 2 deletions packages/core/src/services/chains/run.test.ts
@@ -1,4 +1,9 @@
import { Chain, ContentType, MessageRole } from '@latitude-data/compiler'
import {
Chain,
ContentType,
Conversation,
MessageRole,
} from '@latitude-data/compiler'
import { v4 as uuid } from 'uuid'
import { beforeEach, describe, expect, it, vi } from 'vitest'

@@ -7,6 +12,8 @@ import { ErrorableEntity, LogSources, Providers } from '../../constants'
import { Result } from '../../lib'
import * as factories from '../../tests/factories'
import * as aiModule from '../ai'
import { setCachedResponse } from '../commits/promptCache'
import * as chainValidatorModule from './ChainValidator'
import { runChain } from './run'

// Mock other dependencies
@@ -540,7 +547,6 @@ describe('runChain', () => {
}),
},
})

vi.spyOn(aiModule, 'ai').mockResolvedValue(mockAiResponse as any)
vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
@@ -577,4 +583,177 @@ }),
}),
)
})

describe('with cached response', () => {
let config = { provider: 'openai', model: 'gpt-3.5-turbo' }
let conversation = {
messages: [
{
role: MessageRole.user,
content: [{ type: ContentType.text, text: 'Test message' }],
},
],
config,
} as Conversation

beforeEach(async () => {
vi.mocked(mockChain.step!).mockResolvedValue({
completed: true,
conversation,
})

vi.spyOn(chainValidatorModule, 'ChainValidator').mockImplementation(
() => {
return {
call: vi.fn().mockResolvedValue(
Result.ok({
chainCompleted: true,
config,
conversation,

provider: providersMap.get('openai'),
}),
),
} as any
},
)
})

it('returns the cached response', async () => {
await setCachedResponse({
workspace,
config,
conversation,
response: {
streamType: 'text',
text: 'cached response',
usage: {
promptTokens: 0,
completionTokens: 0,
totalTokens: 0,
},
toolCalls: [],
},
})
const run = await runChain({
workspace,
chain: mockChain as Chain,
providersMap,
source: LogSources.API,
errorableType: ErrorableEntity.DocumentLog,
})
const spy = vi.spyOn(aiModule, 'ai')
const res = await run.response

expect(spy).not.toHaveBeenCalled()
expect(res.value).toEqual(
expect.objectContaining({ text: 'cached response' }),
)
})

describe('with config having temperature != 0', () => {
beforeEach(() => {
// @ts-expect-error - mock
config.temperature = 0.5
})

it('calls the AI module instead of returning the cached response', async () => {
await setCachedResponse({
workspace,
config,
conversation,
response: {
streamType: 'text',
text: 'cached response',
usage: {
promptTokens: 0,
completionTokens: 0,
totalTokens: 0,
},
toolCalls: [],
},
})
const mockAiResponse = createMockAiResponse('AI response', 10)

const spy = vi
.spyOn(aiModule, 'ai')
.mockResolvedValue(mockAiResponse as any)

const run = await runChain({
workspace,
chain: mockChain as Chain,
providersMap,
source: LogSources.API,
errorableType: ErrorableEntity.DocumentLog,
})
const result = await run.response

expect(spy).toHaveBeenCalled()
expect(result.ok).toEqual(true)
expect(result.value).not.toEqual(
expect.objectContaining({ text: 'cached response' }),
)
})
})

describe('with conversation having multiple steps', () => {
beforeEach(() => {
vi.spyOn(chainValidatorModule, 'ChainValidator').mockImplementation(
() => {
return {
call: vi
.fn()
.mockResolvedValue(
Result.ok({ chainCompleted: false, config, conversation }),
)
.mockResolvedValue(
Result.ok({
chainCompleted: true,
config,
conversation,
provider: providersMap.get('openai'),
}),
),
} as any
},
)
})

it('returns the cached response first and then calls ai module', async () => {
await setCachedResponse({
workspace,
config,
conversation,
response: {
streamType: 'text',
text: 'cached response',
usage: {
promptTokens: 0,
completionTokens: 0,
totalTokens: 0,
},
toolCalls: [],
},
})

const mockAiResponse = createMockAiResponse('AI response', 10)
const spy = vi
.spyOn(aiModule, 'ai')
.mockResolvedValue(mockAiResponse as any)

const run = await runChain({
workspace,
chain: mockChain as Chain,
providersMap,
source: LogSources.API,
errorableType: ErrorableEntity.DocumentLog,
})

const result = await run.response

expect(spy).toHaveBeenCalledOnce()
expect(result.ok).toEqual(true)
})
})
})
})
39 changes: 39 additions & 0 deletions packages/core/src/services/chains/run.ts
@@ -12,6 +13,7 @@ import {
import { Result, TypedResult } from '../../lib'
import { generateUUIDIdentifier } from '../../lib/generateUUID'
import { ai } from '../ai'
import { getCachedResponse, setCachedResponse } from '../commits/promptCache'
import { createRunError } from '../runErrors/create'
import { ChainError } from './ChainErrors'
import { ChainStreamConsumer } from './ChainStreamConsumer'
@@ -162,6 +163,35 @@ async function runStep({
try {
const step = await chainValidator.call().then((r) => r.unwrap())
const { messageCount, stepStartTime } = streamConsumer.setup(step)
const cachedResponse = await getCachedResponse({
workspace,
config: step.config,
conversation: step.conversation,
})

if (cachedResponse) {
if (step.chainCompleted) {
streamConsumer.chainCompleted({ step, response: cachedResponse })

return cachedResponse
} else {
streamConsumer.stepCompleted(cachedResponse)

return runStep({
workspace,
source,
chain,
providersMap,
controller,
errorableUuid,
errorableType,
previousCount: previousCount + 1,
previousResponse: cachedResponse,
configOverrides,
})
}
}

const providerProcessor = new ProviderProcessor({
workspace,
source,
@@ -193,12 +223,21 @@
.then((r) => r.unwrap())

if (consumedStream.error) throw consumedStream.error

await setCachedResponse({
workspace,
config: step.config,
conversation: step.conversation,
response,
})

if (step.chainCompleted) {
streamConsumer.chainCompleted({ step, response })

return response
} else {
streamConsumer.stepCompleted(response)

return runStep({
workspace,
source,
114 changes: 114 additions & 0 deletions packages/core/src/services/commits/promptCache.test.ts
@@ -0,0 +1,114 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'

import * as cacheModule from '../../cache'
import { createProject } from '../../tests/factories'
import { getCachedResponse, setCachedResponse } from './promptCache'

describe('promptCache', async () => {
const mockCache = {
get: vi.fn(),
set: vi.fn(),
}

beforeEach(() => {
// @ts-expect-error - mock
vi.spyOn(cacheModule, 'cache').mockResolvedValue(mockCache)
vi.clearAllMocks()
})

const { workspace } = await createProject()
const config = { temperature: 0 }
const conversation = { messages: [], config }
const response = {
streamType: 'text',
text: 'cached response',
usage: {
promptTokens: 0,
completionTokens: 0,
totalTokens: 0,
},
toolCalls: [],
}

describe('getCachedResponse', () => {
it('returns undefined when temperature is not 0', async () => {
const result = await getCachedResponse({
workspace,
config: { temperature: 0.5 },
conversation,
})

expect(result).toBeUndefined()
expect(mockCache.get).not.toHaveBeenCalled()
})

it('returns cached response when available', async () => {
mockCache.get.mockResolvedValueOnce(JSON.stringify(response))

const result = await getCachedResponse({
workspace,
config,
conversation,
})

expect(result).toEqual(response)
expect(mockCache.get).toHaveBeenCalledTimes(1)
})

it('returns undefined when cache throws error', async () => {
mockCache.get.mockRejectedValueOnce(new Error('Cache error'))

const result = await getCachedResponse({
workspace,
config,
conversation,
})

expect(result).toBeUndefined()
})
})

describe('setCachedResponse', () => {
it('does not cache when temperature is not 0', async () => {
await setCachedResponse({
workspace,
config: { temperature: 0.5 },
conversation,
// @ts-expect-error - mock
response,
})

expect(mockCache.set).not.toHaveBeenCalled()
})

it('caches response when temperature is 0', async () => {
await setCachedResponse({
workspace,
config,
conversation,
// @ts-expect-error - mock
response,
})

expect(mockCache.set).toHaveBeenCalledTimes(1)
expect(mockCache.set).toHaveBeenCalledWith(
expect.stringContaining(`workspace:${workspace.id}:prompt:`),
JSON.stringify(response),
)
})

it('silently fails when cache throws error', async () => {
mockCache.set.mockRejectedValueOnce(new Error('Cache error'))

await expect(
setCachedResponse({
workspace,
config,
conversation,
// @ts-expect-error - mock
response,
}),
).resolves.toBeUndefined()
})
})
})
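
The fourth changed file, presumably packages/core/src/services/commits/promptCache.ts, did not render in this view. Reconstructed from the tests above and the call sites in run.ts, a minimal sketch of what the module plausibly contains follows; the key-hashing scheme and the exact shape of the cache client are assumptions, not the committed code.

// Sketch reconstructed from promptCache.test.ts above; not the committed source.
// Assumptions: the key hashes config + conversation (the tests only require the
// "workspace:<id>:prompt:" prefix), and cache() resolves to a Redis-like get/set client.
import { createHash } from 'crypto'

import { cache } from '../../cache'

type CacheableConfig = { temperature?: number } & Record<string, unknown>

function cacheKey(
  workspaceId: number,
  config: CacheableConfig,
  conversation: unknown,
) {
  const hash = createHash('sha256')
    .update(JSON.stringify({ config, conversation }))
    .digest('hex')

  return `workspace:${workspaceId}:prompt:${hash}`
}

export async function getCachedResponse({
  workspace,
  config,
  conversation,
}: {
  workspace: { id: number }
  config: CacheableConfig
  conversation: unknown
}) {
  // Only deterministic prompts (temperature 0) are eligible for caching.
  if (config.temperature !== 0) return undefined

  try {
    const client = await cache()
    const hit = await client.get(cacheKey(workspace.id, config, conversation))

    return hit ? JSON.parse(hit) : undefined
  } catch {
    // Cache failures must never break a chain run.
    return undefined
  }
}

export async function setCachedResponse({
  workspace,
  config,
  conversation,
  response,
}: {
  workspace: { id: number }
  config: CacheableConfig
  conversation: unknown
  response: unknown
}) {
  if (config.temperature !== 0) return

  try {
    const client = await cache()
    await client.set(
      cacheKey(workspace.id, config, conversation),
      JSON.stringify(response),
    )
  } catch {
    // Silently ignore cache errors, as verified by the "silently fails" test above.
  }
}

Restricting the cache to temperature 0 keeps the shortcut safe: only generations that are expected to be deterministic are replayed instead of re-requested.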
