Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-based TCP sockets #20

Merged
merged 9 commits into from
Nov 12, 2024
10 changes: 6 additions & 4 deletions athena.mlb
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local
$(SML_LIB)/mlyacc-lib/mlyacc-lib.mlb
$(SML_LIB)/cml/cml.mlb
$(SML_LIB)/basis/mlton.mlb
$(SML_LIB)/smlnj-lib/INet/inet-lib.mlb
in
compat_11072.sml
base.sig
Expand Down Expand Up @@ -86,11 +87,12 @@ in
topenv_part1.sml
topenv_part2.sml
definition_processor.sml
(**
prolog_solver.sml
**)
repl.sml
athena.sml
mlton_server.sml
sml_c_util.sml
(***
xsb_prolog_solver.sml
***)
athena.sml
mlton_main.sml
end
2 changes: 1 addition & 1 deletion athena.sml
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,5 @@ fun main(arg0,args) =
| [file_name] => M(SOME(file_name),false)
| file_name::(_::_) => M(SOME(file_name),true))
end

end
2 changes: 1 addition & 1 deletion base.sml
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ fun repeat n f =
val (newline,lparen,rparen,lbrack,rbrack,lbrace,rbrace,
blank,comma,period,colon,semi_colon,string_quote) = ("\n","(",")","[","]","{","}"," ",",",".",":",";","\"")

fun mark(s) = (print("\n");repeat 3 (fn _ => print(s));print("\n"))
fun mark(s) = (print("\n");repeat 10 (fn _ => print(s));print("\n"))

fun failLst(messages) = raise FailLst messages

Expand Down
111 changes: 111 additions & 0 deletions client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import sys
import socket
import sys
import json
from utils import *

def extract_response(sock):
chunks = []
while True:
chunk = sock.recv(40960)
if not chunk: # Connection closed by server
break
chunks.append(chunk)
return b''.join(chunks).decode('utf-8')


# Note: send_request_to_server_simple should only be used if the corresponding server uses readAllSimple.

def send_request_to_server_simple(request: str, port=10000) -> str:
server_address = 'localhost'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(4) # Set a 5-second timeout for demonstration
try:
sock.connect((server_address, port))
sock.sendall(request.encode('utf-8'))
sock.sendall(''.encode('utf-8'))
sock.shutdown(socket.SHUT_WR)
return extract_response(sock)
except socket.timeout:
print("Connection timed out.")
return ""
except Exception as e:
print(f"An error occurred: {e}")
return ""

# The default implementation of send_request_to_server encodes the size of the
# Athena payload into the first 4 bytes of the request (thus allowing payloads up to 4GB).
# This must be used in conjunction with readAll on the server side, which first extracts
# the leading 4 bytes of the client's request, transforms those into the integer value N
# they represent, and then reads exactly N bytes from the connection.

def send_request_to_server(request: str, port=10000) -> str:
server_address = 'localhost'
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.settimeout(4)
try:
sock.connect((server_address, port))
# Send length as 4-byte integer first
request_bytes = request.encode('utf-8')
length = len(request_bytes)
sock.sendall(length.to_bytes(4, byteorder='big'))
sock.sendall(request_bytes)
sock.shutdown(socket.SHUT_WR)
return extract_response(sock)
except socket.timeout:
print("Connection timed out.")
return ""
except Exception as e:
print(f"An error occurred: {e}")
return ""

def spaces(i):
if i < 1:
return ""
else:
return " " + spaces(i-1)

def checkProof(line):
#
# This function takes a string representing a line from a .jsonl file like gpt_and_english_proofs_230.jsonl
# and checks whether or not the formal proof that can be found in that line is valid. If it is, it returns
# a pair of the form (True,<theorem>), where the string <theorem> is a formula representing the conclusion of the proof.
# If the proof is not valid, then the result is a pair of the form (False,<error message>), where the string <error message>
# provides more detail on where exactly the proof went wrong, and what exactly was wrong (a syntax error, a logic error, and if
# so, the type of error, etc.).
#
D = json.loads(line)
premises = D['premises']
goal = D['goal']
assumes = '\n'.join([spaces(index) + "assume premise-" + str(index+1) + " := " + premises[index] for index in range(len(premises))])
proof = assumes + "\nconclude " + goal + "\n" + D['ndlProof']
athena_response = send_request_to_server(proof)
athena_response_lower = athena_response.lower()
if 'error' in athena_response_lower or 'failed' in athena_response_lower:
return (False,athena_response)
else:
return (True,athena_response)

