Skip to content

Commit

Permalink
Merge pull request #20 from AthenaFoundation/thread_based_tcp_sockets
Browse files Browse the repository at this point in the history
Thread-based TCP sockets
  • Loading branch information
konstantine4096 authored Nov 12, 2024
2 parents a2d6435 + 212aa71 commit c7b5654
Show file tree
Hide file tree
Showing 14 changed files with 868 additions and 57 deletions.
10 changes: 6 additions & 4 deletions athena.mlb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local
$(SML_LIB)/mlyacc-lib/mlyacc-lib.mlb
$(SML_LIB)/cml/cml.mlb
$(SML_LIB)/basis/mlton.mlb
$(SML_LIB)/smlnj-lib/INet/inet-lib.mlb
in
compat_11072.sml
base.sig
Expand Down Expand Up @@ -86,11 +87,12 @@ in
topenv_part1.sml
topenv_part2.sml
definition_processor.sml
(**
prolog_solver.sml
**)
repl.sml
athena.sml
mlton_server.sml
sml_c_util.sml
(***
xsb_prolog_solver.sml
***)
athena.sml
mlton_main.sml
end
2 changes: 1 addition & 1 deletion athena.sml
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ fun main(arg0,args) =
| [file_name] => M(SOME(file_name),false)
| file_name::(_::_) => M(SOME(file_name),true))
end

end
2 changes: 1 addition & 1 deletion base.sml
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ fun repeat n f =
val (newline,lparen,rparen,lbrack,rbrack,lbrace,rbrace,
blank,comma,period,colon,semi_colon,string_quote) = ("\n","(",")","[","]","{","}"," ",",",".",":",";","\"")

fun mark(s) = (print("\n");repeat 3 (fn _ => print(s));print("\n"))
fun mark(s) = (print("\n");repeat 10 (fn _ => print(s));print("\n"))

fun failLst(messages) = raise FailLst messages

Expand Down
111 changes: 111 additions & 0 deletions client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import sys
import socket
import sys
import json
from utils import *

def extract_response(sock):
chunks = []
while True:
chunk = sock.recv(40960)
if not chunk: # Connection closed by server
break
chunks.append(chunk)
return b''.join(chunks).decode('utf-8')


# Note: send_request_to_server_simple should only be used if the corresponding server uses readAllSimple.

def send_request_to_server_simple(request: str, port=10000) -> str:
server_address = 'localhost'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(4) # Set a 5-second timeout for demonstration
try:
sock.connect((server_address, port))
sock.sendall(request.encode('utf-8'))
sock.sendall(''.encode('utf-8'))
sock.shutdown(socket.SHUT_WR)
return extract_response(sock)
except socket.timeout:
print("Connection timed out.")
return ""
except Exception as e:
print(f"An error occurred: {e}")
return ""

# The default implementation of send_request_to_server encodes the size of the
# Athena payload into the first 4 bytes of the request (thus allowing payloads up to 4GB).
# This must be used in conjunction with readAll on the server side, which first extracts
# the leading 4 bytes of the client's request, transforms those into the integer value N
# they represent, and then reads exactly N bytes from the connection.

def send_request_to_server(request: str, port=10000) -> str:
server_address = 'localhost'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(4)
try:
sock.connect((server_address, port))
# Send length as 4-byte integer first
request_bytes = request.encode('utf-8')
length = len(request_bytes)
sock.sendall(length.to_bytes(4, byteorder='big'))
sock.sendall(request_bytes)
sock.shutdown(socket.SHUT_WR)
return extract_response(sock)
except socket.timeout:
print("Connection timed out.")
return ""
except Exception as e:
print(f"An error occurred: {e}")
return ""

def spaces(i):
if i < 1:
return ""
else:
return " " + spaces(i-1)

