Skip to content

Commit

Permalink
ruff: format with 0.9.0 (#1235)
Browse files Browse the repository at this point in the history
ruff 0.9.0 just landed on PyPI; some (but not all) of these formatting changes
are also accepted by the older 0.8.0 release.
  • Loading branch information
olupton authored Jan 10, 2025
1 parent 5a74526 commit 9dd32f5
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 46 deletions.
10 changes: 5 additions & 5 deletions .github/container/jax-nccl-test
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ if __name__ == "__main__":
)
args = parser.parse_args()

assert (
args.process_id is None or args.distributed
), "--process-id is only relevant with --distributed"
assert args.process_id is None or args.distributed, (
"--process-id is only relevant with --distributed"
)
if args.distributed:
null_args = {
args.coordinator_address is None,
Expand Down Expand Up @@ -108,7 +108,7 @@ if __name__ == "__main__":
f"Rank {args.process_id} has local rank {local_process_id} and "
f"devices {local_device_ids} from a total of {visible_devices} "
f"visible on this node, {args.process_count} processes and "
f"{args.process_count*args.gpus_per_process} total devices.",
f"{args.process_count * args.gpus_per_process} total devices.",
flush=True,
)
jax.distributed.initialize(
Expand Down Expand Up @@ -209,7 +209,7 @@ if __name__ == "__main__":
if host_timer:
result.block_until_ready()
if jax.process_index() == 0:
print(f"First {op} duration {time.time()-start:.2f}s")
print(f"First {op} duration {time.time() - start:.2f}s")
return result

def device_put_local(x: jax.Array):
Expand Down
10 changes: 5 additions & 5 deletions .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,11 @@
"# Print out the largest entries adding up to at least this fraction of the total\n",
"threshold = 0.97\n",
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
"for row in compile_summary[\n",
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
"].itertuples():\n",
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs*1e-3:.2f}s {row.Index}\")"
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs * 1e-3:.2f}s {row.Index}\")"
]
},
{
Expand Down Expand Up @@ -585,9 +585,9 @@
"detailed_mask = (compute_duration_rel_stds > var_threshold) & (\n",
" compute_duration_means > mean_threshold\n",
")\n",
"assert (\n",
" detailed_mask.sum() <= detailed_limit\n",
"), f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
"assert detailed_mask.sum() <= detailed_limit, (\n",
" f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
")\n",
"\n",
"fig, axs = plt.subplots(\n",
" ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True\n",
Expand Down
43 changes: 21 additions & 22 deletions .github/container/nsys_jax/nsys_jax/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
# Error if the communication frame doesn't exist at all, but not if it is empty.
# Calling this on a profile that does not contain any communication should
# gracefully yield empty results.
assert (
frames.communication is not None
), "align_profiler_data_timestamps requires a communication frame"
assert frames.communication is not None, (
"align_profiler_data_timestamps requires a communication frame"
)
if not len(frames.communication):
# Nothing to be done, return an empty result
return frames, {}
Expand All @@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
)
return frames, {}
assert (
num_profiled_devices == max_collective_size
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
assert num_profiled_devices == max_collective_size, (
f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
)
# Find the collectives that will be used
align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
# Calculate the collectives' end times
Expand Down Expand Up @@ -190,19 +190,18 @@ def _get_message_size(
) -> tuple[int, str, int, float, float]:
_, inst = module_proto.find_instruction(instruction)
comm_inst = inst.communication_proto()
assert (
comm_inst.opcode
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
assert comm_inst.opcode in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"dynamic-slice",
"dynamic-update-slice",
"reduce-scatter",
}, (
f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
)

def _byte_size(inst) -> int:
size_bits = math.prod(
Expand Down Expand Up @@ -256,9 +255,9 @@ def _byte_size(inst) -> int:
collective_size = iota_group_list.num_devices_per_group
else:
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
assert (
len(collective_sizes) == 1
), f"Heterogeneous collective {comm_inst} could not be interpreted"
assert len(collective_sizes) == 1, (
f"Heterogeneous collective {comm_inst} could not be interpreted"
)
collective_size = next(iter(collective_sizes))
total_msg_size = 0
for operand_id in comm_inst.operand_ids:
Expand Down
4 changes: 3 additions & 1 deletion .github/container/nsys_jax/nsys_jax/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def _calculate_overlap(thunk_df: pd.DataFrame) -> pd.DataFrame:
serial_mask = (
compute_df["ProjStartMs"].array[1:] >= compute_df["ProjEndMs"].array[:-1]
)
assert serial_mask.all(), f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
assert serial_mask.all(), (
f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
)
# Update the projected duration of each communication kernel to only
# include the non-overlapped time
for comm_thunk in comm_df.itertuples():
Expand Down
6 changes: 3 additions & 3 deletions .github/container/nsys_jax/nsys_jax/protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def _visit_computation(computation_id):
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
called_inst
):
assert (
self._comm_proto is None
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
assert self._comm_proto is None, (
f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
)
self._comm_proto = called_inst

for called_id in self._proto.called_computation_ids:
Expand Down
12 changes: 6 additions & 6 deletions .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def copy_proto_files_to_tmp(
if not osp.isdir(dst_dir):
os.makedirs(dst_dir)
shutil.copy(osp.join(root, proto), osp.join(proto_dir, proto))
print(f"{archive_name}: gathered .proto files in {time.time()-start:.2f}s")
print(f"{archive_name}: gathered .proto files in {time.time() - start:.2f}s")
return proto_dir, proto_files

def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
Expand Down Expand Up @@ -369,7 +369,7 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
if osp.isdir(full_path) or not osp.exists(full_path):
continue
output_queue.put((ofile, full_path, COMPRESS_NONE))
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")

def compress_and_archive(prefix, file, output_queue):
"""
Expand Down Expand Up @@ -403,7 +403,7 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
)
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
compress_and_archive(tmp_dir, ofile, output_queue)
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")

def save_device_stream_thread_names(tmp_dir, report, output_queue):
"""
Expand Down Expand Up @@ -501,7 +501,7 @@ def table_columns(table_name):
else:
print("WARNING: NOT writing device metadata, no device activity profiled?")
print(
f"{archive_name}: extracted device/thread names in {time.time()-start:.2f}s"
f"{archive_name}: extracted device/thread names in {time.time() - start:.2f}s"
)

def find_pb_files_in_tmp(tmp_dir):
Expand Down Expand Up @@ -553,7 +553,7 @@ def gather_source_files(
continue
assert osp.isabs(src_file), f"{src_file} is not absolute"
output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE))
print(f"{archive_name}: gathered source code in {time.time()-start:.2f}s")
print(f"{archive_name}: gathered source code in {time.time() - start:.2f}s")

def execute_analysis_scripts(mirror_dir, analysis_scripts):
"""
Expand Down Expand Up @@ -631,7 +631,7 @@ def write_output_file(to_process, mirror_dir, analysis_scripts):
for path_in_archive, local_path in analysis_outputs:
archive.write(filename=local_path, arcname=path_in_archive)
os.chmod(archive_name, 0o644)
print(f"{archive_name}: wrote in {time.time()-start:.2f}s")
print(f"{archive_name}: wrote in {time.time() - start:.2f}s")
if exit_code != 0:
print("Exiting due to analysis script errors")
sys.exit(exit_code)
Expand Down
6 changes: 3 additions & 3 deletions .github/container/nsys_jax/nsys_jax/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ def analysis_recipe_path(script):
)
if script_file.is_file():
return script_file
assert os.path.exists(
script
), f"{script} does not exist and is not the name of a built-in analysis script"
assert os.path.exists(script), (
f"{script} does not exist and is not the name of a built-in analysis script"
)
return contextlib.nullcontext(pathlib.Path(script))


Expand Down
2 changes: 1 addition & 1 deletion .github/triage/jax_toolbox_triage/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def main():
logger = get_logger(args.output_prefix)
logger.info(
"Verbose output, including stdout/err of triage commands, will be written to "
f'{(args.output_prefix / "debug.log").resolve()}'
f"{(args.output_prefix / 'debug.log').resolve()}"
)
container_url = functools.partial(container_url_base, container=args.container)
container_exists = functools.partial(
Expand Down

0 comments on commit 9dd32f5

Please sign in to comment.