tinc_build/codegen/service/
mod.rs

1use anyhow::Context;
2use indexmap::IndexMap;
3use openapi::{BodyMethod, GeneratedBody, GeneratedParams, InputGenerator, OutputGenerator};
4use openapiv3_1::HttpMethod;
5use quote::{format_ident, quote};
6use syn::{Ident, parse_quote};
7use tinc_pb_prost::http_endpoint_options;
8
9use super::Package;
10use super::utils::{field_ident_from_str, type_ident_from_str};
11use crate::types::{
12    Comments, ProtoPath, ProtoService, ProtoServiceMethod, ProtoServiceMethodEndpoint, ProtoServiceMethodIo,
13    ProtoTypeRegistry, ProtoValueType,
14};
15
16mod openapi;
17
18struct GeneratedMethod {
19    function_body: proc_macro2::TokenStream,
20    openapi: openapiv3_1::path::PathItem,
21    http_method: Ident,
22    path: String,
23}
24
25impl GeneratedMethod {
26    #[allow(clippy::too_many_arguments)]
27    fn new(
28        name: &str,
29        package: &str,
30        service: &ProtoService,
31        method: &ProtoServiceMethod,
32        endpoint: &ProtoServiceMethodEndpoint,
33        types: &ProtoTypeRegistry,
34        components: &mut openapiv3_1::Components,
35    ) -> anyhow::Result<GeneratedMethod> {
36        let (http_method_oa, path) = match &endpoint.method {
37            tinc_pb_prost::http_endpoint_options::Method::Get(path) => (openapiv3_1::HttpMethod::Get, path),
38            tinc_pb_prost::http_endpoint_options::Method::Post(path) => (openapiv3_1::HttpMethod::Post, path),
39            tinc_pb_prost::http_endpoint_options::Method::Put(path) => (openapiv3_1::HttpMethod::Put, path),
40            tinc_pb_prost::http_endpoint_options::Method::Delete(path) => (openapiv3_1::HttpMethod::Delete, path),
41            tinc_pb_prost::http_endpoint_options::Method::Patch(path) => (openapiv3_1::HttpMethod::Patch, path),
42        };
43
44        let trimmed_path = path.trim_start_matches('/');
45        let full_path = if let Some(prefix) = &service.options.prefix {
46            format!("/{}/{}", prefix.trim_end_matches('/'), trimmed_path)
47        } else {
48            format!("/{trimmed_path}")
49        };
50
51        let http_method = quote::format_ident!("{http_method_oa}");
52        let tracker_ident = quote::format_ident!("tracker");
53        let target_ident = quote::format_ident!("target");
54        let state_ident = quote::format_ident!("state");
55        let mut openapi = openapiv3_1::path::Operation::new();
56        let mut generator = InputGenerator::new(
57            types,
58            components,
59            package,
60            method.input.value_type().clone(),
61            tracker_ident.clone(),
62            target_ident.clone(),
63            state_ident.clone(),
64        );
65
66        let GeneratedParams {
67            tokens: path_tokens,
68            params,
69        } = generator.generate_path_parameter(&full_path)?;
70        openapi.parameters(params);
71
72        let is_get_or_delete = matches!(http_method_oa, HttpMethod::Get | HttpMethod::Delete);
73        let request = endpoint.request.as_ref().and_then(|req| req.mode.clone()).unwrap_or_else(|| {
74            if is_get_or_delete {
75                http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams::default())
76            } else {
77                http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody::default())
78            }
79        });
80
81        let request_tokens = match request {
82            http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams { field }) => {
83                let GeneratedParams { tokens, params } = generator.generate_query_parameter(field.as_deref())?;
84                openapi.parameters(params);
85                tokens
86            }
87            http_endpoint_options::request::Mode::Binary(http_endpoint_options::request::BinaryBody {
88                field,
89                content_type_accepts,
90                content_type_field,
91            }) => {
92                let GeneratedBody { tokens, body } = generator.generate_body(
93                    &method.cel,
94                    BodyMethod::Binary(content_type_accepts.as_deref()),
95                    field.as_deref(),
96                    content_type_field.as_deref(),
97                )?;
98                openapi.request_body = Some(body);
99                tokens
100            }
101            http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody { field }) => {
102                let GeneratedBody { tokens, body } =
103                    generator.generate_body(&method.cel, BodyMethod::Json, field.as_deref(), None)?;
104                openapi.request_body = Some(body);
105                tokens
106            }
107            http_endpoint_options::request::Mode::Text(http_endpoint_options::request::TextBody { field }) => {
108                let GeneratedBody { tokens, body } =
109                    generator.generate_body(&method.cel, BodyMethod::Text, field.as_deref(), None)?;
110                openapi.request_body = Some(body);
111                tokens
112            }
113        };
114
115        let input_path = match &method.input {
116            ProtoServiceMethodIo::Single(input) => types.resolve_rust_path(package, input.proto_path()),
117            ProtoServiceMethodIo::Stream(_) => anyhow::bail!("currently streaming is not supported by tinc methods."),
118        };
119
120        let service_method_name = field_ident_from_str(name);
121
122        let response = endpoint
123            .response
124            .as_ref()
125            .and_then(|resp| resp.mode.clone())
126            .unwrap_or_else(
127                || http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json::default()),
128            );
129
130        let response_ident = quote::format_ident!("response");
131        let builder_ident = quote::format_ident!("builder");
132        let mut generator = OutputGenerator::new(
133            types,
134            components,
135            method.output.value_type().clone(),
136            response_ident.clone(),
137            builder_ident.clone(),
138        );
139
140        let GeneratedBody {
141            body: response,
142            tokens: response_tokens,
143        } = match response {
144            http_endpoint_options::response::Mode::Binary(http_endpoint_options::response::Binary {
145                field,
146                content_type_accepts,
147                content_type_field,
148            }) => generator.generate_body(
149                BodyMethod::Binary(content_type_accepts.as_deref()),
150                field.as_deref(),
151                content_type_field.as_deref(),
152            )?,
153            http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json { field }) => {
154                generator.generate_body(BodyMethod::Json, field.as_deref(), None)?
155            }
156            http_endpoint_options::response::Mode::Text(http_endpoint_options::response::Text { field }) => {
157                generator.generate_body(BodyMethod::Text, field.as_deref(), None)?
158            }
159        };
160
161        openapi.response("200", response);
162
163        let validate = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
164            quote! {
165                if let Err(err) = ::tinc::__private::TincValidate::validate_http(&#target_ident, #state_ident, &#tracker_ident) {
166                    return err;
167                }
168            }
169        } else {
170            quote!()
171        };
172
173        let function_impl = quote! {
174            let mut #state_ident = ::tinc::__private::TrackerSharedState::default();
175            let mut #tracker_ident = <<#input_path as ::tinc::__private::TrackerFor>::Tracker as ::core::default::Default>::default();
176            let mut #target_ident = <#input_path as ::core::default::Default>::default();
177
178            #path_tokens
179            #request_tokens
180
181            #validate
182
183            let request = ::tinc::reexports::tonic::Request::from_parts(
184                ::tinc::reexports::tonic::metadata::MetadataMap::from_headers(parts.headers),
185                parts.extensions,
186                target,
187            );
188
189            let (metadata, #response_ident, extensions) = match service.inner.#service_method_name(request).await {
190                ::core::result::Result::Ok(response) => response.into_parts(),
191                ::core::result::Result::Err(status) => return ::tinc::__private::handle_tonic_status(&status),
192            };
193
194            let mut response = {
195                let mut #builder_ident = ::tinc::reexports::http::Response::builder();
196                match #response_tokens {
197                    ::core::result::Result::Ok(v) => v,
198                    ::core::result::Result::Err(err) => return ::tinc::__private::handle_response_build_error(err),
199                }
200            };
201
202            response.headers_mut().extend(metadata.into_headers());
203            *response.extensions_mut() = extensions;
204
205            response
206        };
207
208        Ok(GeneratedMethod {
209            function_body: function_impl,
210            http_method,
211            openapi: openapiv3_1::PathItem::new(http_method_oa, openapi),
212            path: full_path,
213        })
214    }
215
216    pub(crate) fn method_handler(
217        &self,
218        function_name: &Ident,
219        server_module_name: &Ident,
220        service_trait: &Ident,
221        tinc_struct_name: &Ident,
222    ) -> proc_macro2::TokenStream {
223        let function_impl = &self.function_body;
224
225        quote! {
226            #[allow(non_snake_case, unused_mut, dead_code, unused_variables, unused_parens)]
227            async fn #function_name<T>(
228                ::tinc::reexports::axum::extract::State(service): ::tinc::reexports::axum::extract::State<#tinc_struct_name<T>>,
229                request: ::tinc::reexports::axum::extract::Request,
230            ) -> ::tinc::reexports::axum::response::Response
231            where
232                T: super::#server_module_name::#service_trait,
233            {
234                let (mut parts, body) = ::tinc::reexports::axum::RequestExt::with_limited_body(request).into_parts();
235                #function_impl
236            }
237        }
238    }
239
240    pub(crate) fn route(&self, function_name: &Ident) -> proc_macro2::TokenStream {
241        let path = &self.path;
242        let http_method = &self.http_method;
243
244        quote! {
245            .route(#path, ::tinc::reexports::axum::routing::#http_method(#function_name::<T>))
246        }
247    }
248}
249
250#[derive(Debug, Clone, PartialEq)]
251pub(crate) struct ProcessedService {
252    pub full_name: ProtoPath,
253    pub package: ProtoPath,
254    pub comments: Comments,
255    pub openapi: openapiv3_1::OpenApi,
256    pub methods: IndexMap<String, ProcessedServiceMethod>,
257}
258
259impl ProcessedService {
260    pub(crate) fn name(&self) -> &str {
261        self.full_name
262            .strip_prefix(&*self.package)
263            .unwrap_or(&self.full_name)
264            .trim_matches('.')
265    }
266}
267
268#[derive(Debug, Clone, PartialEq)]
269pub(crate) struct ProcessedServiceMethod {
270    pub codec_path: Option<ProtoPath>,
271    pub input: ProtoServiceMethodIo,
272    pub output: ProtoServiceMethodIo,
273    pub comments: Comments,
274}
275
276pub(super) fn handle_service(
277    service: &ProtoService,
278    package: &mut Package,
279    registry: &ProtoTypeRegistry,
280) -> anyhow::Result<()> {
281    let name = service
282        .full_name
283        .strip_prefix(&*service.package)
284        .and_then(|s| s.strip_prefix('.'))
285        .unwrap_or(&*service.full_name);
286
287    let mut components = openapiv3_1::Components::new();
288    let mut paths = openapiv3_1::Paths::builder();
289
290    let snake_name = field_ident_from_str(name);
291    let pascal_name = type_ident_from_str(name);
292
293    let tinc_module_name = quote::format_ident!("{snake_name}_tinc");
294    let server_module_name = quote::format_ident!("{snake_name}_server");
295    let tinc_struct_name = quote::format_ident!("{pascal_name}Tinc");
296
297    let mut method_tokens = Vec::new();
298    let mut route_tokens = Vec::new();
299    let mut method_codecs = Vec::new();
300    let mut methods = IndexMap::new();
301
302    let package_name = format!("{}.{tinc_module_name}", service.package);
303
304    for (name, method) in service.methods.iter() {
305        for (idx, endpoint) in method.endpoints.iter().enumerate() {
306            let gen_method =
307                GeneratedMethod::new(name, &package_name, service, method, endpoint, registry, &mut components)?;
308            let function_name = quote::format_ident!("{name}_{idx}");
309
310            method_tokens.push(gen_method.method_handler(
311                &function_name,
312                &server_module_name,
313                &pascal_name,
314                &tinc_struct_name,
315            ));
316            route_tokens.push(gen_method.route(&function_name));
317            paths = paths.path(gen_method.path, gen_method.openapi);
318        }
319
320        let codec_path = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
321            let input_path = registry.resolve_rust_path(&package_name, method.input.value_type().proto_path());
322            let output_path = registry.resolve_rust_path(&package_name, method.output.value_type().proto_path());
323            let codec_ident = format_ident!("{name}Codec");
324            method_codecs.push(quote! {
325                #[derive(Debug, Clone, Default)]
326                #[doc(hidden)]
327                pub struct #codec_ident<C>(C);
328
329                #[allow(clippy::all, dead_code, unused_imports, unused_variables, unused_parens)]
330                const _: () = {
331                    #[derive(Debug, Clone, Default)]
332                    pub struct Encoder<E>(E);
333                    #[derive(Debug, Clone, Default)]
334                    pub struct Decoder<D>(D);
335
336                    impl<C> ::tinc::reexports::tonic::codec::Codec for #codec_ident<C>
337                    where
338                        C: ::tinc::reexports::tonic::codec::Codec<Encode = #output_path, Decode = #input_path>
339                    {
340                        type Encode = C::Encode;
341                        type Decode = C::Decode;
342
343                        type Encoder = C::Encoder;
344                        type Decoder = Decoder<C::Decoder>;
345
346                        fn encoder(&mut self) -> Self::Encoder {
347                            ::tinc::reexports::tonic::codec::Codec::encoder(&mut self.0)
348                        }
349
350                        fn decoder(&mut self) -> Self::Decoder {
351                            Decoder(
352                                ::tinc::reexports::tonic::codec::Codec::decoder(&mut self.0)
353                            )
354                        }
355                    }
356
357                    impl<D> ::tinc::reexports::tonic::codec::Decoder for Decoder<D>
358                    where
359                        D: ::tinc::reexports::tonic::codec::Decoder<Item = #input_path, Error = ::tinc::reexports::tonic::Status>
360                    {
361                        type Item = D::Item;
362                        type Error = ::tinc::reexports::tonic::Status;
363
364                        fn decode(&mut self, buf: &mut ::tinc::reexports::tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
365                            match ::tinc::reexports::tonic::codec::Decoder::decode(&mut self.0, buf) {
366                                ::core::result::Result::Ok(::core::option::Option::Some(item)) => {
367                                    ::tinc::__private::TincValidate::validate_tonic(&item)?;
368                                    ::core::result::Result::Ok(::core::option::Option::Some(item))
369                                },
370                                ::core::result::Result::Ok(::core::option::Option::None) => ::core::result::Result::Ok(::core::option::Option::None),
371                                ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
372                            }
373                        }
374
375                        fn buffer_settings(&self) -> ::tinc::reexports::tonic::codec::BufferSettings {
376                            ::tinc::reexports::tonic::codec::Decoder::buffer_settings(&self.0)
377                        }
378                    }
379                };
380            });
381            Some(ProtoPath::new(format!("{package_name}.{codec_ident}")))
382        } else {
383            None
384        };
385
386        methods.insert(
387            name.clone(),
388            ProcessedServiceMethod {
389                codec_path,
390                input: method.input.clone(),
391                output: method.output.clone(),
392                comments: method.comments.clone(),
393            },
394        );
395    }
396
397    let openapi = openapiv3_1::OpenApi::builder().components(components).paths(paths).build();
398
399    let json_openapi = openapi.to_json().context("invalid openapi schema generation")?;
400
401    package.push_item(parse_quote! {
402        /// This module was automatically generated by `tinc`.
403        pub mod #tinc_module_name {
404            #![allow(
405                unused_variables,
406                dead_code,
407                missing_docs,
408                clippy::wildcard_imports,
409                clippy::let_unit_value,
410                unused_parens,
411                irrefutable_let_patterns,
412            )]
413
414            /// A tinc service struct that exports gRPC routes via an axum router.
415            pub struct #tinc_struct_name<T> {
416                inner: ::std::sync::Arc<T>,
417            }
418
419            impl<T> #tinc_struct_name<T> {
420                /// Create a new tinc service struct from a service implementation.
421                pub fn new(inner: T) -> Self {
422                    Self { inner: ::std::sync::Arc::new(inner) }
423                }
424
425                /// Create a new tinc service struct from an existing `Arc`.
426                pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
427                    Self { inner }
428                }
429            }
430
431            impl<T> ::std::clone::Clone for #tinc_struct_name<T> {
432                fn clone(&self) -> Self {
433                    Self { inner: ::std::clone::Clone::clone(&self.inner) }
434                }
435            }
436
437            impl<T> ::std::fmt::Debug for #tinc_struct_name<T> {
438                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
439                    write!(f, stringify!(#tinc_struct_name))
440                }
441            }
442
443            impl<T> ::tinc::TincService for #tinc_struct_name<T>
444            where
445                T: super::#server_module_name::#pascal_name
446            {
447                fn into_router(self) -> ::tinc::reexports::axum::Router {
448                    #(#method_tokens)*
449
450                    ::tinc::reexports::axum::Router::new()
451                        #(#route_tokens)*
452                        .with_state(self)
453                }
454
455                fn openapi_schema_str(&self) -> &'static str {
456                    #json_openapi
457                }
458            }
459
460            #(#method_codecs)*
461        }
462    });
463
464    package.services.push(ProcessedService {
465        full_name: service.full_name.clone(),
466        package: service.package.clone(),
467        comments: service.comments.clone(),
468        openapi,
469        methods,
470    });
471
472    Ok(())
473}