def checkProof(line):
#
# This function takes a string representing a line from a .jsonl file like gpt_and_english_proofs_230.jsonl
# and checks whether or not the formal proof that can be found in that line is valid. If it is, it returns
# a pair of the form (True,<theorem>), where the string <theorem> is a formula representing the conclusion of the proof.
# If the proof is not valid, then the result is a pair of the form (False,<error message>), where the string <error message>
# provides more detail on where exactly the proof went wrong, and what exactly was wrong (a syntax error, a logic error, and if
# so, the type of error, etc.).
#
D = json.loads(line)
premises = D['premises']
goal = D['goal']
assumes = '\n'.join([spaces(index) + "assume premise-" + str(index+1) + " := " + premises[index] for index in range(len(premises))])
proof = assumes + "\nconclude " + goal + "\n" + D['ndlProof']
athena_response = send_request_to_server(proof)
athena_response_lower = athena_response.lower()
if 'error' in athena_response_lower or 'failed' in athena_response_lower:
return (False,athena_response)
else:
return (True,athena_response)

def checkAll(file):
#
# This function works with an entire jsonl file like gpt_and_english_proofs_230.jsonl. It basically
# iterates the single-line function checkProof over all the contents of the jsonl file. The result
# is a list of pairs (boolean_flag,<details>), as produced by checkProof, one for each line in the jsonl file.
#
L = getLines(file)
R = []
send_request_to_server('declare A, B, C, D, E, F, G, H, J, I, K: Boolean')
for i in range(len(L)):
print("About to work on proof #" + str(i))
response = checkProof(L[i])
R.append(response)
return R

# Example use (where file is a path like "../data/gpt_generated_athena_and_english_proofs_raw_data_230.ath"):
# R = checkAll(file)
# This will give all successful/valid proofs:
# T = [r for r in R if r[0]]
# And this will give all incorrect proofs:
# F = [r for r in R if not(r[0])]
1 change: 1 addition & 0 deletions make_athena_binary_for_linux
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mlton @MLton gc-summary -- -output athena -verbose 1 -default-ann 'allowFFI true' -drop-pass deepFlatten -codegen native athena.mlb
1 change: 1 addition & 0 deletions make_athena_binary_for_mac
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mlton @MLton gc-summary -- -output athena -verbose 1 -default-ann 'allowFFI true' -drop-pass deepFlatten athena.mlb
56 changes: 55 additions & 1 deletion mlton_main.sml
Original file line number Diff line number Diff line change
@@ -1 +1,55 @@
val _ = Athena.main("",CommandLine.arguments())
(*============================================================================================================
Assuming that the executable produced by running make_athena_binary_for_linux or make_athena_binary_for_mac
is named "athena", you can run the system in a number of ways:
./athena -> Start the Athena REPL
./athena foo.ath -> Load foo.ath first and then start the Athena REPL
./athena foo.ath quit -> Load foo.ath and then quit
./athena -port <number> -> Start an Athena TCP server running on port <number>
./athena -port <number> -file foo.ath -> Load foo.ath first and then start an Athena TCP server on port <number>
============================================================================================================*)


fun mlton_main(arg0,args) =
let fun M(file_name_option:string option,quit_after) =
(print("\nWelcome to Athena!\n");
print("\nType an expression or deduction at the\nprompt below, ");
print("and press Enter to evaluate it.\n");
print("\nTo exit Athena, type \"quit\" at the prompt\nand press Enter.\n");
if quit_after then
Athena.runWithStarterFileAndQuit(file_name_option)
else
Athena.runWithStarterFile(file_name_option);
OS.Process.success)
(**
val i = initializeXSB()
**)
in
(case args of
[] => M(NONE,false)
| [file_name] => M(SOME(file_name),false)
| [arg_1,arg_2] => if arg_1 = "-port" orelse arg_1 = "--port" then
let val port_num_opt = Int.fromString(arg_2)
in
(case port_num_opt of
SOME(port_num) => (Thread.spawn(fn () => AthenaServer.startServer(port_num,NONE)); Thread.run(); OS.Process.success)
| _ => (print("\nInvalid port number."); OS.Process.failure))
end
else M(SOME(arg_1),true)
| [arg_1,arg_2,arg_3,arg_4] =>
if (arg_1 = "-port" orelse arg_1 = "--port") andalso (arg_3 = "--file" orelse arg_3 = "-file") then
let val port_num_opt = Int.fromString(arg_2)
val starter_file = arg_4
in
(case port_num_opt of
SOME(port_num) => (Thread.spawn(fn () => AthenaServer.startServer(port_num,SOME(starter_file))); Thread.run(); OS.Process.success)
| _ => (print("\nInvalid port number."); OS.Process.failure))
end
else
M(SOME(arg_1),true)
| file_name::(_::_) => M(SOME(file_name),true))
end

