Skip to content

Commit

Permalink
Koissi's critical server fix, shrink image fix, huge logs fix, securi…
Browse files Browse the repository at this point in the history
…ty improvements

Koissi's critical server fix, shrink image fix, huge logs fix, security improvements
  • Loading branch information
KillianLucas authored Jul 29, 2024
2 parents 9124d2c + e0d78fa commit caab8ea
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 120 deletions.
23 changes: 11 additions & 12 deletions interpreter/core/async_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def respond(self, run_code=None):
if "end" in chunk:
print("\n```\n\n------------\n\n", flush=True)
if chunk.get("format") != "active_line":
print(chunk.get("content", ""), end="", flush=True)
if "format" in chunk and "base64" in chunk["format"]:
print("\n[An image was produced]")
else:
print(chunk.get("content", ""), end="", flush=True)

self.output_queue.sync_q.put(chunk)

Expand Down Expand Up @@ -700,17 +703,11 @@ async def chat_completion(request: ChatCompletionRequest):
return router


host = os.getenv(
"HOST", "127.0.0.1"
) # IP address for localhost, used for local testing. To expose to local network, use 0.0.0.0
port = int(os.getenv("PORT", 8000)) # Default port is 8000

# FOR TESTING ONLY
# host = "0.0.0.0"


class Server:
def __init__(self, async_interpreter, host="127.0.0.1", port=8000):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 8000

def __init__(self, async_interpreter, host=None, port=None):
self.app = FastAPI()
router = create_router(async_interpreter)
self.authenticate = authenticate_function
Expand All @@ -729,7 +726,9 @@ async def validate_api_key(request: Request, call_next):
)

self.app.include_router(router)
self.config = uvicorn.Config(app=self.app, host=host, port=port)
h = host or os.getenv("HOST", Server.DEFAULT_HOST)
p = port or int(os.getenv("PORT", Server.DEFAULT_PORT))
self.config = uvicorn.Config(app=self.app, host=h, port=p)
self.uvicorn_server = uvicorn.Server(self.config)

