tinc_build/codegen/service.rs

1use anyhow::Context;
2use indexmap::IndexMap;
3use openapi::{BodyMethod, GeneratedBody, GeneratedParams, InputGenerator, OutputGenerator};
4use openapiv3_1::HttpMethod;
5use quote::{format_ident, quote};
6use syn::{Ident, parse_quote};
7use tinc_pb_prost::http_endpoint_options;
8
9use super::Package;
10use super::utils::{field_ident_from_str, type_ident_from_str};
11use crate::types::{
12    Comments, ProtoPath, ProtoService, ProtoServiceMethod, ProtoServiceMethodEndpoint, ProtoServiceMethodIo,
13    ProtoTypeRegistry, ProtoValueType,
14};
15
16mod openapi;
17
18struct GeneratedMethod {
19    function_body: proc_macro2::TokenStream,
20    openapi: openapiv3_1::path::PathItem,
21    http_method: Ident,
22    path: String,
23}
24
25impl GeneratedMethod {
26    #[allow(clippy::too_many_arguments)]
27    fn new(
28        name: &str,
29        package: &str,
30        service_name: &str,
31        service: &ProtoService,
32        method: &ProtoServiceMethod,
33        endpoint: &ProtoServiceMethodEndpoint,
34        types: &ProtoTypeRegistry,
35        components: &mut openapiv3_1::Components,
36    ) -> anyhow::Result<GeneratedMethod> {
37        let (http_method_oa, path) = match &endpoint.method {
38            tinc_pb_prost::http_endpoint_options::Method::Get(path) => (openapiv3_1::HttpMethod::Get, path),
39            tinc_pb_prost::http_endpoint_options::Method::Post(path) => (openapiv3_1::HttpMethod::Post, path),
40            tinc_pb_prost::http_endpoint_options::Method::Put(path) => (openapiv3_1::HttpMethod::Put, path),
41            tinc_pb_prost::http_endpoint_options::Method::Delete(path) => (openapiv3_1::HttpMethod::Delete, path),
42            tinc_pb_prost::http_endpoint_options::Method::Patch(path) => (openapiv3_1::HttpMethod::Patch, path),
43        };
44
45        let full_path = match (
46            path.trim_matches('/'),
47            service.options.prefix.as_deref().map(|p| p.trim_matches('/')),
48        ) {
49            ("", Some(prefix)) => format!("/{prefix}"),
50            (path, None | Some("")) => format!("/{path}"),
51            (path, Some(prefix)) => format!("/{prefix}/{path}"),
52        };
53
54        let http_method = quote::format_ident!("{http_method_oa}");
55        let tracker_ident = quote::format_ident!("tracker");
56        let target_ident = quote::format_ident!("target");
57        let state_ident = quote::format_ident!("state");
58        let mut openapi = openapiv3_1::path::Operation::new();
59        let mut generator = InputGenerator::new(
60            types,
61            components,
62            package,
63            method.input.value_type().clone(),
64            tracker_ident.clone(),
65            target_ident.clone(),
66            state_ident.clone(),
67        );
68
69        openapi.tag(service_name);
70
71        let GeneratedParams {
72            tokens: path_tokens,
73            params,
74        } = generator.generate_path_parameter(full_path.trim_end_matches("/"))?;
75        openapi.parameters(params);
76
77        let is_get_or_delete = matches!(http_method_oa, HttpMethod::Get | HttpMethod::Delete);
78        let request = endpoint.request.as_ref().and_then(|req| req.mode.clone()).unwrap_or_else(|| {
79            if is_get_or_delete {
80                http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams::default())
81            } else {
82                http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody::default())
83            }
84        });
85
86        let request_tokens = match request {
87            http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams { field }) => {
88                let GeneratedParams { tokens, params } = generator.generate_query_parameter(field.as_deref())?;
89                openapi.parameters(params);
90                tokens
91            }
92            http_endpoint_options::request::Mode::Binary(http_endpoint_options::request::BinaryBody {
93                field,
94                content_type_accepts,
95                content_type_field,
96            }) => {
97                let GeneratedBody { tokens, body } = generator.generate_body(
98                    &method.cel,
99                    BodyMethod::Binary(content_type_accepts.as_deref()),
100                    field.as_deref(),
101                    content_type_field.as_deref(),
102                )?;
103                openapi.request_body = Some(body);
104                tokens
105            }
106            http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody { field }) => {
107                let GeneratedBody { tokens, body } =
108                    generator.generate_body(&method.cel, BodyMethod::Json, field.as_deref(), None)?;
109                openapi.request_body = Some(body);
110                tokens
111            }
112            http_endpoint_options::request::Mode::Text(http_endpoint_options::request::TextBody { field }) => {
113                let GeneratedBody { tokens, body } =
114                    generator.generate_body(&method.cel, BodyMethod::Text, field.as_deref(), None)?;
115                openapi.request_body = Some(body);
116                tokens
117            }
118        };
119
120        let input_path = match &method.input {
121            ProtoServiceMethodIo::Single(input) => types.resolve_rust_path(package, input.proto_path()),
122            ProtoServiceMethodIo::Stream(_) => anyhow::bail!("currently streaming is not supported by tinc methods."),
123        };
124
125        let service_method_name = field_ident_from_str(name);
126
127        let response = endpoint
128            .response
129            .as_ref()
130            .and_then(|resp| resp.mode.clone())
131            .unwrap_or_else(
132                || http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json::default()),
133            );
134
135        let response_ident = quote::format_ident!("response");
136        let builder_ident = quote::format_ident!("builder");
137        let mut generator = OutputGenerator::new(
138            types,
139            components,
140            method.output.value_type().clone(),
141            response_ident.clone(),
142            builder_ident.clone(),
143        );
144
145        let GeneratedBody {
146            body: response,
147            tokens: response_tokens,
148        } = match response {
149            http_endpoint_options::response::Mode::Binary(http_endpoint_options::response::Binary {
150                field,
151                content_type_accepts,
152                content_type_field,
153            }) => generator.generate_body(
154                BodyMethod::Binary(content_type_accepts.as_deref()),
155                field.as_deref(),
156                content_type_field.as_deref(),
157            )?,
158            http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json { field }) => {
159                generator.generate_body(BodyMethod::Json, field.as_deref(), None)?
160            }
161            http_endpoint_options::response::Mode::Text(http_endpoint_options::response::Text { field }) => {
162                generator.generate_body(BodyMethod::Text, field.as_deref(), None)?
163            }
164        };
165
166        openapi.response("200", response);
167
168        let validate = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
169            quote! {
170                if let Err(err) = ::tinc::__private::TincValidate::validate_http(&#target_ident, #state_ident, &#tracker_ident) {
171                    return err;
172                }
173            }
174        } else {
175            quote!()
176        };
177
178        let function_impl = quote! {
179            let mut #state_ident = ::tinc::__private::TrackerSharedState::default();
180            let mut #tracker_ident = <<#input_path as ::tinc::__private::TrackerFor>::Tracker as ::core::default::Default>::default();
181            let mut #target_ident = <#input_path as ::core::default::Default>::default();
182
183            #path_tokens
184            #request_tokens
185
186            #validate
187
188            let request = ::tinc::reexports::tonic::Request::from_parts(
189                ::tinc::reexports::tonic::metadata::MetadataMap::from_headers(parts.headers),
190                parts.extensions,
191                target,
192            );
193
194            let (metadata, #response_ident, extensions) = match service.inner.#service_method_name(request).await {
195                ::core::result::Result::Ok(response) => response.into_parts(),
196                ::core::result::Result::Err(status) => return ::tinc::__private::handle_tonic_status(&status),
197            };
198
199            let mut response = {
200                let mut #builder_ident = ::tinc::reexports::http::Response::builder();
201                match #response_tokens {
202                    ::core::result::Result::Ok(v) => v,
203                    ::core::result::Result::Err(err) => return ::tinc::__private::handle_response_build_error(err),
204                }
205            };
206
207            response.headers_mut().extend(metadata.into_headers());
208            *response.extensions_mut() = extensions;
209
210            response
211        };
212
213        Ok(GeneratedMethod {
214            function_body: function_impl,
215            http_method,
216            openapi: openapiv3_1::PathItem::new(http_method_oa, openapi),
217            path: full_path,
218        })
219    }
220
221    pub(crate) fn method_handler(
222        &self,
223        function_name: &Ident,
224        server_module_name: &Ident,
225        service_trait: &Ident,
226        tinc_struct_name: &Ident,
227    ) -> proc_macro2::TokenStream {
228        let function_impl = &self.function_body;
229
230        quote! {
231            #[allow(non_snake_case, unused_mut, dead_code, unused_variables, unused_parens)]
232            async fn #function_name<T>(
233                ::tinc::reexports::axum::extract::State(service): ::tinc::reexports::axum::extract::State<#tinc_struct_name<T>>,
234                request: ::tinc::reexports::axum::extract::Request,
235            ) -> ::tinc::reexports::axum::response::Response
236            where
237                T: super::#server_module_name::#service_trait,
238            {
239                let (mut parts, body) = ::tinc::reexports::axum::RequestExt::with_limited_body(request).into_parts();
240                #function_impl
241            }
242        }
243    }
244
245    pub(crate) fn route(&self, function_name: &Ident) -> proc_macro2::TokenStream {
246        let path = &self.path;
247        let http_method = &self.http_method;
248
249        quote! {
250            .route(#path, ::tinc::reexports::axum::routing::#http_method(#function_name::<T>))
251        }
252    }
253}
254
255#[derive(Debug, Clone, PartialEq)]
256pub(crate) struct ProcessedService {
257    pub full_name: ProtoPath,
258    pub package: ProtoPath,
259    pub comments: Comments,
260    pub openapi: openapiv3_1::OpenApi,
261    pub methods: IndexMap<String, ProcessedServiceMethod>,
262}
263
264impl ProcessedService {
265    pub(crate) fn name(&self) -> &str {
266        self.full_name
267            .strip_prefix(&*self.package)
268            .unwrap_or(&self.full_name)
269            .trim_matches('.')
270    }
271}
272
273#[derive(Debug, Clone, PartialEq)]
274pub(crate) struct ProcessedServiceMethod {
275    pub codec_path: Option<ProtoPath>,
276    pub input: ProtoServiceMethodIo,
277    pub output: ProtoServiceMethodIo,
278    pub comments: Comments,
279}
280
281pub(super) fn handle_service(
282    service: &ProtoService,
283    package: &mut Package,
284    registry: &ProtoTypeRegistry,
285) -> anyhow::Result<()> {
286    let name = service
287        .full_name
288        .strip_prefix(&*service.package)
289        .and_then(|s| s.strip_prefix('.'))
290        .unwrap_or(&*service.full_name);
291
292    let mut components = openapiv3_1::Components::new();
293    let mut paths = openapiv3_1::Paths::builder();
294
295    let snake_name = field_ident_from_str(name);
296    let pascal_name = type_ident_from_str(name);
297
298    let tinc_module_name = quote::format_ident!("{snake_name}_tinc");
299    let server_module_name = quote::format_ident!("{snake_name}_server");
300    let tinc_struct_name = quote::format_ident!("{pascal_name}Tinc");
301
302    let mut method_tokens = Vec::new();
303    let mut route_tokens = Vec::new();
304    let mut method_codecs = Vec::new();
305    let mut methods = IndexMap::new();
306
307    let package_name = format!("{}.{tinc_module_name}", service.package);
308
309    for (method_name, method) in service.methods.iter() {
310        for (idx, endpoint) in method.endpoints.iter().enumerate() {
311            let gen_method = GeneratedMethod::new(
312                method_name,
313                &package_name,
314                name,
315                service,
316                method,
317                endpoint,
318                registry,
319                &mut components,
320            )?;
321            let function_name = quote::format_ident!("{method_name}_{idx}");
322
323            method_tokens.push(gen_method.method_handler(
324                &function_name,
325                &server_module_name,
326                &pascal_name,
327                &tinc_struct_name,
328            ));
329            route_tokens.push(gen_method.route(&function_name));
330            paths = paths.path(gen_method.path, gen_method.openapi);
331        }
332
333        let codec_path = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
334            let input_path = registry.resolve_rust_path(&package_name, method.input.value_type().proto_path());
335            let output_path = registry.resolve_rust_path(&package_name, method.output.value_type().proto_path());
336            let codec_ident = format_ident!("{method_name}Codec");
337            method_codecs.push(quote! {
338                #[derive(Debug, Clone, Default)]
339                #[doc(hidden)]
340                pub struct #codec_ident<C>(C);
341
342                #[allow(clippy::all, dead_code, unused_imports, unused_variables, unused_parens)]
343                const _: () = {
344                    #[derive(Debug, Clone, Default)]
345                    pub struct Encoder<E>(E);
346                    #[derive(Debug, Clone, Default)]
347                    pub struct Decoder<D>(D);
348
349                    impl<C> ::tinc::reexports::tonic::codec::Codec for #codec_ident<C>
350                    where
351                        C: ::tinc::reexports::tonic::codec::Codec<Encode = #output_path, Decode = #input_path>
352                    {
353                        type Encode = C::Encode;
354                        type Decode = C::Decode;
355
356                        type Encoder = C::Encoder;
357                        type Decoder = Decoder<C::Decoder>;
358
359                        fn encoder(&mut self) -> Self::Encoder {
360                            ::tinc::reexports::tonic::codec::Codec::encoder(&mut self.0)
361                        }
362
363                        fn decoder(&mut self) -> Self::Decoder {
364                            Decoder(
365                                ::tinc::reexports::tonic::codec::Codec::decoder(&mut self.0)
366                            )
367                        }
368                    }
369
370                    impl<D> ::tinc::reexports::tonic::codec::Decoder for Decoder<D>
371                    where
372                        D: ::tinc::reexports::tonic::codec::Decoder<Item = #input_path, Error = ::tinc::reexports::tonic::Status>
373                    {
374                        type Item = D::Item;
375                        type Error = ::tinc::reexports::tonic::Status;
376
377                        fn decode(&mut self, buf: &mut ::tinc::reexports::tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
378                            match ::tinc::reexports::tonic::codec::Decoder::decode(&mut self.0, buf) {
379                                ::core::result::Result::Ok(::core::option::Option::Some(item)) => {
380                                    ::tinc::__private::TincValidate::validate_tonic(&item)?;
381                                    ::core::result::Result::Ok(::core::option::Option::Some(item))
382                                },
383                                ::core::result::Result::Ok(::core::option::Option::None) => ::core::result::Result::Ok(::core::option::Option::None),
384                                ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
385                            }
386                        }
387
388                        fn buffer_settings(&self) -> ::tinc::reexports::tonic::codec::BufferSettings {
389                            ::tinc::reexports::tonic::codec::Decoder::buffer_settings(&self.0)
390                        }
391                    }
392                };
393            });
394            Some(ProtoPath::new(format!("{package_name}.{codec_ident}")))
395        } else {
396            None
397        };
398
399        methods.insert(
400            method_name.clone(),
401            ProcessedServiceMethod {
402                codec_path,
403                input: method.input.clone(),
404                output: method.output.clone(),
405                comments: method.comments.clone(),
406            },
407        );
408    }
409
410    let openapi_tag = openapiv3_1::Tag::builder()
411        .name(name)
412        .description(service.comments.to_string())
413        .build();
414    let openapi = openapiv3_1::OpenApi::builder()
415        .components(components)
416        .paths(paths)
417        .tags(vec![openapi_tag])
418        .build();
419
420    let json_openapi = openapi.to_json().context("invalid openapi schema generation")?;
421
422    package.push_item(parse_quote! {
423        /// This module was automatically generated by `tinc`.
424        #[allow(clippy::all)]
425        pub mod #tinc_module_name {
426            #![allow(
427                unused_variables,
428                dead_code,
429                missing_docs,
430                clippy::wildcard_imports,
431                clippy::let_unit_value,
432                unused_parens,
433                irrefutable_let_patterns,
434            )]
435
436            /// A tinc service struct that exports gRPC routes via an axum router.
437            pub struct #tinc_struct_name<T> {
438                inner: ::std::sync::Arc<T>,
439            }
440
441            impl<T> #tinc_struct_name<T> {
442                /// Create a new tinc service struct from a service implementation.
443                pub fn new(inner: T) -> Self {
444                    Self { inner: ::std::sync::Arc::new(inner) }
445                }
446
447                /// Create a new tinc service struct from an existing `Arc`.
448                pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
449                    Self { inner }
450                }
451            }
452
453            impl<T> ::std::clone::Clone for #tinc_struct_name<T> {
454                fn clone(&self) -> Self {
455                    Self { inner: ::std::clone::Clone::clone(&self.inner) }
456                }
457            }
458
459            impl<T> ::std::fmt::Debug for #tinc_struct_name<T> {
460                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
461                    write!(f, stringify!(#tinc_struct_name))
462                }
463            }
464
465            impl<T> ::tinc::TincService for #tinc_struct_name<T>
466            where
467                T: super::#server_module_name::#pascal_name
468            {
469                fn into_router(self) -> ::tinc::reexports::axum::Router {
470                    #(#method_tokens)*
471
472                    ::tinc::reexports::axum::Router::new()
473                        #(#route_tokens)*
474                        .with_state(self)
475                }
476
477                fn openapi_schema_str(&self) -> &'static str {
478                    #json_openapi
479                }
480            }
481
482            #(#method_codecs)*
483        }
484    });
485
486    package.services.push(ProcessedService {
487        full_name: service.full_name.clone(),
488        package: service.package.clone(),
489        comments: service.comments.clone(),
490        openapi,
491        methods,
492    });
493
494    Ok(())
495}