def checkAll(file):
#
# This function works with an entire jsonl file like gpt_and_english_proofs_230.jsonl. It basically
# iterates the single-line function checkProof over all the contents of the jsonl file. The result
# is a list of pairs (boolean_flag,<details>), as produced by checkProof, one for each line in the jsonl file.
#
L = getLines(file)
R = []
send_request_to_server('declare A, B, C, D, E, F, G, H, J, I, K: Boolean')
for i in range(len(L)):
print("About to work on proof #" + str(i))
response = checkProof(L[i])
R.append(response)
return R

# Example use (where file is a path like "../data/gpt_generated_athena_and_english_proofs_raw_data_230.ath"):
# R = checkAll(file)
# This will give all successful/valid proofs:
# T = [r for r in R if r[0]]
# And this will give all incorrect proofs:
# F = [r for r in R if not(r[0])]
1 change: 1 addition & 0 deletions make_athena_binary_for_linux
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mlton @MLton gc-summary -- -output athena -verbose 1 -default-ann 'allowFFI true' -drop-pass deepFlatten -codegen native athena.mlb
1 change: 1 addition & 0 deletions make_athena_binary_for_mac
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mlton @MLton gc-summary -- -output athena -verbose 1 -default-ann 'allowFFI true' -drop-pass deepFlatten athena.mlb
56 changes: 55 additions & 1 deletion mlton_main.sml
Original file line number Diff line number Diff line change
@@ -1 +1,55 @@
val _ = Athena.main("",CommandLine.arguments())
(*============================================================================================================

Assuming that the executable produced by running make_athena_binary_for_linux or make_athena_binary_for_mac
is named "athena", you can run the system in a number of ways:

./athena -> Start the Athena REPL
./athena foo.ath -> Load foo.ath first and then start the Athena REPL
./athena foo.ath quit -> Load foo.ath and then quit
./athena -port <number> -> Start an Athena TCP server running on port <number>
./athena -port <number> -file foo.ath -> Load foo.ath first and then start an Athena TCP server on port <number>

============================================================================================================*)


fun mlton_main(arg0,args) =
let fun M(file_name_option:string option,quit_after) =
(print("\nWelcome to Athena!\n");
print("\nType an expression or deduction at the\nprompt below, ");
print("and press Enter to evaluate it.\n");
print("\nTo exit Athena, type \"quit\" at the prompt\nand press Enter.\n");
if quit_after then
Athena.runWithStarterFileAndQuit(file_name_option)
else
Athena.runWithStarterFile(file_name_option);
OS.Process.success)
(**
val i = initializeXSB()
**)
in
(case args of
[] => M(NONE,false)
| [file_name] => M(SOME(file_name),false)
| [arg_1,arg_2] => if arg_1 = "-port" orelse arg_1 = "--port" then
let val port_num_opt = Int.fromString(arg_2)
in
(case port_num_opt of
SOME(port_num) => (Thread.spawn(fn () => AthenaServer.startServer(port_num,NONE)); Thread.run(); OS.Process.success)
| _ => (print("\nInvalid port number."); OS.Process.failure))
end
else M(SOME(arg_1),true)
| [arg_1,arg_2,arg_3,arg_4] =>
if (arg_1 = "-port" orelse arg_1 = "--port") andalso (arg_3 = "--file" orelse arg_3 = "-file") then
let val port_num_opt = Int.fromString(arg_2)
val starter_file = arg_4
in
(case port_num_opt of
SOME(port_num) => (Thread.spawn(fn () => AthenaServer.startServer(port_num,SOME(starter_file))); Thread.run(); OS.Process.success)
| _ => (print("\nInvalid port number."); OS.Process.failure))
end
else
M(SOME(arg_1),true)
| file_name::(_::_) => M(SOME(file_name),true))
end

