diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/__mocks__/mocks.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/__mocks__/mocks.ts index 5bc9d4e23bc68..1127b4b0f2a9b 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/__mocks__/mocks.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/__mocks__/mocks.ts @@ -5,16 +5,16 @@ * 2.0. */ -import { mockRuleMigrationsDataClient } from '../data/__mocks__/mocks'; -import { mockRuleMigrationsTaskClient } from '../task/__mocks__/mocks'; +import { createRuleMigrationsDataClientMock } from '../data/__mocks__/mocks'; +import { createRuleMigrationsTaskClientMock } from '../task/__mocks__/mocks'; export const createRuleMigrationDataClient = jest .fn() - .mockImplementation(() => mockRuleMigrationsDataClient); + .mockImplementation(() => createRuleMigrationsDataClientMock()); export const createRuleMigrationTaskClient = jest .fn() - .mockImplementation(() => mockRuleMigrationsTaskClient); + .mockImplementation(() => createRuleMigrationsTaskClientMock()); export const createRuleMigrationClient = () => ({ data: createRuleMigrationDataClient(), diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/__mocks__/mocks.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/__mocks__/mocks.ts index 77ed5e87084e9..7a844f671b07f 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/__mocks__/mocks.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/__mocks__/mocks.ts @@ -5,24 +5,28 @@ * 2.0. */ +import type { RuleMigrationsDataIntegrationsClient } from '../rule_migrations_data_integrations_client'; +import type { RuleMigrationsDataLookupsClient } from '../rule_migrations_data_lookups_client'; +import type { RuleMigrationsDataPrebuiltRulesClient } from '../rule_migrations_data_prebuilt_rules_client'; +import type { RuleMigrationsDataResourcesClient } from '../rule_migrations_data_resources_client'; import type { RuleMigrationsDataRulesClient } from '../rule_migrations_data_rules_client'; // Rule migrations data rules client export const mockRuleMigrationsDataRulesClient = { create: jest.fn().mockResolvedValue(undefined), - get: jest.fn().mockResolvedValue([]), + get: jest.fn().mockResolvedValue({ data: [], total: 0 }), searchBatches: jest.fn().mockReturnValue({ next: jest.fn().mockResolvedValue([]), all: jest.fn().mockResolvedValue([]), }), - takePending: jest.fn().mockResolvedValue([]), + saveProcessing: jest.fn().mockResolvedValue(undefined), saveCompleted: jest.fn().mockResolvedValue(undefined), saveError: jest.fn().mockResolvedValue(undefined), releaseProcessing: jest.fn().mockResolvedValue(undefined), updateStatus: jest.fn().mockResolvedValue(undefined), getStats: jest.fn().mockResolvedValue(undefined), getAllStats: jest.fn().mockResolvedValue([]), -} as unknown as RuleMigrationsDataRulesClient; +} as unknown as jest.Mocked; export const MockRuleMigrationsDataRulesClient = jest .fn() .mockImplementation(() => mockRuleMigrationsDataRulesClient); @@ -35,30 +39,42 @@ export const mockRuleMigrationsDataResourcesClient = { next: jest.fn().mockResolvedValue([]), all: jest.fn().mockResolvedValue([]), }), -}; +} as unknown as jest.Mocked; export const MockRuleMigrationsDataResourcesClient = jest .fn() .mockImplementation(() => mockRuleMigrationsDataResourcesClient); export const mockRuleMigrationsDataIntegrationsClient = { + populate: jest.fn().mockResolvedValue(undefined), retrieveIntegrations: jest.fn().mockResolvedValue([]), -}; +} as unknown as jest.Mocked; + +export const mockRuleMigrationsDataPrebuiltRulesClient = { + populate: jest.fn().mockResolvedValue(undefined), + search: jest.fn().mockResolvedValue([]), +} as unknown as jest.Mocked; +export const mockRuleMigrationsDataLookupsClient = { + create: jest.fn().mockResolvedValue(undefined), + indexData: jest.fn().mockResolvedValue(undefined), +} as unknown as jest.Mocked; // Rule migrations data client -export const mockRuleMigrationsDataClient = { +export const createRuleMigrationsDataClientMock = () => ({ rules: mockRuleMigrationsDataRulesClient, resources: mockRuleMigrationsDataResourcesClient, integrations: mockRuleMigrationsDataIntegrationsClient, -}; + prebuiltRules: mockRuleMigrationsDataPrebuiltRulesClient, + lookups: mockRuleMigrationsDataLookupsClient, +}); export const MockRuleMigrationsDataClient = jest .fn() - .mockImplementation(() => mockRuleMigrationsDataClient); + .mockImplementation(() => createRuleMigrationsDataClientMock()); // Rule migrations data service export const mockIndexName = 'mocked_siem_rule_migrations_index_name'; export const mockInstall = jest.fn().mockResolvedValue(undefined); -export const mockCreateClient = jest.fn().mockReturnValue(mockRuleMigrationsDataClient); +export const mockCreateClient = jest.fn(() => createRuleMigrationsDataClientMock()); export const MockRuleMigrationsDataService = jest.fn().mockImplementation(() => ({ createAdapter: jest.fn(), diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/rule_migrations_data_rules_client.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/rule_migrations_data_rules_client.ts index d66b1f54a710c..1097b9475102d 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/rule_migrations_data_rules_client.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/data/rule_migrations_data_rules_client.ts @@ -151,45 +151,19 @@ export class RuleMigrationsDataRulesClient extends RuleMigrationsDataBaseClient } } - /** - * Retrieves `pending` rule migrations with the provided id and updates their status to `processing`. - * This operation is not atomic at migration level: - * - Multiple tasks can process different migrations simultaneously. - * - Multiple tasks should not process the same migration simultaneously. - */ - async takePending(migrationId: string, size: number): Promise { + /** Updates one rule migration status to `processing` */ + async saveProcessing(id: string): Promise { const index = await this.getIndexName(); const profileId = await this.getProfileUid(); - const query = this.getFilterQuery(migrationId, { status: SiemMigrationStatus.PENDING }); - - const storedRuleMigrations = await this.esClient - .search({ index, query, sort: '_doc', size }) - .then((response) => - this.processResponseHits(response, { status: SiemMigrationStatus.PROCESSING }) - ) - .catch((error) => { - this.logger.error(`Error searching rule migrations: ${error.message}`); - throw error; - }); - - await this.esClient - .bulk({ - refresh: 'wait_for', - operations: storedRuleMigrations.flatMap(({ id, status }) => [ - { update: { _id: id, _index: index } }, - { - doc: { status, updated_by: profileId, updated_at: new Date().toISOString() }, - }, - ]), - }) - .catch((error) => { - this.logger.error( - `Error updating for rule migrations status to processing: ${error.message}` - ); - throw error; - }); - - return storedRuleMigrations; + const doc = { + status: SiemMigrationStatus.PROCESSING, + updated_by: profileId, + updated_at: new Date().toISOString(), + }; + await this.esClient.update({ index, id, doc, refresh: 'wait_for' }).catch((error) => { + this.logger.error(`Error updating rule migration status to processing: ${error.message}`); + throw error; + }); } /** Updates one rule migration with the provided data and sets the status to `completed` */ diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/__mocks__/mocks.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/__mocks__/mocks.ts index 1f463c8417b90..91a298d876e1f 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/__mocks__/mocks.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/__mocks__/mocks.ts @@ -5,7 +5,7 @@ * 2.0. */ -export const mockRuleMigrationsTaskClient = { +export const createRuleMigrationsTaskClientMock = () => ({ start: jest.fn().mockResolvedValue({ started: true }), stop: jest.fn().mockResolvedValue({ stopped: true }), getStats: jest.fn().mockResolvedValue({ @@ -19,15 +19,15 @@ export const mockRuleMigrationsTaskClient = { }, }), getAllStats: jest.fn().mockResolvedValue([]), -}; +}); export const MockRuleMigrationsTaskClient = jest .fn() - .mockImplementation(() => mockRuleMigrationsTaskClient); + .mockImplementation(() => createRuleMigrationsTaskClientMock()); // Rule migrations task service export const mockStopAll = jest.fn(); -export const mockCreateClient = jest.fn().mockReturnValue(mockRuleMigrationsTaskClient); +export const mockCreateClient = jest.fn(() => createRuleMigrationsTaskClientMock()); export const MockRuleMigrationsTaskService = jest.fn().mockImplementation(() => ({ createClient: mockCreateClient, diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/retrievers/rule_migrations_retriever.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/retrievers/rule_migrations_retriever.ts index 5616bfd4fb26b..0135dba877c35 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/retrievers/rule_migrations_retriever.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/retrievers/rule_migrations_retriever.ts @@ -18,6 +18,11 @@ export interface RuleMigrationsRetrieverClients { savedObjects: SavedObjectsClientContract; } +/** + * RuleMigrationsRetriever is a class that is responsible for retrieving all the necessary data during the rule migration process. + * It is composed of multiple retrievers that are responsible for retrieving specific types of data. + * Such as rule integrations, prebuilt rules, and rule resources. + */ export class RuleMigrationsRetriever { public readonly resources: RuleResourceRetriever; public readonly integrations: IntegrationRetriever; diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_client.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_client.ts index 0a4999377a133..0a7ead395c2fb 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_client.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_client.ts @@ -6,10 +6,6 @@ */ import type { AuthenticatedUser, Logger } from '@kbn/core/server'; -import { AbortError, abortSignalToPromise } from '@kbn/kibana-utils-plugin/server'; -import type { RunnableConfig } from '@langchain/core/runnables'; -import { TELEMETRY_SIEM_MIGRATION_ID } from './util/constants'; -import { EsqlKnowledgeBase } from './util/esql_knowledge_base'; import { SiemMigrationStatus, SiemMigrationTaskStatus, @@ -19,26 +15,14 @@ import type { RuleMigrationFilters } from '../../../../../common/siem_migrations import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client'; import type { RuleMigrationDataStats } from '../data/rule_migrations_data_rules_client'; import type { SiemRuleMigrationsClientDependencies } from '../types'; -import { getRuleMigrationAgent } from './agent'; -import type { MigrateRuleState } from './agent/types'; -import { RuleMigrationsRetriever } from './retrievers'; -import { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client'; import type { - MigrationAgent, - RuleMigrationTaskCreateAgentParams, - RuleMigrationTaskRunParams, RuleMigrationTaskStartParams, RuleMigrationTaskStartResult, RuleMigrationTaskStopResult, } from './types'; -import type { ChatModel } from './util/actions_client_chat'; -import { ActionsClientChat } from './util/actions_client_chat'; -import { generateAssistantComment } from './util/comments'; +import { RuleMigrationTaskRunner } from './rule_migrations_task_runner'; -const ITERATION_BATCH_SIZE = 15 as const; -const ITERATION_SLEEP_SECONDS = 10 as const; - -type MigrationsRunning = Map; +export type MigrationsRunning = Map; export class RuleMigrationsTaskClient { constructor( @@ -51,7 +35,7 @@ export class RuleMigrationsTaskClient { /** Starts a rule migration task */ async start(params: RuleMigrationTaskStartParams): Promise { - const { migrationId, connectorId } = params; + const { migrationId, connectorId, invocationConfig } = params; if (this.migrationsRunning.has(migrationId)) { return { exists: true, started: false }; } @@ -70,161 +54,40 @@ export class RuleMigrationsTaskClient { if (rules.pending === 0) { return { exists: true, started: false }; } - const abortController = new AbortController(); - const model = await this.createModel(connectorId, migrationId, abortController); - - // run the migration without awaiting it to execute it in the background - this.run({ ...params, model, abortController }).catch((error) => { - this.logger.error(`Error executing migration ID:${migrationId} with error ${error}`); - }); - - return { exists: true, started: true }; - } - - private async run(params: RuleMigrationTaskRunParams): Promise { - const { migrationId, invocationConfig, abortController, model } = params; - if (this.migrationsRunning.has(migrationId)) { - // This should never happen, but just in case - throw new Error(`Task already running for migration ID:${migrationId} `); - } - this.logger.info(`Starting migration ID:${migrationId}`); - - this.migrationsRunning.set(migrationId, { user: this.currentUser.username, abortController }); - - const abortPromise = abortSignalToPromise(abortController.signal); - const withAbortRace = async (task: Promise) => Promise.race([task, abortPromise.promise]); - const sleep = async (seconds: number) => { - this.logger.debug(`Sleeping ${seconds}s for migration ID:${migrationId}`); - await withAbortRace(new Promise((resolve) => setTimeout(resolve, seconds * 1000))); - }; - - const stats = { completed: 0, failed: 0 }; - const telemetryClient = new SiemMigrationTelemetryClient( - this.dependencies.telemetry, - this.logger, + const migrationLogger = this.logger.get(migrationId); + const abortController = new AbortController(); + const migrationTaskRunner = new RuleMigrationTaskRunner( migrationId, - model.model + this.currentUser, + abortController, + this.data, + migrationLogger, + this.dependencies ); - const endSiemMigration = telemetryClient.startSiemMigration(); - try { - this.logger.debug(`Creating agent for migration ID:${migrationId}`); - - const agent = await withAbortRace(this.createAgent({ ...params, model, telemetryClient })); - - const config: RunnableConfig = { - ...invocationConfig, - // signal: abortController.signal, // not working properly https://github.com/langchain-ai/langgraphjs/issues/319 - }; - let isDone: boolean = false; - do { - const ruleMigrations = await this.data.rules.takePending(migrationId, ITERATION_BATCH_SIZE); - this.logger.debug( - `Processing ${ruleMigrations.length} rules for migration ID:${migrationId}` - ); - - await Promise.all( - ruleMigrations.map(async (ruleMigration) => { - this.logger.debug(`Starting migration of rule "${ruleMigration.original_rule.title}"`); - if (ruleMigration.elastic_rule?.id) { - await this.data.rules.saveCompleted(ruleMigration); - return; // skip already installed rules - } - const endRuleTranslation = telemetryClient.startRuleTranslation(); - try { - const invocationData = { - original_rule: ruleMigration.original_rule, - }; - // using withAbortRace is a workaround for the issue with the langGraph signal not working properly - const migrationResult = await withAbortRace( - agent.invoke(invocationData, config) - ); + await migrationTaskRunner.setup(connectorId); - this.logger.debug( - `Migration of rule "${ruleMigration.original_rule.title}" finished` - ); - endRuleTranslation({ migrationResult }); - await this.data.rules.saveCompleted({ - ...ruleMigration, - elastic_rule: migrationResult.elastic_rule, - translation_result: migrationResult.translation_result, - comments: migrationResult.comments, - }); - stats.completed++; - } catch (error) { - stats.failed++; - if (error instanceof AbortError) { - throw error; - } - endRuleTranslation({ error }); - this.logger.error( - `Error migrating rule "${ruleMigration.original_rule.title} with error: ${error.message}"` - ); - await this.data.rules.saveError({ - ...ruleMigration, - comments: [generateAssistantComment(`Error migrating rule: ${error.message}`)], - }); - } - }) - ); - - this.logger.debug(`Batch processed successfully for migration ID:${migrationId}`); - - const { rules } = await this.data.rules.getStats(migrationId); - isDone = rules.pending === 0; - if (!isDone) { - await sleep(ITERATION_SLEEP_SECONDS); - } - } while (!isDone); - - this.logger.info(`Finished migration ID:${migrationId}`); - - endSiemMigration({ stats }); - } catch (error) { - await this.data.rules.releaseProcessing(migrationId); - - if (error instanceof AbortError) { - this.logger.info(`Abort signal received, stopping migration ID:${migrationId}`); - return; - } else { - endSiemMigration({ error, stats }); - this.logger.error(`Error processing migration ID:${migrationId} ${error}`); - } - } finally { - this.migrationsRunning.delete(migrationId); - abortPromise.cleanup(); + if (this.migrationsRunning.has(migrationId)) { + // Just to prevent a race condition in the setup + throw new Error('Task already running for this migration'); } - } + this.migrationsRunning.set(migrationId, migrationTaskRunner); - private async createAgent({ - connectorId, - migrationId, - model, - telemetryClient, - }: RuleMigrationTaskCreateAgentParams): Promise { - const { inferenceClient, rulesClient, savedObjectsClient } = this.dependencies; - const esqlKnowledgeBase = new EsqlKnowledgeBase( - connectorId, - migrationId, - inferenceClient, - this.logger - ); - const ruleMigrationsRetriever = new RuleMigrationsRetriever(migrationId, { - data: this.data, - rules: rulesClient, - savedObjects: savedObjectsClient, - }); + migrationLogger.info('Starting migration'); - await ruleMigrationsRetriever.initialize(); + // run the migration in the background without awaiting and resolve the `start` promise + migrationTaskRunner + .run(invocationConfig) + .catch((error) => { + // no need to throw, the `start` promise is long gone. Just log the error + migrationLogger.error('Error executing migration', error); + }) + .finally(() => { + this.migrationsRunning.delete(migrationId); + }); - return getRuleMigrationAgent({ - model, - esqlKnowledgeBase, - ruleMigrationsRetriever, - telemetryClient, - logger: this.logger, - }); + return { exists: true, started: true }; } /** Updates all the rules in a migration to be re-executed */ @@ -233,9 +96,10 @@ export class RuleMigrationsTaskClient { filter: RuleMigrationFilters ): Promise<{ updated: boolean }> { if (this.migrationsRunning.has(migrationId)) { + // not update migrations that are currently running return { updated: false }; } - + filter.installed = false; // only retry rules that are not installed await this.data.rules.updateStatus(migrationId, filter, SiemMigrationStatus.PENDING, { refresh: true, }); @@ -293,20 +157,4 @@ export class RuleMigrationsTaskClient { return { exists: true, stopped: false }; } } - - private async createModel( - connectorId: string, - migrationId: string, - abortController: AbortController - ): Promise { - const { actionsClient } = this.dependencies; - const actionsClientChat = new ActionsClientChat(connectorId, actionsClient, this.logger); - const model = await actionsClientChat.createModel({ - telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId }, - maxRetries: 10, - signal: abortController.signal, - temperature: 0.05, - }); - return model; - } } diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.test.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.test.ts new file mode 100644 index 0000000000000..2f45f43fc15c3 --- /dev/null +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.test.ts @@ -0,0 +1,383 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { RuleMigrationTaskRunner } from './rule_migrations_task_runner'; +import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants'; +import type { AuthenticatedUser } from '@kbn/core/server'; +import type { SiemRuleMigrationsClientDependencies, StoredRuleMigration } from '../types'; +import { createRuleMigrationsDataClientMock } from '../data/__mocks__/mocks'; +import { loggerMock } from '@kbn/logging-mocks'; + +const mockRetrieverInitialize = jest.fn().mockResolvedValue(undefined); +jest.mock('./retrievers', () => ({ + ...jest.requireActual('./retrievers'), + RuleMigrationsRetriever: jest + .fn() + .mockImplementation(() => ({ initialize: mockRetrieverInitialize })), +})); + +const mockCreateModel = jest.fn(() => ({ model: 'test-model' })); +jest.mock('./util/actions_client_chat', () => ({ + ...jest.requireActual('./util/actions_client_chat'), + ActionsClientChat: jest.fn().mockImplementation(() => ({ createModel: mockCreateModel })), +})); + +const mockInvoke = jest.fn().mockResolvedValue({}); +jest.mock('./agent', () => ({ + ...jest.requireActual('./agent'), + getRuleMigrationAgent: () => ({ invoke: mockInvoke }), +})); + +jest.mock('./rule_migrations_telemetry_client', () => ({ + SiemMigrationTelemetryClient: jest.fn().mockImplementation(() => ({ + startSiemMigrationTask: jest.fn(() => ({ + startRuleTranslation: jest.fn(() => ({ success: jest.fn(), failure: jest.fn() })), + success: jest.fn(), + failure: jest.fn(), + })), + })), +})); + +// Mock dependencies +const mockLogger = loggerMock.create(); + +const mockDependencies: jest.Mocked = { + rulesClient: {}, + savedObjectsClient: {}, + inferenceClient: {}, + actionsClient: {}, + telemetry: {}, +} as unknown as SiemRuleMigrationsClientDependencies; + +const mockUser = {} as unknown as AuthenticatedUser; +const ruleId = 'test-rule-id'; + +jest.useFakeTimers(); +jest.spyOn(global, 'setTimeout'); +const mockTimeout = setTimeout as unknown as jest.Mock; +mockTimeout.mockImplementation((cb) => { + // never actually wait, we'll check the calls manually + cb(); +}); + +describe('RuleMigrationTaskRunner', () => { + let taskRunner: RuleMigrationTaskRunner; + let abortController: AbortController; + let mockRuleMigrationsDataClient: ReturnType; + + beforeEach(() => { + mockRetrieverInitialize.mockResolvedValue(undefined); // Reset the mock + mockInvoke.mockResolvedValue({}); // Reset the mock + mockRuleMigrationsDataClient = createRuleMigrationsDataClientMock(); + jest.clearAllMocks(); + + abortController = new AbortController(); + taskRunner = new RuleMigrationTaskRunner( + 'test-migration-id', + mockUser, + abortController, + mockRuleMigrationsDataClient, + mockLogger, + mockDependencies + ); + }); + + describe('setup', () => { + it('should create the agent and tools', async () => { + await expect(taskRunner.setup('test-connector-id')).resolves.toBeUndefined(); + // @ts-expect-error (checking private properties) + expect(taskRunner.agent).toBeDefined(); + // @ts-expect-error (checking private properties) + expect(taskRunner.retriever).toBeDefined(); + // @ts-expect-error (checking private properties) + expect(taskRunner.telemetry).toBeDefined(); + }); + + it('should throw if an error occurs', async () => { + const errorMessage = 'Test error'; + mockCreateModel.mockImplementationOnce(() => { + throw new Error(errorMessage); + }); + + await expect(taskRunner.setup('test-connector-id')).rejects.toThrowError(errorMessage); + }); + }); + + describe('run', () => { + let runPromise: Promise; + beforeEach(async () => { + await taskRunner.setup('test-connector-id'); + }); + + it('should handle the migration successfully', async () => { + mockRuleMigrationsDataClient.rules.get.mockResolvedValue({ total: 0, data: [] }); + mockRuleMigrationsDataClient.rules.get.mockResolvedValueOnce({ + total: 1, + data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[], + }); + + await taskRunner.setup('test-connector-id'); + await expect(taskRunner.run({})).resolves.toBeUndefined(); + + expect(mockRuleMigrationsDataClient.rules.saveProcessing).toHaveBeenCalled(); + expect(mockTimeout).toHaveBeenCalledTimes(1); // execution sleep + expect(mockInvoke).toHaveBeenCalledTimes(1); + expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalled(); + expect(mockRuleMigrationsDataClient.rules.get).toHaveBeenCalledTimes(2); // One with data, one without + expect(mockLogger.info).toHaveBeenCalledWith('Migration completed successfully'); + }); + + describe('when error occurs', () => { + const errorMessage = 'Test error message'; + + describe('during initialization', () => { + it('should handle abort error correctly', async () => { + runPromise = taskRunner.run({}); + abortController.abort(); // Trigger the abort signal + + await expect(runPromise).resolves.toBeUndefined(); // Ensure the function handles abort gracefully + + expect(mockLogger.info).toHaveBeenCalledWith( + 'Abort signal received, stopping initialization' + ); + }); + + it('should handle other errors correctly', async () => { + mockRetrieverInitialize.mockRejectedValueOnce(new Error(errorMessage)); + + runPromise = taskRunner.run({}); + await expect(runPromise).resolves.toBeUndefined(); // Ensure the function handles abort gracefully + + expect(mockLogger.error).toHaveBeenCalledWith( + `Error initializing migration: Error: ${errorMessage}` + ); + }); + }); + + describe('during migration', () => { + beforeEach(() => { + mockRuleMigrationsDataClient.rules.get.mockRestore(); + mockRuleMigrationsDataClient.rules.get + .mockResolvedValue({ total: 0, data: [] }) + .mockResolvedValueOnce({ + total: 1, + data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[], + }); + }); + + it('should handle abort error correctly', async () => { + runPromise = taskRunner.run({}); + await Promise.resolve(); // Wait for the initialization to complete + abortController.abort(); // Trigger the abort signal + + await expect(runPromise).resolves.toBeUndefined(); // Ensure the function handles abort gracefully + expect(mockLogger.info).toHaveBeenCalledWith('Abort signal received, stopping migration'); + expect(mockRuleMigrationsDataClient.rules.releaseProcessing).toHaveBeenCalled(); + }); + + it('should handle other errors correctly', async () => { + mockInvoke.mockRejectedValue(new Error(errorMessage)); + + runPromise = taskRunner.run({}); + await expect(runPromise).resolves.toBeUndefined(); + + expect(mockLogger.error).toHaveBeenCalledWith( + `Error translating rule \"${ruleId}\" with error: ${errorMessage}` + ); + expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalled(); + }); + + describe('during rate limit errors', () => { + const rule2Id = 'test-rule-id-2'; + const error = new Error('429. You did way too many requests to this random LLM API bud'); + + beforeEach(async () => { + mockRuleMigrationsDataClient.rules.get.mockRestore(); + mockRuleMigrationsDataClient.rules.get + .mockResolvedValue({ total: 0, data: [] }) + .mockResolvedValueOnce({ + total: 2, + data: [ + { id: ruleId, status: SiemMigrationStatus.PENDING }, + { id: rule2Id, status: SiemMigrationStatus.PENDING }, + ] as StoredRuleMigration[], + }); + }); + + it('should retry with exponential backoff', async () => { + mockInvoke + .mockResolvedValue({}) // Successful calls from here on + .mockRejectedValueOnce(error) // First failed call for rule 1 + .mockRejectedValueOnce(error) // First failed call for rule 2 + .mockRejectedValueOnce(error) // Second failed call for rule 1 + .mockRejectedValueOnce(error); // Third failed call for rule 1 + + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + /** + * Invoke calls: + * rule 1 -> failure -> start backoff retries + * rule 2 -> failure -> await for rule 1 backoff + * then: + * rule 1 retry 1 -> failure + * rule 1 retry 2 -> failure + * rule 1 retry 3 -> success + * then: + * rule 2 -> success + */ + expect(mockInvoke).toHaveBeenCalledTimes(6); + expect(mockTimeout).toHaveBeenCalledTimes(6); // 3 backoff sleeps + 3 execution sleeps + expect(mockTimeout).toHaveBeenNthCalledWith( + 1, + expect.any(Function), + expect.any(Number) + ); + expect(mockTimeout).toHaveBeenNthCalledWith( + 2, + expect.any(Function), + expect.any(Number) + ); + expect(mockTimeout).toHaveBeenNthCalledWith(3, expect.any(Function), 1000); + expect(mockTimeout).toHaveBeenNthCalledWith(4, expect.any(Function), 2000); + expect(mockTimeout).toHaveBeenNthCalledWith(5, expect.any(Function), 4000); + expect(mockTimeout).toHaveBeenNthCalledWith( + 6, + expect.any(Function), + expect.any(Number) + ); + + expect(mockLogger.debug).toHaveBeenCalledWith( + `Awaiting backoff task for rule "${rule2Id}"` + ); + expect(mockInvoke).toHaveBeenCalledTimes(6); // 3 retries + 3 executions + expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalledTimes(2); // 2 rules + }); + + it('should fail when reached maxRetries', async () => { + mockInvoke.mockRejectedValue(error); + + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // maxRetries = 8 + expect(mockInvoke).toHaveBeenCalledTimes(10); // 8 retries + 2 executions + expect(mockTimeout).toHaveBeenCalledTimes(10); // 8 backoff sleeps + 2 execution sleeps + + expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalledTimes(2); // 2 rules + }); + + it('should fail when reached max recovery attempts', async () => { + const rule3Id = 'test-rule-id-3'; + const rule4Id = 'test-rule-id-4'; + mockRuleMigrationsDataClient.rules.get.mockRestore(); + mockRuleMigrationsDataClient.rules.get + .mockResolvedValue({ total: 0, data: [] }) + .mockResolvedValueOnce({ + total: 4, + data: [ + { id: ruleId, status: SiemMigrationStatus.PENDING }, + { id: rule2Id, status: SiemMigrationStatus.PENDING }, + { id: rule3Id, status: SiemMigrationStatus.PENDING }, + { id: rule4Id, status: SiemMigrationStatus.PENDING }, + ] as StoredRuleMigration[], + }); + + // max recovery attempts = 3 + mockInvoke + .mockResolvedValue({}) // should never reach this + .mockRejectedValueOnce(error) // 1st failed call for rule 1 + .mockRejectedValueOnce(error) // 1st failed call for rule 2 + .mockRejectedValueOnce(error) // 1st failed call for rule 3 + .mockRejectedValueOnce(error) // 1st failed call for rule 4 + .mockResolvedValueOnce({}) // Successful call for the rule 1 backoff + .mockRejectedValueOnce(error) // 2nd failed call for the rule 2 recover + .mockRejectedValueOnce(error) // 2nd failed call for the rule 3 recover + .mockRejectedValueOnce(error) // 2nd failed call for the rule 4 recover + .mockResolvedValueOnce({}) // Successful call for the rule 2 backoff + .mockRejectedValueOnce(error) // 3rd failed call for the rule 3 recover + .mockRejectedValueOnce(error) // 3rd failed call for the rule 4 recover + .mockResolvedValueOnce({}) // Successful call for the rule 3 backoff + .mockRejectedValueOnce(error); // 4th failed call for the rule 4 recover (max attempts failure) + + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalledTimes(3); // rules 1, 2 and 3 + expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalledTimes(1); // rule 4 + }); + + it('should increase the executor sleep time when rate limited', async () => { + const getResponse = { + total: 1, + data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[], + }; + mockRuleMigrationsDataClient.rules.get.mockRestore(); + mockRuleMigrationsDataClient.rules.get + .mockResolvedValue({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }) + .mockResolvedValueOnce(getResponse) + .mockResolvedValueOnce({ total: 0, data: [] }); + + /** + * Current EXECUTOR_SLEEP settings: + * initialValueSeconds: 3, multiplier: 2, limitSeconds: 96, // 1m36s (5 increases) + */ + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(3); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(6); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(12); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(24); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(48); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(96); + + mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery + await expect(taskRunner.run({})).resolves.toBeUndefined(); // success + + // @ts-expect-error (checking private properties) + expect(taskRunner.executorSleepMultiplier).toBe(96); // limit reached + expect(mockLogger.warn).toHaveBeenCalledWith( + 'Executor sleep reached the maximum value' + ); + }); + }); + }); + }); + }); +}); diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.ts new file mode 100644 index 0000000000000..c4880a173213e --- /dev/null +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_runner.ts @@ -0,0 +1,351 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import assert from 'assert'; +import type { AuthenticatedUser, Logger } from '@kbn/core/server'; +import { abortSignalToPromise, AbortError } from '@kbn/kibana-utils-plugin/server'; +import type { RunnableConfig } from '@langchain/core/runnables'; +import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants'; +import { initPromisePool } from '../../../../utils/promise_pool'; +import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client'; +import type { MigrateRuleState } from './agent/types'; +import { getRuleMigrationAgent } from './agent'; +import { RuleMigrationsRetriever } from './retrievers'; +import { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client'; +import type { MigrationAgent } from './types'; +import { generateAssistantComment } from './util/comments'; +import type { SiemRuleMigrationsClientDependencies, StoredRuleMigration } from '../types'; +import { ActionsClientChat } from './util/actions_client_chat'; +import { EsqlKnowledgeBase } from './util/esql_knowledge_base'; + +/** Number of concurrent rule translations in the pool */ +const TASK_CONCURRENCY = 10 as const; +/** Number of rules loaded in memory to be translated in the pool */ +const TASK_BATCH_SIZE = 100 as const; + +/** Exponential backoff configuration to handle rate limit errors */ +const RETRY_CONFIG = { + initialRetryDelaySeconds: 1, + backoffMultiplier: 2, + maxRetries: 8, + // max waiting time 4m15s (1*2^8 = 256s) +} as const; + +/** Executor sleep configuration + * A sleep time applied at the beginning of each single rule translation in the execution pool, + * The objective of this sleep is to spread the load of concurrent translations, and prevent hitting the rate limit repeatedly. + * The sleep time applied is a random number between [0-value]. Every time we hit rate limit the value is increased by the multiplier, up to the limit. + */ +const EXECUTOR_SLEEP = { + initialValueSeconds: 3, + multiplier: 2, + limitSeconds: 96, // 1m36s (5 increases) +} as const; + +/** This limit should never be reached, it's a safety net to prevent infinite loops. + * It represents the max number of consecutive rate limit recovery & failure attempts. + * This can only happen when the API can not process TASK_CONCURRENCY translations at a time, + * even after the executor sleep is increased on every attempt. + **/ +const EXECUTOR_RECOVER_MAX_ATTEMPTS = 3 as const; + +export class RuleMigrationTaskRunner { + private telemetry?: SiemMigrationTelemetryClient; + private agent?: MigrationAgent; + private retriever?: RuleMigrationsRetriever; + private actionsClientChat: ActionsClientChat; + private abort: ReturnType; + private executorSleepMultiplier: number = EXECUTOR_SLEEP.initialValueSeconds; + public isWaiting: boolean = false; + + constructor( + public readonly migrationId: string, + public readonly startedBy: AuthenticatedUser, + public readonly abortController: AbortController, + private readonly data: RuleMigrationsDataClient, + private readonly logger: Logger, + private readonly dependencies: SiemRuleMigrationsClientDependencies + ) { + this.actionsClientChat = new ActionsClientChat(this.dependencies.actionsClient, this.logger); + this.abort = abortSignalToPromise(this.abortController.signal); + } + + /** Retrieves the connector and creates the migration agent */ + public async setup(connectorId: string) { + const { rulesClient, savedObjectsClient, inferenceClient } = this.dependencies; + + const model = await this.actionsClientChat.createModel({ + connectorId, + migrationId: this.migrationId, + abortController: this.abortController, + }); + + const esqlKnowledgeBase = new EsqlKnowledgeBase( + connectorId, + this.migrationId, + inferenceClient, + this.logger + ); + + this.retriever = new RuleMigrationsRetriever(this.migrationId, { + data: this.data, + rules: rulesClient, + savedObjects: savedObjectsClient, + }); + + this.telemetry = new SiemMigrationTelemetryClient( + this.dependencies.telemetry, + this.logger, + this.migrationId, + model.model + ); + + this.agent = getRuleMigrationAgent({ + model, + esqlKnowledgeBase, + ruleMigrationsRetriever: this.retriever, + telemetryClient: this.telemetry, + logger: this.logger, + }); + } + + /** Initializes the retriever populating ELSER indices. It may take a few minutes */ + private async initialize() { + assert(this.retriever, 'setup() must be called before initialize()'); + await this.retriever.initialize(); + } + + public async run(invocationConfig: RunnableConfig): Promise { + assert(this.telemetry, 'telemetry is missing please call setup() first'); + const { telemetry, migrationId } = this; + + const migrationTaskTelemetry = telemetry.startSiemMigrationTask(); + + try { + // TODO: track the duration of the initialization alone in the telemetry + this.logger.debug('Initializing migration'); + await this.withAbort(this.initialize()); // long running operation + } catch (error) { + migrationTaskTelemetry.failure(error); + if (error instanceof AbortError) { + this.logger.info('Abort signal received, stopping initialization'); + return; + } else { + this.logger.error(`Error initializing migration: ${error}`); + return; + } + } + + const migrateRuleTask = this.createMigrateRuleTask(invocationConfig); + this.logger.debug(`Started rule translations. Concurrency is: ${TASK_CONCURRENCY}`); + + try { + do { + const { data: ruleMigrations } = await this.data.rules.get(migrationId, { + filters: { status: SiemMigrationStatus.PENDING }, + size: TASK_BATCH_SIZE, // keep these rules in memory and process them in the promise pool with concurrency limit + }); + if (ruleMigrations.length === 0) { + break; + } + + this.logger.debug(`Start processing batch of ${ruleMigrations.length} rules`); + + const { errors } = await initPromisePool({ + concurrency: TASK_CONCURRENCY, + abortSignal: this.abortController.signal, + items: ruleMigrations, + executor: async (ruleMigration) => { + const ruleTranslationTelemetry = migrationTaskTelemetry.startRuleTranslation(); + try { + await this.saveRuleProcessing(ruleMigration); + + const migrationResult = await migrateRuleTask(ruleMigration); + + await this.saveRuleCompleted(ruleMigration, migrationResult); + ruleTranslationTelemetry.success(migrationResult); + } catch (error) { + if (error instanceof AbortError) { + throw error; + } + ruleTranslationTelemetry.failure(error); + await this.saveRuleFailed(ruleMigration, error); + } + }, + }); + + if (errors.length > 0) { + throw errors[0].error; // Only AbortError is thrown from the pool. The task was aborted + } + + this.logger.debug('Batch processed successfully'); + } while (true); + + migrationTaskTelemetry.success(); + this.logger.info('Migration completed successfully'); + } catch (error) { + await this.data.rules.releaseProcessing(migrationId); + + migrationTaskTelemetry.failure(error); + if (error instanceof AbortError) { + this.logger.info('Abort signal received, stopping migration'); + return; + } else { + this.logger.error(`Error processing migration: ${error}`); + } + } finally { + this.abort.cleanup(); + } + } + + private createMigrateRuleTask(invocationConfig: RunnableConfig) { + assert(this.agent, 'agent is missing please call setup() first'); + const { agent } = this; + const config: RunnableConfig = { + ...invocationConfig, + // signal: abortController.signal, // not working properly https://github.com/langchain-ai/langgraphjs/issues/319 + }; + + const invoke = async (migrationRule: StoredRuleMigration): Promise => { + // using withAbort in the agent invocation is not ideal but is a workaround for the issue with the langGraph signal not working properly + return this.withAbort( + agent.invoke({ original_rule: migrationRule.original_rule }, config) + ); + }; + + // Invokes the rule translation with exponential backoff, should be called only when the rate limit has been hit + const invokeWithBackoff = async ( + migrationRule: StoredRuleMigration + ): Promise => { + this.logger.debug(`Rate limit backoff started for rule "${migrationRule.id}"`); + let retriesLeft: number = RETRY_CONFIG.maxRetries; + while (true) { + try { + await this.sleepRetry(retriesLeft); + retriesLeft--; + const result = await invoke(migrationRule); + this.logger.info( + `Rate limit backoff completed successfully for rule "${migrationRule.id}" after ${ + RETRY_CONFIG.maxRetries - retriesLeft + } retries` + ); + return result; + } catch (error) { + if (!this.isRateLimitError(error) || retriesLeft === 0) { + this.logger.debug( + `Rate limit backoff completed unsuccessfully for rule "${migrationRule.id}"` + ); + const logMessage = + retriesLeft === 0 + ? `Rate limit backoff completed unsuccessfully for rule "${migrationRule.id}"` + : `Rate limit backoff interrupted for rule "${migrationRule.id}". ${error} `; + this.logger.debug(logMessage); + throw error; + } + this.logger.debug( + `Rate limit backoff not completed for rule "${migrationRule.id}", retries left: ${retriesLeft}` + ); + } + } + }; + + let backoffPromise: Promise | undefined; + // Migrates one rule, this function will be called concurrently by the promise pool. + // Handles rate limit errors and ensures only one task is executing the backoff retries at a time, the rest of translation will await. + const migrateRule = async (migrationRule: StoredRuleMigration): Promise => { + let recoverAttemptsLeft: number = EXECUTOR_RECOVER_MAX_ATTEMPTS; + while (true) { + try { + await this.executorSleep(); // Random sleep, increased every time we hit the rate limit. + return await invoke(migrationRule); + } catch (error) { + if (!this.isRateLimitError(error) || recoverAttemptsLeft === 0) { + throw error; + } + if (!backoffPromise) { + // only one translation handles the rate limit backoff retries, the rest will await it and try again when it's resolved + backoffPromise = invokeWithBackoff(migrationRule); + this.isWaiting = true; + return backoffPromise.finally(() => { + backoffPromise = undefined; + this.increaseExecutorSleep(); + this.isWaiting = false; + }); + } + this.logger.debug(`Awaiting backoff task for rule "${migrationRule.id}"`); + await backoffPromise.catch(() => { + throw error; // throw the original error + }); + recoverAttemptsLeft--; + } + } + }; + + return migrateRule; + } + + private isRateLimitError(error: Error) { + return error.message.match(/\b429\b/); // "429" (whole word in the error message): Too Many Requests. + } + + private async withAbort(promise: Promise): Promise { + return Promise.race([promise, this.abort.promise]); + } + + private async sleep(seconds: number) { + await this.withAbort(new Promise((resolve) => setTimeout(resolve, seconds * 1000))); + } + + // Exponential backoff implementation + private async sleepRetry(retriesLeft: number) { + const seconds = + RETRY_CONFIG.initialRetryDelaySeconds * + Math.pow(RETRY_CONFIG.backoffMultiplier, RETRY_CONFIG.maxRetries - retriesLeft); + this.logger.debug(`Retry sleep: ${seconds}s`); + await this.sleep(seconds); + } + + private executorSleep = async () => { + const seconds = Math.random() * this.executorSleepMultiplier; + this.logger.debug(`Executor sleep: ${seconds.toFixed(3)}s`); + await this.sleep(seconds); + }; + + private increaseExecutorSleep = () => { + const increasedMultiplier = this.executorSleepMultiplier * EXECUTOR_SLEEP.multiplier; + if (increasedMultiplier > EXECUTOR_SLEEP.limitSeconds) { + this.logger.warn('Executor sleep reached the maximum value'); + return; + } + this.executorSleepMultiplier = increasedMultiplier; + }; + + private async saveRuleProcessing(ruleMigration: StoredRuleMigration) { + this.logger.debug(`Starting translation of rule "${ruleMigration.id}"`); + return this.data.rules.saveProcessing(ruleMigration.id); + } + + private async saveRuleCompleted( + ruleMigration: StoredRuleMigration, + migrationResult: MigrateRuleState + ) { + this.logger.debug(`Translation of rule "${ruleMigration.id}" succeeded`); + const ruleMigrationTranslated = { + ...ruleMigration, + elastic_rule: migrationResult.elastic_rule, + translation_result: migrationResult.translation_result, + comments: migrationResult.comments, + }; + return this.data.rules.saveCompleted(ruleMigrationTranslated); + } + + private async saveRuleFailed(ruleMigration: StoredRuleMigration, error: Error) { + this.logger.error(`Error translating rule "${ruleMigration.id}" with error: ${error.message}`); + const comments = [generateAssistantComment(`Error migrating rule: ${error.message}`)]; + return this.data.rules.saveError({ ...ruleMigration, comments }); + } +} diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_service.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_service.ts index 9b8d9f86245d5..3efda577dbd61 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_service.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_task_service.ts @@ -7,12 +7,10 @@ import type { Logger } from '@kbn/core/server'; import type { RuleMigrationTaskCreateClientParams } from './types'; -import { RuleMigrationsTaskClient } from './rule_migrations_task_client'; - -export type MigrationRunning = Map; +import { RuleMigrationsTaskClient, type MigrationsRunning } from './rule_migrations_task_client'; export class RuleMigrationsTaskService { - private migrationsRunning: MigrationRunning; + private migrationsRunning: MigrationsRunning; constructor(private logger: Logger) { this.migrationsRunning = new Map(); diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.test.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.test.ts index 35c0df6484166..026851ba13898 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.test.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.test.ts @@ -18,7 +18,6 @@ const translationResultWithMatchMock = { const translationResultMock = { translation_result: 'partial', } as MigrateRuleState; -const stats = { completed: 2, failed: 2 }; const preFilterRulesMock: RuleMigrationPrebuiltRule[] = [ { rule_id: 'rule1id', @@ -96,13 +95,22 @@ describe('siemMigrationTelemetry', () => { jest.useRealTimers(); }); it('start/end migration with error', async () => { - const endSiemMigration = siemTelemetryClient.startSiemMigration(); - const error = new Error('test'); - endSiemMigration({ stats, error }); + const error = 'test error message'; + const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask(); + const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation(); + + // 2 success and 2 failures + ruleTranslationTelemetry.success(translationResultMock); + ruleTranslationTelemetry.success(translationResultMock); + ruleTranslationTelemetry.failure(new Error('test')); + ruleTranslationTelemetry.failure(new Error('test')); + + siemMigrationTaskTelemetry.failure(new Error(error)); + expect(mockTelemetry.reportEvent).toHaveBeenCalledWith('siem_migrations_migration_failure', { completed: 2, duration: 0, - error: 'test', + error, failed: 2, migrationId: 'testmigration', model: 'testModel', @@ -110,9 +118,17 @@ describe('siemMigrationTelemetry', () => { }); }); it('start/end migration success', async () => { - const endSiemMigration = siemTelemetryClient.startSiemMigration(); + const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask(); + const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation(); + + // 2 success and 2 failures + ruleTranslationTelemetry.success(translationResultMock); + ruleTranslationTelemetry.success(translationResultMock); + ruleTranslationTelemetry.failure(new Error('test')); + ruleTranslationTelemetry.failure(new Error('test')); + + siemMigrationTaskTelemetry.success(); - endSiemMigration({ stats }); expect(mockTelemetry.reportEvent).toHaveBeenCalledWith('siem_migrations_migration_success', { completed: 2, duration: 0, @@ -123,23 +139,23 @@ describe('siemMigrationTelemetry', () => { }); }); it('start/end rule translation with error', async () => { - const endRuleTranslation = siemTelemetryClient.startRuleTranslation(); - const error = new Error('test'); + const error = 'test error message'; + const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask(); + const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation(); + + ruleTranslationTelemetry.failure(new Error(error)); - endRuleTranslation({ error }); expect(mockTelemetry.reportEvent).toHaveBeenCalledWith( 'siem_migrations_rule_translation_failure', - { - error: 'test', - migrationId: 'testmigration', - model: 'testModel', - } + { error, migrationId: 'testmigration', model: 'testModel' } ); }); it('start/end rule translation success with prebuilt', async () => { - const endRuleTranslation = siemTelemetryClient.startRuleTranslation(); + const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask(); + const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation(); + + ruleTranslationTelemetry.success(translationResultWithMatchMock); - endRuleTranslation({ migrationResult: translationResultWithMatchMock }); expect(mockTelemetry.reportEvent).toHaveBeenCalledWith( 'siem_migrations_rule_translation_success', { @@ -152,9 +168,11 @@ describe('siemMigrationTelemetry', () => { ); }); it('start/end rule translation success without prebuilt', async () => { - const endRuleTranslation = siemTelemetryClient.startRuleTranslation(); + const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask(); + const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation(); + + ruleTranslationTelemetry.success(translationResultMock); - endRuleTranslation({ migrationResult: translationResultMock }); expect(mockTelemetry.reportEvent).toHaveBeenCalledWith( 'siem_migrations_rule_translation_success', { diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.ts index b472fb337c840..b00fb54535a3f 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client.ts @@ -27,19 +27,6 @@ interface PrebuiltRuleMatchEvent { postFilterRule?: RuleSemanticSearchResult; } -interface RuleTranslationEvent { - error?: Error; - migrationResult?: MigrateRuleState; -} - -interface SiemMigrationEvent { - error?: Error; - stats: { - failed: number; - completed: number; - }; -} - export class SiemMigrationTelemetryClient { constructor( private readonly telemetry: AnalyticsServiceSetup, @@ -69,6 +56,7 @@ export class SiemMigrationTelemetryClient { postFilterIntegrationCount: postFilterIntegration ? 1 : 0, }); } + public reportPrebuiltRulesMatch({ preFilterRules, postFilterRule, @@ -82,60 +70,58 @@ export class SiemMigrationTelemetryClient { postFilterRuleCount: postFilterRule ? 1 : 0, }); } - public startRuleTranslation(): ( - args: Pick - ) => void { - const startTime = Date.now(); - return ({ error, migrationResult }) => { - const duration = Date.now() - startTime; + public startSiemMigrationTask() { + const startTime = Date.now(); + const stats = { completed: 0, failed: 0 }; - if (error) { - this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_FAILURE, { + return { + startRuleTranslation: () => { + const ruleStartTime = Date.now(); + return { + success: (migrationResult: MigrateRuleState) => { + stats.completed++; + this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_SUCCESS, { + migrationId: this.migrationId, + translationResult: migrationResult.translation_result || '', + duration: Date.now() - ruleStartTime, + model: this.modelName, + prebuiltMatch: migrationResult.elastic_rule?.prebuilt_rule_id ? true : false, + }); + }, + failure: (error: Error) => { + stats.failed++; + this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_FAILURE, { + migrationId: this.migrationId, + error: error.message, + model: this.modelName, + }); + }, + }; + }, + success: () => { + const duration = Date.now() - startTime; + this.reportEvent(SIEM_MIGRATIONS_MIGRATION_SUCCESS, { migrationId: this.migrationId, - error: error.message, - model: this.modelName, + model: this.modelName || '', + completed: stats.completed, + failed: stats.failed, + total: stats.completed + stats.failed, + duration, }); - return; - } - - this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_SUCCESS, { - migrationId: this.migrationId, - translationResult: migrationResult?.translation_result || '', - duration, - model: this.modelName, - prebuiltMatch: migrationResult?.elastic_rule?.prebuilt_rule_id ? true : false, - }); - }; - } - public startSiemMigration(): (args: Pick) => void { - const startTime = Date.now(); - - return ({ error, stats }) => { - const duration = Date.now() - startTime; - const total = stats ? stats.completed + stats.failed : 0; - - if (error) { + }, + failure: (error: Error) => { + const duration = Date.now() - startTime; this.reportEvent(SIEM_MIGRATIONS_MIGRATION_FAILURE, { migrationId: this.migrationId, model: this.modelName || '', - completed: stats ? stats.completed : 0, - failed: stats ? stats.failed : 0, - total, + completed: stats.completed, + failed: stats.failed, + total: stats.completed + stats.failed, duration, error: error.message, }); - return; - } - - this.reportEvent(SIEM_MIGRATIONS_MIGRATION_SUCCESS, { - migrationId: this.migrationId, - model: this.modelName || '', - completed: stats ? stats.completed : 0, - failed: stats ? stats.failed : 0, - total, - duration, - }); + }, }; } } diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/types.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/types.ts index 756938dc0421b..152679b83b461 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/types.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/types.ts @@ -12,6 +12,7 @@ import type { SiemRuleMigrationsClientDependencies } from '../types'; import type { getRuleMigrationAgent } from './agent'; import type { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client'; import type { ChatModel } from './util/actions_client_chat'; +import type { RuleMigrationsRetriever } from './retrievers'; export type MigrationAgent = ReturnType; @@ -32,7 +33,9 @@ export interface RuleMigrationTaskRunParams extends RuleMigrationTaskStartParams abortController: AbortController; } -export interface RuleMigrationTaskCreateAgentParams extends RuleMigrationTaskStartParams { +export interface RuleMigrationTaskCreateAgentParams { + connectorId: string; + retriever: RuleMigrationsRetriever; telemetryClient: SiemMigrationTelemetryClient; model: ChatModel; } diff --git a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/util/actions_client_chat.ts b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/util/actions_client_chat.ts index 555662c8312c9..e232b66454fca 100644 --- a/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/util/actions_client_chat.ts +++ b/x-pack/solutions/security/plugins/security_solution/server/lib/siem_migrations/rules/task/util/actions_client_chat.ts @@ -17,6 +17,7 @@ import type { CustomChatModelInput as ActionsClientBedrockChatModelParams } from import type { ActionsClientChatOpenAIParams } from '@kbn/langchain/server/language_models/chat_openai'; import type { CustomChatModelInput as ActionsClientChatVertexAIParams } from '@kbn/langchain/server/language_models/gemini_chat'; import type { CustomChatModelInput as ActionsClientSimpleChatModelParams } from '@kbn/langchain/server/language_models/simple_chat_model'; +import { TELEMETRY_SIEM_MIGRATION_ID } from './constants'; export type ChatModel = | ActionsClientSimpleChatModel @@ -42,17 +43,23 @@ const llmTypeDictionary: Record = { [`.inference`]: `inference`, }; +interface CreateModelParams { + migrationId: string; + connectorId: string; + abortController: AbortController; +} + export class ActionsClientChat { - constructor( - private readonly connectorId: string, - private readonly actionsClient: ActionsClient, - private readonly logger: Logger - ) {} + constructor(private readonly actionsClient: ActionsClient, private readonly logger: Logger) {} - public async createModel(params?: ChatModelParams): Promise { - const connector = await this.actionsClient.get({ id: this.connectorId }); + public async createModel({ + migrationId, + connectorId, + abortController, + }: CreateModelParams): Promise { + const connector = await this.actionsClient.get({ id: connectorId }); if (!connector) { - throw new Error(`Connector not found: ${this.connectorId}`); + throw new Error(`Connector not found: ${connectorId}`); } const llmType = this.getLLMType(connector.actionTypeId); @@ -60,12 +67,15 @@ export class ActionsClientChat { const model = new ChatModelClass({ actionsClient: this.actionsClient, - connectorId: this.connectorId, - logger: this.logger, + connectorId, llmType, model: connector.config?.defaultModel, - ...params, - streaming: false, // disabling streaming by default + streaming: false, + temperature: 0.05, + maxRetries: 1, // Only retry once inside the model, we will handle backoff retries in the task runner + telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId }, + signal: abortController.signal, + logger: this.logger, }); return model; }