Skip to content

Spring MVC 之 @ServletComponentScan

Spring Boot 中注册 Servlet 的两种方式

使用 @WebServlet 注解

这是 Java EE 的注解，自 Servlet 3.0 起提供。在启用 @ServletComponentScan 后，Spring Boot 会扫描该注解，并把被它标注的类作为 Servlet 注册到 Web 容器中。但 DispatcherServlet 不是我们自定义的 Servlet，而是框架提供的类，我们无法在它上面添加 @WebServlet 注解，所以不能用这种方式注册它。

使用 ServletRegistrationBean

这个 Bean 由 Spring Boot 提供，专门用于注册 Servlet，可以像注册普通 Bean 一样配置 Servlet。例如，DispatcherServlet 就是通过它的子类 DispatcherServletRegistrationBean 注册的。

@ServletComponentScan

该注解通过 @Import 导入 ServletComponentScanRegistrar，用于扫描标注了 @WebServlet、@WebFilter、@WebListener 注解的类。

java
/**
 * Enables scanning for Servlet components: classes annotated with {@code @WebServlet},
 * {@code @WebFilter} or {@code @WebListener}. The scanning infrastructure is set up by the
 * imported {@code ServletComponentScanRegistrar}.
 */
@Target({ ElementType.TYPE })
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Import(ServletComponentScanRegistrar.class)
public @interface ServletComponentScan {

	/**
	 * Alias for {@link #basePackages()}.
	 * @return the base packages to scan
	 */
	@AliasFor(attribute = "basePackages")
	String[] value() default {};

	/**
	 * Base packages to scan for annotated servlet components. If neither this attribute nor
	 * {@link #basePackageClasses()} is set, the package of the annotated class is scanned.
	 * @return the base packages to scan
	 */
	@AliasFor(attribute = "value")
	String[] basePackages() default {};

	/**
	 * Type-safe alternative to {@link #basePackages()}: the package of each listed class is scanned.
	 * @return classes whose packages should be scanned
	 */
	Class<?>[] basePackageClasses() default {};

}

ServletComponentScanRegistrar

java
/**
 * Registrar imported by {@code @ServletComponentScan}. It resolves the packages to scan from the
 * annotation and registers a single {@code ServletComponentRegisteringPostProcessor} bean definition
 * under {@link #BEAN_NAME}, or adds the packages to that definition if it already exists.
 */
class ServletComponentScanRegistrar implements ImportBeanDefinitionRegistrar {

	private static final String BEAN_NAME = "servletComponentRegisteringPostProcessor";