@property
Expand Down
2 changes: 1 addition & 1 deletion interpreter/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
debug=False,
max_output=2800,
safe_mode="off",
shrink_images=False,
shrink_images=True,
loop=False,
loop_message="""Proceed. You CAN run code on my machine. If the entire task I asked for is done, say exactly 'The task is done.' If you need some specific information (like username or password) say EXACTLY 'Please provide more information.' If it's impossible, say 'The task is impossible.' (If I haven't provided a task, say exactly 'Let me know what you'd like to do next.') Otherwise keep going.""",
loop_breakers=[
Expand Down
152 changes: 55 additions & 97 deletions interpreter/core/llm/utils/convert_to_openai_messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
import io
import json
import sys

from PIL import Image

Expand Down Expand Up @@ -102,16 +103,10 @@ def convert_to_openai_messages(
new_message["role"] = "user"
new_message["content"] = content
elif interpreter.code_output_sender == "assistant":
if "@@@SEND_MESSAGE_AS_USER@@@" in message["content"]:
new_message["role"] = "user"
new_message["content"] = message["content"].replace(
"@@@SEND_MESSAGE_AS_USER@@@", ""
)
else:
new_message["role"] = "assistant"
new_message["content"] = (
"\n```output\n" + message["content"] + "\n```"
)
new_message["role"] = "assistant"
new_message["content"] = (
"\n```output\n" + message["content"] + "\n```"
)

elif message["type"] == "image":
if message.get("format") == "description":
Expand All @@ -129,95 +124,18 @@ def convert_to_openai_messages(
else:
extension = "png"

# Construct the content string
content = f"data:image/{extension};base64,{message['content']}"

if shrink_images:
try:
# Decode the base64 image
img_data = base64.b64decode(message["content"])
img = Image.open(io.BytesIO(img_data))

# Resize the image if it's width is more than 1024
if img.width > 1024:
new_height = int(img.height * 1024 / img.width)
img = img.resize((1024, new_height))

# Convert the image back to base64
buffered = io.BytesIO()
img.save(buffered, format=extension)
img_str = base64.b64encode(buffered.getvalue()).decode(
"utf-8"
)
content = f"data:image/{extension};base64,{img_str}"
except:
# This should be non blocking. It's not required
# print("Failed to shrink image. Proceeding with original image size.")
pass

# Must be less than 5mb
# Calculate the size of the original binary data in bytes
content_size_bytes = len(message["content"]) * 3 / 4

# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)

# If the content size is greater than 5 MB, resize the image
if content_size_mb > 5:
try:
# Decode the base64 image
img_data = base64.b64decode(message["content"])
img = Image.open(io.BytesIO(img_data))

# Calculate the size of the original binary data in bytes
content_size_bytes = len(img_data)

# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)

# Run in a loop to make SURE it's less than 5mb
while content_size_mb > 5:
# Calculate the scale factor needed to reduce the image size to 5 MB
scale_factor = (5 / content_size_mb) ** 0.5

# Calculate the new dimensions
new_width = int(img.width * scale_factor)
new_height = int(img.height * scale_factor)

# Resize the image
img = img.resize((new_width, new_height))

# Convert the image back to base64
buffered = io.BytesIO()
img.save(buffered, format=extension)
img_str = base64.b64encode(buffered.getvalue()).decode(
"utf-8"
)

# Set the content
content = f"data:image/{extension};base64,{img_str}"

# Recalculate the size of the content in bytes
content_size_bytes = len(content) * 3 / 4

# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)
except:
# This should be non blocking. It's not required
# print("Failed to shrink image. Proceeding with original image size.")
pass
encoded_string = message["content"]

elif message["format"] == "path":
# Convert to base64
image_path = message["content"]
file_extension = image_path.split(".")[-1]
extension = image_path.split(".")[-1]

with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode(
"utf-8"
)

content = f"data:image/{file_extension};base64,{encoded_string}"
else:
# Probably would be better to move this to a validation pass
# Near core, through the whole messages object
Expand All @@ -228,17 +146,57 @@ def convert_to_openai_messages(
f"Unrecognized image format: {message['format']}"
)

# Calculate the size of the original binary data in bytes
content_size_bytes = len(content) * 3 / 4
content = f"data:image/{extension};base64,{encoded_string}"

if shrink_images:
# Shrink to less than 5mb

# Calculate size
content_size_bytes = sys.getsizeof(str(content))

# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)

# If the content size is greater than 5 MB, resize the image
if content_size_mb > 5:
# Decode the base64 image
img_data = base64.b64decode(encoded_string)
img = Image.open(io.BytesIO(img_data))

# Run in a loop to make SURE it's less than 5mb
for _ in range(10):
# Calculate the scale factor needed to reduce the image size to 4.9 MB
scale_factor = (4.9 / content_size_mb) ** 0.5

# Calculate the new dimensions
new_width = int(img.width * scale_factor)
new_height = int(img.height * scale_factor)

# Resize the image
img = img.resize((new_width, new_height))

# Convert the image back to base64
buffered = io.BytesIO()
img.save(buffered, format=extension)
encoded_string = base64.b64encode(
buffered.getvalue()
).decode("utf-8")

# Set the content
content = f"data:image/{extension};base64,{encoded_string}"

# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)
# Recalculate the size of the content in bytes
content_size_bytes = sys.getsizeof(str(content))

# Print the size of the content in MB
# print(f"File size: {content_size_mb} MB")
# Convert the size to MB
content_size_mb = content_size_bytes / (1024 * 1024)

# Assert that the content size is under 20 MB
assert content_size_mb < 20, "Content size exceeds 20 MB"
if content_size_mb < 5:
break
else:
print(
"Attempted to shrink the image but failed. Sending to the LLM anyway."
)

new_message = {
"role": "user",
Expand Down
27 changes: 20 additions & 7 deletions interpreter/core/respond.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def respond(interpreter):

### RUN THE LLM ###

assert (
len(interpreter.messages) > 0
), "User message was not passed in. You need to pass in at least one message."

if (
interpreter.messages[-1]["type"] != "code"
): # If it is, we should run the code (we do below)
Expand Down Expand Up @@ -292,15 +296,16 @@ def respond(interpreter):
computer_dict = interpreter.computer.to_dict()
if "_hashes" in computer_dict:
computer_dict.pop("_hashes")
if computer_dict:
computer_json = json.dumps(computer_dict)
sync_code = f"""import json\ncomputer.load_dict(json.loads('''{computer_json}'''))"""
interpreter.computer.run("python", sync_code)
if "system_message" in computer_dict:
computer_dict.pop("system_message")
computer_json = json.dumps(computer_dict)
sync_code = f"""import json\ncomputer.load_dict(json.loads('''{computer_json}'''))"""
interpreter.computer.run("python", sync_code)
except Exception as e:
if interpreter.debug:
raise
print(str(e))
print("Continuing...")
print("Failed to sync iComputer with your Computer. Continuing...")

## ↓ CODE IS RUN HERE

Expand All @@ -315,7 +320,15 @@ def respond(interpreter):
# sync up the interpreter's computer with your computer
result = interpreter.computer.run(
"python",
"import json\ncomputer_dict = computer.to_dict()\nif computer_dict:\n if '_hashes' in computer_dict:\n computer_dict.pop('_hashes')\n print(json.dumps(computer_dict))",
"""
import json
computer_dict = computer.to_dict()
if '_hashes' in computer_dict:
computer_dict.pop('_hashes')
if "system_message" in computer_dict:
computer_dict.pop("system_message")
print(json.dumps(computer_dict))
""",
)
result = result[-1]["content"]
interpreter.computer.load_dict(
Expand All @@ -325,7 +338,7 @@ def respond(interpreter):
if interpreter.debug:
raise
print(str(e))
print("Continuing.")
print("Failed to sync your Computer with iComputer. Continuing.")

# yield final "active_line" message, as if to say, no more code is running. unlightlight active lines
# (is this a good idea? is this our responsibility? i think so — we're saying what line of code is running! ...?)
Expand Down
4 changes: 2 additions & 2 deletions interpreter/core/utils/truncate_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
def truncate_output(data, max_output_chars=2800, add_scrollbars=False):
if "@@@DO_NOT_TRUNCATE@@@" in data:
return data
# if "@@@DO_NOT_TRUNCATE@@@" in data:
# return data

needs_truncation = False

Expand Down
1 change: 0 additions & 1 deletion interpreter/terminal_interface/profiles/defaults/os.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

interpreter.os = True
interpreter.llm.supports_vision = True
# interpreter.shrink_images = True # Faster but less accurate

interpreter.llm.model = "gpt-4o"

Expand Down
48 changes: 48 additions & 0 deletions tests/core/test_async_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
from unittest import TestCase, mock

from interpreter.core.async_core import Server, AsyncInterpreter


class TestServerConstruction(TestCase):
"""
Tests to make sure that the underlying server is configured correctly when constructing
the Server object.
"""

def test_host_and_port_defaults(self):
"""
Tests that a Server object takes on the default host and port when
a) no host and port are passed in, and
b) no HOST and PORT are set.
"""
with mock.patch.dict(os.environ, {}):
s = Server(AsyncInterpreter())
self.assertEqual(s.host, Server.DEFAULT_HOST)
self.assertEqual(s.port, Server.DEFAULT_PORT)

def test_host_and_port_passed_in(self):
"""
Tests that a Server object takes on the passed-in host and port when they are passed-in,
ignoring the surrounding HOST and PORT env vars.
"""
host = "the-really-real-host"
port = 2222

with mock.patch.dict(os.environ, {"HOST": "this-is-supes-fake", "PORT": "9876"}):
sboth = Server(AsyncInterpreter(), host, port)
self.assertEqual(sboth.host, host)
self.assertEqual(sboth.port, port)

def test_host_and_port_from_env_1(self):
"""
Tests that the Server object takes on the HOST and PORT env vars as host and port when
nothing has been passed in.
"""
fake_host = "fake_host"
fake_port = 1234

with mock.patch.dict(os.environ, {"HOST": fake_host, "PORT": str(fake_port)}):
s = Server(AsyncInterpreter())
self.assertEqual(s.host, fake_host)
self.assertEqual(s.port, fake_port)

0 comments on commit caab8ea

Please sign in to comment.