/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.crypto;

import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.security.GeneralSecurityException;
import java.util.EnumSet;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.fs.ByteBufferReadable;
import org.apache.hadoop.fs.CanSetDropBehind;
import org.apache.hadoop.fs.CanSetReadahead;
import org.apache.hadoop.fs.HasEnhancedByteBufferAccess;
import org.apache.hadoop.fs.HasFileDescriptor;
import org.apache.hadoop.fs.PositionedReadable;
import org.apache.hadoop.fs.ReadOption;
import org.apache.hadoop.fs.Seekable;
import org.apache.hadoop.io.ByteBufferPool;

import com.google.common.base.Preconditions;

/**
 * CryptoInputStream decrypts data. It is not thread-safe. AES CTR mode is
 * required in order to ensure that the plain text and cipher text have a 1:1
 * mapping. The decryption is buffer based. The key points of the decryption
 * are (1) calculating the counter and (2) padding through stream position:
 * <p/>
 * counter = base + pos/(algorithm blocksize); 
 * padding = pos%(algorithm blocksize); 
 * <p/>
 * The underlying stream offset is maintained as state.
 * <p/>
 * The positioned-read methods ({@link #read(long, byte[], int, int)} and
 * {@link #readFully(long, byte[], int, int)}) do not touch the sequential
 * read state: they use buffers and decryptors taken from internal
 * concurrent pools, and are intended to be thread-safe.
 */
