forked from mrwadams/stride-gpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
attack_tree.py
179 lines (143 loc) · 7.25 KB
/
attack_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import re
import requests
import streamlit as st
from mistralai import Mistral
from openai import OpenAI, AzureOpenAI
def create_attack_tree_prompt(app_type, authentication, internet_facing, sensitive_data, app_input):
    """Build the user prompt describing the application for attack-tree generation.

    Each argument is rendered on its own labelled line. The result starts with
    a newline and ends with a newline after the application description.

    Args:
        app_type: Application type shown after "APPLICATION TYPE:".
        authentication: Authentication methods shown after "AUTHENTICATION METHODS:".
        internet_facing: Value shown after "INTERNET FACING:".
        sensitive_data: Value shown after "SENSITIVE DATA:".
        app_input: Free-text description shown after "APPLICATION DESCRIPTION:".

    Returns:
        The formatted prompt string.
    """
    labelled_fields = (
        ("APPLICATION TYPE", app_type),
        ("AUTHENTICATION METHODS", authentication),
        ("INTERNET FACING", internet_facing),
        ("SENSITIVE DATA", sensitive_data),
        ("APPLICATION DESCRIPTION", app_input),
    )
    body = "".join(f"{label}: {value}\n" for label, value in labelled_fields)
    return "\n" + body
def get_attack_tree(api_key, model_name, prompt):
    """Generate a Mermaid attack tree for an application via the OpenAI API.

    Args:
        api_key: OpenAI API key used to construct the client.
        model_name: Chat-completions model to call.
        prompt: User prompt describing the application (for example the
            output of ``create_attack_tree_prompt``).

    Returns:
        The model's Mermaid code with Markdown code-fence markers removed.

    Raises:
        ValueError: If the model reply carries no text content.
    """
    system_prompt = """
Act as a cyber security expert with more than 20 years experience of using the STRIDE threat modelling methodology to produce comprehensive threat models for a wide range of applications. Your task is to use the application description provided to you to produce an attack tree in Mermaid syntax. The attack tree should reflect the potential threats for the application based on the details given.
You MUST only respond with the Mermaid code block. See below for a simple example of the required format and syntax for your output.
```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{{decide}}
C --> D["Keep"]
C --> E["Edit Definition (Edit)"]
E --> B
D --> F["Save Image and Code"]
F --> B
```
IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""
    client = OpenAI(api_key=api_key)
    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )

    attack_tree_code = response.choices[0].message.content
    # The openai SDK types message.content as optional; without this check a
    # None reply would surface as an opaque TypeError from re.sub below.
    if not isinstance(attack_tree_code, str):
        raise ValueError("The model returned no text content; no attack tree was generated.")

    # Remove the ```mermaid opening fence and any ``` fence marker so only
    # the diagram source is returned.
    return re.sub(r'^```mermaid\s*|\s*```$', '', attack_tree_code, flags=re.MULTILINE)
def get_attack_tree_azure(azure_api_endpoint, azure_api_key, azure_api_version, azure_deployment_name, prompt):
    """Generate a Mermaid attack tree for an application via Azure OpenAI.

    Args:
        azure_api_endpoint: Azure OpenAI resource endpoint.
        azure_api_key: API key for the Azure OpenAI resource.
        azure_api_version: API version string passed to the client.
        azure_deployment_name: Deployment name, passed as ``model``.
        prompt: User prompt describing the application (for example the
            output of ``create_attack_tree_prompt``).

    Returns:
        The model's Mermaid code with Markdown code-fence markers removed.

    Raises:
        ValueError: If the model reply carries no text content.
    """
    system_prompt = """
Act as a cyber security expert with more than 20 years experience of using the STRIDE threat modelling methodology to produce comprehensive threat models for a wide range of applications. Your task is to use the application description provided to you to produce an attack tree in Mermaid syntax. The attack tree should reflect the potential threats for the application based on the details given.
You MUST only respond with the Mermaid code block. See below for a simple example of the required format and syntax for your output.
```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{{decide}}
C --> D["Keep"]
C --> E["Edit Definition (Edit)"]
E --> B
D --> F["Save Image and Code"]
F --> B
```
IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""
    client = AzureOpenAI(
        azure_endpoint=azure_api_endpoint,
        api_key=azure_api_key,
        api_version=azure_api_version,
    )
    response = client.chat.completions.create(
        model=azure_deployment_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )

    attack_tree_code = response.choices[0].message.content
    # The openai SDK types message.content as optional; without this check a
    # None reply would surface as an opaque TypeError from re.sub below.
    if not isinstance(attack_tree_code, str):
        raise ValueError("The model returned no text content; no attack tree was generated.")

    # Remove the ```mermaid opening fence and any ``` fence marker so only
    # the diagram source is returned.
    return re.sub(r'^```mermaid\s*|\s*```$', '', attack_tree_code, flags=re.MULTILINE)
