From 75a9edd3bd867ce65d3c02575d31b61f394d3c1a Mon Sep 17 00:00:00 2001 From: Dong Yuwei Date: Wed, 1 May 2024 23:50:49 +0800 Subject: [PATCH] [WIP] predict next words --- src/InputController.h | 4 ++ src/InputController.mm | 98 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/src/InputController.h b/src/InputController.h index b4c00d2..bf9489e 100644 --- a/src/InputController.h +++ b/src/InputController.h @@ -5,6 +5,7 @@ #import "ConversionEngine.h" @interface InputController : IMKInputController { + NSMutableString *_sentenceBuffer; NSMutableString *_composedBuffer; NSMutableString *_originalBuffer; NSInteger _insertionIndex; @@ -17,6 +18,9 @@ AnnotationWinController *_annotationWin; } +- (NSMutableString *)sentenceBuffer; +- (void)setSentenceBuffer:(NSString *)string; + - (NSMutableString *)composedBuffer; - (void)setComposedBuffer:(NSString *)string; - (NSMutableString *)originalBuffer; diff --git a/src/InputController.mm b/src/InputController.mm index 6ad6509..21c19b1 100644 --- a/src/InputController.mm +++ b/src/InputController.mm @@ -173,6 +173,7 @@ - (BOOL)onKeyEvent:(NSEvent *)event client:(id)sender { if (hasBufferedText) { [self appendToComposedBuffer:characters]; [self commitCompositionWithoutSpace:sender]; + [self setSentenceBuffer: @""]; return YES; } } @@ -228,6 +229,27 @@ - (void)commitComposition:(id)sender { [sender insertText:text replacementRange:NSMakeRange(NSNotFound, NSNotFound)]; [self reset]; + + NSLog(@"Current Sentence Buffer: %@", self.sentenceBuffer); + if ([self doesSentenceBufferIncludeSpace]) { + [self fetchPredictionsForText:self.sentenceBuffer completion:^(NSDictionary *responseDict, NSArray *bertArray, NSError *error) { + if (error) { + NSLog(@"Error: %@", error.localizedDescription); + } else { + NSLog(@"BERT: %@", bertArray); + dispatch_async(dispatch_get_main_queue(), ^{ + [sharedCandidates setCandidateData:bertArray]; + [sharedCandidates show:kIMKLocateCandidatesBelowHint]; + }); + } + }]; + } + +} + +- (BOOL)doesSentenceBufferIncludeSpace { + NSRange range = [self.sentenceBuffer rangeOfString:@" "]; + return range.location != NSNotFound; } - (void)commitCompositionWithoutSpace:(id)sender { @@ -242,6 +264,66 @@ - (void)commitCompositionWithoutSpace:(id)sender { [self reset]; } +- (NSString *) fetchAPIURL { + NSUserDefaults *defaults = [NSUserDefaults standardUserDefaults]; + NSString *apiURL = [defaults stringForKey:@"NEXT_WORD_PREDICTION_SERVICE_URL"]; + if (apiURL) { + return apiURL; + } else { + return @"http://127.0.0.1:8080/get_end_predictions"; + } +} + +- (void)fetchPredictionsForText:(NSString *)text completion:(void(^)(NSDictionary *responseDict, NSArray *bertArray, NSError *error))completionHandler { + NSString *urlString = [self fetchAPIURL]; + NSURL *url = [NSURL URLWithString:urlString]; + NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url]; + request.HTTPMethod = @"POST"; + [request setValue:@"application/json" forHTTPHeaderField:@"Content-Type"]; + + NSDictionary *jsonBody = @{@"input_text": text, @"top_k": @"9"}; + NSError *jsonError; + NSData *jsonData = [NSJSONSerialization dataWithJSONObject:jsonBody options:0 error:&jsonError]; + + if (jsonError) { + completionHandler(nil, nil, jsonError); + return; + } + + request.HTTPBody = jsonData; + + NSURLSession *session = [NSURLSession sharedSession]; + NSURLSessionDataTask *task = [session dataTaskWithRequest:request completionHandler:^(NSData *data, NSURLResponse *response, NSError *error) { + if (error) { + completionHandler(nil, nil, error); + return; + } + + NSError *jsonParsingError; + NSDictionary *responseDict = [NSJSONSerialization JSONObjectWithData:data options:0 error:&jsonParsingError]; + + if (jsonParsingError) { + completionHandler(nil, nil, jsonParsingError); + } else { + NSArray *bertArray = nil; + NSArray *bertCNArray = nil; + + // Parsing the bert string + NSString *bertString = [responseDict objectForKey:@"bert"]; + if (bertString) { + bertArray = [bertString componentsSeparatedByString:@"\n"]; + bertArray = [bertArray filteredArrayUsingPredicate:[NSPredicate predicateWithBlock:^BOOL(id evaluatedObject, NSDictionary *bindings) { + return [evaluatedObject length] > 0 && ![evaluatedObject isEqualToString:@"[UNK]"]; + }]]; + } + + completionHandler(responseDict, bertArray, nil); + } + }]; + + [task resume]; +} + - (void)reset { [self setComposedBuffer:@""]; [self setOriginalBuffer:@""]; @@ -264,6 +346,22 @@ - (NSMutableString *)composedBuffer { - (void)setComposedBuffer:(NSString *)string { NSMutableString *buffer = [self composedBuffer]; + if (string && string.length > 0) { + NSString * sentence = self.sentenceBuffer; + [self setSentenceBuffer: [NSString stringWithFormat:@"%@ %@", sentence, string]]; + } + [buffer setString:string]; +} + +- (NSMutableString *)sentenceBuffer { + if (_sentenceBuffer == nil) { + _sentenceBuffer = [[NSMutableString alloc] init]; + } + return _sentenceBuffer; +} + +- (void)setSentenceBuffer:(NSString *)string { + NSMutableString *buffer = [self sentenceBuffer]; [buffer setString:string]; }