Merge pull request #1 from GameOfPods/v0.0.3
V0.0.3
RedRem95 authored Aug 11, 2024
2 parents 20f5bee + 25bc0a9 commit db0c13d
Showing 10 changed files with 488 additions and 78 deletions.
28 changes: 22 additions & 6 deletions PAT/modules/book/__init__.py
@@ -18,6 +18,7 @@
from typing import Tuple, Dict, Any, List, Union, Set, Optional

from PAT.modules import PATModule
from PAT.utils.summerize import LLM as LLMService


class BookModule(PATModule):
@@ -29,6 +30,7 @@ class BookModule(PATModule):
_SPACY_MODELS = {"en": "en_core_web_trf", "de": "de_core_news_lg"}

_SUMMARIZE_MODEL: Optional[str] = None
_SUMMARIZE_SERVICE: Optional[LLMService] = None

def _chapter_valid(self, chapter_name: str, chapter_counter: Dict[str, int]) -> bool:
if len(self._BOOK_VALID_CHAPTERS[self._book.title]) > 0:
@@ -54,6 +56,11 @@ def config_keys(cls) -> Dict[str, Dict[str, Any]]:
"type": str, "required": False, "default": None,
"help": f"OpenAI model to use to summarize book content. "
f"If not provided no summarization will be performed. (Default: %(default)s)"
},
"book_summarize_service": {
"type": str, "choices": [x.name for x in LLMService], "required": False,
"default": LLMService.OpenAI.name,
"help": f"Service to use for summarization tasks. Choices: %(choices)s. (Default: %(default)s)"
}
}

@@ -68,6 +75,13 @@ def load(cls, config: Dict[str, Any]):
cls._LOGGER.info(f"Set valid chapters for '{k}' to {cls._BOOK_VALID_CHAPTERS[k]} "
f"from '{config['chapter_infos']}'")
cls._SUMMARIZE_MODEL = config.get("book_summarize_model", None)
if "book_summarize_service" in config:
try:
cls._SUMMARIZE_SERVICE = [x for x in LLMService if x.name == config["book_summarize_service"]][0]
except IndexError:
cls._LOGGER.error(
f"Could not find Summarization LLM service with name {config['book_summarize_service']}"
)

def __init__(self, file: str):
super().__init__(file)
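For context, the two new keys read in load() above could be supplied as in the sketch below; the key names come straight from config_keys(), while the concrete model identifier is an illustrative assumption rather than something this diff pins down.

# Hypothetical config fragment for BookModule (sketch only).
# "book_summarize_service" must match the name of an LLM enum member;
# "OpenAI" is the default advertised in config_keys() above.
book_config = {
    "book_summarize_model": "gpt-4o-mini",  # assumed model name, not taken from the diff
    "book_summarize_service": "OpenAI",
}
# BookModule.load(book_config) would then populate _SUMMARIZE_MODEL and _SUMMARIZE_SERVICE.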
@@ -173,14 +187,15 @@ def toRoman(n):
[x.text for x in doc if not any([x.is_space, x.is_punct, x.is_stop])])
chapter_info["entities"] = defaultdict(Counter)

if self._SUMMARIZE_MODEL is not None:
if self._SUMMARIZE_SERVICE is not None and self._SUMMARIZE_MODEL is not None:
try:
from PAT.utils.summerize import summarize, Templates
chapter_info["summary"] = summarize(
text=text,
openai_model=self._SUMMARIZE_MODEL,
language=self._lang,
model=self._SUMMARIZE_MODEL,
lang=self._lang,
template=Templates.LANGCHAIN_DEFAULT,
llm=self._SUMMARIZE_SERVICE,
)
except Exception as ignore:
pass
@@ -200,7 +215,7 @@ def toRoman(n):
"chapters": chapter_infos
}

if self._SUMMARIZE_MODEL is not None:
if self._SUMMARIZE_SERVICE is not None and self._SUMMARIZE_MODEL is not None:
try:
from PAT.utils.summerize import summarize, Templates
self._LOGGER.info("Creating summary of whole book")
@@ -209,9 +224,10 @@ def toRoman(n):
)
ret["summary"] = summarize(
text=text_f,
openai_model=self._SUMMARIZE_MODEL,
language=self._lang,
model=self._SUMMARIZE_MODEL,
lang=self._lang,
template=Templates.LANGCHAIN_DEFAULT,
llm=self._SUMMARIZE_SERVICE,
)
except Exception as ignore:
pass
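Judging from the two updated call sites above, the summarize helper in PAT.utils.summerize appears to now take model, lang, and llm keyword arguments in place of the old openai_model and language. A rough, assumed reconstruction of that signature, for orientation only (the real definition is not part of this diff):

# Assumed signature, reconstructed from the call sites in this file; not copied from summerize.py.
def summarize(text, model, lang, template, llm):
    """Return a summary string for `text`, or None if summarization fails (assumption)."""
    ...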
162 changes: 136 additions & 26 deletions PAT/modules/speaker/__init__.py
@@ -18,8 +18,10 @@

from PAT.modules import PATModule
from PAT.modules.speaker.diarizators import Diarizators, PyannoteDiarizator
from PAT.modules.speaker.speakermatching import SpeakerMatching
from PAT.modules.speaker.transcriptors import Transcriptor, WordTupleSpeaker
from PAT.utils.punctuation import restore_punctuation
from PAT.utils.summerize import LLM as LLMService


class SpeakerDetectionModule(PATModule):
Expand All @@ -45,6 +47,8 @@ class SpeakerDetectionModule(PATModule):
_TRANSCRIPTOR_MODEL: Optional[str] = None
_TRANSCRIPTOR_LANGUAGE: Optional[str] = None
_SUMMARIZE_MODEL: Optional[str] = None
_SPEAKER_MATCHER: Optional[SpeakerMatching] = None
_SUMMARIZE_SERVICE: Optional[LLMService] = None

def __init__(self, file: str):
super().__init__(file)
Expand Down Expand Up @@ -106,8 +110,12 @@ def process(self) -> Union[Tuple[Dict[str, Any], List[Tuple[str, bytes]]], Dict[
)
))
i += 1
self._LOGGER.info(
f"Splitting audio in {len(muc)} {max_clip_len} long snippets for demucs to save on ram usage"
)
else:
muc = [(self.file, speach_file)]
self._LOGGER.info(f"Audio clip is shorter than {max_clip_len}. Will not split audio for demucs")

os.makedirs(os.path.join(self._TEMP_DIR, "demucs"), exist_ok=True)
for f, t in muc:
@@ -146,10 +154,10 @@ def process(self) -> Union[Tuple[Dict[str, Any], List[Tuple[str, bytes]]], Dict[
else:
from uuid import uuid4
from collections import defaultdict
from rttm_manager import export_rttm
from rttm_manager import export_rttm, RTTM
segments_by_speaker = defaultdict(list)

rttm_list = self.__class__._DIARIZER.diarize(audio_file=speach_file, temp_dir=self._TEMP_DIR)
rttm_list: List[RTTM] = self.__class__._DIARIZER.diarize(audio_file=speach_file, temp_dir=self._TEMP_DIR)
for rttm in rttm_list:
label = rttm.speaker_name
segment_start, segment_duration = rttm.turn_onset, rttm.turn_duration
@@ -176,6 +184,73 @@ def process(self) -> Union[Tuple[Dict[str, Any], List[Tuple[str, bytes]]], Dict[
ret[1].append(("speaker.rttm", f.read().encode("utf-8")))
os.remove(_tmp_rttm_file)

all_speakers: Dict[str, Optional[Tuple[str, float]]] = {x.speaker_name: None for x in rttm_list}

def speaker_str(_s):
return f'{all_speakers[_s][0]}' if _s in all_speakers else _s

if self._SPEAKER_MATCHER is not None:
from pydub import AudioSegment
from io import BytesIO
audio = AudioSegment.from_file(speach_file)
for s in all_speakers.keys():
speaker_times: List[Tuple[Optional[float], Optional[float]]] = []

for rttm in (x for x in rttm_list if x.speaker_name == s):
st, en = rttm.turn_onset, rttm.turn_onset + rttm.turn_duration

def overlap(_x: RTTM):
_x1, _x2 = _x.turn_onset, _x.turn_onset + _x.turn_duration
return st <= _x1 <= en or st <= _x2 <= en or _x1 <= st <= _x2 or _x1 <= en <= _x2

if not any(overlap(_x=x) for x in rttm_list if x.speaker_name != s):
speaker_times.append((st, en))
speaker_times = [x for x in speaker_times if not any(y is None for y in x)]
if sum(abs(e - s) for s, e in speaker_times) > 10:
break

if len(speaker_times) > 0:
speaker_audio = audio[speaker_times[0][0] * 1000:speaker_times[0][1] * 1000]
for speaker_time in speaker_times[1:]:
speaker_audio = speaker_audio + audio[speaker_time[0] * 1000:speaker_time[1] * 1000]
with BytesIO() as output:
speaker_audio.export(output, format="mp3")
output.seek(0)
pred_speaker = self._SPEAKER_MATCHER.match_speaker(audio=output)
if pred_speaker:
all_speakers[s] = pred_speaker

self._LOGGER.info(", ".join(f"{k} might be {v1} ({v2:.4f})" for k, (v1, v2) in all_speakers.items()))
speaker_confidence_t = float(os.environ.get("SPEAKER_CONFIDENCE_VALUE", 0.5))
if speaker_confidence_t > 0:
self._LOGGER.info(
f"Removing all speaker mappings that have confidence lower than {speaker_confidence_t}"
)
all_speakers = {x: y if y[1] > speaker_confidence_t else None for x, y in all_speakers.items()}

if len(set(all_speakers.values())) != len(all_speakers):
from collections import Counter
self._LOGGER.error("Two or more speakers matched to the same sample. This should not happen")
c = Counter(all_speakers.values())
for k, v in all_speakers.items():
if c[v] >= 2:
all_speakers[k] = f"{v[0]}-{c[v]}", v[1]
c[v] -= 1

all_speakers_keys = all_speakers.keys()
all_speakers = {x: y for x, y in all_speakers.items() if y is not None}

self._LOGGER.info(f"Final mapping: {', '.join(f'{s} -> {speaker_str(s)}' for s in all_speakers_keys)}")

ret[0]["speaker_mapping"] = {s: speaker_str(s) for s in all_speakers_keys}

if "speaker_counts" in ret[0]:
ret[0]["speaker_counts_o"] = ret[0]["speaker_counts"]
ret[0]["speaker_counts"] = {speaker_str(k): v for k, v in ret[0]["speaker_counts_o"].items()}
if "speaker_durations" in ret[0]:
ret[0]["speaker_durations_o"] = ret[0]["speaker_durations"]
ret[0]["speaker_durations"] = {speaker_str(k): v for k, v in ret[0]["speaker_durations_o"].items()}

if self.__class__._TRANSCRIPTOR is None:
self.__class__._LOGGER.error("No transcriptor selected. Please make sure to meet requirements")
else:
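Note that the confidence cut-off applied in the speaker-matching block above is read from the SPEAKER_CONFIDENCE_VALUE environment variable rather than from module config. A minimal sketch of how one might tighten or disable it before processing (the variable name comes from the diff; the values are examples):

import os

# Require a stricter match than the 0.5 default used above.
os.environ["SPEAKER_CONFIDENCE_VALUE"] = "0.7"

# A value of 0 (or below) skips the filtering step entirely, per the `> 0` check above.
# os.environ["SPEAKER_CONFIDENCE_VALUE"] = "0"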
@@ -186,58 +261,62 @@ def process(self) -> Union[Tuple[Dict[str, Any], List[Tuple[str, bytes]]], Dict[
)
self._LOGGER.info(f"Successful transcription of audio. Transcribed as language {lang}")
ret[0]["language"] = lang

self._LOGGER.info("Creating naive transcript without using diarization")
transcript: str = " ".join(x.word.strip() for x in words if x.word is not None and len(x.word.strip()) > 0)
self._LOGGER.info(f"Fixing punctuation in transcript")
transcript: str = restore_punctuation(text=transcript)
ret[1].append(("transcript.txt", transcript.encode("utf-8")))

if rttm_list is not None and len(rttm_list) > 0:
from rttm_manager import RTTM
from PAT.modules.speaker.util import best_word_speaker_match
rttm_list: List[RTTM]
self._LOGGER.info("Diarization and Transcription ready. Matching conversation to speaker data")

speaker_idx = 0
word_speaker: List[WordTupleSpeaker] = []
for word in tqdm(words, unit="word", leave=False, desc="Matching words to speakers"):
speaker = best_word_speaker_match(word=word, speakers=rttm_list)[0]
word_speaker.append(
WordTupleSpeaker.from_word_tuple(word=word, speaker=speaker.speaker_name)
)
self._LOGGER.info("Building transcription")
# TODO: Match Person to speaker label
transcript: List[Tuple[str, List[str]]] = []
transcript: List[Tuple[WordTupleSpeaker, List[str]]] = []
for ws in word_speaker:
if len(transcript) <= 0 or transcript[-1][0] != ws.speaker:
transcript.append((ws.speaker, []))
transcript[-1][1].append(ws.word)
if len(transcript) <= 0 or transcript[-1][0].speaker != ws.speaker or abs(
transcript[-1][0].end - ws.start) > 10:
transcript.append((ws, []))
transcript[-1][-1].append(ws.word)

final_transcript = "\n".join(
f"{s}: {restore_punctuation(' '.join(t))}" for s, t in
tqdm(transcript, unit="sentence", leave=False, desc="Fixing punctuations in transcript")
f"{timedelta(seconds=s.start)} {speaker_str(s.speaker)}: {restore_punctuation(' '.join(t))}"
for s, t in tqdm(transcript, unit="section", leave=False, desc="Fixing punctuations in transcript")
)
ret[1].append(("transcript.txt", final_transcript.encode("utf-8")))
ret[1].append(("transcript_speaker.txt", final_transcript.encode("utf-8")))

else:
self._LOGGER.info("Diarization not done. Creating single transcript without speaker separation")
transcript: str = " ".join(
x.word.strip() for x in words if x.word is not None and len(x.word.strip()) > 0)
self._LOGGER.info(f"Fixing punctuation in transcript")
transcript: str = restore_punctuation(text=transcript)
ret[1].append(("transcript.txt", transcript.encode("utf-8")))
self._LOGGER.info("Diarization not done. Using naive transcript as final transcript")
final_transcript = transcript

# TODO: Maybe use LLM to fix transcription?

if self._SUMMARIZE_MODEL is not None:
if self._SUMMARIZE_SERVICE is not None and self._SUMMARIZE_MODEL is not None:
try:
from PAT.utils.summerize import summarize, Templates
self._LOGGER.info("Creating summary of final transcript")
summ = summarize(
text=final_transcript,
openai_model=self._SUMMARIZE_MODEL,
language=lang,
template=Templates.Podcast
model=self._SUMMARIZE_MODEL,
lang=lang,
template=Templates.Podcast,
llm=self._SUMMARIZE_SERVICE,
)
if summ is not None:
ret[1].append(("summarization.txt", summ.encode("utf-8")))
except Exception:
self._LOGGER.warning("Failed to summarize transcript")
except Exception as e:
self._LOGGER.error(f"Failed to summarize transcript: {e}")
else:
self._LOGGER.info("No model for summarization set. Will not summarize")
self._LOGGER.info("No service/model for summarization set. Will not summarize")

return ret
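For orientation, each section of the speaker-separated transcript written above is prefixed with a timedelta start time and the (possibly matched) speaker label. With invented data, one line would render roughly like this:

from datetime import timedelta

# Hypothetical section: start time in seconds, matched speaker label, restored-punctuation text.
start, speaker, text = 83.5, "Alice", "Welcome back to the show."
print(f"{timedelta(seconds=start)} {speaker}: {text}")
# -> 0:01:23.500000 Alice: Welcome back to the show.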

@@ -258,6 +337,20 @@ def load(cls, config: Dict[str, Any]):
cls._TRANSCRIPTOR = config.get("transcriptor", None)
cls._TRANSCRIPTOR_MODEL = config.get("transcriptor_model", None)
cls._SUMMARIZE_MODEL = config.get("transcription_summarize_model", None)
cls._SUMMARIZE_SERVICE = config.get("transcription_summarize_service", LLMService.OpenAI)
if "transcription_summarize_service" in config:
try:
cls._SUMMARIZE_SERVICE = [
x for x in LLMService if x.name == config["transcription_summarize_service"]
][0]
except IndexError:
cls._LOGGER.error(
f"Could not find Summarization LLM service with name {config['transcription_summarize_service']}"
)
if "speaker_samples" in config and config["speaker_samples"] is not None and len(config["speaker_samples"]) > 0:
cls._SPEAKER_MATCHER = SpeakerMatching(
{os.path.splitext(os.path.basename(x))[0]: x for x in config["speaker_samples"]}
)

@classmethod
def config_keys(cls) -> Dict[str, Dict[str, Any]]:
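The speaker-sample mapping built in load() above keys each reference file by its filename stem, which is why the help text below asks for /path/to/file/<name_of_speaker>.ext. A quick illustration with made-up paths:

import os

# Hypothetical sample files; the stem of each filename becomes the speaker label.
samples = ["/data/voices/Alice.mp3", "/data/voices/Bob.wav"]
mapping = {os.path.splitext(os.path.basename(p))[0]: p for p in samples}
# -> {"Alice": "/data/voices/Alice.mp3", "Bob": "/data/voices/Bob.wav"}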
@@ -275,13 +368,25 @@ def config_keys(cls) -> Dict[str, Dict[str, Any]]:
ret["diarizer"] = {
"type": lambda x: cls._AVAILABLE_DIARIZERS[x], "required": False,
"default": sorted(cls._AVAILABLE_DIARIZERS)[0], "choices": cls._AVAILABLE_DIARIZERS,
"help": f"Diarization method to use. Options: {', '.join(cls._AVAILABLE_DIARIZERS)}. (Default: %(default)s)"
"help": f"Diarization method to use. Options: {', '.join(cls._AVAILABLE_DIARIZERS)}. "
f"(Default: %(default)s)"
}
if SpeakerMatching.is_available():
def _valid_file(x):
import os
if os.path.exists(x) and os.path.isfile(x):
return x

ret["speaker_samples"] = {
"type": _valid_file, "required": False, "default": None, "nargs": "+",
"help": "Provide sample audio files to match the diarized speakers against. "
"Should be formatted as /path/to/file/<name_of_speaker>.ext"
}
if len(cls._AVAILABLE_TRANSCRIPTORS) > 0:
ret["transcriptor"] = {
"type": lambda x: cls._AVAILABLE_TRANSCRIPTORS[x], "required": False,
"default": sorted(cls._AVAILABLE_TRANSCRIPTORS)[0], "choices": cls._AVAILABLE_TRANSCRIPTORS,
"help": f"Transcription method to use. Options: {', '.join(cls._AVAILABLE_TRANSCRIPTORS)}. (Default: %(default)s)"
"help": f"Transcription method to use. Choices: %(choices)s. (Default: %(default)s)"
}
ret["transcriptor_model"] = {
"type": str, "required": False, "default": "medium",
@@ -291,6 +396,11 @@ def config_keys(cls) -> Dict[str, Dict[str, Any]]:
"type": str, "required": False, "default": None,
"help": f"Transcription language. Omit for autodetect. (Default: %(default)s)"
}
ret["transcription_summarize_service"] = {
"type": str, "choices": [x.name for x in LLMService], "required": False,
"default": LLMService.OpenAI.name,
"help": f"Service to use for summarization tasks. Choices: %(choices)s. (Default: %(default)s)"
}
ret["transcription_summarize_model"] = {
"type": str, "required": False, "default": None,
"help": f"OpenAI model to use to summarize transcription. "
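Putting the new options together, a configuration for this module might look roughly like the sketch below. The key names are taken from config_keys() above; the sample paths and the summarization model identifier are illustrative assumptions.

# Hypothetical configuration fragment for SpeakerDetectionModule (sketch only).
speaker_config = {
    "speaker_samples": ["/data/voices/Alice.mp3", "/data/voices/Bob.wav"],  # optional voice references
    "transcriptor_model": "medium",                   # default advertised above
    "transcription_summarize_service": "OpenAI",      # must name an LLM enum member
    "transcription_summarize_model": "gpt-4o-mini",   # assumed model name, not from the diff
}
# SpeakerDetectionModule.load(speaker_config) would resolve the service and build the speaker matcher.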