/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.yarn.server.resourcemanager.security;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.InetSocketAddress;
import java.security.PrivilegedAction;
import java.security.PrivilegedExceptionAction;
import javax.security.sasl.SaslException;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RemoteException;
import org.apache.hadoop.ipc.Server;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.KerberosInfo;
import org.apache.hadoop.security.SecurityInfo;
import org.apache.hadoop.security.SecurityUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.security.token.TokenInfo;
import org.apache.hadoop.security.token.TokenSelector;
import org.apache.hadoop.service.AbstractService;
import org.apache.hadoop.yarn.api.ApplicationMasterProtocol;
import org.apache.hadoop.yarn.api.ContainerManagementProtocol;
import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationReportRequest;
import org.apache.hadoop.yarn.api.protocolrecords.GetApplicationReportResponse;
import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
import org.apache.hadoop.yarn.api.protocolrecords.StartContainersRequest;
import org.apache.hadoop.yarn.api.protocolrecords.StartContainersResponse;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ApplicationReport;
import org.apache.hadoop.yarn.event.Dispatcher;
import org.apache.hadoop.yarn.event.DrainDispatcher;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.exceptions.YarnRuntimeException;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenSelector;
import org.apache.hadoop.yarn.server.resourcemanager.ClientRMService;
import org.apache.hadoop.yarn.server.resourcemanager.MockAM;
import org.apache.hadoop.yarn.server.resourcemanager.MockNM;
import org.apache.hadoop.yarn.server.resourcemanager.MockRMWithCustomAMLauncher;
import org.apache.hadoop.yarn.server.resourcemanager.RMContext;
import org.apache.hadoop.yarn.server.resourcemanager.rmapp.RMApp;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.YarnScheduler;
import org.apache.hadoop.yarn.server.utils.BuilderUtils;
import org.apache.hadoop.yarn.util.ConverterUtils;
import org.apache.hadoop.yarn.util.Records;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Matchers;
import org.mockito.Mockito;

