Преглед изворног кода

search: Extract function for parallel blob iteration

Johannes Hofmann пре 7 година
родитељ
комит
e115c2acd1
1 измењен фајл са 79 додато и 38 уклоњено
  1. 79
    38
      src/search.rs

+ 79
- 38
src/search.rs Прегледај датотеку

@@ -1,7 +1,7 @@
1
-use scoped_threadpool::Pool;
2 1
 use coord::LatLon;
3
-use osmpbf::{Blob, BlobDecode, BlobReader};
2
+use osmpbf::{Blob, BlobDecode, BlobReader, PrimitiveBlock};
4 3
 use regex::Regex;
4
+use scoped_threadpool::Pool;
5 5
 use std::path::{Path, PathBuf};
6 6
 use std::sync::mpsc::sync_channel;
7 7
 use std::thread;
@@ -55,62 +55,99 @@ pub fn par_search_blocking<P, F>(
55 55
 ) -> Result<(), String>
56 56
 where P: AsRef<Path>,
57 57
       F: Fn(Vec<LatLon>) -> ControlFlow + Send + 'static,
58
+{
59
+    let re = Regex::new(search_pattern)
60
+        .map_err(|e| format!("{}", e))?;
61
+    let re = &re;
62
+
63
+    let search = move |block: &PrimitiveBlock, _: &()| {
64
+        let mut matches = vec![];
65
+
66
+        for node in block.groups().flat_map(|g| g.nodes()) {
67
+            for (_key, val) in node.tags() {
68
+                if re.is_match(val) {
69
+                    let pos = LatLon::new(node.lat(), node.lon());
70
+                    matches.push(pos);
71
+                    break;
72
+                }
73
+            }
74
+        }
75
+
76
+        for node in block.groups().flat_map(|g| g.dense_nodes()) {
77
+            for (_key, val) in node.tags() {
78
+                if re.is_match(val) {
79
+                    let pos = LatLon::new(node.lat(), node.lon());
80
+                    matches.push(pos);
81
+                    break;
82
+                }
83
+            }
84
+        }
85
+
86
+        matches
87
+    };
88
+
89
+    par_iter_blobs(
90
+        pbf_path,
91
+        || {},
92
+        search,
93
+        found_func,
94
+    )
95
+}
96
+
97
+fn par_iter_blobs<P, D, R, IF, CF, RF>(
98
+    pbf_path: P,
99
+    init_func: IF,
100
+    compute_func: CF,
101
+    result_func: RF,
102
+) -> Result<(), String>
103
+where P: AsRef<Path>,
104
+      IF: Fn() -> D,
105
+      CF: Fn(&PrimitiveBlock, &D) -> R + Send + Sync,
106
+      RF: Fn(R) -> ControlFlow + Send + 'static,
107
+      R: Send,
108
+      D: Send,
58 109
 {
59 110
     let num_threads = ::num_cpus::get();
60 111
     let mut pool = Pool::new(num_threads as u32);
61 112
 
62 113
     pool.scoped(|scope| {
63
-        let re = Regex::new(search_pattern)
64
-            .map_err(|e| format!("{}", e))?;
65 114
         let mut reader = BlobReader::from_path(&pbf_path)
66 115
             .map_err(|e| format!("{}", e))?;
67 116
 
68 117
         let mut chans = Vec::with_capacity(num_threads);
69
-        let (result_tx, result_rx) = sync_channel::<(usize, Result<Vec<LatLon>, String>)>(0);
118
+        let (result_tx, result_rx) = sync_channel::<(usize, Result<Option<R>, String>)>(0);
70 119
 
71 120
         for thread_id in 0..num_threads {
72
-            let re = re.clone();
121
+            let thread_data = init_func();
73 122
             let result_tx = result_tx.clone();
74 123
 
75 124
             let (request_tx, request_rx) = sync_channel::<WorkerMessage>(0);
76 125
             chans.push(request_tx);
77 126
 
127
+            let compute = &compute_func;
128
+
78 129
             scope.execute(move || {
79 130
                 for request in request_rx.iter() {
80 131
                     match request {
81 132
                         WorkerMessage::PleaseStop => return,
82 133
                         WorkerMessage::DoBlob(blob) => {
83
-                            let mut matches = vec![];
84
-                            let block = match blob.decode() {
85
-                                Ok(b) => b,
134
+                            match blob.decode() {
135
+                                Ok(BlobDecode::OsmData(block)) => {
136
+                                    let result = compute(&block, &thread_data);
137
+                                    if result_tx.send((thread_id, Ok(Some(result)))).is_err() {
138
+                                        return;
139
+                                    }
140
+                                },
141
+                                //TODO also include other blob types in compute function
142
+                                Ok(_) => {
143
+                                    if result_tx.send((thread_id, Ok(None))).is_err() {
144
+                                        return;
145
+                                    }
146
+                                },
86 147
                                 Err(err) => {
87 148
                                     let _ = result_tx.send((thread_id, Err(format!("{}", err))));
88 149
                                     return;
89
-                                }
90
-                            };
91
-                            if let BlobDecode::OsmData(block) = block {
92
-                                for node in block.groups().flat_map(|g| g.nodes()) {
93
-                                    for (_key, val) in node.tags() {
94
-                                        if re.is_match(val) {
95
-                                            let pos = LatLon::new(node.lat(), node.lon());
96
-                                            matches.push(pos);
97
-                                            break;
98
-                                        }
99
-                                    }
100
-                                }
101
-
102
-                                for node in block.groups().flat_map(|g| g.dense_nodes()) {
103
-                                    for (_key, val) in node.tags() {
104
-                                        if re.is_match(val) {
105
-                                            let pos = LatLon::new(node.lat(), node.lon());
106
-                                            matches.push(pos);
107
-                                            break;
108
-                                        }
109
-                                    }
110
-                                }
111
-                            }
112
-                            if result_tx.send((thread_id, Ok(matches))).is_err() {
113
-                                return;
150
+                                },
114 151
                             }
115 152
                         }
116 153
                     };
@@ -144,10 +181,14 @@ where P: AsRef<Path>,
144 181
         }
145 182
 
146 183
         for (thread_id, matches) in result_rx.iter() {
147
-            let matches = matches?;
148
-
149
-            if found_func(matches) == ControlFlow::Break {
150
-                break;
184
+            match matches {
185
+                Err(err) => return Err(err),
186
+                Ok(Some(matches)) => {
187
+                    if result_func(matches) == ControlFlow::Break {
188
+                        break;
189
+                    }
190
+                },
191
+                _ => {},
151 192
             }
152 193
 
153 194
             match reader.next() {