refreshing causal model visualization

parent bce9a225
......@@ -95,6 +95,50 @@ class CausalModel:
elif iy in self.__causalBN.parents(ix):
self.eraseCausalArc(iy, ix)
def toDot(self):
res = "digraph {"
# latent variables
if gum.config['causal', 'show_latent_names'] == 'True':
shap = "ellipse"
else:
shap = "point"
res += '''
node [fillcolor="{}",
fontcolor="{}",
style=filled,shape={}];
''' .format(gum.config['causal', "default_node_bgcolor"],
gum.config['causal', "default_node_fgcolor"], shap)
res += "\n"
for n in self.nodes():
if n in self.latentVariablesIds():
res += ' "' + self.names()[n] + '";' + "\n"
# not latent variables
res += '''
node [fillcolor="{}",
fontcolor="{}",
style=filled,shape="ellipse"];
''' .format(gum.config['causal', "default_node_bgcolor"],
gum.config['causal', "default_node_fgcolor"])
res += "\n"
for n in self.nodes():
if n not in self.latentVariablesIds():
res += ' "' + self.names()[n] + '";' + "\n"
for a, b in self.arcs():
res += " "+self.names()[a] + "->" + self.names()[b]
if a in self.latentVariablesIds() or b in self.latentVariablesIds():
res += ' [style="dashed"];'
else:
res += ' [color="black:black"];'
res+="\n"
res += "\n};"
return res
def causalBN(self) -> gum.BayesNet:
"""
:return: the causal Bayesian network
......
......@@ -38,47 +38,14 @@ def getCausalModel(cm: csl.CausalModel, size=None) -> str:
:param vals:
:return:
"""
if size is None:
size = gum.config['causal', "default_graph_size"]
graph = dot.Dot(graph_type='digraph')
for n in cm.nodes():
if n not in cm.latentVariablesIds():
bgcol = gum.config['causal', "default_node_bgcolor"]
fgcol = gum.config['causal', "default_node_fgcolor"]
shap = "ellipse"
else:
bgcol = gum.config['causal', "default_latent_bgcolor"]
fgcol = gum.config['causal', "default_latent_fgcolor"]
if gum.config['causal', 'show_latent_names'] == 'True':
shap = "ellipse"
else:
shap = "point"
graph.add_node(dot.Node(cm.names()[n],
shape=shap,
style="filled",
fillcolor=bgcol,
fontcolor=fgcol))
for a, b in cm.arcs():
if a in cm.latentVariablesIds():
graph.add_edge(dot.Edge(cm.names()[a], cm.names()[b],style="dashed"))
else:
graph.add_edge(dot.Edge(cm.names()[a], cm.names()[b], color='"black:black"'))
graph.set_size(size)
return IPython.display.SVG(graph.create_svg()).data
return gnb.getDot(cm.toDot())
def showCausalModel(cm: csl.CausalModel, size: str = "4"):
"""
Shows a graphviz svg representation of the causal DAG ``d``
"""
html = getCausalModel(cm, size)
IPython.display.display(IPython.display.HTML(
"<div align='center'>" + html + "</div>"))
gnb.showDot(cm.toDot())
def getCausalImpact(model: csl.CausalModel, on: Union[str, NameSet], doing: Union[str, NameSet],
......@@ -118,5 +85,6 @@ def showCausalImpact(model: csl.CausalModel, on: Union[str, NameSet], doing: Uni
IPython.display.display(IPython.display.HTML(html))
csl.CausalModel._repr_html_ = lambda self: getCausalModel(self, size="10")
csl.CausalModel._repr_html_ = lambda self: gnb.getDot(
self.toDot(), size=gum.config['causal', 'default_graph_size'])
csl.CausalFormula._repr_html_ = lambda self: f"$${self.toLatex()}$$"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment