Skip to content

Commit

Permalink
add cache to sklearn
Browse files (browse the repository at this point in the history)
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed Jan 23, 2025
1 parent 9c78791 commit c8dbded
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
17 changes: 10 additions & 7 deletions src/unitxt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
IterableDataset,
IterableDatasetDict,
get_dataset_split_names,
- load_dataset_builder,
)
from datasets import load_dataset as hf_load_dataset
from huggingface_hub import HfApi
Expand Down Expand Up @@ -168,7 +167,7 @@ def load_data(self) -> MultiStream:
self.__class__._loader_cache.max_size = settings.loader_cache_size
self.__class__._loader_cache[str(self)] = iterables
if isoftype(iterables, Dict[str, ReusableGenerator]):
- return MultiStream.from_generators(iterables)
+ return MultiStream.from_generators(iterables, copying=True)
return MultiStream.from_iterables(iterables, copying=True)

def process(self) -> MultiStream:
Expand Down Expand Up @@ -476,11 +475,15 @@ def load_iterables(self):
}

def split_generator(self, split: str) -> Generator:
- split_data = self.downloader(subset=split)
- targets = [split_data["target_names"][t] for t in split_data["target"]]
- df = pd.DataFrame([split_data["data"], targets]).T
- df.columns = ["data", "target"]
- dataset = df.to_dict("records")
+ dataset = self.__class__._loader_cache.get(str(self) + "_" + split, None)
+ if dataset is None:
+ split_data = self.downloader(subset=split)
+ targets = [split_data["target_names"][t] for t in split_data["target"]]
+ df = pd.DataFrame([split_data["data"], targets]).T
+ df.columns = ["data", "target"]
+ dataset = df.to_dict("records")
+ self.__class__._loader_cache.max_size = settings.loader_cache_size
+ self.__class__._loader_cache[str(self) + "_" + split] = dataset
yield from dataset


Expand Down
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"filename": "src/unitxt/loaders.py",
"hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
"is_verified": false,
- "line_number": 566,
+ "line_number": 572,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
- "generated_at": "2025-01-22T20:27:31Z"
+ "generated_at": "2025-01-23T10:07:40Z"
}

0 comments on commit c8dbded

Please sign in to comment.