A simple map viewer

search.rs 7.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. use coord::LatLonDeg;
  2. use osmpbf::{Blob, BlobDecode, BlobReader, PrimitiveBlock};
  3. use regex::Regex;
  4. use scoped_threadpool::Pool;
  5. use std::collections::hash_set::HashSet;
  6. use std::path::{Path, PathBuf};
  7. use std::sync::mpsc::sync_channel;
  8. use std::thread;
  /// Tells an iteration driver whether to keep delivering results to a callback.
  ///
  /// Returned by the result callbacks in this module (such as `found_func`).
  #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
  pub enum ControlFlow {
      /// Keep going and deliver further results.
      Continue,
      /// Stop delivering results.
      Break,
  }

  /// Converts `Ok(_)` to `ControlFlow::Continue` and `Err(_)` to
  /// `ControlFlow::Break`.
  ///
  /// Handy when a callback ends with a fallible step, for example forwarding a
  /// batch over a channel with `tx.send(batch).into()`: once the receiver is
  /// gone the send fails and iteration stops.
  impl<T, E> From<Result<T, E>> for ControlFlow {
      fn from(result: Result<T, E>) -> Self {
          match result {
              Ok(_) => ControlFlow::Continue,
              Err(_) => ControlFlow::Break,
          }
      }
  }
  23. enum WorkerMessage {
  24. PleaseStop,
  25. DoBlob(Box<Blob>),
  26. }
  27. pub fn par_search<P, F, G>(
  28. pbf_path: P,
  29. search_pattern: &str,
  30. found_func: F,
  31. finished_func: G,
  32. ) -> Result<thread::JoinHandle<()>, String>
  33. where P: AsRef<Path>,
  34. F: Fn(Vec<LatLonDeg>) -> ControlFlow + Send + 'static,
  35. G: Fn(Result<(), String>) + Send + 'static,
  36. {
  37. let pbf_path = PathBuf::from(pbf_path.as_ref());
  38. let search_pattern = search_pattern.to_string();
  39. let handle = thread::spawn(move|| {
  40. let res = par_search_blocking(pbf_path, &search_pattern, found_func);
  41. finished_func(res);
  42. });
  43. Ok(handle)
  44. }
  45. pub fn par_search_blocking<P, F>(
  46. pbf_path: P,
  47. search_pattern: &str,
  48. found_func: F,
  49. ) -> Result<(), String>
  50. where P: AsRef<Path>,
  51. F: Fn(Vec<LatLonDeg>) -> ControlFlow + Send + 'static,
  52. {
  53. let re = Regex::new(search_pattern)
  54. .map_err(|e| format!("{}", e))?;
  55. let re = &re;
  56. let first_pass = move |block: &PrimitiveBlock, _: &()| {
  57. let mut matches = vec![];
  58. let mut way_node_ids = vec![];
  59. for node in block.groups().flat_map(|g| g.nodes()) {
  60. for (_key, val) in node.tags() {
  61. if re.is_match(val) {
  62. let pos = LatLonDeg::new(node.lat(), node.lon());
  63. matches.push(pos);
  64. break;
  65. }
  66. }
  67. }
  68. for node in block.groups().flat_map(|g| g.dense_nodes()) {
  69. for (_key, val) in node.tags() {
  70. if re.is_match(val) {
  71. let pos = LatLonDeg::new(node.lat(), node.lon());
  72. matches.push(pos);
  73. break;
  74. }
  75. }
  76. }
  77. for way in block.groups().flat_map(|g| g.ways()) {
  78. for (_key, val) in way.tags() {
  79. if re.is_match(val) && !way.refs_slice().is_empty() {
  80. //TODO take middle node, not first one
  81. way_node_ids.push(way.refs_slice()[0]);
  82. break;
  83. }
  84. }
  85. }
  86. (matches, way_node_ids)
  87. };
  88. let mut way_node_ids: HashSet<i64> = HashSet::new();
  89. par_iter_blobs(
  90. &pbf_path,
  91. || {},
  92. first_pass,
  93. |(matches, node_ids)| {
  94. way_node_ids.extend(&node_ids);
  95. found_func(matches)
  96. },
  97. )?;
  98. let way_node_ids = &way_node_ids;
  99. let second_pass = move |block: &PrimitiveBlock, _: &()| {
  100. let mut matches = vec![];
  101. for node in block.groups().flat_map(|g| g.nodes()) {
  102. if way_node_ids.contains(&node.id()) {
  103. let pos = LatLonDeg::new(node.lat(), node.lon());
  104. matches.push(pos);
  105. }
  106. }
  107. for node in block.groups().flat_map(|g| g.dense_nodes()) {
  108. if way_node_ids.contains(&node.id) {
  109. let pos = LatLonDeg::new(node.lat(), node.lon());
  110. matches.push(pos);
  111. }
  112. }
  113. matches
  114. };
  115. par_iter_blobs(
  116. &pbf_path,
  117. || {},
  118. second_pass,
  119. found_func,
  120. )
  121. }
  122. fn par_iter_blobs<P, D, R, IF, CF, RF>(
  123. pbf_path: P,
  124. init_func: IF,
  125. compute_func: CF,
  126. mut result_func: RF,
  127. ) -> Result<(), String>
  128. where P: AsRef<Path>,
  129. IF: Fn() -> D,
  130. CF: Fn(&PrimitiveBlock, &D) -> R + Send + Sync,
  131. RF: FnMut(R) -> ControlFlow,
  132. R: Send,
  133. D: Send,
  134. {
  135. let num_threads = ::num_cpus::get();
  136. let mut pool = Pool::new(num_threads as u32);
  137. pool.scoped(|scope| {
  138. let mut reader = BlobReader::from_path(&pbf_path)
  139. .map_err(|e| format!("{}", e))?;
  140. let mut chans = Vec::with_capacity(num_threads);
  141. let (result_tx, result_rx) = sync_channel::<(usize, Result<Option<R>, String>)>(0);
  142. for thread_id in 0..num_threads {
  143. let thread_data = init_func();
  144. let result_tx = result_tx.clone();
  145. let (request_tx, request_rx) = sync_channel::<WorkerMessage>(0);
  146. chans.push(request_tx);
  147. let compute = &compute_func;
  148. scope.execute(move || {
  149. for request in request_rx.iter() {
  150. match request {
  151. WorkerMessage::PleaseStop => return,
  152. WorkerMessage::DoBlob(blob) => {
  153. match blob.decode() {
  154. Ok(BlobDecode::OsmData(block)) => {
  155. let result = compute(&block, &thread_data);
  156. if result_tx.send((thread_id, Ok(Some(result)))).is_err() {
  157. return;
  158. }
  159. },
  160. //TODO also include other blob types in compute function
  161. Ok(_) => {
  162. if result_tx.send((thread_id, Ok(None))).is_err() {
  163. return;
  164. }
  165. },
  166. Err(err) => {
  167. let _ = result_tx.send((thread_id, Err(format!("{}", err))));
  168. return;
  169. },
  170. }
  171. }
  172. };
  173. }
  174. });
  175. }
  176. let mut stopped_threads = 0;
  177. // send initial message to each worker thread
  178. for channel in &chans {
  179. match reader.next() {
  180. Some(Ok(blob)) => {
  181. channel.send(WorkerMessage::DoBlob(Box::new(blob)))
  182. .map_err(|e| format!("{}", e))?;
  183. },
  184. Some(Err(err)) => {
  185. return Err(format!("{}", err));
  186. },
  187. None => {
  188. channel.send(WorkerMessage::PleaseStop)
  189. .map_err(|e| format!("{}", e))?;
  190. stopped_threads += 1;
  191. },
  192. }
  193. }
  194. if stopped_threads == num_threads {
  195. return Ok(());
  196. }
  197. for (thread_id, matches) in result_rx.iter() {
  198. match matches {
  199. Err(err) => return Err(err),
  200. Ok(Some(matches)) => {
  201. if result_func(matches) == ControlFlow::Break {
  202. break;
  203. }
  204. },
  205. _ => {},
  206. }
  207. match reader.next() {
  208. Some(Ok(blob)) => {
  209. chans[thread_id].send(WorkerMessage::DoBlob(Box::new(blob)))
  210. .map_err(|e| format!("{}", e))?;
  211. },
  212. Some(Err(err)) => {
  213. return Err(format!("{}", err));
  214. },
  215. None => {
  216. chans[thread_id].send(WorkerMessage::PleaseStop)
  217. .map_err(|e| format!("{}", e))?;
  218. stopped_threads += 1;
  219. if stopped_threads == num_threads {
  220. break;
  221. }
  222. }
  223. }
  224. }
  225. Ok(())
  226. })
  227. }