1use anyhow::Context;
2use indexmap::IndexMap;
3use openapi::{BodyMethod, GeneratedBody, GeneratedParams, InputGenerator, OutputGenerator};
4use openapiv3_1::HttpMethod;
5use quote::{format_ident, quote};
6use syn::{Ident, parse_quote};
7use tinc_pb_prost::http_endpoint_options;
8
9use super::Package;
10use super::utils::{field_ident_from_str, type_ident_from_str};
11use crate::types::{
12 Comments, ProtoPath, ProtoService, ProtoServiceMethod, ProtoServiceMethodEndpoint, ProtoServiceMethodIo,
13 ProtoTypeRegistry, ProtoValueType,
14};
15
16mod openapi;
17
18struct GeneratedMethod {
19 function_body: proc_macro2::TokenStream,
20 openapi: openapiv3_1::path::PathItem,
21 http_method: Ident,
22 path: String,
23}
24
25impl GeneratedMethod {
26 #[allow(clippy::too_many_arguments)]
27 fn new(
28 name: &str,
29 package: &str,
30 service_name: &str,
31 service: &ProtoService,
32 method: &ProtoServiceMethod,
33 endpoint: &ProtoServiceMethodEndpoint,
34 types: &ProtoTypeRegistry,
35 components: &mut openapiv3_1::Components,
36 ) -> anyhow::Result<GeneratedMethod> {
37 let (http_method_oa, path) = match &endpoint.method {
38 tinc_pb_prost::http_endpoint_options::Method::Get(path) => (openapiv3_1::HttpMethod::Get, path),
39 tinc_pb_prost::http_endpoint_options::Method::Post(path) => (openapiv3_1::HttpMethod::Post, path),
40 tinc_pb_prost::http_endpoint_options::Method::Put(path) => (openapiv3_1::HttpMethod::Put, path),
41 tinc_pb_prost::http_endpoint_options::Method::Delete(path) => (openapiv3_1::HttpMethod::Delete, path),
42 tinc_pb_prost::http_endpoint_options::Method::Patch(path) => (openapiv3_1::HttpMethod::Patch, path),
43 };
44
45 let full_path = match (
46 path.trim_matches('/'),
47 service.options.prefix.as_deref().map(|p| p.trim_matches('/')),
48 ) {
49 ("", Some(prefix)) => format!("/{prefix}"),
50 (path, None | Some("")) => format!("/{path}"),
51 (path, Some(prefix)) => format!("/{prefix}/{path}"),
52 };
53
54 let http_method = quote::format_ident!("{http_method_oa}");
55 let tracker_ident = quote::format_ident!("tracker");
56 let target_ident = quote::format_ident!("target");
57 let state_ident = quote::format_ident!("state");
58 let mut openapi = openapiv3_1::path::Operation::new();
59 let mut generator = InputGenerator::new(
60 types,
61 components,
62 package,
63 method.input.value_type().clone(),
64 tracker_ident.clone(),
65 target_ident.clone(),
66 state_ident.clone(),
67 );
68
69 openapi.tag(service_name);
70
71 let GeneratedParams {
72 tokens: path_tokens,
73 params,
74 } = generator.generate_path_parameter(full_path.trim_end_matches("/"))?;
75 openapi.parameters(params);
76
77 let is_get_or_delete = matches!(http_method_oa, HttpMethod::Get | HttpMethod::Delete);
78 let request = endpoint.request.as_ref().and_then(|req| req.mode.clone()).unwrap_or_else(|| {
79 if is_get_or_delete {
80 http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams::default())
81 } else {
82 http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody::default())
83 }
84 });
85
86 let request_tokens = match request {
87 http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams { field }) => {
88 let GeneratedParams { tokens, params } = generator.generate_query_parameter(field.as_deref())?;
89 openapi.parameters(params);
90 tokens
91 }
92 http_endpoint_options::request::Mode::Binary(http_endpoint_options::request::BinaryBody {
93 field,
94 content_type_accepts,
95 content_type_field,
96 }) => {
97 let GeneratedBody { tokens, body } = generator.generate_body(
98 &method.cel,
99 BodyMethod::Binary(content_type_accepts.as_deref()),
100 field.as_deref(),
101 content_type_field.as_deref(),
102 )?;
103 openapi.request_body = Some(body);
104 tokens
105 }
106 http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody { field }) => {
107 let GeneratedBody { tokens, body } =
108 generator.generate_body(&method.cel, BodyMethod::Json, field.as_deref(), None)?;
109 openapi.request_body = Some(body);
110 tokens
111 }
112 http_endpoint_options::request::Mode::Text(http_endpoint_options::request::TextBody { field }) => {
113 let GeneratedBody { tokens, body } =
114 generator.generate_body(&method.cel, BodyMethod::Text, field.as_deref(), None)?;
115 openapi.request_body = Some(body);
116 tokens
117 }
118 };
119
120 let input_path = match &method.input {
121 ProtoServiceMethodIo::Single(input) => types.resolve_rust_path(package, input.proto_path()),
122 ProtoServiceMethodIo::Stream(_) => anyhow::bail!("currently streaming is not supported by tinc methods."),
123 };
124
125 let service_method_name = field_ident_from_str(name);
126
127 let response = endpoint
128 .response
129 .as_ref()
130 .and_then(|resp| resp.mode.clone())
131 .unwrap_or_else(
132 || http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json::default()),
133 );
134
135 let response_ident = quote::format_ident!("response");
136 let builder_ident = quote::format_ident!("builder");
137 let mut generator = OutputGenerator::new(
138 types,
139 components,
140 method.output.value_type().clone(),
141 response_ident.clone(),
142 builder_ident.clone(),
143 );
144
145 let GeneratedBody {
146 body: response,
147 tokens: response_tokens,
148 } = match response {
149 http_endpoint_options::response::Mode::Binary(http_endpoint_options::response::Binary {
150 field,
151 content_type_accepts,
152 content_type_field,
153 }) => generator.generate_body(
154 BodyMethod::Binary(content_type_accepts.as_deref()),
155 field.as_deref(),
156 content_type_field.as_deref(),
157 )?,
158 http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json { field }) => {
159 generator.generate_body(BodyMethod::Json, field.as_deref(), None)?
160 }
161 http_endpoint_options::response::Mode::Text(http_endpoint_options::response::Text { field }) => {
162 generator.generate_body(BodyMethod::Text, field.as_deref(), None)?
163 }
164 };
165
166 openapi.response("200", response);
167
168 let validate = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
169 quote! {
170 if let Err(err) = ::tinc::__private::TincValidate::validate_http(&#target_ident, #state_ident, &#tracker_ident) {
171 return err;
172 }
173 }
174 } else {
175 quote!()
176 };
177
178 let function_impl = quote! {
179 let mut #state_ident = ::tinc::__private::TrackerSharedState::default();
180 let mut #tracker_ident = <<#input_path as ::tinc::__private::TrackerFor>::Tracker as ::core::default::Default>::default();
181 let mut #target_ident = <#input_path as ::core::default::Default>::default();
182
183 #path_tokens
184 #request_tokens
185
186 #validate
187
188 let request = ::tinc::reexports::tonic::Request::from_parts(
189 ::tinc::reexports::tonic::metadata::MetadataMap::from_headers(parts.headers),
190 parts.extensions,
191 target,
192 );
193
194 let (metadata, #response_ident, extensions) = match service.inner.#service_method_name(request).await {
195 ::core::result::Result::Ok(response) => response.into_parts(),
196 ::core::result::Result::Err(status) => return ::tinc::__private::handle_tonic_status(&status),
197 };
198
199 let mut response = {
200 let mut #builder_ident = ::tinc::reexports::http::Response::builder();
201 match #response_tokens {
202 ::core::result::Result::Ok(v) => v,
203 ::core::result::Result::Err(err) => return ::tinc::__private::handle_response_build_error(err),
204 }
205 };
206
207 response.headers_mut().extend(metadata.into_headers());
208 *response.extensions_mut() = extensions;
209
210 response
211 };
212
213 Ok(GeneratedMethod {
214 function_body: function_impl,
215 http_method,
216 openapi: openapiv3_1::PathItem::new(http_method_oa, openapi),
217 path: full_path,
218 })
219 }
220
221 pub(crate) fn method_handler(
222 &self,
223 function_name: &Ident,
224 server_module_name: &Ident,
225 service_trait: &Ident,
226 tinc_struct_name: &Ident,
227 ) -> proc_macro2::TokenStream {
228 let function_impl = &self.function_body;
229
230 quote! {
231 #[allow(non_snake_case, unused_mut, dead_code, unused_variables, unused_parens)]
232 async fn #function_name<T>(
233 ::tinc::reexports::axum::extract::State(service): ::tinc::reexports::axum::extract::State<#tinc_struct_name<T>>,
234 request: ::tinc::reexports::axum::extract::Request,
235 ) -> ::tinc::reexports::axum::response::Response
236 where
237 T: super::#server_module_name::#service_trait,
238 {
239 let (mut parts, body) = ::tinc::reexports::axum::RequestExt::with_limited_body(request).into_parts();
240 #function_impl
241 }
242 }
243 }
244
245 pub(crate) fn route(&self, function_name: &Ident) -> proc_macro2::TokenStream {
246 let path = &self.path;
247 let http_method = &self.http_method;
248
249 quote! {
250 .route(#path, ::tinc::reexports::axum::routing::#http_method(#function_name::<T>))
251 }
252 }
253}
254
255#[derive(Debug, Clone, PartialEq)]
256pub(crate) struct ProcessedService {
257 pub full_name: ProtoPath,
258 pub package: ProtoPath,
259 pub comments: Comments,
260 pub openapi: openapiv3_1::OpenApi,
261 pub methods: IndexMap<String, ProcessedServiceMethod>,
262}
263
264impl ProcessedService {
265 pub(crate) fn name(&self) -> &str {
266 self.full_name
267 .strip_prefix(&*self.package)
268 .unwrap_or(&self.full_name)
269 .trim_matches('.')
270 }
271}
272
273#[derive(Debug, Clone, PartialEq)]
274pub(crate) struct ProcessedServiceMethod {
275 pub codec_path: Option<ProtoPath>,
276 pub input: ProtoServiceMethodIo,
277 pub output: ProtoServiceMethodIo,
278 pub comments: Comments,
279}
280
281pub(super) fn handle_service(
282 service: &ProtoService,
283 package: &mut Package,
284 registry: &ProtoTypeRegistry,
285) -> anyhow::Result<()> {
286 let name = service
287 .full_name
288 .strip_prefix(&*service.package)
289 .and_then(|s| s.strip_prefix('.'))
290 .unwrap_or(&*service.full_name);
291
292 let mut components = openapiv3_1::Components::new();
293 let mut paths = openapiv3_1::Paths::builder();
294
295 let snake_name = field_ident_from_str(name);
296 let pascal_name = type_ident_from_str(name);
297
298 let tinc_module_name = quote::format_ident!("{snake_name}_tinc");
299 let server_module_name = quote::format_ident!("{snake_name}_server");
300 let tinc_struct_name = quote::format_ident!("{pascal_name}Tinc");
301
302 let mut method_tokens = Vec::new();
303 let mut route_tokens = Vec::new();
304 let mut method_codecs = Vec::new();
305 let mut methods = IndexMap::new();
306
307 let package_name = format!("{}.{tinc_module_name}", service.package);
308
309 for (method_name, method) in service.methods.iter() {
310 for (idx, endpoint) in method.endpoints.iter().enumerate() {
311 let gen_method = GeneratedMethod::new(
312 method_name,
313 &package_name,
314 name,
315 service,
316 method,
317 endpoint,
318 registry,
319 &mut components,
320 )?;
321 let function_name = quote::format_ident!("{method_name}_{idx}");
322
323 method_tokens.push(gen_method.method_handler(
324 &function_name,
325 &server_module_name,
326 &pascal_name,
327 &tinc_struct_name,
328 ));
329 route_tokens.push(gen_method.route(&function_name));
330 paths = paths.path(gen_method.path, gen_method.openapi);
331 }
332
333 let codec_path = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
334 let input_path = registry.resolve_rust_path(&package_name, method.input.value_type().proto_path());
335 let output_path = registry.resolve_rust_path(&package_name, method.output.value_type().proto_path());
336 let codec_ident = format_ident!("{method_name}Codec");
337 method_codecs.push(quote! {
338 #[derive(Debug, Clone, Default)]
339 #[doc(hidden)]
340 pub struct #codec_ident<C>(C);
341
342 #[allow(clippy::all, dead_code, unused_imports, unused_variables, unused_parens)]
343 const _: () = {
344 #[derive(Debug, Clone, Default)]
345 pub struct Encoder<E>(E);
346 #[derive(Debug, Clone, Default)]
347 pub struct Decoder<D>(D);
348
349 impl<C> ::tinc::reexports::tonic::codec::Codec for #codec_ident<C>
350 where
351 C: ::tinc::reexports::tonic::codec::Codec<Encode = #output_path, Decode = #input_path>
352 {
353 type Encode = C::Encode;
354 type Decode = C::Decode;
355
356 type Encoder = C::Encoder;
357 type Decoder = Decoder<C::Decoder>;
358
359 fn encoder(&mut self) -> Self::Encoder {
360 ::tinc::reexports::tonic::codec::Codec::encoder(&mut self.0)
361 }
362
363 fn decoder(&mut self) -> Self::Decoder {
364 Decoder(
365 ::tinc::reexports::tonic::codec::Codec::decoder(&mut self.0)
366 )
367 }
368 }
369
370 impl<D> ::tinc::reexports::tonic::codec::Decoder for Decoder<D>
371 where
372 D: ::tinc::reexports::tonic::codec::Decoder<Item = #input_path, Error = ::tinc::reexports::tonic::Status>
373 {
374 type Item = D::Item;
375 type Error = ::tinc::reexports::tonic::Status;
376
377 fn decode(&mut self, buf: &mut ::tinc::reexports::tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
378 match ::tinc::reexports::tonic::codec::Decoder::decode(&mut self.0, buf) {
379 ::core::result::Result::Ok(::core::option::Option::Some(item)) => {
380 ::tinc::__private::TincValidate::validate_tonic(&item)?;
381 ::core::result::Result::Ok(::core::option::Option::Some(item))
382 },
383 ::core::result::Result::Ok(::core::option::Option::None) => ::core::result::Result::Ok(::core::option::Option::None),
384 ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
385 }
386 }
387
388 fn buffer_settings(&self) -> ::tinc::reexports::tonic::codec::BufferSettings {
389 ::tinc::reexports::tonic::codec::Decoder::buffer_settings(&self.0)
390 }
391 }
392 };
393 });
394 Some(ProtoPath::new(format!("{package_name}.{codec_ident}")))
395 } else {
396 None
397 };
398
399 methods.insert(
400 method_name.clone(),
401 ProcessedServiceMethod {
402 codec_path,
403 input: method.input.clone(),
404 output: method.output.clone(),
405 comments: method.comments.clone(),
406 },
407 );
408 }
409
410 let openapi_tag = openapiv3_1::Tag::builder()
411 .name(name)
412 .description(service.comments.to_string())
413 .build();
414 let openapi = openapiv3_1::OpenApi::builder()
415 .components(components)
416 .paths(paths)
417 .tags(vec![openapi_tag])
418 .build();
419
420 let json_openapi = openapi.to_json().context("invalid openapi schema generation")?;
421
422 package.push_item(parse_quote! {
423 #[allow(clippy::all)]
425 pub mod #tinc_module_name {
426 #![allow(
427 unused_variables,
428 dead_code,
429 missing_docs,
430 clippy::wildcard_imports,
431 clippy::let_unit_value,
432 unused_parens,
433 irrefutable_let_patterns,
434 )]
435
436 pub struct #tinc_struct_name<T> {
438 inner: ::std::sync::Arc<T>,
439 }
440
441 impl<T> #tinc_struct_name<T> {
442 pub fn new(inner: T) -> Self {
444 Self { inner: ::std::sync::Arc::new(inner) }
445 }
446
447 pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
449 Self { inner }
450 }
451 }
452
453 impl<T> ::std::clone::Clone for #tinc_struct_name<T> {
454 fn clone(&self) -> Self {
455 Self { inner: ::std::clone::Clone::clone(&self.inner) }
456 }
457 }
458
459 impl<T> ::std::fmt::Debug for #tinc_struct_name<T> {
460 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
461 write!(f, stringify!(#tinc_struct_name))
462 }
463 }
464
465 impl<T> ::tinc::TincService for #tinc_struct_name<T>
466 where
467 T: super::#server_module_name::#pascal_name
468 {
469 fn into_router(self) -> ::tinc::reexports::axum::Router {
470 #(#method_tokens)*
471
472 ::tinc::reexports::axum::Router::new()
473 #(#route_tokens)*
474 .with_state(self)
475 }
476
477 fn openapi_schema_str(&self) -> &'static str {
478 #json_openapi
479 }
480 }
481
482 #(#method_codecs)*
483 }
484 });
485
486 package.services.push(ProcessedService {
487 full_name: service.full_name.clone(),
488 package: service.package.clone(),
489 comments: service.comments.clone(),
490 openapi,
491 methods,
492 });
493
494 Ok(())
495}