Skip to content

Commit

Permalink
Merge pull request #38 from Lucs1590/22-add-multi-language-descriptio…
Browse files Browse the repository at this point in the history
…n-support

Add multi language description support
  • Loading branch information
Lucs1590 authored Nov 20, 2024
2 parents 01b0023 + 2c58d58 commit 48af34a
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 18 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ numpy==1.26.4
pandas==2.2.3
python-dotenv==1.0.1
questionary==2.0.1
scipy==1.14.1
scipy==1.13.1
tcxreader==0.4.10
tqdm==4.67.0
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"pandas==2.2.3",
"python-dotenv==1.0.1",
"questionary==2.0.1",
"scipy==1.14.1",
"scipy==1.13.1",
"tcxreader==0.4.10",
"tqdm==4.67.0"
],
Expand Down
40 changes: 26 additions & 14 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ def main():
_, tcx_data = validate_tcx_file(file_path)
if ask_llm_analysis():
plan = ask_training_plan()
language = ask_desired_language()
logger.info("Performing LLM analysis")
perform_llm_analysis(tcx_data, sport, plan)
perform_llm_analysis(tcx_data, sport, plan, language)
else:
logger.error("Invalid sport selected")
raise ValueError("Invalid sport selected")
Expand Down Expand Up @@ -194,35 +195,46 @@ def ask_training_plan() -> str:
).ask()


def perform_llm_analysis(data: TCXReader, sport: str, plan: str) -> str:
dataframe = preprocess_trackpoints_data(data)
def ask_desired_language() -> str:
return questionary.text(
"In which language do you want the analysis to be provided? (Default is Portuguese)",
default="Portuguese (Brazil)"
).ask()

prompt_template = """
SYSTEM: You are an AI coach helping athletes optimize and improve their performance.
Based on the provided {sport} training session data, perform the following analysis:

1. Identify key performance metrics.
2. Highlight the athlete's strengths during the session.
3. Pinpoint areas where the athlete can improve.
4. Offer actionable suggestions for enhancing performance in future {sport} sessions.
def perform_llm_analysis(data: TCXReader, sport: str, plan: str, language: str) -> str:
dataframe = preprocess_trackpoints_data(data)

Training session data:
prompt_template = """
SYSTEM: You are an AI performance coach specializing in analyzing athletic performance to help athletes with their trainings.
Using the provided {sport} training session data, analyze the athlete's performance and deliver a detailed analysis and practical advice in {language} language.
Your analysis should include:
1. Key Performance Metrics: Identify and Evaluate the most relevant metrics from the session, understanding the athlete's overall performance
2. Strengths: Highlight the athlete's strongest aspects during the session, supported by specific metrics.
3. Improvement Opportunities: Pinpoint specific areas for growth and improvement.
4. Actionable Suggestions: Provide clear, practical recommendations to help the athlete enhance their performance in future {sport} sessions.
Ensure your response is data-driven, clear, and motivational, helping the athlete make measurable progress.
Training Session Data:
{training_data}
"""

if plan:
prompt_template += "\nTraining plan details: {plan}"
prompt_template += "\n\nTraining Plan Details:\n{plan}"

prompt = PromptTemplate.from_template(prompt_template).format(
sport=sport,
training_data=dataframe.to_csv(index=False),
language=language,
plan=plan
)

openai_llm = ChatOpenAI(
openai_api_key=os.getenv("OPENAI_API_KEY"),
model_name="gpt-4o",
max_tokens=1500,
model_name="gpt-4o-mini",
max_tokens=2000,
temperature=0.6,
max_retries=5
)
Expand Down
18 changes: 16 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
get_latest_download,
validation,
ask_training_plan,
ask_desired_language,
ask_llm_analysis,
perform_llm_analysis,
preprocess_trackpoints_data,
Expand Down Expand Up @@ -198,6 +199,7 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
mock_validate.assert_not_called()
mock_indent.assert_not_called()

@patch('src.main.ask_desired_language')
@patch('src.main.ask_training_plan')
@patch('src.main.perform_llm_analysis')
@patch('src.main.ask_llm_analysis')
Expand All @@ -211,14 +213,15 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
@patch('src.main.indent_xml_file')
def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask_path, mock_download,
mock_ask_id, mock_ask_location, mock_ask_sport, mock_llm_analysis, mock_perform_llm,
mock_training_plan):
mock_training_plan, mock_language):
mock_ask_sport.return_value = "Bike"
mock_ask_location.return_value = "Local"
mock_ask_path.return_value = "assets/bike.tcx"
mock_llm_analysis.return_value = True
mock_validate.return_value = True, "TCX Data"
mock_perform_llm.return_value = "Training Plan"
mock_training_plan.return_value = ""
mock_language.return_value = "Portuguese"

main()

Expand Down Expand Up @@ -306,6 +309,16 @@ def test_ask_training_plan(self):
)
self.assertEqual(result, "")

def test_ask_desired_language(self):
with patch('src.main.questionary.text') as mock_text:
mock_text.return_value.ask.return_value = "Portuguese"
result = ask_desired_language()
mock_text.assert_called_once_with(
'In which language do you want the analysis to be provided? (Default is Portuguese)',
default='Portuguese (Brazil)'
)
self.assertEqual(result, "Portuguese")

def test_ask_llm_analysis(self):
with patch('src.main.questionary.confirm') as mock_confirm:
mock_confirm.return_value.ask.return_value = True
Expand All @@ -323,8 +336,9 @@ def test_perform_llm_analysis(self, mock_chat):
tcx_data = self.running_example_data
sport = "Run"
plan = "Training Plan"
lang = "Portuguese"

result = perform_llm_analysis(tcx_data, sport, plan)
result = perform_llm_analysis(tcx_data, sport, plan, lang)
self.assertEqual(result, "Training Plan")

def test_preprocess_running_trackpoints_data(self):
Expand Down

0 comments on commit 48af34a

Please sign in to comment.