public class TestClientToAMTokens {
    @Test
    public void testClientToAMTokens() throws Exception {
        Configuration conf = new Configuration();
        conf.set("hadoop.security.authentication", "kerberos");
        UserGroupInformation.setConfiguration((Configuration)conf);
        ContainerManagementProtocol containerManager = (ContainerManagementProtocol)Mockito.mock(ContainerManagementProtocol.class);
        StartContainersResponse mockResponse = (StartContainersResponse)Mockito.mock(StartContainersResponse.class);
        Mockito.when((Object)containerManager.startContainers((StartContainersRequest)Matchers.any())).thenReturn((Object)mockResponse);
        final DrainDispatcher dispatcher = new DrainDispatcher();
        MockRMWithCustomAMLauncher rm = new MockRMWithCustomAMLauncher(conf, containerManager){

            @Override
            protected ClientRMService createClientRMService() {
                return new ClientRMService((RMContext)this.rmContext, (YarnScheduler)this.scheduler, this.rmAppManager, this.applicationACLsManager, this.queueACLsManager, this.getRMDTSecretManager());
            }

            protected Dispatcher createDispatcher() {
                return dispatcher;
            }

            protected void doSecureLogin() throws IOException {
            }
        };
        rm.start();
        RMApp app = rm.submitApp(1024);
        MockNM nm1 = rm.registerNode("localhost:1234", 3072);
        nm1.nodeHeartbeat(true);
        dispatcher.await();
        nm1.nodeHeartbeat(true);
        dispatcher.await();
        ApplicationAttemptId appAttempt = app.getCurrentAppAttempt().getAppAttemptId();
        final MockAM mockAM = new MockAM(rm.getRMContext(), (ApplicationMasterProtocol)rm.getApplicationMasterService(), app.getCurrentAppAttempt().getAppAttemptId());
        UserGroupInformation appUgi = UserGroupInformation.createRemoteUser((String)appAttempt.toString());
        RegisterApplicationMasterResponse response = (RegisterApplicationMasterResponse)appUgi.doAs((PrivilegedAction)new PrivilegedAction<RegisterApplicationMasterResponse>(){

            @Override
            public RegisterApplicationMasterResponse run() {
                RegisterApplicationMasterResponse response = null;
                try {
                    response = mockAM.registerAppAttempt();
                }
                catch (Exception e) {
                    junit.framework.Assert.fail((String)"Exception was not expected");
                }
                return response;
            }
        });
        GetApplicationReportRequest request = (GetApplicationReportRequest)Records.newRecord(GetApplicationReportRequest.class);
        request.setApplicationId(app.getApplicationId());
        GetApplicationReportResponse reportResponse = rm.getClientRMService().getApplicationReport(request);
        ApplicationReport appReport = reportResponse.getApplicationReport();
        org.apache.hadoop.yarn.api.records.Token originalClientToAMToken = appReport.getClientToAMToken();
        junit.framework.Assert.assertNotNull((Object)response.getClientToAMTokenMasterKey());
        junit.framework.Assert.assertTrue((response.getClientToAMTokenMasterKey().array().length > 0 ? 1 : 0) != 0);
        ApplicationAttemptId appAttemptId = (ApplicationAttemptId)app.getAppAttempts().keySet().iterator().next();
        junit.framework.Assert.assertNotNull((Object)appAttemptId);
        CustomAM am = new CustomAM(appAttemptId, response.getClientToAMTokenMasterKey().array());
        am.init(conf);
        am.start();
        SecurityUtil.setSecurityInfoProviders((SecurityInfo[])new SecurityInfo[]{new CustomSecurityInfo()});
        try {
            CustomProtocol client = (CustomProtocol)RPC.getProxy(CustomProtocol.class, (long)1L, (InetSocketAddress)am.address, (Configuration)conf);
            client.ping();
            Assert.fail((String)"Access by unauthenticated user should fail!!");
        }
        catch (Exception e) {
            junit.framework.Assert.assertFalse((boolean)am.pinged);
        }
        Token token = ConverterUtils.convertFromYarn((org.apache.hadoop.yarn.api.records.Token)originalClientToAMToken, (InetSocketAddress)am.address);
        this.verifyTokenWithTamperedID(conf, am, (Token<ClientToAMTokenIdentifier>)token);
        this.verifyTokenWithTamperedUserName(conf, am, (Token<ClientToAMTokenIdentifier>)token);
        this.verifyValidToken(conf, am, (Token<ClientToAMTokenIdentifier>)token);
    }

    private void verifyTokenWithTamperedID(Configuration conf, CustomAM am, Token<ClientToAMTokenIdentifier> token) throws IOException {
        UserGroupInformation ugi = UserGroupInformation.createRemoteUser((String)"me");
        ClientToAMTokenIdentifier maliciousID = new ClientToAMTokenIdentifier(BuilderUtils.newApplicationAttemptId((ApplicationId)BuilderUtils.newApplicationId((long)am.appAttemptId.getApplicationId().getClusterTimestamp(), (int)42), (int)43), UserGroupInformation.getCurrentUser().getShortUserName());
        this.verifyTamperedToken(conf, am, token, ugi, maliciousID);
    }

    private void verifyTokenWithTamperedUserName(Configuration conf, CustomAM am, Token<ClientToAMTokenIdentifier> token) throws IOException {
        UserGroupInformation ugi = UserGroupInformation.createRemoteUser((String)"me");
        ClientToAMTokenIdentifier maliciousID = new ClientToAMTokenIdentifier(am.appAttemptId, "evilOrc");
        this.verifyTamperedToken(conf, am, token, ugi, maliciousID);
    }