val _ = mlton_main("",CommandLine.arguments())
33 changes: 12 additions & 21 deletions repl.sml
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,6 @@ fun printLoadedFiles(loaded_files : (string,bool) HashTable.hash_table) =
end

fun debugPrint(_) = ()

fun exceptionToString(e) =
let fun f(ErrorMsg.ParseError((l,p),str)) = ("\n"^A.posToString({line=l,file=(!Paths.current_file),pos=p})^": Parsing error, "^str^".\n")
| f(A.LexError(str,SOME(pos))) = ("\n"^A.posToString(pos)^": Lexical error, "^str^".\n")
| f(A.LexError(str,NONE)) = ((!Paths.current_file)^": Lexical error at end of file, "^str^".\n")
| f(A.SyntaxError(str,SOME(pos))) = ("\n"^(A.posToString pos)^": Syntax error: "^str^".\n")
| f(A.SyntaxError(str,NONE)) = ("\n"^(!Paths.current_file)^": Syntax error: "^str^".\n")
| f(Semantics.AthenaError(msg)) = ("\n"^msg^"\n")
| f(Semantics.EvalError(x)) = Semantics.makeErrorWithPosInfo(x)
| f(Data.GenEx(str)) = str^"\n"
| f(SemanticValues.GenEx(x as (msg,pos_opt))) = Semantics.makeErrorWithPosInfo(x)
| f(Basic.Fail(str)) = "\n"^str^"\n"
| f(Basic.FailLst(strings)) = "\n"^(Basic.printListStr(strings,fn x => x, "\n"))^"\n"
| f(_) = "\nUnknown error: "^(exnMessage e)
in
f e
end

fun showFreeIds(phr,mod_path) =
let val (new_phrase,vars,fids) = preProcessPhrase(phr,mod_path)
Expand Down Expand Up @@ -140,7 +123,7 @@ in
print(Semantics.summarizeTopCallStack()))
end
| TopEnv.Halt => ()
| _ => print(exceptionToString(e)))
| _ => print(Semantics.exceptionToString(e)))
end