	@Override
	public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
		Set<String> packagesToScan = getPackagesToScan(importingClassMetadata);
		if (registry.containsBeanDefinition(BEAN_NAME)) {
			// Another @ServletComponentScan already registered the post-processor: merge our packages into it.
			updatePostProcessor(registry, packagesToScan);
		}
		else {
			// First occurrence: register the servletComponentRegisteringPostProcessor bean definition.
			addPostProcessor(registry, packagesToScan);
		}
	}

	/**
	 * Adds {@code packagesToScan} to the post-processor definition that is already registered under
	 * {@link #BEAN_NAME}.
	 */
	private void updatePostProcessor(BeanDefinitionRegistry registry, Set<String> packagesToScan) {
		ServletComponentRegisteringPostProcessorBeanDefinition definition = (ServletComponentRegisteringPostProcessorBeanDefinition) registry
				.getBeanDefinition(BEAN_NAME);
		definition.addPackageNames(packagesToScan);
	}

	/**
	 * Registers a new post-processor definition whose bean class is
	 * {@code ServletComponentRegisteringPostProcessor}.
	 */
	private void addPostProcessor(BeanDefinitionRegistry registry, Set<String> packagesToScan) {
		ServletComponentRegisteringPostProcessorBeanDefinition definition = new ServletComponentRegisteringPostProcessorBeanDefinition(
				packagesToScan);
		registry.registerBeanDefinition(BEAN_NAME, definition);
	}

	/**
	 * Collects the packages named by {@code basePackages}/{@code value} and by the packages of
	 * {@code basePackageClasses}, preserving declaration order. If none are given, falls back to the
	 * package of the class carrying the annotation.
	 */
	private Set<String> getPackagesToScan(AnnotationMetadata metadata) {
		AnnotationAttributes attributes = AnnotationAttributes
				.fromMap(metadata.getAnnotationAttributes(ServletComponentScan.class.getName()));
		String[] basePackages = attributes.getStringArray("basePackages");
		Class<?>[] basePackageClasses = attributes.getClassArray("basePackageClasses");
		Set<String> packagesToScan = new LinkedHashSet<>(Arrays.asList(basePackages));
		for (Class<?> basePackageClass : basePackageClasses) {
			packagesToScan.add(ClassUtils.getPackageName(basePackageClass));
		}
		if (packagesToScan.isEmpty()) {
			packagesToScan.add(ClassUtils.getPackageName(metadata.getClassName()));
		}
		return packagesToScan;
	}

	/**
	 * Infrastructure bean definition for the post-processor. It accumulates package names and
	 * hands them to the post-processor through an instance supplier.
	 */
	static final class ServletComponentRegisteringPostProcessorBeanDefinition extends GenericBeanDefinition {

		private final Set<String> packageNames = new LinkedHashSet<>();

		ServletComponentRegisteringPostProcessorBeanDefinition(Collection<String> packageNames) {
			setBeanClass(ServletComponentRegisteringPostProcessor.class);
			setRole(BeanDefinition.ROLE_INFRASTRUCTURE);
			addPackageNames(packageNames);
		}

		/**
		 * Creates the post-processor with the accumulated package names. Without this, the
		 * packages collected here would never reach the post-processor's constructor.
		 */
		@Override
		public java.util.function.Supplier<?> getInstanceSupplier() {
			return () -> new ServletComponentRegisteringPostProcessor(this.packageNames);
		}

		/** Appends additional packages (duplicates are ignored by the set). */
		private void addPackageNames(Collection<String> additionalPackageNames) {
			this.packageNames.addAll(additionalPackageNames);
		}

	}

}

ServletComponentRegisteringPostProcessor

可以看出，它实际上是一个 BeanFactoryPostProcessor。

java
/**
 * Bean factory post-processor that scans the configured packages for servlet components and lets
 * each registered {@link ServletComponentHandler} ({@code @WebServlet}, {@code @WebFilter},
 * {@code @WebListener}) register bean definitions for the candidates it finds.
 */
class ServletComponentRegisteringPostProcessor implements BeanFactoryPostProcessor, ApplicationContextAware {

	/** One handler per supported annotation; an unmodifiable list built once at class initialization. */
	private static final List<ServletComponentHandler> HANDLERS;

	static {
		List<ServletComponentHandler> servletComponentHandlers = new ArrayList<>();
		// handles @WebServlet
		servletComponentHandlers.add(new WebServletHandler());
		// handles @WebFilter
		servletComponentHandlers.add(new WebFilterHandler());
		// handles @WebListener
		servletComponentHandlers.add(new WebListenerHandler());
		HANDLERS = Collections.unmodifiableList(servletComponentHandlers);
	}

	private final Set<String> packagesToScan;

	private ApplicationContext applicationContext;

	ServletComponentRegisteringPostProcessor(Set<String> packagesToScan) {
		this.packagesToScan = packagesToScan;
	}

