Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LM Studio Server option for a custom OpenAI compatible endpoint #37

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions attack_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,43 @@ def get_attack_tree_mistral(mistral_api_key, mistral_model, prompt):
F --> B
```

IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""},
{"role": "user", "content": prompt}
]
)

# Access the 'content' attribute of the 'message' object directly
attack_tree_code = response.choices[0].message.content

# Remove Markdown code block delimiters using regular expression
attack_tree_code = re.sub(r'^```mermaid\s*|\s*```$', '', attack_tree_code, flags=re.MULTILINE)

return attack_tree_code

# Function to get attack tree from an LM Studio Server response.
def get_attack_tree_lmstudio(lmstudio_endpoint, model_name, prompt):
client = OpenAI(base_url=lmstudio_endpoint, api_key="lm-studio")

response = client.chat.completions.create(
model=model_name,
messages=[
{"role": "system", "content": """
Act as a cyber security expert with more than 20 years experience of using the STRIDE threat modelling methodology to produce comprehensive threat models for a wide range of applications. Your task is to use the application description provided to you to produce an attack tree in Mermaid syntax. The attack tree should reflect the potential threats for the application based on the details given.

You MUST only respond with the Mermaid code block. See below for a simple example of the required format and syntax for your output.

```mermaid
graph TD
A[Enter Chart Definition] --> B(Preview)
B --> C{{decide}}
C --> D["Keep"]
C --> E["Edit Definition (Edit)"]
E --> B
D --> F["Save Image and Code"]
F --> B
```

IMPORTANT: Round brackets are special characters in Mermaid syntax. If you want to use round brackets inside a node label you MUST wrap the label in double quotes. For example, ["Example Node Label (ENL)"].
"""},
{"role": "user", "content": prompt}
Expand Down
23 changes: 22 additions & 1 deletion dread.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,25 @@ def get_dread_assessment_mistral(mistral_api_key, mistral_model, prompt):
# Convert the JSON string in the 'content' field to a Python dictionary
response_content = json.loads(response.choices[0].message.content)

return response_content
return response_content

# Function to get DREAD risk assessment from an LM Studio Server response.
def get_dread_assessment_lmstudio(lmstudio_endpoint, model_name, prompt):
    """Request a DREAD risk assessment from an LM Studio server.

    Sends the prompt to the LM Studio OpenAI-compatible endpoint and
    parses the model's JSON reply into a Python dictionary. If the reply
    is not valid JSON, the error is surfaced in the Streamlit UI and an
    empty dict is returned instead of raising.
    """
    lm_client = OpenAI(base_url=lmstudio_endpoint, api_key="lm-studio")

    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
        {"role": "user", "content": prompt},
    ]
    completion = lm_client.chat.completions.create(
        model=model_name,
        response_format={"type": "json_object"},
        messages=chat_messages,
    )

    raw_content = completion.choices[0].message.content
    try:
        return json.loads(raw_content)
    except json.JSONDecodeError as e:
        # Report the parse failure in the UI rather than crashing the app.
        st.write(f"JSON decoding error: {e}")
        return {}
46 changes: 40 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import streamlit as st
import streamlit.components.v1 as components

from threat_model import create_threat_model_prompt, get_threat_model, get_threat_model_azure, get_threat_model_google, get_threat_model_mistral, json_to_markdown, get_image_analysis, create_image_analysis_prompt
from attack_tree import create_attack_tree_prompt, get_attack_tree, get_attack_tree_azure, get_attack_tree_mistral
from mitigations import create_mitigations_prompt, get_mitigations, get_mitigations_azure, get_mitigations_google, get_mitigations_mistral
from test_cases import create_test_cases_prompt, get_test_cases, get_test_cases_azure, get_test_cases_google, get_test_cases_mistral
from dread import create_dread_assessment_prompt, get_dread_assessment, get_dread_assessment_azure, get_dread_assessment_google, get_dread_assessment_mistral, dread_json_to_markdown
from threat_model import create_threat_model_prompt, get_threat_model, get_threat_model_azure, get_threat_model_google, get_threat_model_mistral, get_threat_model_lmstudio, json_to_markdown, get_image_analysis, create_image_analysis_prompt
from attack_tree import create_attack_tree_prompt, get_attack_tree, get_attack_tree_azure, get_attack_tree_mistral, get_attack_tree_lmstudio
from mitigations import create_mitigations_prompt, get_mitigations, get_mitigations_azure, get_mitigations_google, get_mitigations_mistral, get_mitigations_lmstudio
from test_cases import create_test_cases_prompt, get_test_cases, get_test_cases_azure, get_test_cases_google, get_test_cases_mistral, get_test_cases_lmstudio
from dread import create_dread_assessment_prompt, get_dread_assessment, get_dread_assessment_azure, get_dread_assessment_google, get_dread_assessment_mistral, get_dread_assessment_lmstudio, dread_json_to_markdown

# ------------------ Helper Functions ------------------ #

Expand Down Expand Up @@ -63,7 +63,7 @@ def mermaid(code: str, height: int = 500) -> None:
# Add model selection input field to the sidebar
model_provider = st.selectbox(
"Select your preferred model provider:",
["OpenAI API", "Azure OpenAI Service", "Google AI API", "Mistral API"],
["OpenAI API", "Azure OpenAI Service", "Google AI API", "Mistral API", "LMStudio Server"],
key="model_provider",
help="Select the model provider you would like to use. This will determine the models available for selection.",
)
Expand Down Expand Up @@ -167,6 +167,28 @@ def mermaid(code: str, height: int = 500) -> None:
["mistral-large-latest", "mistral-small-latest"],
key="selected_model",
)

if model_provider == "LMStudio Server":
st.markdown(
"""
1. Enter your LM Studio Server IP address 🔑
2. Provide details of the application that you would like to threat model 📝
3. Generate a threat list, attack tree and/or mitigating controls for your application 🚀
"""
)

# Add LM Studio endpoint input field to the sidebar
lmstudio_endpoint = st.text_input(
"Enter your LM Studio endpoint:",
help="In most cases this will be http://localhost:1234/v1/",
)

lmstudio_model = st.text_input(
"Enter your LM Studio model name:",
help="Check the LM Studio Server examples to see which model name is in use (Publisher/Repository)",
)



st.markdown("""---""")

Expand Down Expand Up @@ -374,6 +396,8 @@ def encode_image(uploaded_file):
model_output = get_threat_model_google(google_api_key, google_model, threat_model_prompt)
elif model_provider == "Mistral API":
model_output = get_threat_model_mistral(mistral_api_key, mistral_model, threat_model_prompt)
elif model_provider == "LMStudio Server":
model_output = get_threat_model_lmstudio(lmstudio_endpoint, lmstudio_model, threat_model_prompt)

# Access the threat model and improvement suggestions from the parsed content
threat_model = model_output.get("threat_model", [])
Expand Down Expand Up @@ -445,6 +469,9 @@ def encode_image(uploaded_file):
mermaid_code = get_attack_tree(openai_api_key, selected_model, attack_tree_prompt)
elif model_provider == "Mistral API":
mermaid_code = get_attack_tree_mistral(mistral_api_key, mistral_model, attack_tree_prompt)
elif model_provider == "LMStudio Server":
mermaid_code = get_attack_tree_lmstudio(lmstudio_endpoint, lmstudio_model, attack_tree_prompt)


# Display the generated attack tree code
st.write("Attack Tree Code:")
Expand Down Expand Up @@ -523,6 +550,8 @@ def encode_image(uploaded_file):
mitigations_markdown = get_mitigations_google(google_api_key, google_model, mitigations_prompt)
elif model_provider == "Mistral API":
mitigations_markdown = get_mitigations_mistral(mistral_api_key, mistral_model, mitigations_prompt)
elif model_provider == "LMStudio Server":
mitigations_markdown = get_mitigations_lmstudio(lmstudio_endpoint, lmstudio_model, mitigations_prompt)

# Display the suggested mitigations in Markdown
st.markdown(mitigations_markdown)
Expand Down Expand Up @@ -581,6 +610,9 @@ def encode_image(uploaded_file):
dread_assessment = get_dread_assessment_google(google_api_key, google_model, dread_assessment_prompt)
elif model_provider == "Mistral API":
dread_assessment = get_dread_assessment_mistral(mistral_api_key, mistral_model, dread_assessment_prompt)
elif model_provider == "LMStudio Server":
dread_assessment = get_dread_assessment_lmstudio(lmstudio_endpoint, lmstudio_model, dread_assessment_prompt)

# Save the DREAD assessment to the session state for later use in test cases
st.session_state['dread_assessment'] = dread_assessment
break # Exit the loop if successful
Expand Down Expand Up @@ -644,6 +676,8 @@ def encode_image(uploaded_file):
test_cases_markdown = get_test_cases_google(google_api_key, google_model, test_cases_prompt)
elif model_provider == "Mistral API":
test_cases_markdown = get_test_cases_mistral(mistral_api_key, mistral_model, test_cases_prompt)
elif model_provider == "LMStudio Server":
get_test_cases_lmstudio(lmstudio_endpoint, lmstudio_model, test_cases_prompt)

# Display the suggested mitigations in Markdown
st.markdown(test_cases_markdown)
Expand Down
18 changes: 18 additions & 0 deletions mitigations.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,22 @@ def get_mitigations_mistral(mistral_api_key, mistral_model, prompt):
# Access the content directly as the response will be in text format
mitigations = response.choices[0].message.content

return mitigations


# Function to get mitigations from an LM Studio Server response.
def get_mitigations_lmstudio(lmstudio_endpoint, model_name, prompt):
    """Ask an LM Studio server for suggested threat mitigations.

    Returns the model's reply as-is; it is expected to be Markdown text,
    so no parsing is performed.
    """
    lm_client = OpenAI(base_url=lmstudio_endpoint, api_key="lm-studio")

    completion = lm_client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "You are a helpful assistant that provides threat mitigation strategies in Markdown format."},
            {"role": "user", "content": prompt},
        ],
    )

    # The response is plain text (Markdown), returned directly to the caller.
    return completion.choices[0].message.content
17 changes: 17 additions & 0 deletions test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,21 @@ def get_test_cases_mistral(mistral_api_key, mistral_model, prompt):
# Access the content directly as the response will be in text format
test_cases = response.choices[0].message.content

return test_cases

# Function to get test cases from an LM Studio Server response.
def get_test_cases_lmstudio(lmstudio_endpoint, model_name, prompt):
    """Ask an LM Studio server for Gherkin test cases.

    Returns the model's reply as-is; it is expected to be Markdown text,
    so no parsing is performed.
    """
    lm_client = OpenAI(base_url=lmstudio_endpoint, api_key="lm-studio")

    completion = lm_client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "You are a helpful assistant that provides Gherkin test cases in Markdown format."},
            {"role": "user", "content": prompt},
        ],
    )

    # The response is plain text (Markdown), returned directly to the caller.
    return completion.choices[0].message.content
18 changes: 18 additions & 0 deletions threat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,4 +214,22 @@ def get_threat_model_mistral(mistral_api_key, mistral_model, prompt):
# Convert the JSON string in the 'content' field to a Python dictionary
response_content = json.loads(response.choices[0].message.content)

return response_content

def get_threat_model_lmstudio(lmstudio_endpoint, model_name, prompt):
    """Generate a threat model via an LM Studio server.

    The model is instructed to reply with JSON (enforced through the
    ``response_format`` option) and the reply is parsed into a Python
    dictionary. A malformed reply raises ``json.JSONDecodeError``.
    """
    lm_client = OpenAI(base_url=lmstudio_endpoint, api_key="lm-studio")

    completion = lm_client.chat.completions.create(
        model=model_name,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
            {"role": "user", "content": prompt},
        ],
        # NOTE(review): -1 looks like LM Studio's "no limit" convention for
        # max_tokens — confirm against the LM Studio server docs.
        max_tokens=-1,
    )

    # The content field holds a JSON string; decode it before returning.
    return json.loads(completion.choices[0].message.content)