fun pathToString(path) = if null(path) then "[]" else Basic.printListStr(path,Symbol.name,".")
Expand Down Expand Up @@ -337,7 +320,7 @@ and processModuleExtension(module:A.module_entry as {module_name,module_contents
val _ = returned_env := SV.valEnv({val_map=val_map1',mod_map=Symbol.enter(mod_map1,mod_sym,new_module)})
in
eval_env := SV.valEnv({val_map=val_map2,mod_map=Symbol.enter(mod_map2,mod_sym,new_module)})
end handle ex => (error_msg := exceptionToString(ex);
end handle ex => (error_msg := Semantics.exceptionToString(ex);
Paths.open_mod_paths := starting_open_mod_paths_val;
Paths.open_mod_directives := starting_open_mod_directives_val;
eval_env := starting_eval_env;
Expand Down Expand Up @@ -408,7 +391,7 @@ and processModule(module:A.module_entry as {module_name,module_contents,module_f
else ()
in
()
end) handle ex => (error_msg := exceptionToString(ex);
end) handle ex => (error_msg := Semantics.exceptionToString(ex);
Paths.open_mod_paths := starting_open_mod_paths_val;
Paths.open_mod_directives := starting_open_mod_directives_val;
eval_env := starting_eval_env;
Expand Down Expand Up @@ -795,7 +778,7 @@ fun getInputAndProcess() =
((Parse.parse_from_stream istream),true,"")
handle e => let val _ = Parse.setLinePos(1,0)
in
([],false,exceptionToString(e))
([],false,Semantics.exceptionToString(e))
end
in
if ok_input then
Expand Down Expand Up @@ -830,6 +813,7 @@ fun escape(str) =
loop(L,[])
end

(**** qqq PROCESSSTRING DEFINITION ****)
fun processString(cmd,mod_path,env,eval_env) =
let val stream = TextIO.openString (cmd)
val inputs = Parse.parse_from_stream(stream)
Expand All @@ -840,6 +824,13 @@ fun processString(cmd,mod_path,env,eval_env) =

val _ = (Semantics.processString := processString)

fun processAlreadyParsedInputs(inputs,mod_path,env,eval_env) =
let val _ = List.app (fn i => (processInput(i,mod_path,env, Semantics.top_val_env, N.top_level_name,top_loaded_files_ht))) inputs
in ()
end

val _ = (Semantics.processAlreadParsedInputsRef := processAlreadyParsedInputs)

fun makeLibFileName(home,fname) = List.foldl (fn (dir, path) => Paths.makePath(path, dir)) home ["lib", "basic", fname]

val (athena_home_option,athena_home) = (Paths.athena_home_option,Paths.athena_home)
Expand Down
19 changes: 19 additions & 0 deletions semantics.sml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ val evaluateStringFlexible:((string * value_environment ref) -> value) ref = ref

val processString:((string * (Symbol.symbol list) * value_environment ref * value_environment ref) -> unit) ref = ref (fn _ => ())

val processAlreadParsedInputsRef :((A.user_input list * (Symbol.symbol list) * value_environment ref * value_environment ref) -> unit) ref = ref (fn _ => ())

val (ABaseInsert,ABaseAugment) = (ABase.insert,ABase.augment)

fun putValIntoAB(propVal(P),ab) = ABase.insert(P,ab)
Expand Down Expand Up @@ -2203,6 +2205,23 @@ fun liftArg(arg_val,expected_arity,pos_opt) =
end
else evError(wrongArgKindExpectationOnly(termLCType,arg_val),pos_opt)
| _ => evError(wrongArgKindExpectationOnly(termLCType,arg_val),pos_opt))

fun exceptionToString(e) =
let fun f(ErrorMsg.ParseError((l,p),str)) = ("\n"^A.posToString({line=l,file=(!Paths.current_file),pos=p})^": Parsing error, "^str^".\n")
| f(A.LexError(str,SOME(pos))) = ("\n"^A.posToString(pos)^": Lexical error, "^str^".\n")
| f(A.LexError(str,NONE)) = ((!Paths.current_file)^": Lexical error at end of file, "^str^".\n")
| f(A.SyntaxError(str,SOME(pos))) = ("\n"^(A.posToString pos)^": Syntax error: "^str^".\n")
| f(A.SyntaxError(str,NONE)) = ("\n"^(!Paths.current_file)^": Syntax error: "^str^".\n")
| f(AthenaError(msg)) = ("\n"^msg^"\n")
| f(EvalError(x)) = makeErrorWithPosInfo(x)
| f(Data.GenEx(str)) = str^"\n"
| f(GenEx(x as (msg,pos_opt))) = makeErrorWithPosInfo(x)
| f(Basic.Fail(str)) = "\n"^str^"\n"
| f(Basic.FailLst(strings)) = "\n"^(Basic.printListStr(strings,fn x => x, "\n"))^"\n"
| f(_) = "\nUnknown error: "^(exnMessage e)
in
f e
end

val (lp,rp,space) = ("(",")"," ")

Expand Down
7 changes: 5 additions & 2 deletions server.sml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

A function (Server.makeServerFun) that creates an Athena TCP server, which
can be hit by any TCP client (written in any language), local or remote.
Note: This is a single-threaded server and its use is not recommended.
For a robust Athena server, use MLton to generate an executable and then
start a TCP server on the desired port as described in mlton_main.sml.

=======================================================================*)

Expand All @@ -25,7 +28,7 @@ fun makeServerFun([termVal(t),cv],env,ab) =
".\nThe procedure must take a string and return a string."))
end
fun runServerFun([termVal(pt)],env,ab) =
let val serverFun = Socket.makeServer(input_buffer_size,processString)
let val serverFun = SocketImp.makeServer(input_buffer_size,processString)
val port = (case AthTerm.getNumVal(pt) of
SOME(A.int_num(p,_),_) => p
| _ => primError("A port number (numeral) was expected here"))
Expand All @@ -44,4 +47,4 @@ fun makeServerFun([termVal(t),cv],env,ab) =
| makeServerFun(vals,env,ab) =
primError(wrongArgNumber(N.makeServerFun_name,length(vals),2))

end
end
2 changes: 1 addition & 1 deletion sockets.sml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ to implement an Athena server that can be hit by arbitrary TCP clients

=======================================================================*)

structure Socket = struct
structure SocketImp = struct

open TextIO

Expand Down
Loading
Loading