val _ = mlton_main("",CommandLine.arguments())
33 changes: 12 additions & 21 deletions repl.sml
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,6 @@ fun printLoadedFiles(loaded_files : (string,bool) HashTable.hash_table) =
end

fun debugPrint(_) = ()

fun exceptionToString(e) =
let fun f(ErrorMsg.ParseError((l,p),str)) = ("\n"^A.posToString({line=l,file=(!Paths.current_file),pos=p})^": Parsing error, "^str^".\n")
| f(A.LexError(str,SOME(pos))) = ("\n"^A.posToString(pos)^": Lexical error, "^str^".\n")
| f(A.LexError(str,NONE)) = ((!Paths.current_file)^": Lexical error at end of file, "^str^".\n")
| f(A.SyntaxError(str,SOME(pos))) = ("\n"^(A.posToString pos)^": Syntax error: "^str^".\n")
| f(A.SyntaxError(str,NONE)) = ("\n"^(!Paths.current_file)^": Syntax error: "^str^".\n")
| f(Semantics.AthenaError(msg)) = ("\n"^msg^"\n")
| f(Semantics.EvalError(x)) = Semantics.makeErrorWithPosInfo(x)
| f(Data.GenEx(str)) = str^"\n"
| f(SemanticValues.GenEx(x as (msg,pos_opt))) = Semantics.makeErrorWithPosInfo(x)
| f(Basic.Fail(str)) = "\n"^str^"\n"
| f(Basic.FailLst(strings)) = "\n"^(Basic.printListStr(strings,fn x => x, "\n"))^"\n"
| f(_) = "\nUnknown error: "^(exnMessage e)
in
f e
end

fun showFreeIds(phr,mod_path) =
let val (new_phrase,vars,fids) = preProcessPhrase(phr,mod_path)
Expand Down Expand Up @@ -140,7 +123,7 @@ in
print(Semantics.summarizeTopCallStack()))
end
| TopEnv.Halt => ()
| _ => print(exceptionToString(e)))
| _ => print(Semantics.exceptionToString(e)))
end

