Skip to content

Commit

Permalink
chore: fix OpenAI GPT template
Browse files Browse the repository at this point in the history
  • Loading branch information
Two Dev committed Dec 29, 2024
1 parent 3ba6107 commit 6200dbc
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
2 changes: 1 addition & 1 deletion gemma_template/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
__url__ = "https://github.com/thewebscraping/gemma-template"
__author__ = "Tu Pham"
__author_email__ = "thetwofarm@gmail.com"
__version__ = "0.1.2"
__version__ = "0.1.3"
__license__ = "Apache-2.0"
45 changes: 32 additions & 13 deletions gemma_template/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ class StructureField(BaseTemplate):
"description": [
"Description",
"Introduction",
"Summary",
"Intro",
"Meta Description",
],
"document": ["Article", "Edit Article"],
Expand Down Expand Up @@ -401,6 +399,7 @@ def load_dataset(
min_chars_length: int = 2,
max_chars_length: int = 0,
max_concurrency: int = 4,
is_remove_data: bool = True,
is_close_async_loop: bool = True,
**kwargs,
) -> Union[Dataset, DatasetDict]:
Expand Down Expand Up @@ -436,6 +435,8 @@ def load_dataset(
Maximum character of a word, used to create unigrams, bigrams and trigrams. Default is 0.
max_concurrency (int):
Maximum number of concurrent threads for processing data. Default is 4.
is_remove_data (bool):
If True, the original data is removed from the dataset; otherwise it is kept under the `data` field in the dataset.
is_close_async_loop (bool):
By default, the asyncio event loop is closed each time dataset processing finishes.
Although the `RuntimeError` exception is handled, you should set this to False when running on Kaggle Notebooks or Colab.
Expand Down Expand Up @@ -478,6 +479,7 @@ async def create_task(config, hidden_count: int = 0):
min_chars_length=min_chars_length,
max_chars_length=max_chars_length,
excluded_fields=excluded_fields,
is_remove_data=is_remove_data,
)
)
if max_hidden_ratio > 0 and hidden_count < max_hidden_count:
Expand Down Expand Up @@ -836,8 +838,7 @@ def generate_user_prompt(

def generate_model_prompt(
self,
structure_template: Optional[TemplateTypes] = None,
excluded_fields: Optional[Sequence[str]] = (),
structure_template: Optional[TemplateTypes] = "",
bullet_style: Optional[Union[str, Literal["dash", "number"]]] = "dash",
**kwargs,
) -> str:
Expand All @@ -849,7 +850,6 @@ def generate_model_prompt(
Args:
structure_template (Optional[Union[str, Callable]]): A structure template defining the generating structure prompt.
excluded_fields (Sequence[str]): Fields excluded to response. Default is empty sequence.
bullet_style (Optional[str]): Bullet list style start dash or number. Default is dash.
**kwargs: See also `Template.template`.
Expand All @@ -866,11 +866,6 @@ def generate_model_prompt(
""" # noqa: E501

output_document = kwargs.get("output", "")
if excluded_fields:
for excluded_field in excluded_fields:
if excluded_field in kwargs:
kwargs.pop(excluded_field)

if isinstance(structure_template, (str, Callable)):
kwargs["document"] = output_document
if isinstance(structure_template, Callable):
Expand Down Expand Up @@ -916,6 +911,7 @@ def to_text(
language_code=user_kwargs.get("language_code", "auto"),
language=user_kwargs.get("language"),
is_masked=bool(user_kwargs.get("is_masked")),
data=self._get_origin_data(**kwargs),
)

def to_alpaca(
Expand All @@ -941,6 +937,7 @@ def to_alpaca(
language_code=user_kwargs.get("language_code", "auto"),
language=user_kwargs.get("language"),
is_masked=bool(user_kwargs.get("is_masked")),
data=self._get_origin_data(**kwargs),
)

def to_openai(
Expand All @@ -955,8 +952,16 @@ def to_openai(
user_template, instruction_template, structure_template, **kwargs
)
return dict(
human=user_template,
gpt=model_template,
conversations=[
{
"from": "human",
"value": user_template,
},
{
"from": "gpt",
"value": model_template,
},
],
is_instructed=bool(instruction_template is not None),
is_structured=bool(structure_template is not None),
unigrams=user_kwargs.get("unigrams", []) or [],
Expand All @@ -965,6 +970,7 @@ def to_openai(
language_code=user_kwargs.get("language_code", "auto"),
language=user_kwargs.get("language"),
is_masked=bool(user_kwargs.get("is_masked")),
data=self._get_origin_data(**kwargs),
)

def _get_template(
Expand Down Expand Up @@ -1069,10 +1075,14 @@ def _ftm_template(word):
def _formatting_structure_user_fn(
self,
structure_template: str = STRUCTURE_TEMPLATE,
excluded_fields: Sequence[str] = (),
**kwargs,
) -> str:
prompts = []
for _, data in self._get_structure_attrs(**kwargs).items():
for field, data in self._get_structure_attrs(**kwargs).items():
if excluded_fields and field in excluded_fields:
continue

prompts.append(
"{field} {prompt}".format(
field=data["bold_value"], prompt=data["prompt"]
Expand All @@ -1085,6 +1095,7 @@ def _formatting_structure_model_fn(
self,
structure_data: dict,
bullet_style: str = None,
excluded_fields: Sequence[str] = (),
*args,
**kwargs,
) -> str:
Expand All @@ -1097,6 +1108,9 @@ def _formatting_structure_model_fn(
if field not in kwargs:
continue

if excluded_fields and field in excluded_fields:
continue

value = kwargs[field]
if not value:
continue
Expand Down Expand Up @@ -1129,6 +1143,11 @@ def _get_structure_attrs(self, **kwargs):
}
return mapping

def _get_origin_data(self, **kwargs) -> dict:
if not kwargs.get("is_remove_data", True):
return {k: v for k, v in kwargs.items() if hasattr(self, k)}
return {}


gemma_template = Template()
vietnamese_gemma_template = Template(
Expand Down

0 comments on commit 6200dbc

Please sign in to comment.