-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Compute Binary Model Solutions and Plot (#7)
- Loading branch information
Showing
11 changed files
with
620 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Combine all model solutions into a single DataFrame.""" | ||
|
||
from pathlib import Path | ||
from typing import Annotated | ||
|
||
import pandas as pd # type: ignore[import-untyped] | ||
import pytask | ||
from pytask import Product | ||
|
||
from thesis.config import BLD | ||
from thesis.pyvmte.task_solve_simple_model_sharp import ID_TO_KWARGS | ||
|
||
paths_to_res = [ID_TO_KWARGS[key].path_to_data for key in ID_TO_KWARGS] | ||
|
||
|
||
@pytask.mark.wip | ||
def task_combine_model_solutions_simple_model( | ||
path_to_combined: Annotated[Path, Product] = BLD | ||
/ "data" | ||
/ "solutions" | ||
/ "solutions_simple_model_combined.pkl", | ||
paths_to_res: list[Path] = paths_to_res, | ||
) -> pd.DataFrame: | ||
"""Combine model solutions into a single DataFrame.""" | ||
dfs_single = [pd.read_pickle(path) for path in paths_to_res] | ||
|
||
df_combined = pd.concat(dfs_single, ignore_index=True) | ||
|
||
df_combined["b_late"] = df_combined["y1_c"] - df_combined["y0_c"] | ||
|
||
df_combined.to_pickle(path_to_combined) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
"""Plot simple model by LATE for paper for different restrictions.""" | ||
|
||
from pathlib import Path | ||
from typing import Annotated, NamedTuple | ||
|
||
import pandas as pd # type: ignore[import-untyped] | ||
import plotly.graph_objects as go # type: ignore[import-untyped] | ||
import pytask | ||
from pytask import Product, task | ||
|
||
from thesis.config import BLD | ||
|
||
|
||
class _Arguments(NamedTuple): | ||
idestimands: str | ||
bfunc_type: str | ||
constraint: str | ||
path_to_plot: Annotated[Path, Product] | ||
|
||
|
||
_shape_constr_to_plot = ("decreasing", "decreasing") | ||
_mte_monotone_to_plot = "decreasing" | ||
_monotone_response_to_plot = "positive" | ||
|
||
_constr_vals = { | ||
"shape_constraints": "_".join(_shape_constr_to_plot), | ||
"mte_monotone": _mte_monotone_to_plot, | ||
"monotone_response": _monotone_response_to_plot, | ||
} | ||
|
||
_constr_subtitle = { | ||
"shape_constraints": ( | ||
f"{_constr_vals['shape_constraints'].split('_')[0].capitalize()}" | ||
" MTR Functions" | ||
), | ||
"mte_monotone": ( | ||
f"{_constr_vals['mte_monotone'].capitalize()} Marginal Treatment Effect" | ||
), | ||
"monotone_response": ( | ||
f"{_constr_vals['monotone_response'].capitalize()} Treatment Response" | ||
), | ||
None: "None", | ||
} | ||
|
||
bfunc_types_to_plot = ["constant", "bernstein"] | ||
idestimands_to_plot = ["late", "sharp"] | ||
constraints_to_plot = [None, "shape_constraints", "mte_monotone", "monotone_response"] | ||
|
||
|
||
ID_TO_KWARGS = { | ||
f"{idestimands}_{bfunc_type}_{constraint}": _Arguments( | ||
idestimands=idestimands, | ||
bfunc_type=bfunc_type, | ||
constraint=constraint, # type: ignore[arg-type] | ||
path_to_plot=BLD | ||
/ "figures" | ||
/ "solutions" | ||
/ f"simple_model_by_late_{idestimands}_{bfunc_type}_{constraint}.png", | ||
) | ||
for idestimands in idestimands_to_plot | ||
for bfunc_type in bfunc_types_to_plot | ||
for constraint in constraints_to_plot | ||
} | ||
|
||
for id_, kwargs in ID_TO_KWARGS.items(): | ||
|
||
@pytask.mark.wip | ||
@task(id=id_, kwargs=kwargs) # type: ignore[arg-type] | ||
def task_plot_simple_model_by_late( | ||
idestimands: str, | ||
bfunc_type: str, | ||
constraint: str, | ||
path_to_plot: Annotated[Path, Product], | ||
path_to_combined: Path = BLD | ||
/ "data" | ||
/ "solutions" | ||
/ "solutions_simple_model_combined.pkl", | ||
) -> None: | ||
"""Plot simple model by LATE for different restrictions.""" | ||
df_combined = pd.read_pickle(path_to_combined) | ||
|
||
fig = go.Figure() | ||
|
||
bfunc_type_to_color = {"constant": "blue", "bernstein": "red"} | ||
bound_to_dash = {"upper_bound": "solid", "lower_bound": "solid"} | ||
|
||
for bfunc_type in ["constant", "bernstein"]: | ||
df_plot = df_combined[df_combined["bfunc_type"] == bfunc_type] | ||
|
||
if constraint is not None: | ||
df_plot = df_plot[df_plot["constraint_type"] == constraint] | ||
df_plot = df_plot[df_plot["constraint_val"] == _constr_vals[constraint]] | ||
else: | ||
df_plot = df_plot[df_plot["constraint_type"] == "none"] | ||
|
||
df_plot = df_plot[df_plot["idestimands"] == idestimands] | ||
|
||
_k_bernstein = df_plot["k_bernstein"].unique() | ||
|
||
assert len(_k_bernstein) == 1 | ||
|
||
_legend_title_by_bfunc = { | ||
"constant": "Constant", | ||
"bernstein": ( | ||
f"Bernstein, Degree {int(_k_bernstein[0])}" | ||
if bfunc_type == "bernstein" | ||
else None | ||
), | ||
} | ||
|
||
for bound in ["upper_bound", "lower_bound"]: | ||
fig.add_trace( | ||
go.Scatter( | ||
x=df_plot["b_late"], | ||
y=df_plot[bound], | ||
mode="lines", | ||
name=f"{bound.split('_')[0].capitalize()} Bound", | ||
legendgroup=bfunc_type, | ||
legendgrouptitle={"text": _legend_title_by_bfunc[bfunc_type]}, | ||
line={ | ||
"color": bfunc_type_to_color[bfunc_type], | ||
"dash": bound_to_dash[bound], | ||
}, | ||
), | ||
) | ||
|
||
_subtitle = ( | ||
f" <br><sup> Identified Estimands: {idestimands.capitalize()} </sup>" | ||
f" <br><sup> Shape constraints: {_constr_subtitle[constraint]} </sup>" | ||
) | ||
|
||
fig.update_layout( | ||
title="Bounds on Target LATE(0.4, 0.8) for Binary-IV Model" + _subtitle, | ||
xaxis_title="Identified LATE", | ||
yaxis_title="Bounds", | ||
) | ||
|
||
# Add note with num_gridpoints | ||
_num_gridpoints = df_plot["num_gridpoints"].unique() | ||
assert len(_num_gridpoints) == 1 | ||
_num_gridpoints = _num_gridpoints[0] | ||
|
||
fig.add_annotation( | ||
text=f"Number of gridpoints: {_num_gridpoints}", | ||
showarrow=False, | ||
xref="paper", | ||
yref="paper", | ||
x=0.99, | ||
y=0.01, | ||
) | ||
|
||
fig.write_image(path_to_plot) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.