@InterfaceAudience.Private
@InterfaceStability.Evolving
public class CryptoInputStream extends FilterInputStream implements 
    Seekable, PositionedReadable, ByteBufferReadable, HasFileDescriptor, 
    CanSetDropBehind, CanSetReadahead, HasEnhancedByteBufferAccess, 
    ReadableByteChannel {
  /**
   * Scratch buffer for {@link #read()}. It must be per-instance: a static
   * array would be shared by every stream in the JVM, so single-byte reads
   * on different streams in different threads could clobber each other.
   */
  private final byte[] oneByteBuf = new byte[1];
  private final CryptoCodec codec;
  private final Decryptor decryptor;
  private final int bufferSize;
  
  /**
   * Input data buffer. The data starts at inBuffer.position() and ends at 
   * inBuffer.limit().
   */
  private ByteBuffer inBuffer;
  
  /**
   * The decrypted data buffer. The data starts at outBuffer.position() and 
   * ends at outBuffer.limit().
   */
  private ByteBuffer outBuffer;
  private long streamOffset = 0; // Underlying stream offset.
  
  /**
   * Whether the underlying stream supports 
   * {@link org.apache.hadoop.fs.ByteBufferReadable}. {@code null} until the
   * first sequential read has probed the stream.
   */
  private Boolean usingByteBufferRead = null;
  
  /**
   * Padding = pos%(algorithm blocksize); Padding is put into {@link #inBuffer} 
   * before any other data goes in. The purpose of padding is to put the input 
   * data at proper position.
   */
  private byte padding;
  private boolean closed;
  private final byte[] key;
  private final byte[] initIV;
  private byte[] iv;
  private final boolean isByteBufferReadable;
  private final boolean isReadableByteChannel;
  
  /** Lazily allocated heap buffer used when copying from a plain stream. */
  private byte[] tmpBuf;
  
  /** DirectBuffer pool */
  private final Queue<ByteBuffer> bufferPool = 
      new ConcurrentLinkedQueue<ByteBuffer>();
  /** Decryptor pool */
  private final Queue<Decryptor> decryptorPool = 
      new ConcurrentLinkedQueue<Decryptor>();
  
  /**
   * Creates a decrypting stream whose starting offset is obtained from
   * {@code CryptoStreamUtils.getInputStreamOffset(in)}.
   *
   * @param in the encrypted input stream.
   * @param codec the codec providing the cipher suite and decryptors.
   * @param bufferSize the requested buffer size, validated by
   *        {@code CryptoStreamUtils.checkBufferSize}.
   * @param key the decryption key (copied).
   * @param iv the initial IV (copied).
   * @throws IOException if the decryptor cannot be created or initialized.
   */
  public CryptoInputStream(InputStream in, CryptoCodec codec, 
      int bufferSize, byte[] key, byte[] iv) throws IOException {
    this(in, codec, bufferSize, key, iv, 
        CryptoStreamUtils.getInputStreamOffset(in));
  }
  
  /**
   * Creates a decrypting stream positioned at {@code streamOffset} of the
   * underlying stream.
   *
   * @param in the encrypted input stream.
   * @param codec the codec providing the cipher suite and decryptors.
   * @param bufferSize the requested buffer size, validated by
   *        {@code CryptoStreamUtils.checkBufferSize}.
   * @param key the decryption key (copied).
   * @param iv the initial IV (copied).
   * @param streamOffset the current offset of {@code in}.
   * @throws IOException if the decryptor cannot be created or initialized.
   */
  public CryptoInputStream(InputStream in, CryptoCodec codec,
      int bufferSize, byte[] key, byte[] iv, long streamOffset) throws IOException {
    super(in);
    this.bufferSize = CryptoStreamUtils.checkBufferSize(codec, bufferSize);
    this.codec = codec;
    this.key = key.clone();
    this.initIV = iv.clone();
    this.iv = iv.clone();
    this.streamOffset = streamOffset;
    isByteBufferReadable = in instanceof ByteBufferReadable;
    isReadableByteChannel = in instanceof ReadableByteChannel;
    inBuffer = ByteBuffer.allocateDirect(this.bufferSize);
    outBuffer = ByteBuffer.allocateDirect(this.bufferSize);
    decryptor = getDecryptor();
    resetStreamOffset(streamOffset);
  }
  
  /**
   * Creates a decrypting stream using the buffer size from
   * {@code CryptoStreamUtils.getBufferSize(codec.getConf())}.
   */
  public CryptoInputStream(InputStream in, CryptoCodec codec,
      byte[] key, byte[] iv) throws IOException {
    this(in, codec, CryptoStreamUtils.getBufferSize(codec.getConf()), key, iv);
  }
  
  /** @return the wrapped (encrypted) stream. */
  public InputStream getWrappedStream() {
    return in;
  }
  
  /**
   * Decryption is buffer based.
   * If there is data in {@link #outBuffer}, then read it out of this buffer.
   * If there is no data in {@link #outBuffer}, then read more from the 
   * underlying stream and do the decryption.
   * @param b the buffer into which the decrypted data is read.
   * @param off the buffer offset.
   * @param len the maximum number of decrypted data bytes to read.
   * @return int the total number of decrypted data bytes read into the
   *         buffer; may be 0 if the underlying stream returned no bytes, or
   *         -1 at end of stream.
   * @throws IOException
   */
  @Override
  public int read(byte[] b, int off, int len) throws IOException {
    checkStream();
    if (b == null) {
      throw new NullPointerException();
    } else if (off < 0 || len < 0 || len > b.length - off) {
      throw new IndexOutOfBoundsException();
    } else if (len == 0) {
      return 0;
    }
    
    final int remaining = outBuffer.remaining();
    if (remaining > 0) {
      int n = Math.min(len, remaining);
      outBuffer.get(b, off, n);
      return n;
    } else {
      int n = 0;
      
      /*
       * Check whether the underlying stream is {@link ByteBufferReadable},
       * it can avoid bytes copy.
       */
      if (usingByteBufferRead == null) {
        if (isByteBufferReadable || isReadableByteChannel) {
          try {
            n = isByteBufferReadable ? 
                ((ByteBufferReadable) in).read(inBuffer) : 
                  ((ReadableByteChannel) in).read(inBuffer);
            usingByteBufferRead = Boolean.TRUE;
          } catch (UnsupportedOperationException e) {
            usingByteBufferRead = Boolean.FALSE;
          }
        } else {
          usingByteBufferRead = Boolean.FALSE;
        }
        if (!usingByteBufferRead) {
          n = readFromUnderlyingStream(inBuffer);
        }
      } else {
        if (usingByteBufferRead) {
          n = isByteBufferReadable ? ((ByteBufferReadable) in).read(inBuffer) : 
                ((ReadableByteChannel) in).read(inBuffer);
        } else {
          n = readFromUnderlyingStream(inBuffer);
        }
      }
      if (n <= 0) {
        return n;
      }
      
      streamOffset += n; // Read n bytes
      decrypt(decryptor, inBuffer, outBuffer, padding);
      padding = afterDecryption(decryptor, inBuffer, streamOffset, iv);
      n = Math.min(len, outBuffer.remaining());
      outBuffer.get(b, off, n);
      return n;
    }
  }
  
  /** Read data from underlying stream. */
  private int readFromUnderlyingStream(ByteBuffer inBuffer) throws IOException {
    final int toRead = inBuffer.remaining();
    final byte[] tmp = getTmpBuf();
    final int n = in.read(tmp, 0, toRead);
    if (n > 0) {
      inBuffer.put(tmp, 0, n);
    }
    return n;
  }
  
  /** Returns the heap copy buffer, allocating it on first use. */
  private byte[] getTmpBuf() {
    if (tmpBuf == null) {
      tmpBuf = new byte[bufferSize];
    }
    return tmpBuf;
  }
  
  /**
   * Do the decryption using inBuffer as input and outBuffer as output.
   * Upon return, inBuffer is cleared; the decrypted data starts at 
   * outBuffer.position() and ends at outBuffer.limit(). If inBuffer holds
   * nothing beyond the padding, both buffers are left untouched.
   */
  private void decrypt(Decryptor decryptor, ByteBuffer inBuffer, 
      ByteBuffer outBuffer, byte padding) throws IOException {
    Preconditions.checkState(inBuffer.position() >= padding);
    if(inBuffer.position() == padding) {
      // There is no real data in inBuffer.
      return;
    }
    inBuffer.flip();
    outBuffer.clear();
    decryptor.decrypt(inBuffer, outBuffer);
    inBuffer.clear();
    outBuffer.flip();
    if (padding > 0) {
      /*
       * The plain text and cipher text have a 1:1 mapping, they start at the 
       * same position.
       */
      outBuffer.position(padding);
    }
  }
  
  /**
   * This method is executed immediately after decryption. Check whether 
   * decryptor should be updated and recalculate padding if needed. 
   *
   * @return the padding to use for the next decryption (0 unless the
   *         decryptor reported a context reset).
   */
  private byte afterDecryption(Decryptor decryptor, ByteBuffer inBuffer, 
      long position, byte[] iv) throws IOException {
    byte padding = 0;
    if (decryptor.isContextReset()) {
      /*
       * This code is generally not executed since the decryptor usually 
       * maintains decryption context (e.g. the counter) internally. However, 
       * some implementations can't maintain context so a re-init is necessary 
       * after each decryption call.
       */
      updateDecryptor(decryptor, position, iv);
      padding = getPadding(position);
      inBuffer.position(padding);
    }
    return padding;
  }
  
  /** Block counter for a stream position: position / algorithm block size. */
  private long getCounter(long position) {
    return position / codec.getCipherSuite().getAlgorithmBlockSize();
  }
  
  /** Offset of a stream position within its cipher block. */
  private byte getPadding(long position) {
    return (byte)(position % codec.getCipherSuite().getAlgorithmBlockSize());
  }
  
  /** Calculate the counter and iv, update the decryptor. */
  private void updateDecryptor(Decryptor decryptor, long position, byte[] iv) 
      throws IOException {
    final long counter = getCounter(position);
    codec.calculateIV(initIV, counter, iv);
    decryptor.init(key, iv);
  }
  
  /**
   * Reset the underlying stream offset; clear {@link #inBuffer} and 
   * {@link #outBuffer}. This typically happens during {@link #seek(long)} 
   * or {@link #skip(long)}.
   */
  private void resetStreamOffset(long offset) throws IOException {
    streamOffset = offset;
    inBuffer.clear();
    outBuffer.clear();
    outBuffer.limit(0);
    updateDecryptor(decryptor, offset, iv);
    padding = getPadding(offset);
    inBuffer.position(padding); // Set proper position for input data.
  }
  
  /**
   * Closes the wrapped stream and releases this stream's buffers. Calling
   * it again after it has completed is a no-op.
   */
  @Override
  public void close() throws IOException {
    if (closed) {
      return;
    }
    
    try {
      super.close();
    } finally {
      // Release buffers and mark closed even if the wrapped stream failed
      // to close, so later calls fail in checkStream() instead of reusing
      // buffers that freeBuffers() has already released.
      freeBuffers();
      closed = true;
    }
  }
  
  /** Positioned read. It is thread-safe */
  @Override
  public int read(long position, byte[] buffer, int offset, int length)
      throws IOException {
    checkStream();
    try {
      final int n = ((PositionedReadable) in).read(position, buffer, offset, 
          length);
      if (n > 0) {
        // This operation does not change the current offset of the file
        decrypt(position, buffer, offset, n);
      }
      
      return n;
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " +
          "positioned read.");
    }
  }
  
  /**
   * Decrypt length bytes in buffer starting at offset. Output is also put 
   * into buffer starting at offset. It is thread-safe: it uses pooled
   * buffers, a pooled decryptor and a private IV copy rather than the
   * stream's sequential-read state.
   */
  private void decrypt(long position, byte[] buffer, int offset, int length) 
      throws IOException {
    ByteBuffer inBuffer = getBuffer();
    ByteBuffer outBuffer = getBuffer();
    Decryptor decryptor = null;
    try {
      decryptor = getDecryptor();
      byte[] iv = initIV.clone();
      updateDecryptor(decryptor, position, iv);
      byte padding = getPadding(position);
      inBuffer.position(padding); // Set proper position for input data.
      
      int n = 0;
      while (n < length) {
        int toDecrypt = Math.min(length - n, inBuffer.remaining());
        inBuffer.put(buffer, offset + n, toDecrypt);
        // Do decryption
        decrypt(decryptor, inBuffer, outBuffer, padding);
        
        outBuffer.get(buffer, offset + n, toDecrypt);
        n += toDecrypt;
        padding = afterDecryption(decryptor, inBuffer, position + n, iv);
      }
    } finally {
      returnBuffer(inBuffer);
      returnBuffer(outBuffer);
      returnDecryptor(decryptor);
    }
  }
  
  /** Positioned read fully. It is thread-safe */
  @Override
  public void readFully(long position, byte[] buffer, int offset, int length)
      throws IOException {
    checkStream();
    try {
      ((PositionedReadable) in).readFully(position, buffer, offset, length);
      if (length > 0) {
        // This operation does not change the current offset of the file
        decrypt(position, buffer, offset, length);
      }
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " +
          "positioned readFully.");
    }
  }

  /** Positioned read fully into the whole of {@code buffer}. */
  @Override
  public void readFully(long position, byte[] buffer) throws IOException {
    readFully(position, buffer, 0, buffer.length);
  }

  /**
   * Seek to a position. If the target is already decrypted in
   * {@link #outBuffer}, only that buffer is repositioned; otherwise the
   * underlying stream is sought and the decryption state reset.
   */
  @Override
  public void seek(long pos) throws IOException {
    Preconditions.checkArgument(pos >= 0, "Cannot seek to negative offset.");
    checkStream();
    try {
      /*
       * If data of target pos in the underlying stream has already been read 
       * and decrypted in outBuffer, we just need to re-position outBuffer.
       */
      if (pos <= streamOffset && pos >= (streamOffset - outBuffer.remaining())) {
        int forward = (int) (pos - (streamOffset - outBuffer.remaining()));
        if (forward > 0) {
          outBuffer.position(outBuffer.position() + forward);
        }
      } else {
        ((Seekable) in).seek(pos);
        resetStreamOffset(pos);
      }
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " +
          "seek.");
    }
  }
  
  /** Skip n bytes */
  @Override
  public long skip(long n) throws IOException {
    Preconditions.checkArgument(n >= 0, "Negative skip length.");
    checkStream();
    
    if (n == 0) {
      return 0;
    } else if (n <= outBuffer.remaining()) {
      int pos = outBuffer.position() + (int) n;
      outBuffer.position(pos);
      return n;
    } else {
      /*
       * Subtract outBuffer.remaining() to see how many bytes we need to 
       * skip in the underlying stream. Add outBuffer.remaining() to the 
       * actual number of skipped bytes in the underlying stream to get the 
       * number of skipped bytes from the user's point of view.
       */
      n -= outBuffer.remaining();
      long skipped = in.skip(n);
      if (skipped < 0) {
        skipped = 0;
      }
      long pos = streamOffset + skipped;
      skipped += outBuffer.remaining();
      resetStreamOffset(pos);
      return skipped;
    }
  }

  /** Get underlying stream position. */
  @Override
  public long getPos() throws IOException {
    checkStream();
    // Equals: ((Seekable) in).getPos() - outBuffer.remaining()
    return streamOffset - outBuffer.remaining();
  }
  
  /**
   * ByteBuffer read. Decrypts in place when the wrapped stream supports
   * ByteBuffer reads; otherwise falls back to {@link #read(byte[], int, int)}.
   */
  @Override
  public int read(ByteBuffer buf) throws IOException {
    checkStream();
    if (isByteBufferReadable || isReadableByteChannel) {
      final int unread = outBuffer.remaining();
      if (unread > 0) { // Have unread decrypted data in buffer.
        int toRead = buf.remaining();
        if (toRead <= unread) {
          final int limit = outBuffer.limit();
          outBuffer.limit(outBuffer.position() + toRead);
          buf.put(outBuffer);
          outBuffer.limit(limit);
          return toRead;
        } else {
          buf.put(outBuffer);
        }
      }
      
      final int pos = buf.position();
      final int n = isByteBufferReadable ? ((ByteBufferReadable) in).read(buf) : 
            ((ReadableByteChannel) in).read(buf);
      if (n > 0) {
        streamOffset += n; // Read n bytes
        decrypt(buf, n, pos);
      }
      
      if (n >= 0) {
        return unread + n;
      } else {
        if (unread == 0) {
          return -1;
        } else {
          return unread;
        }
      }
    } else {
      int n = 0;
      if (buf.hasArray()) {
        // array() exposes the entire backing array; a sliced buffer's
        // element 0 lives at arrayOffset(), so include it in the index.
        n = read(buf.array(), buf.arrayOffset() + buf.position(),
            buf.remaining());
        if (n > 0) {
          buf.position(buf.position() + n);
        }
      } else {
        byte[] tmp = new byte[buf.remaining()];
        n = read(tmp);
        if (n > 0) {
          buf.put(tmp, 0, n);
        }
      }
      return n;
    }
  }
  
  /**
   * Decrypt all data in buf: total n bytes from given start position.
   * Output is also buf and same start position.
   * buf.position() and buf.limit() should be unchanged after decryption.
   */
  private void decrypt(ByteBuffer buf, int n, int start) 
      throws IOException {
    final int pos = buf.position();
    final int limit = buf.limit();
    int len = 0;
    while (len < n) {
      buf.position(start + len);
      buf.limit(start + len + Math.min(n - len, inBuffer.remaining()));
      inBuffer.put(buf);
      // Do decryption
      try {
        decrypt(decryptor, inBuffer, outBuffer, padding);
        buf.position(start + len);
        buf.limit(limit);
        len += outBuffer.remaining();
        buf.put(outBuffer);
      } finally {
        padding = afterDecryption(decryptor, inBuffer, streamOffset - (n - len), iv);
      }
    }
    buf.position(pos);
  }
  
  /** Bytes available from the wrapped stream plus decrypted bytes buffered. */
  @Override
  public int available() throws IOException {
    checkStream();
    
    return in.available() + outBuffer.remaining();
  }

  /** Mark/reset is not supported. */
  @Override
  public boolean markSupported() {
    return false;
  }
  
  /** No-op: mark/reset is not supported. */
  @Override
  public void mark(int readLimit) {
  }
  
  /** Always throws: mark/reset is not supported. */
  @Override
  public void reset() throws IOException {
    throw new IOException("Mark/reset not supported");
  }

  /**
   * Delegates to the wrapped stream's seekToNewSource and resets the
   * decryption state to {@code targetPos}.
   */
  @Override
  public boolean seekToNewSource(long targetPos) throws IOException {
    Preconditions.checkArgument(targetPos >= 0, 
        "Cannot seek to negative offset.");
    checkStream();
    try {
      boolean result = ((Seekable) in).seekToNewSource(targetPos);
      resetStreamOffset(targetPos);
      return result;
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " +
          "seekToNewSource.");
    }
  }

  /**
   * Enhanced ByteBuffer read: obtains a buffer from the wrapped stream and
   * decrypts its contents in place. Any unread decrypted data is discarded
   * by re-seeking the wrapped stream to {@link #getPos()} first.
   */
  @Override
  public ByteBuffer read(ByteBufferPool bufferPool, int maxLength,
      EnumSet<ReadOption> opts) throws IOException,
      UnsupportedOperationException {
    checkStream();
    try {
      if (outBuffer.remaining() > 0) {
        // Have some decrypted data unread, need to reset.
        ((Seekable) in).seek(getPos());
        resetStreamOffset(getPos());
      }
      final ByteBuffer buffer = ((HasEnhancedByteBufferAccess) in).
          read(bufferPool, maxLength, opts);
      if (buffer != null) {
        final int n = buffer.remaining();
        if (n > 0) {
          streamOffset += buffer.remaining(); // Read n bytes
          final int pos = buffer.position();
          decrypt(buffer, n, pos);
        }
      }
      return buffer;
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " + 
          "enhanced byte buffer access.");
    }
  }

  /** Returns a buffer obtained from the enhanced read to the wrapped stream. */
  @Override
  public void releaseBuffer(ByteBuffer buffer) {
    try {
      ((HasEnhancedByteBufferAccess) in).releaseBuffer(buffer);
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " + 
          "release buffer.");
    }
  }

  /** Delegates to the wrapped stream's {@code setReadahead}. */
  @Override
  public void setReadahead(Long readahead) throws IOException,
      UnsupportedOperationException {
    try {
      ((CanSetReadahead) in).setReadahead(readahead);
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not support " +
          "setting the readahead caching strategy.");
    }
  }

  /** Delegates to the wrapped stream's {@code setDropBehind}. */
  @Override
  public void setDropBehind(Boolean dropCache) throws IOException,
      UnsupportedOperationException {
    try {
      ((CanSetDropBehind) in).setDropBehind(dropCache);
    } catch (ClassCastException e) {
      throw new UnsupportedOperationException("This stream does not " +
          "support setting the drop-behind caching setting.");
    }
  }

  /**
   * @return the wrapped stream's file descriptor if it exposes one, or
   *         {@code null} otherwise.
   */
  @Override
  public FileDescriptor getFileDescriptor() throws IOException {
    if (in instanceof HasFileDescriptor) {
      return ((HasFileDescriptor) in).getFileDescriptor();
    } else if (in instanceof FileInputStream) {
      return ((FileInputStream) in).getFD();
    } else {
      return null;
    }
  }
  
  /**
   * Reads a single decrypted byte.
   *
   * @return the byte as an int in the range 0-255, or -1 at end of stream.
   * @throws IOException if the stream is closed or the read fails.
   */
  @Override
  public int read() throws IOException {
    int n;
    // read(byte[], int, int) returns 0 when the wrapped ByteBufferReadable /
    // ReadableByteChannel yields no bytes; retry rather than returning the
    // stale contents of oneByteBuf.
    do {
      n = read(oneByteBuf, 0, 1);
    } while (n == 0);
    return (n == -1) ? -1 : (oneByteBuf[0] & 0xff);
  }
  
  /** @throws IOException if {@link #close()} has been called. */
  private void checkStream() throws IOException {
    if (closed) {
      throw new IOException("Stream closed");
    }
  }
  
  /** Get direct buffer from pool */
  private ByteBuffer getBuffer() {
    ByteBuffer buffer = bufferPool.poll();
    if (buffer == null) {
      buffer = ByteBuffer.allocateDirect(bufferSize);
    }
    
    return buffer;
  }
  
  /** Return direct buffer to pool */
  private void returnBuffer(ByteBuffer buf) {
    if (buf != null) {
      buf.clear();
      bufferPool.add(buf);
    }
  }
  
  /** Forcibly free the direct buffers. */
  private void freeBuffers() {
    CryptoStreamUtils.freeDB(inBuffer);
    CryptoStreamUtils.freeDB(outBuffer);
    cleanBufferPool();
  }
  
  /** Clean direct buffer pool */
  private void cleanBufferPool() {
    ByteBuffer buf;
    while ((buf = bufferPool.poll()) != null) {
      CryptoStreamUtils.freeDB(buf);
    }
  }
  
  /** Get decryptor from pool */
  private Decryptor getDecryptor() throws IOException {
    Decryptor decryptor = decryptorPool.poll();
    if (decryptor == null) {
      try {
        decryptor = codec.createDecryptor();
      } catch (GeneralSecurityException e) {
        throw new IOException(e);
      }
    }
    
    return decryptor;
  }
  
  /** Return decryptor to pool */
  private void returnDecryptor(Decryptor decryptor) {
    if (decryptor != null) {
      decryptorPool.add(decryptor);
    }
  }

  /** @return {@code true} until {@link #close()} has completed. */
  @Override
  public boolean isOpen() {
    return !closed;
  }
}
