tinc_build/
lib.rs

1//! The code generator for [`tinc`](https://crates.io/crates/tinc).
2#![cfg_attr(feature = "docs", doc = "## Feature flags")]
3#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
4//! ## Usage
5//!
6//! In your `build.rs`:
7//!
8//! ```rust,no_run
9//! # #[allow(clippy::needless_doctest_main)]
10//! fn main() {
11//!     tinc_build::Config::prost()
12//!         .compile_protos(&["proto/test.proto"], &["proto"])
13//!         .unwrap();
14//! }
15//! ```
16//!
17//! Look at [`Config`] to see different options to configure the generator.
18//!
19//! ## License
20//!
21//! This project is licensed under the MIT or Apache-2.0 license.
22//! You can choose between one of them if you use this work.
23//!
24//! `SPDX-License-Identifier: MIT OR Apache-2.0`
25#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
26#![cfg_attr(docsrs, feature(doc_auto_cfg))]
27#![deny(missing_docs)]
28#![deny(unsafe_code)]
29#![deny(unreachable_pub)]
30#![cfg_attr(not(feature = "prost"), allow(unused_variables, dead_code))]
31
32use std::io::ErrorKind;
33use std::path::Path;
34
35use anyhow::Context;
36use extern_paths::ExternPaths;
37mod codegen;
38mod extern_paths;
39
40#[cfg(feature = "prost")]
41mod prost_explore;
42
43mod types;
44
/// Selects which code generation backend the generator drives.
///
/// Each variant is gated behind the cargo feature of the same name; at the
/// moment `prost` is the only backend that exists.
#[derive(Clone, Copy, Debug)]
pub enum Mode {
    /// Generate the protobuf structures with `prost`.
    #[cfg(feature = "prost")]
    Prost,
}
52
53impl quote::ToTokens for Mode {
54    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
55        match self {
56            #[cfg(feature = "prost")]
57            Mode::Prost => quote::quote!(prost).to_tokens(tokens),
58            #[cfg(not(feature = "prost"))]
59            _ => unreachable!(),
60        }
61    }
62}
63
/// Proto paths collected by the [`Config`] setters, grouped by the override
/// that should be applied to them when code is generated.
#[derive(Debug, Default)]
struct PathConfigs {
    /// Paths whose proto maps should be generated as `BTreeMap`.
    btree_maps: Vec<String>,
    /// Paths whose proto bytes should be generated as `bytes::Bytes`.
    bytes: Vec<String>,
    /// Paths whose values should be wrapped in a `Box`.
    boxed: Vec<String>,
}
70
71/// A config for configuring how tinc builds / generates code.
72#[derive(Debug)]
73pub struct Config {
74    disable_tinc_include: bool,
75    root_module: bool,
76    mode: Mode,
77    paths: PathConfigs,
78    extern_paths: ExternPaths,
79}
80
81impl Config {
82    /// New config with prost mode.
83    #[cfg(feature = "prost")]
84    pub fn prost() -> Self {
85        Self::new(Mode::Prost)
86    }
87
88    /// Make a new config with a given mode.
89    pub fn new(mode: Mode) -> Self {
90        Self {
91            disable_tinc_include: false,
92            mode,
93            paths: PathConfigs::default(),
94            extern_paths: ExternPaths::new(mode),
95            root_module: true,
96        }
97    }
98
99    /// Disable tinc auto-include. By default tinc will add its own
100    /// annotations into the include path of protoc.
101    pub fn disable_tinc_include(&mut self) -> &mut Self {
102        self.disable_tinc_include = true;
103        self
104    }
105
106    /// Disable the root module generation
107    /// which allows for `tinc::include_protos!()` without
108    /// providing a package.
109    pub fn disable_root_module(&mut self) -> &mut Self {
110        self.root_module = false;
111        self
112    }
113
114    /// Specify a path to generate a `BTreeMap` instead of a `HashMap` for proto map.
115    pub fn btree_map(&mut self, path: impl std::fmt::Display) -> &mut Self {
116        self.paths.btree_maps.push(path.to_string());
117        self
118    }
119
120    /// Specify a path to generate `bytes::Bytes` instead of `Vec<u8>` for proto bytes.
121    pub fn bytes(&mut self, path: impl std::fmt::Display) -> &mut Self {
122        self.paths.bytes.push(path.to_string());
123        self
124    }
125
126    /// Specify a path to wrap around a `Box` instead of including it directly into the struct.
127    pub fn boxed(&mut self, path: impl std::fmt::Display) -> &mut Self {
128        self.paths.boxed.push(path.to_string());
129        self
130    }
131
132    /// Compile and generate all the protos with the includes.
133    pub fn compile_protos(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
134        match self.mode {
135            #[cfg(feature = "prost")]
136            Mode::Prost => self.compile_protos_prost(protos, includes),
137        }
138    }
139
140    #[cfg(feature = "prost")]
141    fn compile_protos_prost(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
142        use std::collections::BTreeMap;
143
144        use codegen::prost_sanatize::to_snake;
145        use codegen::utils::get_common_import_path;
146        use proc_macro2::Span;
147        use prost_reflect::DescriptorPool;
148        use quote::{ToTokens, quote};
149        use syn::parse_quote;
150        use types::{ProtoPath, ProtoTypeRegistry};
151
152        let out_dir_str = std::env::var("OUT_DIR").context("OUT_DIR must be set, typically set by a cargo build script")?;
153        let out_dir = std::path::PathBuf::from(&out_dir_str);
154        let ft_path = out_dir.join("tinc.fd.bin");
155
156        let mut config = prost_build::Config::new();
157        config.file_descriptor_set_path(&ft_path);
158
159        config.btree_map(self.paths.btree_maps.iter());
160        self.paths.boxed.iter().for_each(|path| {
161            config.boxed(path);
162        });
163        config.bytes(self.paths.bytes.iter());
164
165        let mut includes = includes.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
166
167        {
168            let tinc_out = out_dir.join("tinc");
169            std::fs::create_dir_all(&tinc_out).context("failed to create tinc directory")?;
170            std::fs::write(tinc_out.join("annotations.proto"), tinc_pb_prost::TINC_ANNOTATIONS)
171                .context("failed to write tinc_annotations.rs")?;
172            includes.push(Path::new(&out_dir_str));
173            config.protoc_arg(format!("--descriptor_set_in={}", tinc_pb_prost::TINC_ANNOTATIONS_PB_PATH));
174        }
175
176        let fds = config.load_fds(protos, &includes).context("failed to generate tonic fds")?;
177
178        let fds_bytes = std::fs::read(ft_path).context("failed to read tonic fds")?;
179
180        let pool = DescriptorPool::decode(&mut fds_bytes.as_slice()).context("failed to decode tonic fds")?;
181
182        let mut registry = ProtoTypeRegistry::new(self.mode, self.extern_paths.clone());
183
184        config.compile_well_known_types();
185        for (proto, rust) in self.extern_paths.paths() {
186            let proto = if proto.starts_with('.') {
187                proto.to_string()
188            } else {
189                format!(".{proto}")
190            };
191            config.extern_path(proto, rust.to_token_stream().to_string());
192        }
193
194        prost_explore::Extensions::new(&pool)
195            .process(&mut registry)
196            .context("failed to process extensions")?;
197
198        let mut packages = codegen::generate_modules(&registry)?;
199
200        packages.iter_mut().for_each(|(path, package)| {
201            if self.extern_paths.contains(path) {
202                return;
203            }
204
205            package.enum_configs().for_each(|(path, enum_config)| {
206                if self.extern_paths.contains(path) {
207                    return;
208                }
209
210                enum_config.attributes().for_each(|attribute| {
211                    config.enum_attribute(path, attribute.to_token_stream().to_string());
212                });
213                enum_config.variants().for_each(|variant| {
214                    let path = format!("{path}.{variant}");
215                    enum_config.variant_attributes(variant).for_each(|attribute| {
216                        config.field_attribute(&path, attribute.to_token_stream().to_string());
217                    });
218                });
219            });
220
221            package.message_configs().for_each(|(path, message_config)| {
222                if self.extern_paths.contains(path) {
223                    return;
224                }
225
226                message_config.attributes().for_each(|attribute| {
227                    config.message_attribute(path, attribute.to_token_stream().to_string());
228                });
229                message_config.fields().for_each(|field| {
230                    let path = format!("{path}.{field}");
231                    message_config.field_attributes(field).for_each(|attribute| {
232                        config.field_attribute(&path, attribute.to_token_stream().to_string());
233                    });
234                });
235                message_config.oneof_configs().for_each(|(field, oneof_config)| {
236                    let path = format!("{path}.{field}");
237                    oneof_config.attributes().for_each(|attribute| {
238                        // In prost oneofs (container) are treated as enums
239                        config.enum_attribute(&path, attribute.to_token_stream().to_string());
240                    });
241                    oneof_config.fields().for_each(|field| {
242                        let path = format!("{path}.{field}");
243                        oneof_config.field_attributes(field).for_each(|attribute| {
244                            config.field_attribute(&path, attribute.to_token_stream().to_string());
245                        });
246                    });
247                });
248            });
249
250            package.extra_items.extend(package.services.iter().flat_map(|service| {
251                let mut builder = tonic_build::CodeGenBuilder::new();
252
253                builder.emit_package(true).build_transport(true);
254
255                let make_service = |is_client: bool| {
256                    let mut builder = tonic_build::manual::Service::builder()
257                        .name(service.name())
258                        .package(&service.package);
259
260                    if !service.comments.is_empty() {
261                        builder = builder.comment(service.comments.to_string());
262                    }
263
264                    service
265                        .methods
266                        .iter()
267                        .fold(builder, |service_builder, (name, method)| {
268                            let codec_path =
269                                if let Some(Some(codec_path)) = (!is_client).then_some(method.codec_path.as_ref()) {
270                                    let path = get_common_import_path(&service.full_name, codec_path);
271                                    quote!(#path::<::tinc::reexports::tonic::codec::ProstCodec<_, _>>)
272                                } else {
273                                    quote!(::tinc::reexports::tonic::codec::ProstCodec)
274                                };
275
276                            let mut builder = tonic_build::manual::Method::builder()
277                                .input_type(
278                                    registry
279                                        .resolve_rust_path(&service.full_name, method.input.value_type().proto_path())
280                                        .unwrap()
281                                        .to_token_stream()
282                                        .to_string(),
283                                )
284                                .output_type(
285                                    registry
286                                        .resolve_rust_path(&service.full_name, method.output.value_type().proto_path())
287                                        .unwrap()
288                                        .to_token_stream()
289                                        .to_string(),
290                                )
291                                .codec_path(codec_path.to_string())
292                                .name(to_snake(name))
293                                .route_name(name);
294
295                            if method.input.is_stream() {
296                                builder = builder.client_streaming()
297                            }
298
299                            if method.output.is_stream() {
300                                builder = builder.server_streaming();
301                            }
302
303                            if !method.comments.is_empty() {
304                                builder = builder.comment(method.comments.to_string());
305                            }
306
307                            service_builder.method(builder.build())
308                        })
309                        .build()
310                };
311
312                let mut client: syn::ItemMod = syn::parse2(builder.generate_client(&make_service(true), "")).unwrap();
313                client.content.as_mut().unwrap().1.insert(
314                    0,
315                    parse_quote!(
316                        use ::tinc::reexports::tonic;
317                    ),
318                );
319
320                let mut server: syn::ItemMod = syn::parse2(builder.generate_server(&make_service(false), "")).unwrap();
321                server.content.as_mut().unwrap().1.insert(
322                    0,
323                    parse_quote!(
324                        use ::tinc::reexports::tonic;
325                    ),
326                );
327
328                [client.into(), server.into()]
329            }));
330        });
331
332        for package in packages.keys() {
333            match std::fs::remove_file(out_dir.join(format!("{package}.rs"))) {
334                Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("remove")),
335                _ => {}
336            }
337        }
338
339        config.compile_fds(fds).context("prost compile")?;
340
341        for (package, module) in &mut packages {
342            if self.extern_paths.contains(package) {
343                continue;
344            };
345
346            let path = out_dir.join(format!("{package}.rs"));
347            write_module(&path, std::mem::take(&mut module.extra_items)).with_context(|| package.to_owned())?;
348        }
349
350        #[derive(Default)]
351        struct Module<'a> {
352            proto_path: Option<&'a ProtoPath>,
353            children: BTreeMap<&'a str, Module<'a>>,
354        }
355
356        impl ToTokens for Module<'_> {
357            fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
358                let include = self
359                    .proto_path
360                    .map(|p| p.as_ref())
361                    .map(|path| quote!(include!(concat!(#path, ".rs"));));
362                let children = self.children.iter().map(|(part, child)| {
363                    let ident = syn::Ident::new(&to_snake(part), Span::call_site());
364                    quote! {
365                        pub mod #ident {
366                            #child
367                        }
368                    }
369                });
370                quote! {
371                    #include
372                    #(#children)*
373                }
374                .to_tokens(tokens);
375            }
376        }
377
378        if self.root_module {
379            let mut module = Module::default();
380            for package in packages.keys() {
381                let mut module = &mut module;
382                for part in package.split('.') {
383                    module = module.children.entry(part).or_default();
384                }
385                module.proto_path = Some(package);
386            }
387
388            let file: syn::File = parse_quote!(#module);
389            std::fs::write(out_dir.join("___root_module.rs"), prettyplease::unparse(&file)).context("write root module")?;
390        }
391
392        Ok(())
393    }
394}
395
396fn write_module(path: &std::path::Path, module: Vec<syn::Item>) -> anyhow::Result<()> {
397    let mut file = match std::fs::read_to_string(path) {
398        Ok(content) if !content.is_empty() => syn::parse_file(&content).context("parse")?,
399        Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("read")),
400        _ => syn::File {
401            attrs: Vec::new(),
402            items: Vec::new(),
403            shebang: None,
404        },
405    };
406
407    file.items.extend(module);
408    std::fs::write(path, prettyplease::unparse(&file)).context("write")?;
409
410    Ok(())
411}