Alternative Python multiprocessing usage patterns that avoid propagating global state?

This (extremely simplified) example works fine (Python 2.6.6, Debian Squeeze):

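```python
from multiprocessing import Pool
import numpy as np

src = None  # module-level global, filled in by main()

def process(row):
    return np.sum(src[row])

def main():
    global src
    src = np.ones((100, 100))
    # The pool is created after src is set, so the forked workers inherit it.
    pool = Pool(processes=16)
    rows = pool.map(process, range(100))
    print(rows)

if __name__ == "__main__":
    main()
```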

However, after many years of learning that global state is bad!!!, all my instincts tell me that I'd really, really like to be writing something closer to:

```python
from multiprocessing import Pool
import numpy as np

def main():
    src = np.ones((100, 100))

    def process(row):
        return np.sum(src[row])

    pool = Pool(processes=16)
    rows = pool.map(process, range(100))
    print(rows)

if __name__ == "__main__":
    main()
```

But of course it doesn't work: it hangs, unable to pickle something.
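The "unable to pickle" part can be reproduced without multiprocessing at all: `Pool.map` has to pickle the function to send it to the workers, and pickle stores a plain function as a reference to its module-level name, which a function defined inside `main()` does not have. A minimal stdlib-only sketch (the names `top_level`, `make_nested` and `can_pickle` are hypothetical, chosen for illustration):

```python
import pickle


def top_level(row):
    """Defined at module level, so pickle can store it by name."""
    return row * 2


def make_nested():
    """Return a function defined inside another function, like process() in main()."""
    def nested(row):
        return row * 2
    return nested


def can_pickle(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exception raised for a local function differs between
        # Python versions, so catch the likely candidates.
        return False
    return True


if __name__ == "__main__":
    print(can_pickle(top_level))      # True
    print(can_pickle(make_nested()))  # False
```

Run as a script, this shows that the nested function, not the array, is what pickle rejects.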

The example here is trivial, but by the time you add several more "process" functions, each depending on several additional inputs... well, it all starts to look like something written in BASIC 30 years ago. Using classes to at least group the state with its related functions seems like the obvious solution, but it doesn't seem that simple in practice.

Is there a recommended pattern or style for using multiprocessing.Pool that avoids a proliferation of global state to support every function I want to parallel-map?

How do experienced multiprocessing pros deal with this?

Update: I'm really interested in handling much larger arrays, so variations on the above that pickle src for every call/iteration aren't as good as ones that fork it into the pool's worker processes.

1 answer

You can always pass a callable object instead, like this; the object can then hold the shared state:

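```python
from multiprocessing import Pool
import numpy as np

class RowProcessor(object):
    """Callable that carries the array as instance state instead of a global."""

    def __init__(self, src):
        self.__src = src

    def __call__(self, row):
        return np.sum(self.__src[row])

def main():
    src = np.ones((100, 100))
    p = RowProcessor(src)
    pool = Pool(processes=16)
    # p is picklable (its class is defined at module level), so Pool can send it.
    rows = pool.map(p, range(100))
    print(rows)

if __name__ == "__main__":
    main()
```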
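A caveat relevant to the question's update: because the array is an attribute of `p`, it is pickled together with `p`, and `Pool.map` sends the callable along with every chunk of tasks, so the array is serialized roughly once per chunk rather than once per worker. A minimal sketch that measures this (`pickled_size` is a hypothetical helper; `RowProcessor` is repeated from the answer so the block runs on its own):

```python
import pickle

import numpy as np


class RowProcessor(object):
    """Same callable as in the answer: holds the array as instance state."""

    def __init__(self, src):
        self.__src = src

    def __call__(self, row):
        return np.sum(self.__src[row])


def pickled_size(obj):
    """Number of bytes pickle produces for obj."""
    return len(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))


if __name__ == "__main__":
    src = np.ones((100, 100))
    p = RowProcessor(src)
    print(src.nbytes)       # 80000
    print(pickled_size(p))  # more than src.nbytes: the whole array is in the pickle
```

If that per-chunk cost matters, passing a larger `chunksize` to `pool.map` reduces the number of chunks, and with it the number of copies.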