    private void verifyTamperedToken(final Configuration conf, final CustomAM am, Token<ClientToAMTokenIdentifier> token, UserGroupInformation ugi, ClientToAMTokenIdentifier maliciousID) {
        Token maliciousToken = new Token(maliciousID.getBytes(), token.getPassword(), token.getKind(), token.getService());
        ugi.addToken(maliciousToken);
        try {
            ugi.doAs((PrivilegedExceptionAction)new PrivilegedExceptionAction<Void>(){

                @Override
                public Void run() throws Exception {
                    try {
                        CustomProtocol client = (CustomProtocol)RPC.getProxy(CustomProtocol.class, (long)1L, (InetSocketAddress)am.address, (Configuration)conf);
                        client.ping();
                        Assert.fail((String)"Connection initiation with illegally modified tokens is expected to fail.");
                        return null;
                    }
                    catch (YarnException ex) {
                        Assert.fail((String)"Cannot get a YARN remote exception as it will indicate RPC success");
                        throw ex;
                    }
                }
            });
        }
        catch (Exception e2) {
            junit.framework.Assert.assertEquals((String)RemoteException.class.getName(), (String)e2.getClass().getName());
            IOException e2 = ((RemoteException)e2).unwrapRemoteException();
            junit.framework.Assert.assertEquals((String)SaslException.class.getCanonicalName(), (String)e2.getClass().getCanonicalName());
            junit.framework.Assert.assertTrue((boolean)e2.getMessage().contains("DIGEST-MD5: digest response format violation. Mismatched response."));
            junit.framework.Assert.assertFalse((boolean)am.pinged);
        }
    }

    private void verifyValidToken(final Configuration conf, final CustomAM am, Token<ClientToAMTokenIdentifier> token) throws IOException, InterruptedException {
        UserGroupInformation ugi = UserGroupInformation.createRemoteUser((String)"me");
        ugi.addToken(token);
        ugi.doAs((PrivilegedExceptionAction)new PrivilegedExceptionAction<Void>(){

            @Override
            public Void run() throws Exception {
                CustomProtocol client = (CustomProtocol)RPC.getProxy(CustomProtocol.class, (long)1L, (InetSocketAddress)am.address, (Configuration)conf);
                client.ping();
                junit.framework.Assert.assertTrue((boolean)am.pinged);
                return null;
            }
        });
    }

    private static class CustomAM
    extends AbstractService
    implements CustomProtocol {
        private final ApplicationAttemptId appAttemptId;
        private final byte[] secretKey;
        private InetSocketAddress address;
        private boolean pinged = false;

        public CustomAM(ApplicationAttemptId appId, byte[] secretKey) {
            super("CustomAM");
            this.appAttemptId = appId;
            this.secretKey = secretKey;
        }

        @Override
        public void ping() throws YarnException, IOException {
            this.pinged = true;
        }

        protected void serviceStart() throws Exception {
            RPC.Server server;
            Configuration conf = this.getConfig();
            try {
                server = new RPC.Builder(conf).setProtocol(CustomProtocol.class).setNumHandlers(1).setSecretManager((SecretManager)new ClientToAMTokenSecretManager(this.appAttemptId, this.secretKey)).setInstance((Object)this).build();
            }
            catch (Exception e) {
                throw new YarnRuntimeException((Throwable)e);
            }
            server.start();
            this.address = NetUtils.getConnectAddress((Server)server);
            super.serviceStart();
        }
    }

    private static class CustomSecurityInfo
    extends SecurityInfo {
        private CustomSecurityInfo() {
        }

        public TokenInfo getTokenInfo(Class<?> protocol, Configuration conf) {
            return new TokenInfo(){

                public Class<? extends Annotation> annotationType() {
                    return null;
                }

                public Class<? extends TokenSelector<? extends TokenIdentifier>> value() {
                    return ClientToAMTokenSelector.class;
                }
            };
        }

        public KerberosInfo getKerberosInfo(Class<?> protocol, Configuration conf) {
            return null;
        }
    }

    private static interface CustomProtocol {
        public static final long versionID = 1L;

        public void ping() throws YarnException, IOException;
    }
}

