From aea7b46e1bb92dcaff5b7033b383c2e24eda220f Mon Sep 17 00:00:00 2001 From: abrar Date: Wed, 18 Dec 2024 01:45:18 -0700 Subject: [PATCH] Add random guardrail on top of sup_primary --- src/public/viz-guardrails/LineChart.tsx | 69 +++++++++++++++---------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/src/public/viz-guardrails/LineChart.tsx b/src/public/viz-guardrails/LineChart.tsx index da5622145..f4c8437ec 100644 --- a/src/public/viz-guardrails/LineChart.tsx +++ b/src/public/viz-guardrails/LineChart.tsx @@ -6,7 +6,7 @@ /* eslint-disable camelcase */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { useMemo, useState } from 'react'; +import { useMemo, useState, useEffect } from 'react'; import * as d3 from 'd3'; import { Center, Text } from '@mantine/core'; import { XAxis } from './XAxis'; @@ -87,31 +87,42 @@ export function LineChart({ const height = 400 - margin.top - margin.bottom; /// ////////// Setting scales - const { - yMin, yMax, - } = useMemo(() => { - let relevant_selection: string[] = []; - switch (guardrail) { - case 'super_data': - relevant_selection = selection?.concat(controlsSelection) as string[]; - break; - default: - relevant_selection = selection as string[]; - break; + const allCountries = useMemo(() => Array.from(new Set(data.map((val) => val[parameters.cat_var]))), [data, parameters.cat_var]); + + const [randomCountries, setRandomCountries] = useState([]); + const [randomSelectedFlag, setRandomSelectedFlag] = useState(false); + + useEffect(() => { + if (guardrail === 'super_data' && !randomSelectedFlag) { + const unselectedCountries = allCountries.filter( + (country) => !(selection || []).includes(country), + ); + + setRandomCountries(d3.shuffle(unselectedCountries).slice(0, 2)); + setRandomSelectedFlag(true); } - const yData: number[] = data.filter((val) => relevant_selection.includes(val[parameters.cat_var])).map((d) => +d[parameters.y_var]).filter((val) => val !== null) as number[]; - const [yMinSel, yMaxSel] = (dataname === 'clean_stocks' ? (d3.extent(yData) as [number, number]) : ([0, d3.extent(yData)[1]] as [number, number])); - const [lowerq, upperq] = [d3.min(avgData.map((val) => val.lowerq)) as number, d3.max(avgData.map((val) => val.upperq)) as number]; + if (guardrail !== 'super_data') { + setRandomCountries([]); + setRandomSelectedFlag(false); + } + }, [guardrail, selection, allCountries]); - const yMin = (guardrail === 'super_summ' ? d3.min([yMinSel, lowerq]) : yMinSel) as number; - const yMax = (guardrail === 'super_summ' ? d3.max([yMaxSel, upperq]) : yMaxSel) as number; + const { + yMin, yMax, + } = useMemo(() => { + const relevantCountries = selection?.concat(randomCountries) || []; + const yData = data + .filter((val) => relevantCountries.includes(val[parameters.cat_var])) + .map((d) => +d[parameters.y_var]) + .filter((val) => val !== null) as number[]; + const [yMinSel, yMaxSel] = d3.extent(yData) as [number, number]; return { - yMin, - yMax, + yMin: yMinSel || 0, + yMax: yMaxSel || 1, }; - }, [data, selection, guardrail, avgData, controlsSelection, parameters, dataname]); + }, [data, selection, randomCountries, parameters]); const xScale = useMemo(() => { if (range) { @@ -151,16 +162,18 @@ export function LineChart({ return null; } - const lineGenerator = d3.line(); - lineGenerator.x((d: any) => xScale(d3.timeParse('%Y-%m-%d')(d[parameters.x_var]) as Date)); - lineGenerator.y((d: any) => yScale(d[parameters.y_var])); - lineGenerator.curve(d3.curveBasis); - const paths = controlsSelection?.map((x) => ({ - country: x as string, - path: lineGenerator(data.filter((val) => (val[parameters.cat_var] === x))) as string, + const lineGenerator = d3.line() + .x((d: any) => xScale(d3.timeParse('%Y-%m-%d')(d[parameters.x_var]) as Date)) + .y((d: any) => yScale(d[parameters.y_var])) + .curve(d3.curveBasis); + + const paths = randomCountries.map((country) => ({ + country, + path: lineGenerator(data.filter((val) => val[parameters.cat_var] === country)) as string, })); + return paths; - }, [data, xScale, yScale, guardrail, controlsSelection, parameters, dataname]); + }, [data, xScale, yScale, randomCountries, parameters]); const superimposeSummary = useMemo(() => { if (guardrail !== 'super_summ') {