def get_attack_tree_mistral(mistral_api_key, mistral_model, prompt):
    """Generate a Mermaid attack tree for an application via the Mistral API.

    Args:
        mistral_api_key: Mistral API key used to construct the client.
        mistral_model: Name of the Mistral chat model to call.
        prompt: User prompt describing the application (for example the
            output of ``create_attack_tree_prompt``).

    Returns:
        The model's Mermaid code with Markdown code-fence markers removed.

    Raises:
        ValueError: If the model reply is not plain text.
    """
    system_prompt = """
Act as a cyber security expert with more than 20 years experience of using the STRIDE threat modelling methodology to produce comprehensive threat models for a wide range of applications. Your task is to use the application description provided to you to produce an attack tree in Mermaid syntax. The attack tree should reflect the potential threats for the application based on the details given.
You MUST only respond with the Mermaid code block. See below for a simple example of the required format and syntax for your output.
```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{{decide}}
C --> D["Keep"]
C --> E["Edit Definition (Edit)"]
E --> B
D --> F["Save Image and Code"]
F --> B
```
IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""
    client = Mistral(api_key=mistral_api_key)
    response = client.chat.complete(
        model=mistral_model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    )

    attack_tree_code = response.choices[0].message.content
    # NOTE(review): the mistralai SDK's message type allows content to be None
    # (and, per its typing, a list of content chunks); re.sub needs a str, so
    # fail with a clear message instead of an opaque TypeError.
    if not isinstance(attack_tree_code, str):
        raise ValueError("The model returned no text content; no attack tree was generated.")

    # Remove the ```mermaid opening fence and any ``` fence marker so only
    # the diagram source is returned.
    return re.sub(r'^```mermaid\s*|\s*```$', '', attack_tree_code, flags=re.MULTILINE)
def get_attack_tree_ollama(ollama_endpoint, ollama_model, prompt):
    """Generate a Mermaid attack tree for an application via an Ollama server.

    Args:
        ollama_endpoint: Base URL to which "/chat" is appended; a trailing
            slash is tolerated.
        ollama_model: Name of the Ollama model to run.
        prompt: User prompt describing the application (for example the
            output of ``create_attack_tree_prompt``).

    Returns:
        The model's Mermaid code with Markdown code-fence markers removed.

    Raises:
        requests.HTTPError: If the server answers with an HTTP error status.
    """
    system_prompt = """
Act as a cyber security expert with more than 20 years experience of using the STRIDE threat modelling methodology to produce comprehensive threat models for a wide range of applications. Your task is to use the application description provided to you to produce an attack tree in Mermaid syntax. The attack tree should reflect the potential threats for the application based on the details given.
You MUST only respond with the Mermaid code block. See below for a simple example of the required format and syntax for your output.
```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{{decide}}
C --> D["Keep"]
C --> E["Edit Definition (Edit)"]
E --> B
D --> F["Save Image and Code"]
F --> B
```
IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""
    # Drop a trailing slash so an endpoint like ".../api/" does not become ".../api//chat".
    url = ollama_endpoint.rstrip("/") + "/chat"
    data = {
        "model": ollama_model,
        # Ask for one complete response object rather than a stream of chunks.
        "stream": False,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    }

    # NOTE(review): no timeout is set, matching prior behaviour; local
    # generation can be slow, so pick a value deliberately if one is added.
    response = requests.post(url, json=data)
    # Surface server errors (such as an unknown model name) directly instead of
    # a confusing KeyError when the error body lacks a "message" field.
    response.raise_for_status()
    outer_json = response.json()

    attack_tree_code = outer_json["message"]["content"]
    # Remove the ```mermaid opening fence and any ``` fence marker so only
    # the diagram source is returned.
    return re.sub(r'^```mermaid\s*|\s*```$', '', attack_tree_code, flags=re.MULTILINE)