Source code for spacebench.eval

from collections import defaultdict

import numpy as np

from spacebench.env import SpaceEnv, SpaceDataset


class DatasetEvaluator:
    """
    Class for evaluating the performance of a causal inference method
    on a specific SpaceDataset.
    """

    def __init__(self, dataset: SpaceDataset) -> None:
        self.dataset = dataset
        self.buffer = defaultdict(list)

    def eval(
        self,
        ate: np.ndarray | None = None,
        att: np.ndarray | None = None,
        counterfactuals: np.ndarray | None = None,
        erf: np.ndarray | None = None,
    ) -> dict[str, float]:
        errors = {}
        cf_true = self.dataset.counterfactuals
        t = self.dataset.treatment
        # All errors are normalized by the standard deviation of the outcome.
        scale = np.std(self.dataset.outcome)

        if ate is not None:
            assert self.dataset.has_binary_treatment(), "ATE is only valid for binary treatments"
            ate_true = (cf_true[:, 1] - cf_true[:, 0]).mean()
            errors["ate_error"] = (ate - ate_true) / scale
            errors["ate_se"] = np.square(errors["ate_error"])

        if att is not None:
            assert self.dataset.has_binary_treatment(), "ATT is only valid for binary treatments"
            assert np.min(t) == 0.0 and np.max(t) == 1.0
            att_true = (cf_true[t == 1, 1] - cf_true[t == 1, 0]).mean()
            errors["att_error"] = (att - att_true) / scale
            errors["att_se"] = np.square(errors["att_error"])

        if counterfactuals is not None:
            # Precision in estimating heterogeneous effects (PEHE), one value per treatment level.
            errors["pehe_curve"] = ((counterfactuals - cf_true) ** 2).mean(0) / scale**2
            errors["pehe_av"] = errors["pehe_curve"].mean()

        if erf is not None:
            # Exposure-response function (ERF) error.
            erf_true = self.dataset.erf()
            errors["erf_error"] = (erf - erf_true) / scale
            errors["erf_av"] = np.square(errors["erf_error"]).mean()

        return errors
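A minimal usage sketch for DatasetEvaluator follows. It assumes a binary-treatment SpaceDataset named dataset is already in hand (for example, drawn from a SpaceEnv); the naive difference-in-means estimate is purely illustrative and not part of this module.

# Usage sketch (assumption: `dataset` is a binary-treatment SpaceDataset).
t = dataset.treatment
y = dataset.outcome
ate_hat = y[t == 1].mean() - y[t == 0].mean()  # naive, unadjusted ATE estimate

evaluator = DatasetEvaluator(dataset)
metrics = evaluator.eval(ate=ate_hat)
print(metrics["ate_error"], metrics["ate_se"])  # normalized ATE error and its square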
class EnvEvaluator:
    """
    Class for evaluating the performance of a causal inference method
    on a specific SpaceEnv.
    """

    def __init__(self, env: SpaceEnv) -> None:
        self.env = env
        self.buffer = defaultdict(list)
    def add(
        self,
        dataset: SpaceDataset,
        ate: np.ndarray | None = None,
        att: np.ndarray | None = None,
        counterfactuals: np.ndarray | None = None,
        erf: np.ndarray | None = None,
    ) -> None:
        """
        Add a dataset to the buffer.
        """
        evaluator = DatasetEvaluator(dataset)
        metrics = evaluator.eval(
            ate=ate,
            att=att,
            counterfactuals=counterfactuals,
            erf=erf,
        )
        for k, v in metrics.items():
            self.buffer[k].append(v)
    def summarize(self) -> dict[str, float]:
        """
        Summarize the errors in causal prediction across the added datasets.
        """
        if len(self.buffer) == 0:
            raise ValueError("Use add before summarize")

        res = dict()

        # ate bias and variance
        if "ate_error" in self.buffer:
            res["ate_bias"] = np.array(self.buffer["ate_error"]).mean()
            res["ate_variance"] = np.array(self.buffer["ate_error"]).var()

        # att bias and variance
        if "att_error" in self.buffer:
            res["att_bias"] = np.array(self.buffer["att_error"]).mean()
            res["att_variance"] = np.array(self.buffer["att_error"]).var()

        # pehe averaged over datasets
        if "pehe_curve" in self.buffer:
            res["pehe_curve"] = np.array(self.buffer["pehe_curve"]).mean(0)
            res["pehe_av"] = np.array(self.buffer["pehe_av"]).mean()

        # response curve bias and variance
        if "erf_error" in self.buffer:
            rc = np.array(self.buffer["erf_error"])
            res["erf_bias"] = rc.mean(0)
            res["erf_variance"] = rc.var(0)

        return res
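For completeness, a sketch of the intended add/summarize loop for EnvEvaluator. Here env stands for a SpaceEnv instance and datasets for an iterable of SpaceDataset objects drawn from it; both names, as well as the estimate_ate helper, are placeholders rather than spacebench API.

def estimate_ate(ds: SpaceDataset) -> float:
    # Naive difference-in-means estimate, for illustration only.
    t, y = ds.treatment, ds.outcome
    return float(y[t == 1].mean() - y[t == 0].mean())

evaluator = EnvEvaluator(env)          # `env`: a SpaceEnv instance (placeholder)
for dataset in datasets:               # `datasets`: iterable of SpaceDataset (placeholder)
    evaluator.add(dataset, ate=estimate_ate(dataset))

summary = evaluator.summarize()        # e.g. {"ate_bias": ..., "ate_variance": ...}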