	@Override
	public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
		// NOTE(review): isRunningInEmbeddedWebServer() is not shown in this excerpt; judging by its name,
		// scanning only happens when running in an embedded web server — confirm against your Spring Boot version.
		if (isRunningInEmbeddedWebServer()) {
			// The provider performs the actual classpath scanning.
			ClassPathScanningCandidateComponentProvider componentProvider = createComponentProvider();
			for (String packageToScan : this.packagesToScan) {
				scanPackage(componentProvider, packageToScan);
			}
		}
	}

	/** Scans one package and offers every annotated candidate to every handler. */
	private void scanPackage(ClassPathScanningCandidateComponentProvider componentProvider, String packageToScan) {
		for (BeanDefinition candidate : componentProvider.findCandidateComponents(packageToScan)) {
			if (candidate instanceof AnnotatedBeanDefinition) {
				for (ServletComponentHandler handler : HANDLERS) {
					// handle() is expected to delegate to the handler's doHandle() (see WebServletHandler).
					// The cast assumes the application context also implements BeanDefinitionRegistry.
					handler.handle(((AnnotatedBeanDefinition) candidate),
							(BeanDefinitionRegistry) this.applicationContext);
				}
			}
		}
	}

	/**
	 * Creates a scanner with the default component filters disabled. It includes only the type
	 * filters contributed by the handlers and uses the application context's environment and
	 * resource loading.
	 */
	private ClassPathScanningCandidateComponentProvider createComponentProvider() {
		ClassPathScanningCandidateComponentProvider componentProvider = new ClassPathScanningCandidateComponentProvider(
				false);
		componentProvider.setEnvironment(this.applicationContext.getEnvironment());
		componentProvider.setResourceLoader(this.applicationContext);
		for (ServletComponentHandler handler : HANDLERS) {
			componentProvider.addIncludeFilter(handler.getTypeFilter());
		}
		return componentProvider;
	}

	/** Stores the context used for scanning and as the target registry. */
	@Override
	public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
		this.applicationContext = applicationContext;
	}

}

以 WebServletHandler 为例

可以发现，各类组件分别以如下类型的 Bean 注册到 Spring 容器中：

  • Servlet: ServletRegistrationBean
  • Filter: FilterRegistrationBean
  • Listener: ServletComponentWebListenerRegistrar
java
/**
 * Handler for classes annotated with {@code @WebServlet}. For each such class it registers a
 * {@code ServletRegistrationBean} bean definition configured from the annotation's attributes and,
 * when present, from {@code @MultipartConfig}.
 */
class WebServletHandler extends ServletComponentHandler {

	WebServletHandler() {
		super(WebServlet.class);
	}

	/**
	 * Registers a ServletRegistrationBean definition for one scanned servlet class.
	 *
	 * @param attributes the {@code @WebServlet} attribute values
	 * @param beanDefinition the scanned servlet class
	 * @param registry the registry receiving the definition, keyed by the servlet name
	 */
	@Override
	public void doHandle(Map<String, Object> attributes, AnnotatedBeanDefinition beanDefinition,
			BeanDefinitionRegistry registry) {
		String servletName = determineName(attributes, beanDefinition);
		// ServletRegistrationBean is the Spring Boot bean type used to register the servlet;
		// the scanned class is passed in as its "servlet" property.
		BeanDefinitionBuilder registration = BeanDefinitionBuilder.rootBeanDefinition(ServletRegistrationBean.class)
				.addPropertyValue("asyncSupported", attributes.get("asyncSupported"))
				.addPropertyValue("initParameters", extractInitParameters(attributes))
				.addPropertyValue("loadOnStartup", attributes.get("loadOnStartup"))
				.addPropertyValue("name", servletName)
				.addPropertyValue("servlet", beanDefinition)
				.addPropertyValue("urlMappings", extractUrlPatterns(attributes))
				.addPropertyValue("multipartConfig", determineMultipartConfig(beanDefinition));
		registry.registerBeanDefinition(servletName, registration.getBeanDefinition());
	}

	/** Uses the annotation's {@code name} if it has text, otherwise the servlet's class name. */
	private String determineName(Map<String, Object> attributes, BeanDefinition beanDefinition) {
		String explicitName = (String) attributes.get("name");
		if (StringUtils.hasText(explicitName)) {
			return explicitName;
		}
		return beanDefinition.getBeanClassName();
	}

	/**
	 * Builds a MultipartConfigElement from the class's {@code @MultipartConfig}, or returns
	 * {@code null} when the class has no such annotation.
	 */
	private MultipartConfigElement determineMultipartConfig(AnnotatedBeanDefinition beanDefinition) {
		Map<String, Object> multipart = beanDefinition.getMetadata()
				.getAnnotationAttributes(MultipartConfig.class.getName());
		if (multipart == null) {
			return null;
		}
		String location = (String) multipart.get("location");
		long maxFileSize = (Long) multipart.get("maxFileSize");
		long maxRequestSize = (Long) multipart.get("maxRequestSize");
		int fileSizeThreshold = (Integer) multipart.get("fileSizeThreshold");
		return new MultipartConfigElement(location, maxFileSize, maxRequestSize, fileSizeThreshold);
	}

}