fun pathToString(path) = if null(path) then "[]" else Basic.printListStr(path,Symbol.name,".")
Expand Down Expand Up @@ -337,7 +320,7 @@ and processModuleExtension(module:A.module_entry as {module_name,module_contents
val _ = returned_env := SV.valEnv({val_map=val_map1',mod_map=Symbol.enter(mod_map1,mod_sym,new_module)})
in
eval_env := SV.valEnv({val_map=val_map2,mod_map=Symbol.enter(mod_map2,mod_sym,new_module)})
end handle ex => (error_msg := exceptionToString(ex);
end handle ex => (error_msg := Semantics.exceptionToString(ex);
Paths.open_mod_paths := starting_open_mod_paths_val;
Paths.open_mod_directives := starting_open_mod_directives_val;
eval_env := starting_eval_env;
Expand Down Expand Up @@ -408,7 +391,7 @@ and processModule(module:A.module_entry as {module_name,module_contents,module_f
else ()
in
()
end) handle ex => (error_msg := exceptionToString(ex);
end) handle ex => (error_msg := Semantics.exceptionToString(ex);
Paths.open_mod_paths := starting_open_mod_paths_val;
Paths.open_mod_directives := starting_open_mod_directives_val;
eval_env := starting_eval_env;
Expand Down Expand Up @@ -795,7 +778,7 @@ fun getInputAndProcess() =
((Parse.parse_from_stream istream),true,"")
handle e => let val _ = Parse.setLinePos(1,0)
in
([],false,exceptionToString(e))
([],false,Semantics.exceptionToString(e))
end
in
if ok_input then
Expand Down Expand Up @@ -830,6 +813,7 @@ fun escape(str) =
loop(L,[])
end

(**** qqq PROCESSSTRING DEFINITION ****)
fun processString(cmd,mod_path,env,eval_env) =
let val stream = TextIO.openString (cmd)
val inputs = Parse.parse_from_stream(stream)
Expand All @@ -840,6 +824,13 @@ fun processString(cmd,mod_path,env,eval_env) =

val _ = (Semantics.processString := processString)

fun processAlreadyParsedInputs(inputs,mod_path,env,eval_env) =
let val _ = List.app (fn i => (processInput(i,mod_path,env, Semantics.top_val_env, N.top_level_name,top_loaded_files_ht))) inputs
in ()
end

val _ = (Semantics.processAlreadParsedInputsRef := processAlreadyParsedInputs)

fun makeLibFileName(home,fname) = List.foldl (fn (dir, path) => Paths.makePath(path, dir)) home ["lib", "basic", fname]

val (athena_home_option,athena_home) = (Paths.athena_home_option,Paths.athena_home)
Expand Down
19 changes: 19 additions & 0 deletions semantics.sml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ val evaluateStringFlexible:((string * value_environment ref) -> value) ref = ref

val processString:((string * (Symbol.symbol list) * value_environment ref * value_environment ref) -> unit) ref = ref (fn _ => ())

val processAlreadParsedInputsRef :((A.user_input list * (Symbol.symbol list) * value_environment ref * value_environment ref) -> unit) ref = ref (fn _ => ())

val (ABaseInsert,ABaseAugment) = (ABase.insert,ABase.augment)

fun putValIntoAB(propVal(P),ab) = ABase.insert(P,ab)
Expand Down Expand Up @@ -2203,6 +2205,23 @@ fun liftArg(arg_val,expected_arity,pos_opt) =
end
else evError(wrongArgKindExpectationOnly(termLCType,arg_val),pos_opt)
| _ => evError(wrongArgKindExpectationOnly(termLCType,arg_val),pos_opt))

fun exceptionToString(e) =
let fun f(ErrorMsg.ParseError((l,p),str)) = ("\n"^A.posToString({line=l,file=(!Paths.current_file),pos=p})^": Parsing error, "^str^".\n")
| f(A.LexError(str,SOME(pos))) = ("\n"^A.posToString(pos)^": Lexical error, "^str^".\n")
| f(A.LexError(str,NONE)) = ((!Paths.current_file)^": Lexical error at end of file, "^str^".\n")
| f(A.SyntaxError(str,SOME(pos))) = ("\n"^(A.posToString pos)^": Syntax error: "^str^".\n")
| f(A.SyntaxError(str,NONE)) = ("\n"^(!Paths.current_file)^": Syntax error: "^str^".\n")
| f(AthenaError(msg)) = ("\n"^msg^"\n")
| f(EvalError(x)) = makeErrorWithPosInfo(x)
| f(Data.GenEx(str)) = str^"\n"
| f(GenEx(x as (msg,pos_opt))) = makeErrorWithPosInfo(x)
| f(Basic.Fail(str)) = "\n"^str^"\n"
| f(Basic.FailLst(strings)) = "\n"^(Basic.printListStr(strings,fn x => x, "\n"))^"\n"
| f(_) = "\nUnknown error: "^(exnMessage e)
in
f e
end

val (lp,rp,space) = ("(",")"," ")

Expand Down
7 changes: 5 additions & 2 deletions server.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
A function (Server.makeServerFun) that creates an Athena TCP server, which
can be hit by any TCP client (written in any language), local or remote.
Note: This is a single-threaded server and its use is not recommended.
For a robust Athena server, use MLton to generate an executable and then
start a TCP server on the desired port as described in mlton_main.sml.
=======================================================================*)

Expand All @@ -25,7 +28,7 @@ fun makeServerFun([termVal(t),cv],env,ab) =
".\nThe procedure must take a string and return a string."))
end
fun runServerFun([termVal(pt)],env,ab) =
let val serverFun = Socket.makeServer(input_buffer_size,processString)
let val serverFun = SocketImp.makeServer(input_buffer_size,processString)
val port = (case AthTerm.getNumVal(pt) of
SOME(A.int_num(p,_),_) => p
| _ => primError("A port number (numeral) was expected here"))
Expand All @@ -44,4 +47,4 @@ fun makeServerFun([termVal(t),cv],env,ab) =
| makeServerFun(vals,env,ab) =
primError(wrongArgNumber(N.makeServerFun_name,length(vals),2))

end
end
2 changes: 1 addition & 1 deletion sockets.sml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ to implement an Athena server that can be hit by arbitrary TCP clients
=======================================================================*)

structure Socket = struct
structure SocketImp = struct

open TextIO

Expand Down
Loading

0 comments on commit c7b5654

Please sign in to comment.