diff --git a/models/main.ipynb b/models/main.ipynb index fd25110..a17d01e 100644 --- a/models/main.ipynb +++ b/models/main.ipynb @@ -2,16 +2,24 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + }, { "data": { "text/plain": [ "" ] }, - "execution_count": 1, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -22,13 +30,16 @@ "\n", "# Imports\n", "import json\n", - "from mitreattack.stix20 import MitreAttackData\n", "import tensorflow as tf\n", "import recommender\n", + "from matrix import ReportTechniqueMatrix\n", + "from matrix_builder import ReportTechniqueMatrixBuilder\n", "import random\n", "import math\n", "import importlib\n", "import pandas as pd\n", + "import numpy as np\n", + "from utils import get_mitre_technique_ids_to_names\n", "\n", "tf.config.run_functions_eagerly(True)\n", "\n", @@ -39,43 +50,10 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "def get_mitre_technique_ids(stix_filepath: str) -> frozenset[str]:\n", - " \"\"\"Gets all MITRE technique ids.\"\"\"\n", - " mitre_attack_data = MitreAttackData(stix_filepath)\n", - " techniques = mitre_attack_data.get_techniques(remove_revoked_deprecated=True)\n", - "\n", - " all_technique_ids = set()\n", - "\n", - " for technique in techniques:\n", - " external_references = technique.get(\"external_references\")\n", - " mitre_references = tuple(filter(lambda external_reference: external_reference.get(\"source_name\") == \"mitre-attack\", external_references))\n", - " assert len(mitre_references) == 1\n", - " mitre_technique_id = mitre_references[0][\"external_id\"]\n", - " all_technique_ids.add(mitre_technique_id)\n", - "\n", - " return frozenset(all_technique_ids)\n", - "\n", - "def 
def view_prediction_performance_table_for_report(
    train_data: ReportTechniqueMatrix,
    test_data: ReportTechniqueMatrix,
    predictions: pd.DataFrame,
    report_id: int,
    enterprise_attack_filepath: str = "../enterprise-attack.json",
) -> pd.DataFrame:
    """Build a per-technique comparison table for a single report.

    Args:
        train_data: matrix holding the training entries.
        test_data: matrix holding the held-out test entries.
        predictions: model predictions; the row labelled ``report_id`` is used.
        report_id: row label of the report to inspect in all three inputs.
        enterprise_attack_filepath: location of the ATT&CK STIX bundle used to
            look up technique names. Defaults to the path previously
            hard-coded in this function.

    Returns:
        A dataframe indexed by the column labels of the inputs, with the
        columns ``predictions``, ``training_data``, ``test_data`` and
        ``technique_name``. ``technique_name`` is None for labels that are
        missing from the ATT&CK bundle.
    """
    # ``rename`` returns a fresh Series, so the rows taken out of the input
    # frames are never modified in place.
    report_data = pd.concat(
        (
            predictions.loc[report_id].rename("predictions"),
            train_data.to_pandas().loc[report_id].rename("training_data"),
            test_data.to_pandas().loc[report_id].rename("test_data"),
        ),
        axis=1,
    )

    # add name for convenience; a direct lookup over the index avoids a
    # row-wise DataFrame.apply
    technique_names = get_mitre_technique_ids_to_names(enterprise_attack_filepath)
    report_data["technique_name"] = [
        technique_names.get(technique_id) for technique_id in report_data.index
    ]

    return report_data
"train_indices = frozenset(random.sample(data.indices, k=math.floor((1-test_ratio) * len(data.indices))))\n", + "test_indices = frozenset(data.indices).difference(train_indices)\n", "\n", - " # train\n", - " model = recommender.FactorizationRecommender(m=len(campaigns), n=len(all_mitre_technique_ids), k=10)\n", - " model.fit(training_data, num_iterations=1000, learning_rate=10.)\n", + "training_data = data.mask(train_indices)\n", + "test_data = data.mask(test_indices)\n", "\n", - " evaluation = model.evaluate(test_data)\n", - " print(\"MSE\", evaluation)\n", + "# train\n", + "model = recommender.FactorizationRecommender(m=data.m, n=data.n, k=embedding_dimension)\n", + "model.fit(training_data.to_sparse_tensor(), num_iterations=1000, learning_rate=10., regularization_coefficient=0.1, gravity_coefficient=0.0)\n", "\n", - " predictions = model.predict()\n", + "evaluation = model.evaluate(test_data.to_sparse_tensor())\n", + "print(\"MSE Error\", evaluation)\n", "\n", - " predictions_dataframe = pd.DataFrame(predictions, columns=all_mitre_technique_ids)\n", + "predictions = model.predict()\n", "\n", - " print(predictions_dataframe)\n" + "predictions_dataframe = pd.DataFrame(predictions, columns=data.technique_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/homebrew/anaconda3/envs/tie/lib/python3.11/site-packages/numpy/core/fromnumeric.py:3504: RuntimeWarning: Mean of empty slice.\n", + " return _methods._mean(a, axis=axis, dtype=dtype,\n", + "/opt/homebrew/anaconda3/envs/tie/lib/python3.11/site-packages/numpy/core/_methods.py:121: RuntimeWarning: invalid value encountered in divide\n", + " ret = um.true_divide(\n" + ] + } + ], + "source": [ + "# get best and worst test performance\n", + "test_ndarray = test_data.to_numpy()\n", + "predictions_ndarray = predictions_dataframe.to_numpy()\n", + "# where test data, use predictions, else, fill with Nan\n", 
+ "test_performance = np.mean(np.square(predictions_ndarray - test_ndarray), axis=1, where=test_ndarray > 0.5)\n", + "\n", + "best_test_perf = np.nanargmin(test_performance, )\n", + "worst_test_perf = np.nanargmax(test_performance)\n", + "\n", + "best_performance_results = view_prediction_performance_table_for_report(\n", + " train_data=training_data,\n", + " test_data=test_data,\n", + " predictions=predictions_dataframe,\n", + " report_id=best_test_perf\n", + ")\n", + "\n", + "worst_performance_results = view_prediction_performance_table_for_report(\n", + " train_data=training_data,\n", + " test_data=test_data,\n", + " predictions=predictions_dataframe,\n", + " report_id=worst_test_perf\n", + ")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "MSE 0.34170407\n", - " T1578 T1547.014 T1030 T1112 T1550.003 T1049 T1092 \\\n", - "0 0.784831 0.377630 0.101901 1.001325 -0.158771 0.694602 0.411023 \n", - "1 -0.763482 0.224698 1.636720 1.009761 1.009708 0.661873 0.355017 \n", - "2 0.544853 -0.018921 0.331601 1.006808 0.202223 1.563948 0.297718 \n", - "3 0.902573 0.454326 0.440447 0.998493 -1.109977 0.140948 0.042689 \n", - "4 -0.050296 0.301442 0.311153 1.002875 1.002803 0.711632 0.073821 \n", - ".. ... ... ... ... ... ... ... \n", - "186 0.437564 -0.162894 0.365576 0.994123 -0.190575 0.995859 0.380639 \n", - "187 1.199729 0.407794 -0.073012 0.963695 -0.157554 0.741034 0.225907 \n", - "188 0.499829 0.735397 0.651445 1.002326 0.695246 0.366094 0.020164 \n", - "189 0.358869 0.484445 0.639233 1.018354 0.295820 1.279500 0.540394 \n", - "190 1.458898 0.310654 -0.401749 0.997384 -0.898985 1.000898 0.677263 \n", - "\n", - " T1505.004 T1218.014 T1564.005 ... T1053 T1134.002 T1542.001 \\\n", - "0 0.285281 0.335994 0.902758 ... 0.566240 -0.611709 1.115770 \n", - "1 -0.629567 0.746796 0.346939 ... -0.243643 0.755467 0.519914 \n", - "2 -0.550626 1.306988 1.579033 ... 
0.914382 1.859166 -0.776076 \n", - "3 -1.224650 -0.232498 0.439434 ... -0.266217 -0.269991 1.799628 \n", - "4 -0.252191 0.439709 0.809605 ... 0.299892 0.683129 -0.487585 \n", - ".. ... ... ... ... ... ... ... \n", - "186 -0.327226 0.823837 1.397233 ... 0.826141 0.927247 -0.261659 \n", - "187 -0.130431 0.093835 1.231475 ... 0.198718 0.316289 0.038840 \n", - "188 -0.414925 -0.820273 1.041236 ... -0.876987 -0.187479 1.266579 \n", - "189 -0.641792 0.559728 1.007619 ... -0.391899 1.019392 0.607860 \n", - "190 -0.092881 1.040499 0.375332 ... 0.440140 -0.377884 0.358312 \n", + " predictions training_data test_data \\\n", + "T1027 1.001889 0.0 1.0 \n", + "T1552.001 0.848310 0.0 0.0 \n", + "T1561 0.326331 0.0 0.0 \n", + "T1573.001 0.939929 0.0 0.0 \n", + "T1485 0.451991 0.0 0.0 \n", + "T1190 0.950845 1.0 0.0 \n", + "T1132.001 0.248229 0.0 0.0 \n", + "T1078.004 0.772365 0.0 0.0 \n", + "T1056.001 0.943955 0.0 0.0 \n", + "T1001.002 0.234360 0.0 0.0 \n", + "T1095 0.881229 0.0 0.0 \n", + "T1110.003 0.727067 0.0 0.0 \n", + "T1078.002 0.852815 0.0 0.0 \n", + "T1218.011 0.918010 0.0 0.0 \n", + "T1553.004 0.275539 0.0 0.0 \n", "\n", - " T1071.002 T1569 T1550.001 T1584.001 T1656 T1008 T1505.005 \n", - "0 -0.296587 0.878848 0.176307 0.307053 0.391668 -0.267710 0.141556 \n", - "1 0.566720 0.062479 -0.699203 -1.095403 0.726196 -0.411449 0.335667 \n", - "2 0.987327 0.377138 0.462070 -0.883505 2.029478 0.422649 0.337644 \n", - "3 -0.875782 0.071158 0.010061 1.022122 -0.773336 -0.457362 0.303583 \n", - "4 0.645899 0.741363 0.021589 0.063035 0.143244 0.868330 0.655884 \n", - ".. ... ... ... ... ... ... ... 
\n", - "186 0.506374 0.350138 1.159779 -0.838942 1.277720 -0.040173 0.016383 \n", - "187 0.118864 1.013651 0.930431 0.365024 0.448756 0.507837 -0.008061 \n", - "188 0.261102 0.962185 -0.047198 0.115494 -0.060542 0.513719 -0.190671 \n", - "189 -0.084509 0.162352 0.334240 -1.347584 2.198924 -0.044667 0.445749 \n", - "190 -1.493496 0.578612 0.991296 0.092978 0.497014 -0.455444 0.495544 \n", + " technique_name \n", + "T1027 Obfuscated Files or Information \n", + "T1552.001 Credentials In Files \n", + "T1561 Disk Wipe \n", + "T1573.001 Symmetric Cryptography \n", + "T1485 Data Destruction \n", + "T1190 Exploit Public-Facing Application \n", + "T1132.001 Standard Encoding \n", + "T1078.004 Cloud Accounts \n", + "T1056.001 Keylogging \n", + "T1001.002 Steganography \n", + "T1095 Non-Application Layer Protocol \n", + "T1110.003 Password Spraying \n", + "T1078.002 Domain Accounts \n", + "T1218.011 Rundll32 \n", + "T1553.004 Install Root Certificate \n" + ] + } + ], + "source": [ + "print(best_performance_results.sort_values(\"test_data\", ascending=False).head(15))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " predictions training_data test_data \\\n", + "T1106 -0.030872 0.0 1.0 \n", + "T1552.001 -0.026832 0.0 0.0 \n", + "T1561 -0.000862 0.0 0.0 \n", + "T1573.001 -0.027231 0.0 0.0 \n", + "T1485 -0.000736 0.0 0.0 \n", + "T1190 -0.031759 0.0 0.0 \n", + "T1132.001 0.010440 0.0 0.0 \n", + "T1078.004 -0.028884 0.0 0.0 \n", + "T1056.001 -0.029308 0.0 0.0 \n", + "T1001.002 0.004404 0.0 0.0 \n", + "T1095 -0.023185 0.0 0.0 \n", + "T1110.003 -0.021503 0.0 0.0 \n", + "T1078.002 -0.027470 0.0 0.0 \n", + "T1218.011 -0.031216 0.0 0.0 \n", + "T1553.004 0.005378 0.0 0.0 \n", "\n", - "[191 rows x 625 columns]\n" + " technique_name \n", + "T1106 Native API \n", + "T1552.001 Credentials In Files \n", + "T1561 Disk Wipe \n", + "T1573.001 Symmetric Cryptography \n", + "T1485 Data 
class ReportTechniqueMatrix:
    """An immutable sparse report-by-technique matrix.

    Rows correspond to reports and columns to techniques. Only the entries
    listed in ``indices`` carry a value; all other entries are zero when the
    matrix is converted to a dense representation.
    """

    # Abstraction function:
    #   AF(indices, values, report_ids, technique_ids) = a sparse matrix A where
    #   A_{ij} = values[k] where k is the position of (i, j) in indices, if
    #   present, and A_{ij} corresponds to the report report_ids[i] and
    #   technique technique_ids[j]
    # Rep invariant:
    #   - len(indices) > 0
    #   - len(values) > 0
    #   - len(indices) == len(values)
    #   - TODO every row contains value
    #   - TODO every column contains value
    # Safety from rep exposure:
    #   - all fields in rep are private and immutable

    def __init__(
        self,
        indices: tuple[tuple[int, int], ...],
        values: tuple[int, ...],
        report_ids: tuple[int, ...],
        technique_ids: tuple[str, ...],
    ):
        """Initializes a ReportTechniqueMatrix object.

        Args:
            indices: iterable of indices of the format (row, column) of matrix
                entries.
            values: iterable of matrix entry values such that values[k] contains
                the entry for indices[k] for all k.
            report_ids: unique identifiers for reports such that report_ids[i]
                is the identifier for row i of the sparse matrix.
            technique_ids: unique identifiers for techniques such that
                technique_ids[j] is the identifier for column j of the sparse
                matrix.
        """
        # Each index is copied into a tuple so the rep stays immutable and
        # hashable (mask() relies on set membership) even if the caller passed
        # lists.
        self._indices = tuple(tuple(index) for index in indices)
        self._values = tuple(values)
        self._report_ids = tuple(report_ids)
        self._technique_ids = tuple(technique_ids)

        self._checkrep()

    def _checkrep(self):
        """Asserts the rep invariant."""
        # - len(indices) > 0
        assert len(self._indices) > 0
        # - len(values) > 0
        assert len(self._values) > 0
        # - len(indices) == len(values)
        assert len(self._indices) == len(self._values)

    @property
    def m(self) -> int:
        """The number of rows of the matrix."""
        self._checkrep()
        return len(self._report_ids)

    @property
    def n(self) -> int:
        """The number of columns of the matrix."""
        self._checkrep()
        return len(self._technique_ids)

    @property
    def shape(self) -> tuple[int, int]:
        """Gets the shape (rows, columns) of the matrix."""
        return (self.m, self.n)

    @property
    def indices(self) -> tuple[tuple[int, int], ...]:
        """Gets the nonempty indices of the matrix."""
        # ok to return directly since the tuple of tuples is immutable
        self._checkrep()
        return self._indices

    @property
    def technique_ids(self) -> tuple[str, ...]:
        """Gets the technique ids that make up the column index of the matrix."""
        return self._technique_ids

    def to_sparse_tensor(self) -> "tf.SparseTensor":
        """Converts the matrix to a sparse tensor of dense shape (m, n)."""
        self._checkrep()
        return tf.SparseTensor(
            indices=self._indices, values=self._values, dense_shape=(self.m, self.n)
        )

    def to_numpy(self) -> np.ndarray:
        """Converts the matrix to a dense numpy array of shape (m, n).

        Entries that are not listed in the matrix indices are zero.
        """
        dense = np.zeros(self.shape)

        # split the (row, column) pairs into parallel row / column sequences
        rows, columns = zip(*self._indices)
        dense[rows, columns] = self._values

        self._checkrep()
        return dense

    def to_pandas(self) -> pd.DataFrame:
        """Converts the matrix to a dense dataframe.

        The dataframe is indexed by report id and has one column per
        technique id.
        """
        self._checkrep()
        return pd.DataFrame(
            data=self.to_numpy(),
            index=self._report_ids,
            columns=self._technique_ids,
        )

    def mask(self, indices: frozenset[tuple[int, int]]) -> "ReportTechniqueMatrix":
        """Generates a new ReportTechniqueMatrix with only a subset of the indices.

        Args:
            indices: indices to include in the new object. Every index must be
                present in this matrix, and at least one index is required.

        Returns:
            A new ReportTechniqueMatrix object with the same report and
            technique ids, holding only the requested entries.
        """
        # Normalise once so membership tests are O(1) regardless of the
        # container type the caller passed.
        wanted = frozenset(tuple(index) for index in indices)

        kept = [
            (index, value)
            for index, value in zip(self._indices, self._values)
            if index in wanted
        ]

        assert len(kept) == len(
            wanted
        ), "every requested index must be present in the matrix"

        self._checkrep()

        return ReportTechniqueMatrix(
            indices=[index for index, _ in kept],
            values=[value for _, value in kept],
            report_ids=self._report_ids,
            technique_ids=self._technique_ids,
        )
def build(self) -> ReportTechniqueMatrix:
    """Builds a ReportTechniqueMatrix.

    The rows of the matrix consist of the reports sourced from the combined
    dataset, zero-indexed in their order of appearance in the json file. The
    columns consist of the techniques that are both mentioned in the dataset
    and present in the ATT&CK bundle, in sorted technique-id order so that the
    column layout is identical across runs. An entry (i, j) with value 1 is
    stored if technique j is mentioned in report i; no entry is stored
    otherwise.

    Returns:
        A matrix of report data.
    """
    # want matrix of reports on horizontal, techniques on vertical
    reports = self._get_report_techniques(self._combined_datset_filepath)
    all_mitre_technique_ids_to_names = get_mitre_technique_ids_to_names(
        self._enterprise_attack_filepath
    )

    # get all techniques present in all reports
    all_report_technique_ids = set().union(*reports)

    # some reports contain invalid techniques from ATT&CK v1, so keep only
    # ids known to the bundle. Sorting makes the column order deterministic:
    # iterating a set of strings depends on the per-process hash seed.
    technique_ids = tuple(
        sorted(all_report_technique_ids.intersection(all_mitre_technique_ids_to_names))
    )

    techniques_to_index = {
        technique_id: column for column, technique_id in enumerate(technique_ids)
    }

    indices = []
    values = []
    report_ids = tuple(range(len(reports)))

    # for each report, record a 1 for every valid technique it mentions;
    # columns are emitted in ascending order so entries are row-major
    for row, report in enumerate(reports):
        columns = sorted(
            techniques_to_index[technique_id]
            for technique_id in report
            if technique_id in techniques_to_index
        )
        for column in columns:
            # report index, technique index
            indices.append((row, column))
            values.append(1)

    data = ReportTechniqueMatrix(
        indices=indices,
        values=values,
        report_ids=report_ids,
        technique_ids=technique_ids,
    )

    self._checkrep()

    return data
def _calculate_regularized_loss(
    self,
    data: tf.Tensor,
    predictions: tf.Tensor,
    regularization_coefficient: float,
    gravity_coefficient: float,
) -> tf.Tensor:
    r"""Gets the regularized loss.

    The regularized loss is the sum of:
        - The loss ``self._loss`` between the observed values and predictions.
        - A regularization term which is the average of the squared norm of
          each entity embedding, plus the average of the squared norm of each
          item embedding
          r = 1/m \sum_i ||U_i||^2 + 1/n \sum_j ||V_j||^2
        - A gravity term which is the average of the squares of all
          predictions.
          g = 1/(mn) \sum_{ij} (UV^T)_{ij}^2

    Args:
        data: the observed values to compare against, aligned with
            predictions. A tf.SparseTensor is also accepted, in which case
            its ``values`` are used.
        predictions: the model predictions on which to evaluate. Requires that
            predictions[i] is the prediction for the i-th observed value.
        regularization_coefficient: the coefficient for the regularization
            component of the loss function.
        gravity_coefficient: the coefficient for the gravity component of the
            loss function.

    Returns:
        A scalar tensor containing the regularized loss.
    """
    # The caller in fit() passes data.values; also accept the sparse tensor
    # itself so both documented usages work.
    observed = data.values if isinstance(data, tf.SparseTensor) else data

    m = self._U.shape[0]
    n = self._V.shape[0]

    regularization_loss = regularization_coefficient * (
        tf.reduce_sum(self._U * self._U) / m + tf.reduce_sum(self._V * self._V) / n
    )

    estimated_matrix = tf.matmul(self._U, self._V, transpose_b=True)
    gravity = 1.0 / (m * n) * tf.reduce_sum(tf.square(estimated_matrix))
    gravity_loss = gravity_coefficient * gravity

    return self._loss(observed, predictions) + regularization_loss + gravity_loss
def get_mitre_technique_ids_to_names(stix_filepath: str) -> dict[str, str]:
    """Map each MITRE ATT&CK technique id to the technique's name.

    Techniques are read from the STIX bundle at ``stix_filepath`` with revoked
    and deprecated techniques removed.

    Args:
        stix_filepath: location of the ATT&CK STIX bundle.

    Returns:
        A dict keyed by the ``external_id`` of each technique's
        ``mitre-attack`` external reference, with the technique's ``name`` as
        value.
    """
    attack_data = MitreAttackData(stix_filepath)

    ids_to_names = {}

    for technique in attack_data.get_techniques(remove_revoked_deprecated=True):
        attack_references = [
            reference
            for reference in technique.get("external_references")
            if reference.get("source_name") == "mitre-attack"
        ]
        # each technique is expected to carry exactly one mitre-attack reference
        assert len(attack_references) == 1
        technique_id = attack_references[0]["external_id"]
        ids_to_names[technique_id] = technique.get("name")

    return ids_to_names