Skip to content

Commit

Permalink
feat: Move function transition tracking into celery (#25173)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjackwhite authored Sep 30, 2024
1 parent 0ef0fe1 commit e8ce162
Show file tree
Hide file tree
Showing 8 changed files with 231 additions and 103 deletions.
28 changes: 1 addition & 27 deletions plugin-server/src/cdp/cdp-consumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import {
} from '../types'
import { createKafkaProducerWrapper } from '../utils/db/hub'
import { KafkaProducerWrapper } from '../utils/db/kafka-producer-wrapper'
import { captureTeamEvent } from '../utils/posthog'
import { status } from '../utils/status'
import { castTimestampOrNow } from '../utils/utils'
import { RustyHook } from '../worker/rusty-hook'
Expand All @@ -44,7 +43,6 @@ import {
HogFunctionInvocationSerialized,
HogFunctionInvocationSerializedCompressed,
HogFunctionMessageToProduce,
HogFunctionType,
HogHooksFetchResponse,
} from './types'
import {
Expand Down Expand Up @@ -117,9 +115,7 @@ abstract class CdpConsumerBase {
constructor(protected hub: Hub) {
this.redis = createCdpRedisPool(hub)
this.hogFunctionManager = new HogFunctionManager(hub)
this.hogWatcher = new HogWatcher(hub, this.redis, (id, state) => {
void this.captureInternalPostHogEvent(id, 'hog function state changed', { state })
})
this.hogWatcher = new HogWatcher(hub, this.redis)
this.hogMasker = new HogMasker(this.redis)
this.hogExecutor = new HogExecutor(this.hub, this.hogFunctionManager)
const rustyHook = this.hub?.rustyHook ?? new RustyHook(this.hub)
Expand All @@ -136,28 +132,6 @@ abstract class CdpConsumerBase {
}
}

/**
 * Captures a team-scoped event about a hog function.
 *
 * Looks the function up in the hog function manager, then fetches its team;
 * if either lookup comes back empty nothing is captured. The caller's
 * `properties` are sent along with `hog_function_id` and `hog_function_url`,
 * which are written last and therefore override same-named caller keys.
 *
 * @param hogFunctionId - ID of the hog function the event is about.
 * @param event - Name of the event to capture.
 * @param properties - Extra properties to attach to the event.
 */
private async captureInternalPostHogEvent(
    hogFunctionId: HogFunctionType['id'],
    event: string,
    properties: any = {}
) {
    const fn = this.hogFunctionManager.getHogFunction(hogFunctionId)
    // Only hit the team manager when the function is actually known.
    const owningTeam = fn ? await this.hub.teamManager.fetchTeam(fn.team_id) : null
    if (!owningTeam) {
        return
    }

    const hogFunctionUrl = `${this.hub.SITE_URL}/project/${owningTeam.id}/pipeline/destinations/hog-${hogFunctionId}`

    captureTeamEvent(owningTeam, event, {
        ...properties,
        hog_function_id: hogFunctionId,
        hog_function_url: hogFunctionUrl,
    })
}

protected async runWithHeartbeat<T>(func: () => Promise<T> | T): Promise<T> {
// Helper function to ensure that looping over lots of hog functions doesn't block up the thread, killing the consumer
const res = await func()
Expand Down
28 changes: 16 additions & 12 deletions plugin-server/src/cdp/hog-watcher.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ export type HogWatcherFunctionState = {
rating: number
}

// Name of the Celery task that HogWatcher enqueues (via `hub.db.celeryApplyAsync`) with
// `[id, state]` whenever a hog function's state changes.
// TODO: As a follow-up, replace this with an API call (or similar). Addressing the task by a
// name derived from its Python module path is brittle and hard to test.
export const CELERY_TASK_ID = 'posthog.tasks.hog_functions.hog_function_state_transition'

export class HogWatcher {
constructor(
private hub: Hub,
private redis: CdpRedis,
private onStateChange: (id: HogFunctionType['id'], state: HogWatcherState) => void
) {}
constructor(private hub: Hub, private redis: CdpRedis) {}

/**
 * Reports a hog function state transition by enqueueing the Celery task named
 * by `CELERY_TASK_ID`, passing `[id, state]` as the task arguments.
 *
 * The returned promise settles when `celeryApplyAsync` does, so a rejection
 * from it propagates to whoever awaits this method.
 *
 * @param id - ID of the hog function whose state changed.
 * @param state - The state the function has moved to.
 */
private async onStateChange(id: HogFunctionType['id'], state: HogWatcherState) {
await this.hub.db.celeryApplyAsync(CELERY_TASK_ID, [id, state])
}

private rateLimitArgs(id: HogFunctionType['id'], cost: number) {
const nowSeconds = Math.round(now() / 1000)
Expand Down Expand Up @@ -117,7 +121,7 @@ export class HogWatcher {
}
})

this.onStateChange(id, state)
await this.onStateChange(id, state)
}

public async observeResults(results: HogFunctionInvocationResult[]): Promise<void> {
Expand Down Expand Up @@ -208,15 +212,15 @@ export class HogWatcher {
}

// Finally track the results
functionsToDisablePermanently.forEach((id) => {
this.onStateChange(id, HogWatcherState.disabledIndefinitely)
})
for (const id of functionsToDisablePermanently) {
await this.onStateChange(id, HogWatcherState.disabledIndefinitely)
}

functionsTempDisabled.forEach((id) => {
for (const id of functionsTempDisabled) {
if (!functionsToDisablePermanently.includes(id)) {
this.onStateChange(id, HogWatcherState.disabledForPeriod)
await this.onStateChange(id, HogWatcherState.disabledForPeriod)
}
})
}
}
}
}
49 changes: 34 additions & 15 deletions plugin-server/tests/cdp/hog-watcher.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ jest.mock('../../src/utils/now', () => {
now: jest.fn(() => Date.now()),
}
})
import { BASE_REDIS_KEY, HogWatcher, HogWatcherState } from '../../src/cdp/hog-watcher'
import { BASE_REDIS_KEY, CELERY_TASK_ID, HogWatcher, HogWatcherState } from '../../src/cdp/hog-watcher'
import { CdpRedis, createCdpRedisPool } from '../../src/cdp/redis'
import { HogFunctionInvocationResult } from '../../src/cdp/types'
import { Hub } from '../../src/types'
Expand Down Expand Up @@ -43,20 +43,21 @@ describe('HogWatcher', () => {
let now: number
let hub: Hub
let watcher: HogWatcher
let mockStateChangeCallback: jest.Mock
let mockCeleryApplyAsync: jest.Mock
let redis: CdpRedis

beforeEach(async () => {
hub = await createHub()
mockCeleryApplyAsync = jest.fn()
hub.db.celeryApplyAsync = mockCeleryApplyAsync

now = 1720000000000
mockNow.mockReturnValue(now)
mockStateChangeCallback = jest.fn()

redis = createCdpRedisPool(hub)
await deleteKeysWithPrefix(redis, BASE_REDIS_KEY)

watcher = new HogWatcher(hub, redis, mockStateChangeCallback)
watcher = new HogWatcher(hub, redis)
})

const advanceTime = (ms: number) => {
Expand Down Expand Up @@ -183,8 +184,11 @@ describe('HogWatcher', () => {

await watcher.observeResults(badResults)

expect(mockStateChangeCallback).toHaveBeenCalledTimes(1)
expect(mockStateChangeCallback).toHaveBeenCalledWith('id1', HogWatcherState.disabledForPeriod)
expect(mockCeleryApplyAsync).toHaveBeenCalledTimes(1)
expect(mockCeleryApplyAsync).toHaveBeenCalledWith(CELERY_TASK_ID, [
'id1',
HogWatcherState.disabledForPeriod,
])

expect(await watcher.getState('id1')).toMatchInlineSnapshot(`
Object {
Expand Down Expand Up @@ -216,7 +220,7 @@ describe('HogWatcher', () => {
"tokens": 10000,
}
`)
expect(mockStateChangeCallback).toHaveBeenCalledWith('id1', HogWatcherState.healthy)
expect(mockCeleryApplyAsync).toHaveBeenCalledWith(CELERY_TASK_ID, ['id1', HogWatcherState.healthy])
})
it('should force degraded', async () => {
await watcher.forceStateChange('id1', HogWatcherState.degraded)
Expand All @@ -227,7 +231,7 @@ describe('HogWatcher', () => {
"tokens": 8000,
}
`)
expect(mockStateChangeCallback).toHaveBeenCalledWith('id1', HogWatcherState.degraded)
expect(mockCeleryApplyAsync).toHaveBeenCalledWith(CELERY_TASK_ID, ['id1', HogWatcherState.degraded])
})
it('should force disabledForPeriod', async () => {
await watcher.forceStateChange('id1', HogWatcherState.disabledForPeriod)
Expand All @@ -238,7 +242,10 @@ describe('HogWatcher', () => {
"tokens": 0,
}
`)
expect(mockStateChangeCallback).toHaveBeenCalledWith('id1', HogWatcherState.disabledForPeriod)
expect(mockCeleryApplyAsync).toHaveBeenCalledWith(CELERY_TASK_ID, [
'id1',
HogWatcherState.disabledForPeriod,
])
})
it('should force disabledIndefinitely', async () => {
await watcher.forceStateChange('id1', HogWatcherState.disabledIndefinitely)
Expand All @@ -249,7 +256,10 @@ describe('HogWatcher', () => {
"tokens": 0,
}
`)
expect(mockStateChangeCallback).toHaveBeenCalledWith('id1', HogWatcherState.disabledIndefinitely)
expect(mockCeleryApplyAsync).toHaveBeenCalledWith(CELERY_TASK_ID, [
'id1',
HogWatcherState.disabledIndefinitely,
])
})
})

Expand All @@ -269,17 +279,26 @@ describe('HogWatcher', () => {
expect((await watcher.getState('id1')).state).toEqual(HogWatcherState.degraded)
}

expect(mockStateChangeCallback).toHaveBeenCalledTimes(2)
expect(mockStateChangeCallback.mock.calls[0]).toEqual(['id1', HogWatcherState.disabledForPeriod])
expect(mockStateChangeCallback.mock.calls[1]).toEqual(['id1', HogWatcherState.disabledForPeriod])
expect(mockCeleryApplyAsync).toHaveBeenCalledTimes(2)
expect(mockCeleryApplyAsync.mock.calls[0]).toEqual([
CELERY_TASK_ID,
['id1', HogWatcherState.disabledForPeriod],
])
expect(mockCeleryApplyAsync.mock.calls[1]).toEqual([
CELERY_TASK_ID,
['id1', HogWatcherState.disabledForPeriod],
])

await watcher.observeResults([createResult({ id: 'id1', error: 'error!' })])
expect((await watcher.getState('id1')).state).toEqual(HogWatcherState.disabledIndefinitely)
await reallyAdvanceTime(1000)
expect((await watcher.getState('id1')).state).toEqual(HogWatcherState.disabledIndefinitely)

expect(mockStateChangeCallback).toHaveBeenCalledTimes(3)
expect(mockStateChangeCallback.mock.calls[2]).toEqual(['id1', HogWatcherState.disabledIndefinitely])
expect(mockCeleryApplyAsync).toHaveBeenCalledTimes(3)
expect(mockCeleryApplyAsync.mock.calls[2]).toEqual([
CELERY_TASK_ID,
['id1', HogWatcherState.disabledIndefinitely],
])
})
})
})
Expand Down
5 changes: 5 additions & 0 deletions posthog/models/hog_functions/hog_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
patch_hog_function_status,
reload_hog_functions_on_workers,
)
from posthog.utils import absolute_uri

DEFAULT_STATE = {"state": 0, "tokens": 0, "rating": 0}

Expand Down Expand Up @@ -121,6 +122,10 @@ def move_secret_inputs(self):
self.inputs = final_inputs
self.encrypted_inputs = final_encrypted_inputs

@property
def url(self):
    """Return the absolute URI of this object's destination path.

    The path is ``/project/<team_id>/pipeline/destinations/hog-<id>`` and is
    made absolute by ``absolute_uri``.
    """
    # `!s` applies str() to the id, exactly as the explicit str() call would.
    destination_path = f"/project/{self.team_id}/pipeline/destinations/hog-{self.id!s}"
    return absolute_uri(destination_path)

def save(self, *args, **kwargs):
from posthog.cdp.filters import compile_filters_bytecode

Expand Down
Loading

0 comments on commit e8ce162

Please sign in to comment.