Skip to content

Commit

Permalink
Merge pull request #5 from center-for-threat-informed-defense/matt/pr…
Browse files Browse the repository at this point in the history
…ediction-labels

Matt/prediction labels
  • Loading branch information
mturner-ml authored Mar 5, 2024
2 parents 7a3f2e2 + c4da065 commit ee53265
Show file tree
Hide file tree
Showing 5 changed files with 558 additions and 134 deletions.
322 changes: 201 additions & 121 deletions models/main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"data": {
"text/plain": [
"<module 'recommender' from '/Users/mjturner/code/technique-inference-engine/models/recommender/__init__.py'>"
]
},
"execution_count": 1,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -22,13 +30,16 @@
"\n",
"# Imports\n",
"import json\n",
"from mitreattack.stix20 import MitreAttackData\n",
"import tensorflow as tf\n",
"import recommender\n",
"from matrix import ReportTechniqueMatrix\n",
"from matrix_builder import ReportTechniqueMatrixBuilder\n",
"import random\n",
"import math\n",
"import importlib\n",
"import pandas as pd\n",
"import numpy as np\n",
"from utils import get_mitre_technique_ids_to_names\n",
"\n",
"tf.config.run_functions_eagerly(True)\n",
"\n",
Expand All @@ -39,43 +50,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def get_mitre_technique_ids(stix_filepath: str) -> frozenset[str]:\n",
" \"\"\"Gets all MITRE technique ids.\"\"\"\n",
" mitre_attack_data = MitreAttackData(stix_filepath)\n",
" techniques = mitre_attack_data.get_techniques(remove_revoked_deprecated=True)\n",
"\n",
" all_technique_ids = set()\n",
"\n",
" for technique in techniques:\n",
" external_references = technique.get(\"external_references\")\n",
" mitre_references = tuple(filter(lambda external_reference: external_reference.get(\"source_name\") == \"mitre-attack\", external_references))\n",
" assert len(mitre_references) == 1\n",
" mitre_technique_id = mitre_references[0][\"external_id\"]\n",
" all_technique_ids.add(mitre_technique_id)\n",
"\n",
" return frozenset(all_technique_ids)\n",
"\n",
"def get_campaign_techniques(filepath: str) -> tuple[frozenset[str]]:\n",
" \"\"\"Gets a set of MITRE technique ids present in each campaign.\"\"\"\n",
"\n",
" with open(filepath) as f:\n",
" data = json.load(f)\n",
"\n",
" campaigns = data[\"bags_of_techniques\"]\n",
"\n",
" ret = []\n",
"\n",
" for campaign in campaigns:\n",
"\n",
" techniques = campaign[\"mitre_techniques\"]\n",
" ret.append(frozenset(techniques.keys()))\n",
"\n",
" return ret\n",
"\n",
"def train_test_split(indices: list, values: list, test_ratio: float=0.1) -> tuple:\n",
" n = len(indices)\n",
" assert len(values) == n\n",
Expand All @@ -100,116 +78,218 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
" # want matrix of campaigns on horizontal, techniques on vertical\n",
" all_mitre_technique_ids = tuple(get_mitre_technique_ids(\"../enterprise-attack.json\"))\n",
" mitre_technique_ids_to_index = {all_mitre_technique_ids[i]: i for i in range(len(all_mitre_technique_ids))}\n",
"\n",
" campaigns = get_campaign_techniques(\"../data/combined_dataset_full_frequency.json\")\n",
"\n",
" indices = []\n",
" values = []\n",
"\n",
" # for each campaign, make a vector, filling in each present technique with a 1\n",
" for i in range(len(campaigns)):\n",
"\n",
" campaign = campaigns[i]\n",
"\n",
" for mitre_technique_id in campaign:\n",
" if mitre_technique_id in mitre_technique_ids_to_index:\n",
" # campaign id, technique id\n",
" index = [i, mitre_technique_ids_to_index[mitre_technique_id]]\n",
"\n",
" indices.append(index)\n",
" values.append(1)\n",
"\n",
"def view_prediction_performance_table_for_report(\n",
" train_data: ReportTechniqueMatrix,\n",
" test_data: ReportTechniqueMatrix,\n",
" predictions: pd.DataFrame,\n",
" report_id: int,\n",
" ) -> pd.DataFrame:\n",
" \"\"\"Gets a dataframe to visualize the training data, test data, and predictions for a report.\"\"\"\n",
" # 1. training_data\n",
" training_dataframe = train_data.to_pandas()\n",
" report_train_techniques = training_dataframe.loc[report_id]\n",
" report_train_techniques.name = \"training_data\"\n",
"\n",
" # 2. predictions\n",
" predicted_techniques = predictions.loc[report_id]\n",
" predicted_techniques.name = \"predictions\"\n",
"\n",
" # now test data\n",
" test_dataframe = test_data.to_pandas()\n",
" report_test_techniques = test_dataframe.loc[report_id]\n",
" report_test_techniques.name = \"test_data\"\n",
"\n",
" report_data = pd.concat((predicted_techniques, report_train_techniques, report_test_techniques), axis=1)\n",
"\n",
" # add name for convenience\n",
" all_mitre_technique_ids_to_names = get_mitre_technique_ids_to_names(\"../enterprise-attack.json\")\n",
" report_data.loc[:, \"technique_name\"] = report_data.apply(lambda row: all_mitre_technique_ids_to_names.get(row.name), axis=1)\n",
"\n",
" return report_data\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE Error 0.07193438\n"
]
}
],
"source": [
"test_ratio = 0.1\n",
"embedding_dimension = 10\n",
"\n",
" train_indices, train_values, test_indices, test_values = train_test_split(indices, values)\n",
"data_builder = ReportTechniqueMatrixBuilder(\n",
" combined_dataset_filepath=\"../data/combined_dataset_full_frequency.json\",\n",
" enterprise_attack_filepath=\"../enterprise-attack.json\",\n",
")\n",
"data = data_builder.build()\n",
"\n",
" training_data = tf.SparseTensor(\n",
" indices=train_indices,\n",
" values=train_values,\n",
" dense_shape=(len(campaigns), len(all_mitre_technique_ids))\n",
" )\n",
" test_data = tf.SparseTensor(\n",
" indices=test_indices,\n",
" values=test_values,\n",
" dense_shape=(len(campaigns), len(all_mitre_technique_ids))\n",
" )\n",
"train_indices = frozenset(random.sample(data.indices, k=math.floor((1-test_ratio) * len(data.indices))))\n",
"test_indices = frozenset(data.indices).difference(train_indices)\n",
"\n",
" # train\n",
" model = recommender.FactorizationRecommender(m=len(campaigns), n=len(all_mitre_technique_ids), k=10)\n",
" model.fit(training_data, num_iterations=1000, learning_rate=10.)\n",
"training_data = data.mask(train_indices)\n",
"test_data = data.mask(test_indices)\n",
"\n",
" evaluation = model.evaluate(test_data)\n",
" print(\"MSE\", evaluation)\n",
"# train\n",
"model = recommender.FactorizationRecommender(m=data.m, n=data.n, k=embedding_dimension)\n",
"model.fit(training_data.to_sparse_tensor(), num_iterations=1000, learning_rate=10., regularization_coefficient=0.1, gravity_coefficient=0.0)\n",
"\n",
" predictions = model.predict()\n",
"evaluation = model.evaluate(test_data.to_sparse_tensor())\n",
"print(\"MSE Error\", evaluation)\n",
"\n",
" predictions_dataframe = pd.DataFrame(predictions, columns=all_mitre_technique_ids)\n",
"predictions = model.predict()\n",
"\n",
" print(predictions_dataframe)\n"
"predictions_dataframe = pd.DataFrame(predictions, columns=data.technique_ids)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/anaconda3/envs/tie/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n",
" return _methods._mean(a, axis=axis, dtype=dtype,\n",
"/opt/homebrew/anaconda3/envs/tie/lib/python3.11/site-packages/numpy/core/_methods.py:121: RuntimeWarning: invalid value encountered in divide\n",
" ret = um.true_divide(\n"
]
}
],
"source": [
"# get best and worst test performance\n",
"test_ndarray = test_data.to_numpy()\n",
"predictions_ndarray = predictions_dataframe.to_numpy()\n",
"# where test data, use predictions, else, fill with Nan\n",
"test_performance = np.mean(np.square(predictions_ndarray - test_ndarray), axis=1, where=test_ndarray > 0.5)\n",
"\n",
"best_test_perf = np.nanargmin(test_performance, )\n",
"worst_test_perf = np.nanargmax(test_performance)\n",
"\n",
"best_performance_results = view_prediction_performance_table_for_report(\n",
" train_data=training_data,\n",
" test_data=test_data,\n",
" predictions=predictions_dataframe,\n",
" report_id=best_test_perf\n",
")\n",
"\n",
"worst_performance_results = view_prediction_performance_table_for_report(\n",
" train_data=training_data,\n",
" test_data=test_data,\n",
" predictions=predictions_dataframe,\n",
" report_id=worst_test_perf\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE 0.34170407\n",
" T1578 T1547.014 T1030 T1112 T1550.003 T1049 T1092 \\\n",
"0 0.784831 0.377630 0.101901 1.001325 -0.158771 0.694602 0.411023 \n",
"1 -0.763482 0.224698 1.636720 1.009761 1.009708 0.661873 0.355017 \n",
"2 0.544853 -0.018921 0.331601 1.006808 0.202223 1.563948 0.297718 \n",
"3 0.902573 0.454326 0.440447 0.998493 -1.109977 0.140948 0.042689 \n",
"4 -0.050296 0.301442 0.311153 1.002875 1.002803 0.711632 0.073821 \n",
".. ... ... ... ... ... ... ... \n",
"186 0.437564 -0.162894 0.365576 0.994123 -0.190575 0.995859 0.380639 \n",
"187 1.199729 0.407794 -0.073012 0.963695 -0.157554 0.741034 0.225907 \n",
"188 0.499829 0.735397 0.651445 1.002326 0.695246 0.366094 0.020164 \n",
"189 0.358869 0.484445 0.639233 1.018354 0.295820 1.279500 0.540394 \n",
"190 1.458898 0.310654 -0.401749 0.997384 -0.898985 1.000898 0.677263 \n",
"\n",
" T1505.004 T1218.014 T1564.005 ... T1053 T1134.002 T1542.001 \\\n",
"0 0.285281 0.335994 0.902758 ... 0.566240 -0.611709 1.115770 \n",
"1 -0.629567 0.746796 0.346939 ... -0.243643 0.755467 0.519914 \n",
"2 -0.550626 1.306988 1.579033 ... 0.914382 1.859166 -0.776076 \n",
"3 -1.224650 -0.232498 0.439434 ... -0.266217 -0.269991 1.799628 \n",
"4 -0.252191 0.439709 0.809605 ... 0.299892 0.683129 -0.487585 \n",
".. ... ... ... ... ... ... ... \n",
"186 -0.327226 0.823837 1.397233 ... 0.826141 0.927247 -0.261659 \n",
"187 -0.130431 0.093835 1.231475 ... 0.198718 0.316289 0.038840 \n",
"188 -0.414925 -0.820273 1.041236 ... -0.876987 -0.187479 1.266579 \n",
"189 -0.641792 0.559728 1.007619 ... -0.391899 1.019392 0.607860 \n",
"190 -0.092881 1.040499 0.375332 ... 0.440140 -0.377884 0.358312 \n",
" predictions training_data test_data \\\n",
"T1027 1.001889 0.0 1.0 \n",
"T1552.001 0.848310 0.0 0.0 \n",
"T1561 0.326331 0.0 0.0 \n",
"T1573.001 0.939929 0.0 0.0 \n",
"T1485 0.451991 0.0 0.0 \n",
"T1190 0.950845 1.0 0.0 \n",
"T1132.001 0.248229 0.0 0.0 \n",
"T1078.004 0.772365 0.0 0.0 \n",
"T1056.001 0.943955 0.0 0.0 \n",
"T1001.002 0.234360 0.0 0.0 \n",
"T1095 0.881229 0.0 0.0 \n",
"T1110.003 0.727067 0.0 0.0 \n",
"T1078.002 0.852815 0.0 0.0 \n",
"T1218.011 0.918010 0.0 0.0 \n",
"T1553.004 0.275539 0.0 0.0 \n",
"\n",
" T1071.002 T1569 T1550.001 T1584.001 T1656 T1008 T1505.005 \n",
"0 -0.296587 0.878848 0.176307 0.307053 0.391668 -0.267710 0.141556 \n",
"1 0.566720 0.062479 -0.699203 -1.095403 0.726196 -0.411449 0.335667 \n",
"2 0.987327 0.377138 0.462070 -0.883505 2.029478 0.422649 0.337644 \n",
"3 -0.875782 0.071158 0.010061 1.022122 -0.773336 -0.457362 0.303583 \n",
"4 0.645899 0.741363 0.021589 0.063035 0.143244 0.868330 0.655884 \n",
".. ... ... ... ... ... ... ... \n",
"186 0.506374 0.350138 1.159779 -0.838942 1.277720 -0.040173 0.016383 \n",
"187 0.118864 1.013651 0.930431 0.365024 0.448756 0.507837 -0.008061 \n",
"188 0.261102 0.962185 -0.047198 0.115494 -0.060542 0.513719 -0.190671 \n",
"189 -0.084509 0.162352 0.334240 -1.347584 2.198924 -0.044667 0.445749 \n",
"190 -1.493496 0.578612 0.991296 0.092978 0.497014 -0.455444 0.495544 \n",
" technique_name \n",
"T1027 Obfuscated Files or Information \n",
"T1552.001 Credentials In Files \n",
"T1561 Disk Wipe \n",
"T1573.001 Symmetric Cryptography \n",
"T1485 Data Destruction \n",
"T1190 Exploit Public-Facing Application \n",
"T1132.001 Standard Encoding \n",
"T1078.004 Cloud Accounts \n",
"T1056.001 Keylogging \n",
"T1001.002 Steganography \n",
"T1095 Non-Application Layer Protocol \n",
"T1110.003 Password Spraying \n",
"T1078.002 Domain Accounts \n",
"T1218.011 Rundll32 \n",
"T1553.004 Install Root Certificate \n"
]
}
],
"source": [
"print(best_performance_results.sort_values(\"test_data\", ascending=False).head(15))"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" predictions training_data test_data \\\n",
"T1106 -0.030872 0.0 1.0 \n",
"T1552.001 -0.026832 0.0 0.0 \n",
"T1561 -0.000862 0.0 0.0 \n",
"T1573.001 -0.027231 0.0 0.0 \n",
"T1485 -0.000736 0.0 0.0 \n",
"T1190 -0.031759 0.0 0.0 \n",
"T1132.001 0.010440 0.0 0.0 \n",
"T1078.004 -0.028884 0.0 0.0 \n",
"T1056.001 -0.029308 0.0 0.0 \n",
"T1001.002 0.004404 0.0 0.0 \n",
"T1095 -0.023185 0.0 0.0 \n",
"T1110.003 -0.021503 0.0 0.0 \n",
"T1078.002 -0.027470 0.0 0.0 \n",
"T1218.011 -0.031216 0.0 0.0 \n",
"T1553.004 0.005378 0.0 0.0 \n",
"\n",
"[191 rows x 625 columns]\n"
" technique_name \n",
"T1106 Native API \n",
"T1552.001 Credentials In Files \n",
"T1561 Disk Wipe \n",
"T1573.001 Symmetric Cryptography \n",
"T1485 Data Destruction \n",
"T1190 Exploit Public-Facing Application \n",
"T1132.001 Standard Encoding \n",
"T1078.004 Cloud Accounts \n",
"T1056.001 Keylogging \n",
"T1001.002 Steganography \n",
"T1095 Non-Application Layer Protocol \n",
"T1110.003 Password Spraying \n",
"T1078.002 Domain Accounts \n",
"T1218.011 Rundll32 \n",
"T1553.004 Install Root Certificate \n"
]
}
],
"source": [
"main()"
"print(worst_performance_results.sort_values(\"test_data\", ascending=False).head(15))"
]
}
],
Expand Down
Loading

0 comments on commit ee53265

Please sign in to comment.