Skip to content
Snippets Groups Projects
Commit e3c808b6 authored by Alexander Heidelbach's avatar Alexander Heidelbach
Browse files

Adjust fitparameter handling

parent 4fce12bd
Branches
No related tags found
No related merge requests found
import copy
from typing import Any, Dict, List, Optional, Union
import zfit
import zfit.core
......@@ -13,7 +14,9 @@ __all__ = ["ModelsReadType", "LimitType", "Model"]
ModelsReadType = Dict[str, Dict[str, Dict[str, Dict[str, Any]]]]
LimitType = Union[List[float], List[List[float]], list]
ModelsOutputType = Union[ModelsReadType, str, LimitType, CovarianceType, float]
ModelsOutputType = Union[
ModelsReadType, str, LimitType, CovarianceType, float, dict[str, float]
]
class Model(BaseModel):
......@@ -23,18 +26,14 @@ class Model(BaseModel):
modelstring: str,
models: ModelsReadType,
modelname: Optional[str] = None,
gof: float = 0.0,
) -> None:
super().__init__(obs)
self.name = modelname if modelname is not None else "Model"
self._modelstring = modelstring
self._models = self._prepare_models(models)
self._models = self._prepare_models(copy.deepcopy(models))
parser = ModelParser(modelstring)
self._node = parser.parse()
self._fraction = 0
self.gof = gof
def _prepare_models(self, models: ModelsReadType) -> Dict[str, Dict[str, Any]]:
tmpmodels = {} # type: Dict[str, Dict[str, Dict[str, Any]]]
......@@ -57,29 +56,53 @@ class Model(BaseModel):
except TypeError:
parameter = {parametername: values}
self._parameters.update(dictToFitparameter(**parameter))
composed_target = None
if isinstance(values, dict) and "composed" in values.keys():
if values["composed"]:
assert (
values["composed"]["parameter"] in self._parameters.keys()
), "Composed parameter not found in parameters"
composed_target = self._parameters[
values["composed"]["parameter"]
]
self._parameters.update(
dictToFitparameter(composed_target=composed_target, **parameter)
)
tmpmodels.update({modelname: modelparameters})
return tmpmodels
def _compute_model(self, node: Node) -> zfit.core.basepdf.BasePDF:
if node.token_type == TokenType.T_MODEL or node.token_type == TokenType.T_EXT:
if node.token_type == TokenType.T_EXT:
assert node.value is not None
assert node.ext_param is not None
parser = ModelParser(node.value)
model = self._compute_model(parser.parse())
ext_param = self.modelparameters[node.ext_param]
zfit_model = model.create_extended(ext_param)
return zfit_model
elif node.token_type == TokenType.T_MODEL:
assert node.value is not None
model = self._models[node.value]
pdf = model["pdf"]
parameter = {} # type: Dict[str, zfit.Parameter]
ext_param: Optional[zfit.Parameter] = None
for parametername, values in model["parameter"].items():
if not isinstance(values, list):
if not self.parameters[values["name"]].extended:
if (
not self.parameters[values["name"]].fraction
and not self.parameters[values["name"]].extended
):
parameter.update(
{parametername: self.modelparameters[values["name"]]}
)
else:
ext_param = self.modelparameters[values["name"]]
else:
if len(values) > 0:
parameter.update(
......@@ -104,8 +127,6 @@ class Model(BaseModel):
)
zfit_model = getattr(zfit.pdf, pdf)(obs=obs, **parameter)
if ext_param is not None:
zfit_model = zfit_model.create_extended(ext_param)
return zfit_model
......@@ -128,35 +149,23 @@ class Model(BaseModel):
)
if not left_result.is_extended:
fraction_infodict = {
f"fraction{self._fraction}": {
"name": f"fraction{self._fraction}",
"value": 0.1,
"lower": 0.0,
"upper": 1.0,
}
}
fraction_dict = dictToFitparameter(**fraction_infodict)
self._parameters.update(
{
f"fraction{self._fraction}": fraction_dict[
f"fraction{self._fraction}"
]
}
)
self._modelparameters.update(
{
f"fraction{self._fraction}": fraction_dict[
f"fraction{self._fraction}"
].to_zfitParameter
}
)
kwargs.update(
{"fracs": self._modelparameters[f"fraction{self._fraction}"]}
assert node.children[0].value is not None
model = self._models[node.children[0].value]
# find the parameter that steers the fraction of the sum
fraction_parameter_name = ""
for modelparameter in model["parameter"].values():
if "fraction" in modelparameter.keys():
if modelparameter["fraction"]:
fraction_parameter_name = modelparameter["name"]
break
assert fraction_parameter_name, "No fraction parameter found"
fraction_parameter = self.parameters.get(
fraction_parameter_name, None
)
assert fraction_parameter, "Fraction parameter not found"
self._fraction += 1
kwargs.update({"fracs": fraction_parameter.to_zfitParameter})
kwargs.update({"pdfs": [left_result, right_result]})
......@@ -191,13 +200,6 @@ class Model(BaseModel):
)
)
for model in models:
for modelparametername, modelparameter in self.modelparameters.items():
if modelparametername in model.modelparameters.keys():
model.update_modelparameter(
name=modelparametername, value=modelparameter.value().numpy()
)
return models
@property
......@@ -254,5 +256,7 @@ class Model(BaseModel):
output.update({"limits": limits})
output.update({"gof": self.gof})
output.update({"significance": self.significance})
output.update({"upperlimit": self.upperlimit})
return output
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment