Skip to content

Commit

Permalink
Security improvements (#619)
Browse files Browse the repository at this point in the history
* Improve security

* Fix format
  • Loading branch information
oeway authored Jul 4, 2024
1 parent f728c7d commit 55a407e
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 56 deletions.
111 changes: 60 additions & 51 deletions docs/hypha_data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
from urllib.parse import parse_qs


class HyphaDataStore:
def __init__(self):
self.storage = {}
Expand All @@ -12,15 +13,15 @@ def __init__(self):

async def setup(self, server, service_id="data-store", visibility="public"):
self._server = server
self._svc = await server.register_service({
"id": service_id,
"type": "functions",
"config": {
"visibility": visibility,
"require_context": False
self._svc = await server.register_service(
{
"id": service_id,
"type": "functions",
"config": {"visibility": visibility, "require_context": False},
"get": self.http_get,
},
"get": self.http_get,
}, overwrite=True)
overwrite=True,
)

def get_url(self, obj_id: str):
assert self._svc, "Service not initialized, call `setup()`"
Expand All @@ -30,30 +31,30 @@ def get_url(self, obj_id: str):
def put(self, obj_type: str, value: any, name: str, comment: str = ""):
assert self._svc, "Please call `setup()` before using the store"
obj_id = str(uuid.uuid4())
if obj_type == 'file':
if obj_type == "file":
data = value
assert isinstance(data, (str, bytes)), "Value must be a string or bytes"
if isinstance(data, str) and data.startswith("file://"):
# File URL examples:
# Absolute URL: `file:///home/data/myfile.png`
# Relative URL: `file://./myimage.png`, or `file://myimage.png`
with open(data.replace("file://", ""), 'rb') as fil:
with open(data.replace("file://", ""), "rb") as fil:
data = fil.read()
mime_type, _ = mimetypes.guess_type(name)
self.storage[obj_id] = {
'type': obj_type,
'name': name,
'value': data,
'mime_type': mime_type or 'application/octet-stream',
'comment': comment
"type": obj_type,
"name": name,
"value": data,
"mime_type": mime_type or "application/octet-stream",
"comment": comment,
}
else:
self.storage[obj_id] = {
'type': obj_type,
'name': name,
'value': value,
'mime_type': 'application/json',
'comment': comment
"type": obj_type,
"name": name,
"value": value,
"mime_type": "application/json",
"comment": comment,
}
return obj_id

Expand All @@ -63,47 +64,51 @@ def get(self, id: str):
return obj

def http_get(self, scope, context=None):
query_string = scope['query_string']
id = parse_qs(query_string).get('id', [])[0]
query_string = scope["query_string"]
id = parse_qs(query_string).get("id", [])[0]
obj = self.storage.get(id)
if obj is None:
return {'status': 404, 'headers': {}, 'body': "Not found: " + id}
return {"status": 404, "headers": {}, "body": "Not found: " + id}

if obj['type'] == 'file':
data = obj['value']
if obj["type"] == "file":
data = obj["value"]
if isinstance(data, str):
if not os.path.isfile(data):
return {
"status": 404,
'headers': {'Content-Type': 'text/plain'},
"body": "File not found: " + data
"headers": {"Content-Type": "text/plain"},
"body": "File not found: " + data,
}
with open(data, 'rb') as fil:
with open(data, "rb") as fil:
data = fil.read()
headers = {
'Content-Type': obj['mime_type'],
'Content-Length': str(len(obj['value'])),
'Content-Disposition': f'inline; filename="{obj["name"].split("/")[-1]}"'
}

return {
'status': 200,
'headers': headers,
'body': obj['value']
"Content-Type": obj["mime_type"],
"Content-Length": str(len(obj["value"])),
"Content-Disposition": f'inline; filename="{obj["name"].split("/")[-1]}"',
}

return {"status": 200, "headers": headers, "body": obj["value"]}
else:
return {
'status': 200,
'headers': {'Content-Type': 'application/json'},
'body': json.dumps(obj['value'])
"status": 200,
"headers": {"Content-Type": "application/json"},
"body": json.dumps(obj["value"]),
}

def http_list(self, scope, context=None):
query_string = scope.get('query_string', b'')
kws = parse_qs(query_string).get('keyword', [])
query_string = scope.get("query_string", b"")
kws = parse_qs(query_string).get("keyword", [])
keyword = kws[0] if kws else None
result = [value for key, value in self.storage.items() if not keyword or keyword in value['name']]
return {'status': 200, 'headers': {'Content-Type': 'application/json'}, 'body': json.dumps(result)}
result = [
value
for key, value in self.storage.items()
if not keyword or keyword in value["name"]
]
return {
"status": 200,
"headers": {"Content-Type": "application/json"},
"body": json.dumps(result),
}

def remove(self, obj_id: str):
assert self._svc, "Please call `setup()` before using the store"
Expand All @@ -112,8 +117,10 @@ def remove(self, obj_id: str):
return True
raise IndexError("Not found: " + obj_id)


async def test_data_store(server_url="https://ai.imjoy.io"):
from imjoy_rpc.hypha import connect_to_server, login

token = await login({"server_url": server_url})
server = await connect_to_server({"server_url": server_url, "token": token})

Expand All @@ -122,20 +129,22 @@ async def test_data_store(server_url="https://ai.imjoy.io"):
await ds.setup(server)

# Test PUT operation
file_id = ds.put('file', 'file:///home/data.txt', 'data.txt')
binary_id = ds.put('file', b'Some binary content', 'example.bin')
json_id = ds.put('json', {'hello': 'world'}, 'example.json')
file_id = ds.put("file", "file:///home/data.txt", "data.txt")
binary_id = ds.put("file", b"Some binary content", "example.bin")
json_id = ds.put("json", {"hello": "world"}, "example.json")

# Test GET operation
assert ds.get(file_id)['type'] == 'file'
assert ds.get(binary_id)['type'] == 'file'
assert ds.get(json_id)['type'] == 'json'
assert ds.get(file_id)["type"] == "file"
assert ds.get(binary_id)["type"] == "file"
assert ds.get(json_id)["type"] == "json"

# Test GET URL generation
print("URL for getting file", ds.get_url(file_id))
print("URL for getting binary object", ds.get_url(binary_id))
print("URL for getting json object", ds.get_url(json_id))


if __name__ == "__main__":
import asyncio
asyncio.run(test_data_store())

asyncio.run(test_data_store())
6 changes: 6 additions & 0 deletions hypha/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,10 @@ async def save_file(key, content):
"name" in att and "source" in att
), "Attachment should contain `name` and `source`"
if att["source"].startswith("http") and "\n" not in att["source"]:
if not att["source"].startswith("https://"):
raise Exception(
"Only https sources are allowed: " + att["source"]
)
with urlopen(att["source"]) as stream:
output = stream.read()
att["source"] = output
Expand Down Expand Up @@ -395,6 +399,8 @@ async def install(
)

if source.startswith("http"):
if not source.startswith("https://"):
raise Exception("Only https sources are allowed: " + source)
with urlopen(source) as stream:
output = stream.read()
source = output.decode("utf-8")
Expand Down
2 changes: 1 addition & 1 deletion hypha/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
user_info = user_info.model_dump()
else:
user_info = None
scope['context'] = {"user": user_info, "_rkwargs": True}
scope["context"] = {"user": user_info, "_rkwargs": True}
result = await func(scope)
headers = Headers(headers=result.get("headers"))
body = result.get("body")
Expand Down
5 changes: 2 additions & 3 deletions hypha/core/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

if not JWT_SECRET:
logger.warning("JWT_SECRET is not defined")
JWT_SECRET = str(shortuuid.uuid())
JWT_SECRET = shortuuid.ShortUUID().random(length=22)


class AuthError(Exception):
Expand Down Expand Up @@ -135,8 +135,7 @@ def get_rsa_key(kid, refresh=False):
if JWKS is None or refresh:
with urlopen(
f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
# pylint: disable=protected-access
context=ssl._create_unverified_context(),
context=ssl._create_default_https_context(),
) as jsonurl:
JWKS = json.loads(jsonurl.read())
rsa_key = {}
Expand Down
4 changes: 3 additions & 1 deletion hypha/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ async def disconnect(code):
try:
user_info, ws, cid = parse_reconnection_token(reconnection_token)
except jwt.JWTError as err:
logger.error("Invalid reconnection token: %s", {reconnection_token[:5]})
logger.error(
"Invalid reconnection token: %s", {reconnection_token[:5]}
)
await disconnect(code=status.WS_1003_UNSUPPORTED_DATA)
return
if not await store.disconnected_client_exists(f"{ws}/{cid}"):
Expand Down

0 comments on commit 55a407e

Please sign in to comment.