# Source code for graphframes.examples.belief_propagation
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import math
# Import subpackage examples here explicitly so that
# this module can be run directly with spark-submit.
import graphframes.examples
from graphframes import GraphFrame
from graphframes.lib import AggregateMessages as AM
from pyspark.sql import SparkSession, functions as sqlfunctions, types
__all__ = ['BeliefPropagation']
class BeliefPropagation(object):
    r"""Example code for Belief Propagation (BP)

    This provides a template for building customized BP algorithms for different types of
    graphical models.

    This example:

    * Ising model on a grid
    * Parallel Belief Propagation using colored fields

    Ising models are probabilistic graphical models over binary variables
    (see :meth:`Graphs.gridIsingModel()`).

    Belief Propagation (BP) provides marginal probabilities of the values of the variables
    x\ :sub:`i`, i.e., P(x\ :sub:`i`) for each i. This allows a user to understand likely values
    of variables. See `Wikipedia <https://en.wikipedia.org/wiki/Belief_propagation>`__ for more
    information on BP.

    We use a batch synchronous BP algorithm, where batches of vertices are updated synchronously.
    We follow the mean field update algorithm in Slide 13 of the
    `talk slides <http://www.eecs.berkeley.edu/~wainwrig/Talks/A_GraphModel_Tutorial>`__ from:
    Wainwright. "Graphical models, message-passing algorithms, and convex optimization."

    The batches are chosen according to a coloring. For background on graph colorings for
    inference, see for example: Gonzalez et al. "Parallel Gibbs Sampling: From Colored Fields to
    Thin Junction Trees." AISTATS, 2011.

    The BP algorithm works by:

    * Coloring the graph by assigning a color to each vertex such that no neighboring vertices
      share the same color.
    * In each step of BP, update all vertices of a single color. Alternate colors.
    """

    @classmethod
    def runBPwithGraphFrames(cls, g, numIter):
        """Run Belief Propagation using GraphFrame.

        This implementation of BP shows how to use GraphFrame's aggregateMessages method.

        :param g: GraphFrame whose vertices have columns "id", "i", "j" and "a" and whose edges
                  have a column "b" (e.g. a grid from :meth:`Graphs.gridIsingModel()`)
        :param numIter: number of BP sweeps; each sweep updates every color once
        :return: GraphFrame with the same edges and the original vertex columns plus a new
                 "belief" column (the temporary "color" column is dropped)
        """
        # choose colors for vertices for BP scheduling
        colorG = cls._colorGraph(g)
        # Colors are iterated as 0 .. numColors-1, which matches what _colorGraph produces.
        numColors = colorG.vertices.select('color').distinct().count()
        # TODO: handle vertices without any edges. Such vertices receive no messages, so the
        # left outer join below leaves them at their initial belief of 0.0.
        # initialize vertex beliefs at 0.0
        gx = GraphFrame(colorG.vertices.withColumn('belief', sqlfunctions.lit(0.0)), colorG.edges)
        # numerically stable sigmoid; wrap it as a UDF once instead of once per color step
        logistic = sqlfunctions.udf(cls._sigmoid, returnType=types.DoubleType())
        # run BP for numIter iterations
        for _ in range(numIter):
            # for each color, have that color receive messages from neighbors
            for color in range(numColors):
                # Send messages to vertices of the current color.
                # We may send to source or destination since edges are treated as undirected.
                # `when` without `otherwise` yields NULL for endpoints of other colors.
                msgForSrc = sqlfunctions.when(
                    AM.src['color'] == color,
                    AM.edge['b'] * AM.dst['belief'])
                msgForDst = sqlfunctions.when(
                    AM.dst['color'] == color,
                    AM.edge['b'] * AM.src['belief'])
                aggregates = gx.aggregateMessages(
                    sqlfunctions.sum(AM.msg).alias("aggMess"),
                    sendToSrc=msgForSrc,
                    sendToDst=msgForDst)
                v = gx.vertices
                # receive messages and update beliefs for vertices of the current color
                newBeliefCol = sqlfunctions.when(
                    (v['color'] == color) & (aggregates['aggMess'].isNotNull()),
                    logistic(aggregates['aggMess'] + v['a'])
                ).otherwise(v['belief'])  # keep old beliefs for other colors
                newVertices = (v
                               .join(aggregates, on=(v['id'] == aggregates['id']),
                                     how='left_outer')
                               .drop(aggregates['id'])  # drop duplicate ID column (outer join)
                               .withColumn('newBelief', newBeliefCol)  # compute new beliefs
                               .drop('aggMess')  # drop messages
                               .drop('belief')  # drop old beliefs
                               .withColumnRenamed('newBelief', 'belief')
                               )
                # cache new vertices using workaround for SPARK-1334
                cachedNewVertices = AM.getCachedDataFrame(newVertices)
                gx = GraphFrame(cachedNewVertices, gx.edges)
        # Drop the "color" column from vertices
        return GraphFrame(gx.vertices.drop('color'), gx.edges)

    @staticmethod
    def _colorGraph(g):
        """Given a GraphFrame, choose colors for each vertex.

        No neighboring vertices will share the same color. The number of colors is minimized.
        This is written specifically for grid graphs. For non-grid graphs, it should be
        generalized, such as by using a greedy coloring scheme.

        :param g: Grid graph generated by :meth:`Graphs.gridIsingModel()`; its vertices must
                  have integer columns "i" and "j"
        :return: Same graph, but with a new vertex column "color" of type Int (0 or 1)
        """
        # Checkerboard coloring: grid neighbors differ by 1 in exactly one of i/j.
        colorUDF = sqlfunctions.udf(lambda i, j: (i + j) % 2, returnType=types.IntegerType())
        v = g.vertices.withColumn('color', colorUDF(sqlfunctions.col('i'), sqlfunctions.col('j')))
        return GraphFrame(v, g.edges)

    @staticmethod
    def _sigmoid(x):
        """Numerically stable sigmoid function 1 / (1 + exp(-x)).

        :param x: real number, or None (e.g. a SQL NULL passed through the UDF)
        :return: sigmoid of ``x``, or None when ``x`` is None
        """
        # Only None means "missing": 0 / 0.0 is a valid input with sigmoid(0) == 0.5.
        if x is None:
            return None
        if x >= 0:
            # exp(-x) <= 1 here, so it cannot overflow
            z = math.exp(-x)
            return 1 / (1 + z)
        # for negative x use exp(x) / (1 + exp(x)), again keeping the exponent non-positive
        z = math.exp(x)
        return z / (1 + z)
def main():
    """Run the belief propagation algorithm for an example problem.

    Builds a 3 x 3 grid Ising model, prints its vertices and edges, runs 5 iterations of BP and
    prints the resulting per-vertex beliefs. The Spark session is stopped even if a step fails.
    """
    # setup spark session
    spark = SparkSession.builder.appName("BeliefPropagation example").getOrCreate()
    try:
        # create graphical model g of size 3 x 3
        g = graphframes.examples.Graphs(spark).gridIsingModel(3)
        print("Original Ising model:")
        g.vertices.show()
        g.edges.show()
        # run BP for 5 iterations
        numIter = 5
        results = BeliefPropagation.runBPwithGraphFrames(g, numIter)
        # display beliefs
        beliefs = results.vertices.select('id', 'belief')
        print("Done with BP. Final beliefs after {} iterations:".format(numIter))
        beliefs.show()
    finally:
        # Stop the session even when an earlier step raises, so its resources are released.
        spark.stop()
# Script entry point: run the example only when executed directly (e.g. via spark-submit).
if __name__ == '__main__':
    main()